1use std::{
2 future::Future,
3 io::{Error as IoError, ErrorKind as IoErrorKind},
4 pin::Pin,
5 sync::Arc,
6 task::Poll,
7};
8
9use anyhow::{Context, Result as AnyhowResult, ensure};
10use bimap::BiMap;
11use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, future::poll_fn};
12use hotshot_types::traits::signature_key::SignatureKey;
13use libp2p::{
14 Transport,
15 core::{
16 StreamMuxer,
17 muxing::StreamMuxerExt,
18 transport::{DialOpts, TransportEvent},
19 },
20 identity::PeerId,
21};
22use parking_lot::Mutex;
23use pin_project::pin_project;
24use serde::{Deserialize, Serialize};
25use tokio::time::timeout;
26use tracing::warn;
27
/// Maximum accepted length, in bytes, of a remote peer's authentication message.
///
/// `verify_peer_authentication` passes this as the size cap to `read_length_delimited`,
/// which rejects a larger declared length before allocating or reading the payload.
const MAX_AUTH_MESSAGE_SIZE: usize = 1024;
31
/// Time budget for the authentication handshake on a freshly upgraded connection.
///
/// `gen_handshake` wraps both opening the handshake substream and the message exchange in
/// this timeout; when it elapses the upgrade fails with an `IoErrorKind::TimedOut` error.
const AUTH_HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
36
/// A [`Transport`] wrapper that runs a consensus-key authentication handshake on the
/// connections produced by the wrapped transport (both dialed and incoming ones).
///
/// Type parameters:
/// - `T`: the wrapped transport.
/// - `S`: the signature key type used as the consensus key.
/// - `C`: the stream muxer type on which the handshake substream is opened.
#[pin_project]
pub struct ConsensusKeyAuthentication<
    T: Transport,
    S: SignatureKey + 'static,
    C: StreamMuxer + Unpin,
> {
    /// The wrapped transport; dialing, listening and polling are delegated to it.
    #[pin]
    pub inner: T,

    /// This node's pre-serialized authentication message. When `None`, the handshake
    /// neither sends nor verifies any authentication message.
    pub auth_message: Arc<Option<Vec<u8>>>,

    /// Shared bidirectional map from consensus key to libp2p `PeerId`. An entry is inserted
    /// each time a remote peer passes verification.
    pub consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,

    /// Marker tying the otherwise unused `C` and `S` parameters to the struct.
    pd: std::marker::PhantomData<(C, S)>,
}
58
/// Boxed, `Send` future that resolves to the wrapped transport's output or error.
/// This is the type returned by `gen_handshake`.
type UpgradeFuture<T> =
    Pin<Box<dyn Future<Output = Result<<T as Transport>::Output, <T as Transport>::Error>> + Send>>;
