hotshot_libp2p_networking/network/
transport.rs

1use std::{
2    future::Future,
3    io::{Error as IoError, ErrorKind as IoErrorKind},
4    pin::Pin,
5    sync::Arc,
6    task::Poll,
7};
8
9use anyhow::{Context, Result as AnyhowResult, ensure};
10use bimap::BiMap;
11use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, future::poll_fn};
12use hotshot_types::traits::signature_key::SignatureKey;
13use libp2p::{
14    Transport,
15    core::{
16        StreamMuxer,
17        muxing::StreamMuxerExt,
18        transport::{DialOpts, TransportEvent},
19    },
20    identity::PeerId,
21};
22use parking_lot::Mutex;
23use pin_project::pin_project;
24use serde::{Deserialize, Serialize};
25use tokio::time::timeout;
26use tracing::warn;
27
/// Upper bound, in bytes, on the payload of a single authentication message
/// (1 KiB). Length prefixes announcing anything larger are rejected before the
/// payload buffer is allocated, which keeps a remote peer from forcing large
/// allocations.
const MAX_AUTH_MESSAGE_SIZE: usize = 1 << 10;

/// How long the whole authentication handshake may take once the underlying
/// connection is established. Bounding it prevents a peer from pinning a
/// connection open indefinitely by leaving the handshake half-finished.
const AUTH_HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
36
37/// A wrapper for a `Transport` that bidirectionally associates (and verifies)
38/// the corresponding consensus keys.
39#[pin_project]
40pub struct ConsensusKeyAuthentication<
41    T: Transport,
42    S: SignatureKey + 'static,
43    C: StreamMuxer + Unpin,
44> {
45    #[pin]
46    /// The underlying transport we are wrapping
47    pub inner: T,
48
49    /// A pre-signed message that we (depending on if it's specified or not) send to the remote peer for authentication
50    pub auth_message: Arc<Option<Vec<u8>>>,
51
52    /// The (verified) map of consensus keys to peer IDs
53    pub consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
54
55    /// Phantom data for the connection type
56    pd: std::marker::PhantomData<(C, S)>,
57}
58
59/// A type alias for the future that upgrades a connection to perform the authentication handshake
60type UpgradeFuture<T> =
61    Pin<Box<dyn Future<Output = Result<<T as Transport>::Output, <T as Transport>::Error>> + Send>>;
62
63impl<T: Transport, S: SignatureKey + 'static, C: StreamMuxer + Unpin>
64    ConsensusKeyAuthentication<T, S, C>
65{
66    /// Create a new `ConsensusKeyAuthentication` transport that wraps the given transport
67    /// and authenticates connections against the stake table. If the auth message is `None`,
68    /// the authentication is disabled.
69    pub fn new(
70        inner: T,
71        auth_message: Option<Vec<u8>>,
72        consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
73    ) -> Self {
74        Self {
75            inner,
76            auth_message: Arc::from(auth_message),
77            consensus_key_to_pid_map,
78            pd: std::marker::PhantomData,
79        }
80    }
81
82    /// Prove to the remote peer that we are in the stake table by sending
83    /// them our authentication message.
84    ///
85    /// # Errors
86    /// - If we fail to write the message to the stream
87    pub async fn authenticate_with_remote_peer<W: AsyncWrite + Unpin>(
88        stream: &mut W,
89        auth_message: &[u8],
90    ) -> AnyhowResult<()> {
91        // Write the length-delimited message
92        write_length_delimited(stream, auth_message).await?;
93
94        Ok(())
95    }
96
97    /// Verify that the remote peer is:
98    /// - In the stake table
99    /// - Sending us a valid authentication message
100    /// - Sending us a valid signature
101    /// - Matching the peer ID we expect
102    ///
103    /// # Errors
104    /// If the peer fails verification. This can happen if:
105    /// - We fail to read the message from the stream
106    /// - The message is too large
107    /// - The message is invalid
108    /// - The peer is not in the stake table
109    /// - The signature is invalid
110    pub async fn verify_peer_authentication<R: AsyncReadExt + Unpin>(
111        stream: &mut R,
112        required_peer_id: &PeerId,
113        consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
114    ) -> AnyhowResult<()> {
115        // Read the length-delimited message from the remote peer
116        let message = read_length_delimited(stream, MAX_AUTH_MESSAGE_SIZE).await?;
117
118        // Deserialize the authentication message
119        let auth_message: AuthMessage<S> =
120            bincode::deserialize(&message).with_context(|| "Failed to deserialize auth message")?;
121
122        // Verify the signature on the public keys
123        let public_key = auth_message
124            .validate()
125            .with_context(|| "Failed to verify authentication message")?;
126
127        // Deserialize the `PeerId`
128        let peer_id = PeerId::from_bytes(&auth_message.peer_id_bytes)
129            .with_context(|| "Failed to deserialize peer ID")?;
130
131        // Verify that the peer ID is the same as the remote peer
132        if peer_id != *required_peer_id {
133            return Err(anyhow::anyhow!("Peer ID mismatch"));
134        }
135
136        // If we got here, the peer is authenticated. Add the consensus key to the map
137        consensus_key_to_pid_map.lock().insert(public_key, peer_id);
138
139        Ok(())
140    }
141
142    /// Wrap the supplied future in an upgrade that performs the authentication handshake.
143    ///
144    /// `outgoing` is a boolean that indicates if the connection is incoming or outgoing.
145    /// This is needed because the flow of the handshake is different for each.
146    fn gen_handshake<F: Future<Output = Result<T::Output, T::Error>> + Send + 'static>(
147        original_future: F,
148        outgoing: bool,
149        auth_message: Arc<Option<Vec<u8>>>,
150        consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
151    ) -> UpgradeFuture<T>
152    where
153        T::Error: From<<C as StreamMuxer>::Error> + From<IoError>,
154        T::Output: AsOutput<C> + Send,
155        C::Substream: Unpin + Send,
156    {
157        // Create a new upgrade that performs the authentication handshake on top
158        Box::pin(async move {
159            // Wait for the original future to resolve
160            let mut stream = original_future.await?;
161
162            // Time out the authentication block
163            timeout(AUTH_HANDSHAKE_TIMEOUT, async {
164                // Open a substream for the handshake.
165                // The handshake order depends on whether the connection is incoming or outgoing.
166                let mut substream = if outgoing {
167                    poll_fn(|cx| stream.as_connection().poll_outbound_unpin(cx)).await?
168                } else {
169                    poll_fn(|cx| stream.as_connection().poll_inbound_unpin(cx)).await?
170                };
171
172                // Conditionally authenticate depending on whether we specified an auth message
173                if let Some(auth_message) = auth_message.as_ref() {
174                    if outgoing {
175                        // If the connection is outgoing, authenticate with the remote peer first
176                        Self::authenticate_with_remote_peer(&mut substream, auth_message)
177                            .await
178                            .map_err(|e| {
179                                warn!("Failed to authenticate with remote peer: {e:?}");
180                                IoError::other(e)
181                            })?;
182
183                        // Verify the remote peer's authentication
184                        Self::verify_peer_authentication(
185                            &mut substream,
186                            stream.as_peer_id(),
187                            consensus_key_to_pid_map,
188                        )
189                        .await
190                        .map_err(|e| {
191                            warn!("Failed to verify remote peer: {e:?}");
192                            IoError::other(e)
193                        })?;
194                    } else {
195                        // If it is incoming, verify the remote peer's authentication first
196                        Self::verify_peer_authentication(
197                            &mut substream,
198                            stream.as_peer_id(),
199                            consensus_key_to_pid_map,
200                        )
201                        .await
202                        .map_err(|e| {
203                            warn!("Failed to verify remote peer: {e:?}");
204                            IoError::other(e)
205                        })?;
206
207                        // Authenticate with the remote peer
208                        Self::authenticate_with_remote_peer(&mut substream, auth_message)
209                            .await
210                            .map_err(|e| {
211                                warn!("Failed to authenticate with remote peer: {e:?}");
212                                IoError::other(e)
213                            })?;
214                    }
215                }
216
217                Ok(stream)
218            })
219            .await
220            .map_err(|e| {
221                warn!("Timed out performing authentication handshake: {e:?}");
222                IoError::new(IoErrorKind::TimedOut, e)
223            })?
224        })
225    }
226}
227
228/// The deserialized form of an authentication message that is sent to the remote peer
229#[derive(Clone, Serialize, Deserialize)]
230struct AuthMessage<S: SignatureKey> {
231    /// The encoded (stake table) public key of the sender. This, along with the peer ID, is
232    /// signed. It is still encoded here to enable easy verification.
233    public_key_bytes: Vec<u8>,
234
235    /// The encoded peer ID of the sender. This is appended to the public key before signing.
236    /// It is still encoded here to enable easy verification.
237    peer_id_bytes: Vec<u8>,
238
239    /// The signature on the public key
240    signature: S::PureAssembledSignatureType,
241}
242
243impl<S: SignatureKey> AuthMessage<S> {
244    /// Validate the signature on the public key and return it if valid
245    pub fn validate(&self) -> AnyhowResult<S> {
246        // Deserialize the stake table public key
247        let public_key = S::from_bytes(&self.public_key_bytes)
248            .with_context(|| "Failed to deserialize public key")?;
249
250        // Reconstruct the signed message from the public key and peer ID
251        let mut signed_message = public_key.to_bytes();
252        signed_message.extend(self.peer_id_bytes.clone());
253
254        // Check if the signature is valid across both
255        if !public_key.validate(&self.signature, &signed_message) {
256            return Err(anyhow::anyhow!("Invalid signature"));
257        }
258
259        Ok(public_key)
260    }
261}
262
263/// Create an sign an authentication message to be sent to the remote peer
264///
265/// # Errors
266/// - If we fail to sign the public key
267/// - If we fail to serialize the authentication message
268pub fn construct_auth_message<S: SignatureKey + 'static>(
269    public_key: &S,
270    peer_id: &PeerId,
271    private_key: &S::PrivateKey,
272) -> AnyhowResult<Vec<u8>> {
273    // Serialize the stake table public key
274    let mut public_key_bytes = public_key.to_bytes();
275
276    // Serialize the peer ID and append it
277    let peer_id_bytes = peer_id.to_bytes();
278    public_key_bytes.extend_from_slice(&peer_id_bytes);
279
280    // Sign our public key
281    let signature =
282        S::sign(private_key, &public_key_bytes).with_context(|| "Failed to sign public key")?;
283
284    // Create the auth message
285    let auth_message = AuthMessage::<S> {
286        public_key_bytes,
287        peer_id_bytes,
288        signature,
289    };
290
291    // Serialize the auth message
292    bincode::serialize(&auth_message).with_context(|| "Failed to serialize auth message")
293}
294
295impl<T: Transport, S: SignatureKey + 'static, C: StreamMuxer + Unpin> Transport
296    for ConsensusKeyAuthentication<T, S, C>
297where
298    T::Dial: Future<Output = Result<T::Output, T::Error>> + Send + 'static,
299    T::ListenerUpgrade: Send + 'static,
300    T::Output: AsOutput<C> + Send,
301    T::Error: From<<C as StreamMuxer>::Error> + From<IoError>,
302    C::Substream: Unpin + Send,
303{
304    // `Dial` is for connecting out, `ListenerUpgrade` is for accepting incoming connections
305    type Dial = Pin<Box<dyn Future<Output = Result<T::Output, T::Error>> + Send>>;
306    type ListenerUpgrade = Pin<Box<dyn Future<Output = Result<T::Output, T::Error>> + Send>>;
307
308    // These are just passed through
309    type Output = T::Output;
310    type Error = T::Error;
311
312    /// Dial a remote peer. This function is changed to perform an authentication handshake
313    /// on top.
314    fn dial(
315        &mut self,
316        addr: libp2p::Multiaddr,
317        opts: DialOpts,
318    ) -> Result<Self::Dial, libp2p::TransportError<Self::Error>> {
319        // Perform the inner dial
320        let res = self.inner.dial(addr, opts);
321
322        // Clone the necessary fields
323        let auth_message = Arc::clone(&self.auth_message);
324
325        // If the dial was successful, perform the authentication handshake on top
326        match res {
327            Ok(dial) => Ok(Self::gen_handshake(
328                dial,
329                true,
330                auth_message,
331                Arc::clone(&self.consensus_key_to_pid_map),
332            )),
333            Err(err) => Err(err),
334        }
335    }
336
337    /// This function is where we perform the authentication handshake for _incoming_ connections.
338    /// The flow in this case is the reverse of the `dial` function: we first verify the remote peer's
339    /// authentication, and then authenticate with them.
340    fn poll(
341        mut self: std::pin::Pin<&mut Self>,
342        cx: &mut std::task::Context<'_>,
343    ) -> std::task::Poll<libp2p::core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>>
344    {
345        match Transport::poll(self.as_mut().project().inner, cx) {
346            Poll::Ready(event) => Poll::Ready(match event {
347                // If we have an incoming connection, we need to perform the authentication handshake
348                TransportEvent::Incoming {
349                    listener_id,
350                    upgrade,
351                    local_addr,
352                    send_back_addr,
353                } => {
354                    // Clone the necessary fields
355                    let auth_message = Arc::clone(&self.auth_message);
356
357                    // Generate the handshake upgrade future (inbound)
358                    let auth_upgrade = Self::gen_handshake(
359                        upgrade,
360                        false,
361                        auth_message,
362                        Arc::clone(&self.consensus_key_to_pid_map),
363                    );
364
365                    // Return the new event
366                    TransportEvent::Incoming {
367                        listener_id,
368                        upgrade: auth_upgrade,
369                        local_addr,
370                        send_back_addr,
371                    }
372                },
373
374                // We need to re-map the other events because we changed the type of the upgrade
375                TransportEvent::AddressExpired {
376                    listener_id,
377                    listen_addr,
378                } => TransportEvent::AddressExpired {
379                    listener_id,
380                    listen_addr,
381                },
382                TransportEvent::ListenerClosed {
383                    listener_id,
384                    reason,
385                } => TransportEvent::ListenerClosed {
386                    listener_id,
387                    reason,
388                },
389                TransportEvent::ListenerError { listener_id, error } => {
390                    TransportEvent::ListenerError { listener_id, error }
391                },
392                TransportEvent::NewAddress {
393                    listener_id,
394                    listen_addr,
395                } => TransportEvent::NewAddress {
396                    listener_id,
397                    listen_addr,
398                },
399            }),
400
401            Poll::Pending => Poll::Pending,
402        }
403    }
404
405    /// The below functions just pass through to the inner transport, but we had
406    /// to define them
407    fn remove_listener(&mut self, id: libp2p::core::transport::ListenerId) -> bool {
408        self.inner.remove_listener(id)
409    }
410    fn listen_on(
411        &mut self,
412        id: libp2p::core::transport::ListenerId,
413        addr: libp2p::Multiaddr,
414    ) -> Result<(), libp2p::TransportError<Self::Error>> {
415        self.inner.listen_on(id, addr)
416    }
417}
418
419/// A helper trait that allows us to access the underlying connection
420/// and `PeerId` from a transport output
421trait AsOutput<C: StreamMuxer + Unpin> {
422    /// Get a mutable reference to the underlying connection
423    fn as_connection(&mut self) -> &mut C;
424
425    /// Get a mutable reference to the underlying `PeerId`
426    fn as_peer_id(&mut self) -> &mut PeerId;
427}
428
429/// The implementation of the `AsConnection` trait for a tuple of a `PeerId`
430/// and a connection.
431impl<C: StreamMuxer + Unpin> AsOutput<C> for (PeerId, C) {
432    /// Get a mutable reference to the underlying connection
433    fn as_connection(&mut self) -> &mut C {
434        &mut self.1
435    }
436
437    /// Get a mutable reference to the underlying `PeerId`
438    fn as_peer_id(&mut self) -> &mut PeerId {
439        &mut self.0
440    }
441}
442
443/// A helper function to read a length-delimited message from a stream. Takes into
444/// account the maximum message size.
445///
446/// # Errors
447/// - If the message is too big
448/// - If we fail to read from the stream
449pub async fn read_length_delimited<S: AsyncRead + Unpin>(
450    stream: &mut S,
451    max_size: usize,
452) -> AnyhowResult<Vec<u8>> {
453    // Receive the first 8 bytes of the message, which is the length
454    let mut len_bytes = [0u8; 4];
455    stream
456        .read_exact(&mut len_bytes)
457        .await
458        .with_context(|| "Failed to read message length")?;
459
460    // Parse the length of the message as a `u32`
461    let len = usize::try_from(u32::from_be_bytes(len_bytes))?;
462
463    // Quit if the message is too large
464    ensure!(len <= max_size, "Message too large");
465
466    // Read the actual message
467    let mut message = vec![0u8; len];
468    stream
469        .read_exact(&mut message)
470        .await
471        .with_context(|| "Failed to read message")?;
472
473    Ok(message)
474}
475
476/// A helper function to write a length-delimited message to a stream.
477///
478/// # Errors
479/// - If we fail to write to the stream
480pub async fn write_length_delimited<S: AsyncWrite + Unpin>(
481    stream: &mut S,
482    message: &[u8],
483) -> AnyhowResult<()> {
484    // Write the length of the message
485    stream
486        .write_all(&u32::try_from(message.len())?.to_be_bytes())
487        .await
488        .with_context(|| "Failed to write message length")?;
489
490    // Write the actual message
491    stream
492        .write_all(message)
493        .await
494        .with_context(|| "Failed to write message")?;
495
496    Ok(())
497}
498
#[cfg(test)]
mod test {
    use hotshot_types::{signature_key::BLSPubKey, traits::signature_key::SignatureKey};
    use libp2p::{core::transport::dummy::DummyTransport, quic::Connection};
    use rand::Rng;

