hotshot_libp2p_networking/network/transport.rs

1use std::{
2    future::Future,
3    io::{Error as IoError, ErrorKind as IoErrorKind},
4    pin::Pin,
5    sync::Arc,
6    task::Poll,
7};
8
9use anyhow::{ensure, Context, Result as AnyhowResult};
10use bimap::BiMap;
11use futures::{future::poll_fn, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
12use hotshot_types::traits::signature_key::SignatureKey;
13use libp2p::{
14    core::{
15        muxing::StreamMuxerExt,
16        transport::{DialOpts, TransportEvent},
17        StreamMuxer,
18    },
19    identity::PeerId,
20    Transport,
21};
22use parking_lot::Mutex;
23use pin_project::pin_project;
24use serde::{Deserialize, Serialize};
25use tokio::time::timeout;
26use tracing::warn;
27
/// Largest authentication message, in bytes, that we are willing to receive. Capping it keeps
/// a peer from making us allocate a large buffer just by advertising a huge message length.
const MAX_AUTH_MESSAGE_SIZE: usize = 1 << 10;

/// Upper bound on how long the authentication handshake may take. Without it, a peer could
/// hold a connection open indefinitely by starting the handshake and never completing it.
const AUTH_HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(5_000);
36
37/// A wrapper for a `Transport` that bidirectionally associates (and verifies)
38/// the corresponding consensus keys.
39#[pin_project]
40pub struct ConsensusKeyAuthentication<
41    T: Transport,
42    S: SignatureKey + 'static,
43    C: StreamMuxer + Unpin,
44> {
45    #[pin]
46    /// The underlying transport we are wrapping
47    pub inner: T,
48
49    /// A pre-signed message that we (depending on if it's specified or not) send to the remote peer for authentication
50    pub auth_message: Arc<Option<Vec<u8>>>,
51
52    /// The (verified) map of consensus keys to peer IDs
53    pub consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
54
55    /// Phantom data for the connection type
56    pd: std::marker::PhantomData<(C, S)>,
57}
58
59/// A type alias for the future that upgrades a connection to perform the authentication handshake
60type UpgradeFuture<T> =
61    Pin<Box<dyn Future<Output = Result<<T as Transport>::Output, <T as Transport>::Error>> + Send>>;
62
63impl<T: Transport, S: SignatureKey + 'static, C: StreamMuxer + Unpin>
64    ConsensusKeyAuthentication<T, S, C>
65{
66    /// Create a new `ConsensusKeyAuthentication` transport that wraps the given transport
67    /// and authenticates connections against the stake table. If the auth message is `None`,
68    /// the authentication is disabled.
69    pub fn new(
70        inner: T,
71        auth_message: Option<Vec<u8>>,
72        consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
73    ) -> Self {
74        Self {
75            inner,
76            auth_message: Arc::from(auth_message),
77            consensus_key_to_pid_map,
78            pd: std::marker::PhantomData,
79        }
80    }
81
82    /// Prove to the remote peer that we are in the stake table by sending
83    /// them our authentication message.
84    ///
85    /// # Errors
86    /// - If we fail to write the message to the stream
87    pub async fn authenticate_with_remote_peer<W: AsyncWrite + Unpin>(
88        stream: &mut W,
89        auth_message: &[u8],
90    ) -> AnyhowResult<()> {
91        // Write the length-delimited message
92        write_length_delimited(stream, auth_message).await?;
93
94        Ok(())
95    }
96
97    /// Verify that the remote peer is:
98    /// - In the stake table
99    /// - Sending us a valid authentication message
100    /// - Sending us a valid signature
101    /// - Matching the peer ID we expect
102    ///
103    /// # Errors
104    /// If the peer fails verification. This can happen if:
105    /// - We fail to read the message from the stream
106    /// - The message is too large
107    /// - The message is invalid
108    /// - The peer is not in the stake table
109    /// - The signature is invalid
110    pub async fn verify_peer_authentication<R: AsyncReadExt + Unpin>(
111        stream: &mut R,
112        required_peer_id: &PeerId,
113        consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
114    ) -> AnyhowResult<()> {
115        // Read the length-delimited message from the remote peer
116        let message = read_length_delimited(stream, MAX_AUTH_MESSAGE_SIZE).await?;
117
118        // Deserialize the authentication message
119        let auth_message: AuthMessage<S> =
120            bincode::deserialize(&message).with_context(|| "Failed to deserialize auth message")?;
121
122        // Verify the signature on the public keys
123        let public_key = auth_message
124            .validate()
125            .with_context(|| "Failed to verify authentication message")?;
126
127        // Deserialize the `PeerId`
128        let peer_id = PeerId::from_bytes(&auth_message.peer_id_bytes)
129            .with_context(|| "Failed to deserialize peer ID")?;
130
131        // Verify that the peer ID is the same as the remote peer
132        if peer_id != *required_peer_id {
133            return Err(anyhow::anyhow!("Peer ID mismatch"));
134        }
135
136        // If we got here, the peer is authenticated. Add the consensus key to the map
137        consensus_key_to_pid_map.lock().insert(public_key, peer_id);
138
139        Ok(())
140    }
141
142    /// Wrap the supplied future in an upgrade that performs the authentication handshake.
143    ///
144    /// `outgoing` is a boolean that indicates if the connection is incoming or outgoing.
145    /// This is needed because the flow of the handshake is different for each.
146    fn gen_handshake<F: Future<Output = Result<T::Output, T::Error>> + Send + 'static>(
147        original_future: F,
148        outgoing: bool,
149        auth_message: Arc<Option<Vec<u8>>>,
150        consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
151    ) -> UpgradeFuture<T>
152    where
153        T::Error: From<<C as StreamMuxer>::Error> + From<IoError>,
154        T::Output: AsOutput<C> + Send,
155
156        C::Substream: Unpin + Send,
157    {
158        // Create a new upgrade that performs the authentication handshake on top
159        Box::pin(async move {
160            // Wait for the original future to resolve
161            let mut stream = original_future.await?;
162
163            // Time out the authentication block
164            timeout(AUTH_HANDSHAKE_TIMEOUT, async {
165                // Open a substream for the handshake.
166                // The handshake order depends on whether the connection is incoming or outgoing.
167                let mut substream = if outgoing {
168                    poll_fn(|cx| stream.as_connection().poll_outbound_unpin(cx)).await?
169                } else {
170                    poll_fn(|cx| stream.as_connection().poll_inbound_unpin(cx)).await?
171                };
172
173                // Conditionally authenticate depending on whether we specified an auth message
174                if let Some(auth_message) = auth_message.as_ref() {
175                    if outgoing {
176                        // If the connection is outgoing, authenticate with the remote peer first
177                        Self::authenticate_with_remote_peer(&mut substream, auth_message)
178                            .await
179                            .map_err(|e| {
180                                warn!("Failed to authenticate with remote peer: {e:?}");
181                                IoError::other(e)
182                            })?;
183
184                        // Verify the remote peer's authentication
185                        Self::verify_peer_authentication(
186                            &mut substream,
187                            stream.as_peer_id(),
188                            consensus_key_to_pid_map,
189                        )
190                        .await
191                        .map_err(|e| {
192                            warn!("Failed to verify remote peer: {e:?}");
193                            IoError::other(e)
194                        })?;
195                    } else {
196                        // If it is incoming, verify the remote peer's authentication first
197                        Self::verify_peer_authentication(
198                            &mut substream,
199                            stream.as_peer_id(),
200                            consensus_key_to_pid_map,
201                        )
202                        .await
203                        .map_err(|e| {
204                            warn!("Failed to verify remote peer: {e:?}");
205                            IoError::other(e)
206                        })?;
207
208                        // Authenticate with the remote peer
209                        Self::authenticate_with_remote_peer(&mut substream, auth_message)
210                            .await
211                            .map_err(|e| {
212                                warn!("Failed to authenticate with remote peer: {e:?}");
213                                IoError::other(e)
214                            })?;
215                    }
216                }
217
218                Ok(stream)
219            })
220            .await
221            .map_err(|e| {
222                warn!("Timed out performing authentication handshake: {e:?}");
223                IoError::new(IoErrorKind::TimedOut, e)
224            })?
225        })
226    }
227}
228
229/// The deserialized form of an authentication message that is sent to the remote peer
230#[derive(Clone, Serialize, Deserialize)]
231struct AuthMessage<S: SignatureKey> {
232    /// The encoded (stake table) public key of the sender. This, along with the peer ID, is
233    /// signed. It is still encoded here to enable easy verification.
234    public_key_bytes: Vec<u8>,
235
236    /// The encoded peer ID of the sender. This is appended to the public key before signing.
237    /// It is still encoded here to enable easy verification.
238    peer_id_bytes: Vec<u8>,
239
240    /// The signature on the public key
241    signature: S::PureAssembledSignatureType,
242}
243
244impl<S: SignatureKey> AuthMessage<S> {
245    /// Validate the signature on the public key and return it if valid
246    pub fn validate(&self) -> AnyhowResult<S> {
247        // Deserialize the stake table public key
248        let public_key = S::from_bytes(&self.public_key_bytes)
249            .with_context(|| "Failed to deserialize public key")?;
250
251        // Reconstruct the signed message from the public key and peer ID
252        let mut signed_message = public_key.to_bytes();
253        signed_message.extend(self.peer_id_bytes.clone());
254
255        // Check if the signature is valid across both
256        if !public_key.validate(&self.signature, &signed_message) {
257            return Err(anyhow::anyhow!("Invalid signature"));
258        }
259
260        Ok(public_key)
261    }
262}
263
264/// Create an sign an authentication message to be sent to the remote peer
265///
266/// # Errors
267/// - If we fail to sign the public key
268/// - If we fail to serialize the authentication message
269pub fn construct_auth_message<S: SignatureKey + 'static>(
270    public_key: &S,
271    peer_id: &PeerId,
272    private_key: &S::PrivateKey,
273) -> AnyhowResult<Vec<u8>> {
274    // Serialize the stake table public key
275    let mut public_key_bytes = public_key.to_bytes();
276
277    // Serialize the peer ID and append it
278    let peer_id_bytes = peer_id.to_bytes();
279    public_key_bytes.extend_from_slice(&peer_id_bytes);
280
281    // Sign our public key
282    let signature =
283        S::sign(private_key, &public_key_bytes).with_context(|| "Failed to sign public key")?;
284
285    // Create the auth message
286    let auth_message = AuthMessage::<S> {
287        public_key_bytes,
288        peer_id_bytes,
289        signature,
290    };
291
292    // Serialize the auth message
293    bincode::serialize(&auth_message).with_context(|| "Failed to serialize auth message")
294}
295
296impl<T: Transport, S: SignatureKey + 'static, C: StreamMuxer + Unpin> Transport
297    for ConsensusKeyAuthentication<T, S, C>
298where
299    T::Dial: Future<Output = Result<T::Output, T::Error>> + Send + 'static,
300    T::ListenerUpgrade: Send + 'static,
301    T::Output: AsOutput<C> + Send,
302    T::Error: From<<C as StreamMuxer>::Error> + From<IoError>,
303
304    C::Substream: Unpin + Send,
305{
306    // `Dial` is for connecting out, `ListenerUpgrade` is for accepting incoming connections
307    type Dial = Pin<Box<dyn Future<Output = Result<T::Output, T::Error>> + Send>>;
308    type ListenerUpgrade = Pin<Box<dyn Future<Output = Result<T::Output, T::Error>> + Send>>;
309
310    // These are just passed through
311    type Output = T::Output;
312    type Error = T::Error;
313
314    /// Dial a remote peer. This function is changed to perform an authentication handshake
315    /// on top.
316    fn dial(
317        &mut self,
318        addr: libp2p::Multiaddr,
319        opts: DialOpts,
320    ) -> Result<Self::Dial, libp2p::TransportError<Self::Error>> {
321        // Perform the inner dial
322        let res = self.inner.dial(addr, opts);
323
324        // Clone the necessary fields
325        let auth_message = Arc::clone(&self.auth_message);
326
327        // If the dial was successful, perform the authentication handshake on top
328        match res {
329            Ok(dial) => Ok(Self::gen_handshake(
330                dial,
331                true,
332                auth_message,
333                Arc::clone(&self.consensus_key_to_pid_map),
334            )),
335            Err(err) => Err(err),
336        }
337    }
338
339    /// This function is where we perform the authentication handshake for _incoming_ connections.
340    /// The flow in this case is the reverse of the `dial` function: we first verify the remote peer's
341    /// authentication, and then authenticate with them.
342    fn poll(
343        mut self: std::pin::Pin<&mut Self>,
344        cx: &mut std::task::Context<'_>,
345    ) -> std::task::Poll<libp2p::core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>>
346    {
347        match self.as_mut().project().inner.poll(cx) {
348            Poll::Ready(event) => Poll::Ready(match event {
349                // If we have an incoming connection, we need to perform the authentication handshake
350                TransportEvent::Incoming {
351                    listener_id,
352                    upgrade,
353                    local_addr,
354                    send_back_addr,
355                } => {
356                    // Clone the necessary fields
357                    let auth_message = Arc::clone(&self.auth_message);
358
359                    // Generate the handshake upgrade future (inbound)
360                    let auth_upgrade = Self::gen_handshake(
361                        upgrade,
362                        false,
363                        auth_message,
364                        Arc::clone(&self.consensus_key_to_pid_map),
365                    );
366
367                    // Return the new event
368                    TransportEvent::Incoming {
369                        listener_id,
370                        upgrade: auth_upgrade,
371                        local_addr,
372                        send_back_addr,
373                    }
374                },
375
376                // We need to re-map the other events because we changed the type of the upgrade
377                TransportEvent::AddressExpired {
378                    listener_id,
379                    listen_addr,
380                } => TransportEvent::AddressExpired {
381                    listener_id,
382                    listen_addr,
383                },
384                TransportEvent::ListenerClosed {
385                    listener_id,
386                    reason,
387                } => TransportEvent::ListenerClosed {
388                    listener_id,
389                    reason,
390                },
391                TransportEvent::ListenerError { listener_id, error } => {
392                    TransportEvent::ListenerError { listener_id, error }
393                },
394                TransportEvent::NewAddress {
395                    listener_id,
396                    listen_addr,
397                } => TransportEvent::NewAddress {
398                    listener_id,
399                    listen_addr,
400                },
401            }),
402
403            Poll::Pending => Poll::Pending,
404        }
405    }
406
407    /// The below functions just pass through to the inner transport, but we had
408    /// to define them
409    fn remove_listener(&mut self, id: libp2p::core::transport::ListenerId) -> bool {
410        self.inner.remove_listener(id)
411    }
412    fn listen_on(
413        &mut self,
414        id: libp2p::core::transport::ListenerId,
415        addr: libp2p::Multiaddr,
416    ) -> Result<(), libp2p::TransportError<Self::Error>> {
417        self.inner.listen_on(id, addr)
418    }
419}
420
421/// A helper trait that allows us to access the underlying connection
422/// and `PeerId` from a transport output
423trait AsOutput<C: StreamMuxer + Unpin> {
424    /// Get a mutable reference to the underlying connection
425    fn as_connection(&mut self) -> &mut C;
426
427    /// Get a mutable reference to the underlying `PeerId`
428    fn as_peer_id(&mut self) -> &mut PeerId;
429}
430
431/// The implementation of the `AsConnection` trait for a tuple of a `PeerId`
432/// and a connection.
433impl<C: StreamMuxer + Unpin> AsOutput<C> for (PeerId, C) {
434    /// Get a mutable reference to the underlying connection
435    fn as_connection(&mut self) -> &mut C {
436        &mut self.1
437    }
438
439    /// Get a mutable reference to the underlying `PeerId`
440    fn as_peer_id(&mut self) -> &mut PeerId {
441        &mut self.0
442    }
443}
444
445/// A helper function to read a length-delimited message from a stream. Takes into
446/// account the maximum message size.
447///
448/// # Errors
449/// - If the message is too big
450/// - If we fail to read from the stream
451pub async fn read_length_delimited<S: AsyncRead + Unpin>(
452    stream: &mut S,
453    max_size: usize,
454) -> AnyhowResult<Vec<u8>> {
455    // Receive the first 8 bytes of the message, which is the length
456    let mut len_bytes = [0u8; 4];
457    stream
458        .read_exact(&mut len_bytes)
459        .await
460        .with_context(|| "Failed to read message length")?;
461
462    // Parse the length of the message as a `u32`
463    let len = usize::try_from(u32::from_be_bytes(len_bytes))?;
464
465    // Quit if the message is too large
466    ensure!(len <= max_size, "Message too large");
467
468    // Read the actual message
469    let mut message = vec![0u8; len];
470    stream
471        .read_exact(&mut message)
472        .await
473        .with_context(|| "Failed to read message")?;
474
475    Ok(message)
476}
477
478/// A helper function to write a length-delimited message to a stream.
479///
480/// # Errors
481/// - If we fail to write to the stream
482pub async fn write_length_delimited<S: AsyncWrite + Unpin>(
483    stream: &mut S,
484    message: &[u8],
485) -> AnyhowResult<()> {
486    // Write the length of the message
487    stream
488        .write_all(&u32::try_from(message.len())?.to_be_bytes())
489        .await
490        .with_context(|| "Failed to write message length")?;
491
492    // Write the actual message
493    stream
494        .write_all(message)
495        .await
496        .with_context(|| "Failed to write message")?;
497
498    Ok(())
499}
500
#[cfg(test)]
mod test {
    use hotshot_types::{signature_key::BLSPubKey, traits::signature_key::SignatureKey};
    use libp2p::{core::transport::dummy::DummyTransport, quic::Connection};
    use rand::Rng;