62
63impl<T: Transport, S: SignatureKey + 'static, C: StreamMuxer + Unpin>
64 ConsensusKeyAuthentication<T, S, C>
65{
66 pub fn new(
70 inner: T,
71 auth_message: Option<Vec<u8>>,
72 consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
73 ) -> Self {
74 Self {
75 inner,
76 auth_message: Arc::from(auth_message),
77 consensus_key_to_pid_map,
78 pd: std::marker::PhantomData,
79 }
80 }
81
82 pub async fn authenticate_with_remote_peer<W: AsyncWrite + Unpin>(
88 stream: &mut W,
89 auth_message: &[u8],
90 ) -> AnyhowResult<()> {
91 write_length_delimited(stream, auth_message).await?;
93
94 Ok(())
95 }
96
97 pub async fn verify_peer_authentication<R: AsyncReadExt + Unpin>(
111 stream: &mut R,
112 required_peer_id: &PeerId,
113 consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
114 ) -> AnyhowResult<()> {
115 let message = read_length_delimited(stream, MAX_AUTH_MESSAGE_SIZE).await?;
117
118 let auth_message: AuthMessage<S> =
120 bincode::deserialize(&message).with_context(|| "Failed to deserialize auth message")?;
121
122 let public_key = auth_message
124 .validate()
125 .with_context(|| "Failed to verify authentication message")?;
126
127 let peer_id = PeerId::from_bytes(&auth_message.peer_id_bytes)
129 .with_context(|| "Failed to deserialize peer ID")?;
130
131 if peer_id != *required_peer_id {
133 return Err(anyhow::anyhow!("Peer ID mismatch"));
134 }
135
136 consensus_key_to_pid_map.lock().insert(public_key, peer_id);
138
139 Ok(())
140 }
141
142 fn gen_handshake<F: Future<Output = Result<T::Output, T::Error>> + Send + 'static>(
147 original_future: F,
148 outgoing: bool,
149 auth_message: Arc<Option<Vec<u8>>>,
150 consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
151 ) -> UpgradeFuture<T>
152 where
153 T::Error: From<<C as StreamMuxer>::Error> + From<IoError>,
154 T::Output: AsOutput<C> + Send,
155 C::Substream: Unpin + Send,
156 {
157 Box::pin(async move {
159 let mut stream = original_future.await?;
161
162 timeout(AUTH_HANDSHAKE_TIMEOUT, async {
164 let mut substream = if outgoing {
167 poll_fn(|cx| stream.as_connection().poll_outbound_unpin(cx)).await?
168 } else {
169 poll_fn(|cx| stream.as_connection().poll_inbound_unpin(cx)).await?
170 };
171
172 if let Some(auth_message) = auth_message.as_ref() {
174 if outgoing {
175 Self::authenticate_with_remote_peer(&mut substream, auth_message)
177 .await
178 .map_err(|e| {
179 warn!("Failed to authenticate with remote peer: {e:?}");
180 IoError::other(e)
181 })?;
182
183 Self::verify_peer_authentication(
185 &mut substream,
186 stream.as_peer_id(),
187 consensus_key_to_pid_map,
188 )
189 .await
190 .map_err(|e| {
191 warn!("Failed to verify remote peer: {e:?}");
192 IoError::other(e)
193 })?;
194 } else {
195 Self::verify_peer_authentication(
197 &mut substream,
198 stream.as_peer_id(),
199 consensus_key_to_pid_map,
200 )
201 .await
202 .map_err(|e| {
203 warn!("Failed to verify remote peer: {e:?}");
204 IoError::other(e)
205 })?;
206
207 Self::authenticate_with_remote_peer(&mut substream, auth_message)
209 .await
210 .map_err(|e| {
211 warn!("Failed to authenticate with remote peer: {e:?}");
212 IoError::other(e)
213 })?;
214 }
215 }
216
217 Ok(stream)
218 })
219 .await
220 .map_err(|e| {
221 warn!("Timed out performing authentication handshake: {e:?}");
222 IoError::new(IoErrorKind::TimedOut, e)
223 })?
224 })
225 }
226}
227
/// The message a peer sends to prove that its libp2p peer ID is bound to a consensus key.
///
/// Built and `bincode`-serialized by `construct_auth_message`; checked by `validate`.
#[derive(Clone, Serialize, Deserialize)]
struct AuthMessage<S: SignatureKey> {
    /// Bytes from which `validate` parses the sender's consensus public key via
    /// `S::from_bytes`. As built by `construct_auth_message`, this buffer holds the
    /// serialized key followed by the peer ID bytes.
    public_key_bytes: Vec<u8>,

    /// The sender's serialized `PeerId`.
    peer_id_bytes: Vec<u8>,