    use super::*;

    /// A concrete instantiation of the transport, used only to reach its
    /// associated functions
    type MockStakeTableAuth = ConsensusKeyAuthentication<DummyTransport, BLSPubKey, Connection>;

    // Generates a fresh consensus keypair, a fresh peer ID, and the signed
    // authentication message binding the two
    macro_rules! new_identity {
        () => {{
            // Gen a new seed
            let seed = rand::rngs::OsRng.r#gen::<[u8; 32]>();

            // Create a new keypair
            let keypair = BLSPubKey::generated_from_seed_indexed(seed, 1337);

            // Create a peer ID
            let peer_id = libp2p::identity::Keypair::generate_ed25519()
                .public()
                .to_peer_id();

            // Construct an authentication message
            let auth_message =
                super::construct_auth_message(&keypair.0, &peer_id, &keypair.1).unwrap();

            (keypair, peer_id, auth_message)
        }};
    }

    // Produces a cursor, rewound to the start, that holds the given message as
    // one length-delimited frame. Must be invoked from an async context.
    macro_rules! cursor_from {
        ($auth_message:expr) => {{
            let mut stream = futures::io::Cursor::new(vec![]);
            write_length_delimited(&mut stream, &$auth_message)
                .await
                .expect("Failed to write message");
            stream.set_position(0);
            stream
        }};
    }