    use super::*;

    /// A mock type to help with readability
    type MockStakeTableAuth = ConsensusKeyAuthentication<DummyTransport, BLSPubKey, Connection>;

    // Helper macro for generating a new identity and authentication message
    macro_rules! new_identity {
        () => {{
            // Gen a new seed
            let seed = rand::rngs::OsRng.gen::<[u8; 32]>();

            // Create a new keypair
            let keypair = BLSPubKey::generated_from_seed_indexed(seed, 1337);

            // Create a peer ID
            let peer_id = libp2p::identity::Keypair::generate_ed25519()
                .public()
                .to_peer_id();

            // Construct an authentication message
            let auth_message =
                super::construct_auth_message(&keypair.0, &peer_id, &keypair.1).unwrap();

            (keypair, peer_id, auth_message)
        }};
    }

    // Helper macro to generate a cursor positioned at the start of a length-delimited message
    macro_rules! cursor_from {
        ($auth_message:expr) => {{
            let mut stream = futures::io::Cursor::new(vec![]);
            write_length_delimited(&mut stream, &$auth_message)
                .await
                .expect("Failed to write message");
            stream.set_position(0);
            stream
        }};
    }

    /// Test valid construction and verification of an authentication message
    #[test]
    fn signature_verify() {
        // Create a new identity
        let (_, _, auth_message) = new_identity!();

        // Verify the authentication message
        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_ok());
    }

    /// Test invalid construction and verification of an authentication message with
    /// an invalid public key. This ensures we are signing over it correctly.
    #[test]
    fn signature_verify_invalid_public_key() {
        // Create a new identity
        let (_, _, auth_message) = new_identity!();

        // Deserialize the authentication message
        let mut auth_message: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(&auth_message).unwrap();

        // Change the public key
        auth_message.public_key_bytes[0] ^= 0x01;

        // Serialize the message again
        let auth_message = bincode::serialize(&auth_message).unwrap();

        // Verify the authentication message
        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_err());
    }

    /// Test invalid construction and verification of an authentication message with
    /// an invalid peer ID. This ensures we are signing over it correctly.
    #[test]
    fn signature_verify_invalid_peer_id() {
        // Create a new identity
        let (_, _, auth_message) = new_identity!();

        // Deserialize the authentication message
        let mut auth_message: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(&auth_message).unwrap();

        // Change the peer ID
        auth_message.peer_id_bytes[0] ^= 0x01;

        // Serialize the message again
        let auth_message = bincode::serialize(&auth_message).unwrap();

        // Verify the authentication message
        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_err());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn valid_authentication() {
        // Create a new identity
        let (keypair, peer_id, auth_message) = new_identity!();

        // Create a stream and write the message to it
        let mut stream = cursor_from!(auth_message);

        // Create a map from consensus keys to peer IDs
        let consensus_key_to_pid_map = Arc::new(parking_lot::Mutex::new(BiMap::new()));

        // Verify the authentication message
        let result = MockStakeTableAuth::verify_peer_authentication(
            &mut stream,
            &peer_id,
            Arc::clone(&consensus_key_to_pid_map),
        )
        .await;

        // Check the result first, so a failure reports its actual error instead of
        // panicking on the map lookup below
        assert!(
            result.is_ok(),
            "Should have passed authentication but did not: {result:?}"
        );

        // Make sure the map has the correct entry
        assert_eq!(
            consensus_key_to_pid_map.lock().get_by_left(&keypair.0),
            Some(&peer_id),
            "Map does not have the correct entry"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn peer_id_mismatch() {
        // Create a new identity and authentication message
        let (_, _, auth_message) = new_identity!();

        // Create a second (malicious) identity
        let (_, malicious_peer_id, _) = new_identity!();

        // Create a stream and write the message to it
        let mut stream = cursor_from!(auth_message);

        // Create a map from consensus keys to peer IDs
        let consensus_key_to_pid_map = Arc::new(parking_lot::Mutex::new(BiMap::new()));

        // Check against the malicious peer ID
        let result = MockStakeTableAuth::verify_peer_authentication(
            &mut stream,
            &malicious_peer_id,
            Arc::clone(&consensus_key_to_pid_map),
        )
        .await;

        // Make sure it errored for the right reason
        assert!(
            result
                .expect_err("Should have failed authentication but did not")
                .to_string()
                .contains("Peer ID mismatch"),
            "Did not fail with the correct error"
        );

        // Make sure the map does not have the malicious peer ID
        assert!(
            consensus_key_to_pid_map.lock().is_empty(),
            "Malicious peer ID should not be in the map"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn read_and_write_length_delimited() {
        // Create a message
        let message = b"Hello, world!";

        // Write the message to a buffer
        let mut buffer = Vec::new();
        write_length_delimited(&mut buffer, message).await.unwrap();

        // Read the message from the buffer
        let read_message = read_length_delimited(&mut buffer.as_slice(), 1024)
            .await
            .unwrap();

        // Check if the messages are the same
        assert_eq!(message, read_message.as_slice());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn read_length_delimited_enforces_max_size() {
        // Write a 16-byte message to a buffer
        let message = [0u8; 16];
        let mut buffer = Vec::new();
        write_length_delimited(&mut buffer, &message).await.unwrap();

        // A message exactly at the limit is accepted
        let read_message = read_length_delimited(&mut buffer.as_slice(), message.len())
            .await
            .unwrap();
        assert_eq!(message.as_slice(), read_message.as_slice());

        // One byte over the limit is rejected
        let error = read_length_delimited(&mut buffer.as_slice(), message.len() - 1)
            .await
            .expect_err("Oversized message should have been rejected");
        assert!(
            error.to_string().contains("Message too large"),
            "Did not fail with the correct error"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn read_length_delimited_truncated() {
        // Write a message, then drop its last byte
        let mut buffer = Vec::new();
        write_length_delimited(&mut buffer, b"Hello, world!")
            .await
            .unwrap();
        buffer.truncate(buffer.len() - 1);

        // Reading must fail rather than return a short message
        let result = read_length_delimited(&mut buffer.as_slice(), 1024).await;
        assert!(result.is_err(), "Truncated message should fail to read");
    }
}