    /// Signature that `validate` checks against the parsed key's `to_bytes()` output
    /// followed by `peer_id_bytes`.
    signature: S::PureAssembledSignatureType,
}
242
243impl<S: SignatureKey> AuthMessage<S> {
244 pub fn validate(&self) -> AnyhowResult<S> {
246 let public_key = S::from_bytes(&self.public_key_bytes)
248 .with_context(|| "Failed to deserialize public key")?;
249
250 let mut signed_message = public_key.to_bytes();
252 signed_message.extend(self.peer_id_bytes.clone());
253
254 if !public_key.validate(&self.signature, &signed_message) {
256 return Err(anyhow::anyhow!("Invalid signature"));
257 }
258
259 Ok(public_key)
260 }
261}
262
263pub fn construct_auth_message<S: SignatureKey + 'static>(
269 public_key: &S,
270 peer_id: &PeerId,
271 private_key: &S::PrivateKey,
272) -> AnyhowResult<Vec<u8>> {
273 let mut public_key_bytes = public_key.to_bytes();
275
276 let peer_id_bytes = peer_id.to_bytes();
278 public_key_bytes.extend_from_slice(&peer_id_bytes);
279
280 let signature =
282 S::sign(private_key, &public_key_bytes).with_context(|| "Failed to sign public key")?;
283
284 let auth_message = AuthMessage::<S> {
286 public_key_bytes,
287 peer_id_bytes,
288 signature,
289 };
290
291 bincode::serialize(&auth_message).with_context(|| "Failed to serialize auth message")
293}
294
295impl<T: Transport, S: SignatureKey + 'static, C: StreamMuxer + Unpin> Transport
296 for ConsensusKeyAuthentication<T, S, C>
297where
298 T::Dial: Future<Output = Result<T::Output, T::Error>> + Send + 'static,
299 T::ListenerUpgrade: Send + 'static,
300 T::Output: AsOutput<C> + Send,
301 T::Error: From<<C as StreamMuxer>::Error> + From<IoError>,
302 C::Substream: Unpin + Send,
303{
304 type Dial = Pin<Box<dyn Future<Output = Result<T::Output, T::Error>> + Send>>;
306 type ListenerUpgrade = Pin<Box<dyn Future<Output = Result<T::Output, T::Error>> + Send>>;
307
308 type Output = T::Output;
310 type Error = T::Error;
311
312 fn dial(
315 &mut self,
316 addr: libp2p::Multiaddr,
317 opts: DialOpts,
318 ) -> Result<Self::Dial, libp2p::TransportError<Self::Error>> {
319 let res = self.inner.dial(addr, opts);
321
322 let auth_message = Arc::clone(&self.auth_message);
324
325 match res {
327 Ok(dial) => Ok(Self::gen_handshake(
328 dial,
329 true,
330 auth_message,
331 Arc::clone(&self.consensus_key_to_pid_map),
332 )),
333 Err(err) => Err(err),
334 }
335 }
336
337 fn poll(
341 mut self: std::pin::Pin<&mut Self>,
342 cx: &mut std::task::Context<'_>,
343 ) -> std::task::Poll<libp2p::core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>>
344 {
345 match Transport::poll(self.as_mut().project().inner, cx) {
346 Poll::Ready(event) => Poll::Ready(match event {
347 TransportEvent::Incoming {
349 listener_id,
350 upgrade,
351 local_addr,
352 send_back_addr,
353 } => {
354 let auth_message = Arc::clone(&self.auth_message);
356
357 let auth_upgrade = Self::gen_handshake(
359 upgrade,
360 false,
361 auth_message,
362 Arc::clone(&self.consensus_key_to_pid_map),
363 );
364
365 TransportEvent::Incoming {
367 listener_id,
368 upgrade: auth_upgrade,
369 local_addr,
370 send_back_addr,
371 }
372 },
373
374 TransportEvent::AddressExpired {
376 listener_id,
377 listen_addr,
378 } => TransportEvent::AddressExpired {
379 listener_id,
380 listen_addr,
381 },
382 TransportEvent::ListenerClosed {
383 listener_id,
384 reason,
385 } => TransportEvent::ListenerClosed {
386 listener_id,
387 reason,
388 },
389 TransportEvent::ListenerError { listener_id, error } => {
390 TransportEvent::ListenerError { listener_id, error }
391 },
392 TransportEvent::NewAddress {
393 listener_id,
394 listen_addr,
395 } => TransportEvent::NewAddress {
396 listener_id,
397 listen_addr,
398 },
399 }),
400
401 Poll::Pending => Poll::Pending,
402 }
403 }
404
405 fn remove_listener(&mut self, id: libp2p::core::transport::ListenerId) -> bool {
408 self.inner.remove_listener(id)
409 }
410 fn listen_on(
411 &mut self,
412 id: libp2p::core::transport::ListenerId,
413 addr: libp2p::Multiaddr,
414 ) -> Result<(), libp2p::TransportError<Self::Error>> {
415 self.inner.listen_on(id, addr)
416 }
417}
418
/// Accessor trait giving the handshake uniform access to the two parts of a transport
/// output it needs: the muxed connection and the peer ID that comes with it.
trait AsOutput<C: StreamMuxer + Unpin> {
    /// Returns the stream-muxed connection, on which the handshake substream is opened.
    fn as_connection(&mut self) -> &mut C;