    /// Decode a serialized authentication message and run signature validation
    /// on it. Panics if the bytes do not decode at all.
    fn decode_and_validate(encoded: &[u8]) -> anyhow::Result<BLSPubKey> {
        let message: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(encoded).expect("Failed to deserialize auth message");
        message.validate()
    }

    /// Test valid construction and verification of an authentication message
    #[test]
    fn signature_verify() {
        // Create a new identity
        let (_, _, auth_message) = new_identity!();

        // An untouched message must validate
        assert!(decode_and_validate(&auth_message).is_ok());
    }

    /// Test that a message whose public key was tampered with is rejected.
    /// This ensures the signature covers the public key.
    #[test]
    fn signature_verify_invalid_public_key() {
        // Create a new identity
        let (_, _, auth_message) = new_identity!();

        // Deserialize the authentication message
        let mut tampered: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(&auth_message).unwrap();

        // Change the public key
        tampered.public_key_bytes[0] ^= 0x01;

        // Round-trip through the wire encoding, then validate
        let reencoded = bincode::serialize(&tampered).unwrap();
        assert!(decode_and_validate(&reencoded).is_err());
    }

    /// Test that a message whose peer ID was tampered with is rejected.
    /// This ensures the signature covers the peer ID.
    #[test]
    fn signature_verify_invalid_peer_id() {
        // Create a new identity
        let (_, _, auth_message) = new_identity!();

        // Deserialize the authentication message
        let mut tampered: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(&auth_message).unwrap();

        // Change the peer ID
        tampered.peer_id_bytes[0] ^= 0x01;

        // Round-trip through the wire encoding, then validate
        let reencoded = bincode::serialize(&tampered).unwrap();
        assert!(decode_and_validate(&reencoded).is_err());
    }