    /// Returns the peer ID carried by the output; the handshake uses it as the peer ID
    /// the remote's authentication message must match.
    fn as_peer_id(&mut self) -> &mut PeerId;
}
428
429impl<C: StreamMuxer + Unpin> AsOutput<C> for (PeerId, C) {
432 fn as_connection(&mut self) -> &mut C {
434 &mut self.1
435 }
436
437 fn as_peer_id(&mut self) -> &mut PeerId {
439 &mut self.0
440 }
441}
442
443pub async fn read_length_delimited<S: AsyncRead + Unpin>(
450 stream: &mut S,
451 max_size: usize,
452) -> AnyhowResult<Vec<u8>> {
453 let mut len_bytes = [0u8; 4];
455 stream
456 .read_exact(&mut len_bytes)
457 .await
458 .with_context(|| "Failed to read message length")?;
459
460 let len = usize::try_from(u32::from_be_bytes(len_bytes))?;
462
463 ensure!(len <= max_size, "Message too large");
465
466 let mut message = vec![0u8; len];
468 stream
469 .read_exact(&mut message)
470 .await
471 .with_context(|| "Failed to read message")?;
472
473 Ok(message)
474}
475
476pub async fn write_length_delimited<S: AsyncWrite + Unpin>(
481 stream: &mut S,
482 message: &[u8],
483) -> AnyhowResult<()> {
484 stream
486 .write_all(&u32::try_from(message.len())?.to_be_bytes())
487 .await
488 .with_context(|| "Failed to write message length")?;
489
490 stream
492 .write_all(message)
493 .await
494 .with_context(|| "Failed to write message")?;
495
496 Ok(())
497}
498
#[cfg(test)]
mod test {
    use hotshot_types::{signature_key::BLSPubKey, traits::signature_key::SignatureKey};
    use libp2p::{core::transport::dummy::DummyTransport, quic::Connection};
    use rand::Rng;

    use super::*;

    /// Concrete instantiation used to reach the associated functions under test.
    type MockStakeTableAuth = ConsensusKeyAuthentication<DummyTransport, BLSPubKey, Connection>;

    /// Generates a fresh identity: a BLS keypair from a random seed, a random Ed25519-derived
    /// peer ID, and the serialized auth message binding the two.
    macro_rules! new_identity {
        () => {{
            // Random seed for the consensus keypair
            let seed = rand::rngs::OsRng.r#gen::<[u8; 32]>();

            // (public key, private key)
            let keypair = BLSPubKey::generated_from_seed_indexed(seed, 1337);

            let peer_id = libp2p::identity::Keypair::generate_ed25519()
                .public()
                .to_peer_id();

            let auth_message =
                super::construct_auth_message(&keypair.0, &peer_id, &keypair.1).unwrap();

            (keypair, peer_id, auth_message)
        }};
    }

    /// Builds an in-memory stream containing the given message as one length-delimited
    /// frame, rewound to the start so it can be read back. Must be used in an async context.
    macro_rules! cursor_from {
        ($auth_message:expr) => {{
            let mut stream = futures::io::Cursor::new(vec![]);
            write_length_delimited(&mut stream, &$auth_message)
                .await
                .expect("Failed to write message");
            stream.set_position(0);
            stream
        }};
    }

    /// A freshly constructed auth message validates.
    #[test]
    fn signature_verify() {
        let (_, _, auth_message) = new_identity!();

        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_ok());
    }

    /// Tampering with the public key bytes makes validation fail.
    #[test]
    fn signature_verify_invalid_public_key() {
        let (_, _, auth_message) = new_identity!();

        let mut auth_message: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(&auth_message).unwrap();

        // Flip one bit of the public key
        auth_message.public_key_bytes[0] ^= 0x01;

        let auth_message = bincode::serialize(&auth_message).unwrap();

        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_err());
    }

    /// Tampering with the peer ID bytes makes validation fail.
    #[test]
    fn signature_verify_invalid_peer_id() {
        let (_, _, auth_message) = new_identity!();

        let mut auth_message: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(&auth_message).unwrap();

        // Flip one bit of the peer ID
        auth_message.peer_id_bytes[0] ^= 0x01;

        let auth_message = bincode::serialize(&auth_message).unwrap();

        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_err());
    }

    /// A valid message from the expected peer authenticates and is recorded in the map.
    #[tokio::test(flavor = "multi_thread")]
    async fn valid_authentication() {
        let (keypair, peer_id, auth_message) = new_identity!();

        let mut stream = cursor_from!(auth_message);

        let consensus_key_to_pid_map = Arc::new(parking_lot::Mutex::new(BiMap::new()));

        let result = MockStakeTableAuth::verify_peer_authentication(
            &mut stream,
            &peer_id,
            Arc::clone(&consensus_key_to_pid_map),
        )
        .await;

        // Check the outcome first so a failure reports the authentication error itself
        // instead of surfacing as a missing map entry.
        assert!(
            result.is_ok(),
            "Should have passed authentication but did not: {result:?}"
        );

        assert_eq!(
            consensus_key_to_pid_map.lock().get_by_left(&keypair.0),
            Some(&peer_id),
            "Map does not have the correct entry"
        );
    }

    /// A valid message presented for a different peer ID is rejected and not recorded.
    #[tokio::test(flavor = "multi_thread")]
    async fn peer_id_mismatch() {
        let (_, _, auth_message) = new_identity!();

        // A second, unrelated identity whose peer ID the message was not signed for
        let (_, malicious_peer_id, _) = new_identity!();

        let mut stream = cursor_from!(auth_message);

        let consensus_key_to_pid_map = Arc::new(parking_lot::Mutex::new(BiMap::new()));

        let result = MockStakeTableAuth::verify_peer_authentication(
            &mut stream,
            &malicious_peer_id,
            Arc::clone(&consensus_key_to_pid_map),
        )
        .await;

        assert!(
            result
                .expect_err("Should have failed authentication but did not")
                .to_string()
                .contains("Peer ID mismatch"),
            "Did not fail with the correct error"
        );

        assert!(
            consensus_key_to_pid_map.lock().is_empty(),
            "Malicious peer ID should not be in the map"
        );
    }

    /// A message written with `write_length_delimited` reads back unchanged.
    #[tokio::test(flavor = "multi_thread")]
    async fn read_and_write_length_delimited() {
        let message = b"Hello, world!";

        let mut buffer = Vec::new();
        write_length_delimited(&mut buffer, message).await.unwrap();

        let read_message = read_length_delimited(&mut buffer.as_slice(), 1024)
            .await
            .unwrap();

        assert_eq!(message, read_message.as_slice());
    }
}