    /// A correctly signed message for the expected peer ID must pass
    /// verification and be recorded in the map
    #[tokio::test(flavor = "multi_thread")]
    async fn valid_authentication() {
        // Create a new identity
        let (keypair, peer_id, auth_message) = new_identity!();

        // Create a stream and write the message to it
        let mut stream = cursor_from!(auth_message);

        // Create a map from consensus keys to peer IDs
        let consensus_key_to_pid_map = Arc::new(parking_lot::Mutex::new(BiMap::new()));

        // Verify the authentication message
        let result = MockStakeTableAuth::verify_peer_authentication(
            &mut stream,
            &peer_id,
            Arc::clone(&consensus_key_to_pid_map),
        )
        .await;

        // Check the outcome first, so a failure reports the real error rather
        // than a panic on a missing map entry
        assert!(
            result.is_ok(),
            "Should have passed authentication but did not: {result:?}"
        );

        // Make sure the map has the correct entry
        assert_eq!(
            consensus_key_to_pid_map.lock().get_by_left(&keypair.0),
            Some(&peer_id),
            "Map does not have the correct entry"
        );
    }

    /// A correctly signed message presented over a connection with a different
    /// peer ID must be rejected and must not be recorded
    #[tokio::test(flavor = "multi_thread")]
    async fn peer_id_mismatch() {
        // Create a new identity and authentication message
        let (_, _, auth_message) = new_identity!();

        // Create a second (malicious) identity
        let (_, malicious_peer_id, _) = new_identity!();

        // Create a stream and write the message to it
        let mut stream = cursor_from!(auth_message);

        // Create a map from consensus keys to peer IDs
        let consensus_key_to_pid_map = Arc::new(parking_lot::Mutex::new(BiMap::new()));

        // Check against the malicious peer ID
        let result = MockStakeTableAuth::verify_peer_authentication(
            &mut stream,
            &malicious_peer_id,
            Arc::clone(&consensus_key_to_pid_map),
        )
        .await;

        // Make sure it errored for the right reason
        assert!(
            result
                .expect_err("Should have failed authentication but did not")
                .to_string()
                .contains("Peer ID mismatch"),
            "Did not fail with the correct error"
        );

        // Make sure the map does not have the malicious peer ID
        assert!(
            consensus_key_to_pid_map.lock().is_empty(),
            "Malicious peer ID should not be in the map"
        );
    }

    /// A message written with `write_length_delimited` must be read back
    /// unchanged by `read_length_delimited`
    #[tokio::test(flavor = "multi_thread")]
    async fn read_and_write_length_delimited() {
        // Create a message
        let message = b"Hello, world!";

        // Write the message to a buffer
        let mut buffer = Vec::new();
        write_length_delimited(&mut buffer, message).await.unwrap();

        // Read the message from the buffer
        let read_message = read_length_delimited(&mut buffer.as_slice(), 1024)
            .await
            .unwrap();

        // Check if the messages are the same
        assert_eq!(message, read_message.as_slice());
    }
}