1use std::{
2 future::Future,
3 io::{Error as IoError, ErrorKind as IoErrorKind},
4 pin::Pin,
5 sync::Arc,
6 task::Poll,
7};
8
9use anyhow::{ensure, Context, Result as AnyhowResult};
10use bimap::BiMap;
11use futures::{future::poll_fn, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
12use hotshot_types::traits::signature_key::SignatureKey;
13use libp2p::{
14 core::{
15 muxing::StreamMuxerExt,
16 transport::{DialOpts, TransportEvent},
17 StreamMuxer,
18 },
19 identity::PeerId,
20 Transport,
21};
22use parking_lot::Mutex;
23use pin_project::pin_project;
24use serde::{Deserialize, Serialize};
25use tokio::time::timeout;
26use tracing::warn;
27
/// Largest authentication message, in bytes, accepted from a remote peer.
///
/// Passed as `max_size` to `read_length_delimited`, which rejects a larger length
/// prefix before allocating the message buffer.
const MAX_AUTH_MESSAGE_SIZE: usize = 1 << 10;

/// Time budget for the authentication handshake once the underlying connection has
/// been established: opening the handshake substream plus exchanging auth messages.
const AUTH_HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(5_000);
36
37#[pin_project]
40pub struct ConsensusKeyAuthentication<
41 T: Transport,
42 S: SignatureKey + 'static,
43 C: StreamMuxer + Unpin,
44> {
45 #[pin]
46 pub inner: T,
48
49 pub auth_message: Arc<Option<Vec<u8>>>,
51
52 pub consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
54
55 pd: std::marker::PhantomData<(C, S)>,
57}
58
59type UpgradeFuture<T> =
61 Pin<Box<dyn Future<Output = Result<<T as Transport>::Output, <T as Transport>::Error>> + Send>>;
62
63impl<T: Transport, S: SignatureKey + 'static, C: StreamMuxer + Unpin>
64 ConsensusKeyAuthentication<T, S, C>
65{
66 pub fn new(
70 inner: T,
71 auth_message: Option<Vec<u8>>,
72 consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
73 ) -> Self {
74 Self {
75 inner,
76 auth_message: Arc::from(auth_message),
77 consensus_key_to_pid_map,
78 pd: std::marker::PhantomData,
79 }
80 }
81
82 pub async fn authenticate_with_remote_peer<W: AsyncWrite + Unpin>(
88 stream: &mut W,
89 auth_message: &[u8],
90 ) -> AnyhowResult<()> {
91 write_length_delimited(stream, auth_message).await?;
93
94 Ok(())
95 }
96
97 pub async fn verify_peer_authentication<R: AsyncReadExt + Unpin>(
111 stream: &mut R,
112 required_peer_id: &PeerId,
113 consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
114 ) -> AnyhowResult<()> {
115 let message = read_length_delimited(stream, MAX_AUTH_MESSAGE_SIZE).await?;
117
118 let auth_message: AuthMessage<S> =
120 bincode::deserialize(&message).with_context(|| "Failed to deserialize auth message")?;
121
122 let public_key = auth_message
124 .validate()
125 .with_context(|| "Failed to verify authentication message")?;
126
127 let peer_id = PeerId::from_bytes(&auth_message.peer_id_bytes)
129 .with_context(|| "Failed to deserialize peer ID")?;
130
131 if peer_id != *required_peer_id {
133 return Err(anyhow::anyhow!("Peer ID mismatch"));
134 }
135
136 consensus_key_to_pid_map.lock().insert(public_key, peer_id);
138
139 Ok(())
140 }
141
142 fn gen_handshake<F: Future<Output = Result<T::Output, T::Error>> + Send + 'static>(
147 original_future: F,
148 outgoing: bool,
149 auth_message: Arc<Option<Vec<u8>>>,
150 consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
151 ) -> UpgradeFuture<T>
152 where
153 T::Error: From<<C as StreamMuxer>::Error> + From<IoError>,
154 T::Output: AsOutput<C> + Send,
155
156 C::Substream: Unpin + Send,
157 {
158 Box::pin(async move {
160 let mut stream = original_future.await?;
162
163 timeout(AUTH_HANDSHAKE_TIMEOUT, async {
165 let mut substream = if outgoing {
168 poll_fn(|cx| stream.as_connection().poll_outbound_unpin(cx)).await?
169 } else {
170 poll_fn(|cx| stream.as_connection().poll_inbound_unpin(cx)).await?
171 };
172
173 if let Some(auth_message) = auth_message.as_ref() {
175 if outgoing {
176 Self::authenticate_with_remote_peer(&mut substream, auth_message)
178 .await
179 .map_err(|e| {
180 warn!("Failed to authenticate with remote peer: {e:?}");
181 IoError::other(e)
182 })?;
183
184 Self::verify_peer_authentication(
186 &mut substream,
187 stream.as_peer_id(),
188 consensus_key_to_pid_map,
189 )
190 .await
191 .map_err(|e| {
192 warn!("Failed to verify remote peer: {e:?}");
193 IoError::other(e)
194 })?;
195 } else {
196 Self::verify_peer_authentication(
198 &mut substream,
199 stream.as_peer_id(),
200 consensus_key_to_pid_map,
201 )
202 .await
203 .map_err(|e| {
204 warn!("Failed to verify remote peer: {e:?}");
205 IoError::other(e)
206 })?;
207
208 Self::authenticate_with_remote_peer(&mut substream, auth_message)
210 .await
211 .map_err(|e| {
212 warn!("Failed to authenticate with remote peer: {e:?}");
213 IoError::other(e)
214 })?;
215 }
216 }
217
218 Ok(stream)
219 })
220 .await
221 .map_err(|e| {
222 warn!("Timed out performing authentication handshake: {e:?}");
223 IoError::new(IoErrorKind::TimedOut, e)
224 })?
225 })
226 }
227}
228
229#[derive(Clone, Serialize, Deserialize)]
231struct AuthMessage<S: SignatureKey> {
232 public_key_bytes: Vec<u8>,
235
236 peer_id_bytes: Vec<u8>,
239
240 signature: S::PureAssembledSignatureType,
242}
243
244impl<S: SignatureKey> AuthMessage<S> {
245 pub fn validate(&self) -> AnyhowResult<S> {
247 let public_key = S::from_bytes(&self.public_key_bytes)
249 .with_context(|| "Failed to deserialize public key")?;
250
251 let mut signed_message = public_key.to_bytes();
253 signed_message.extend(self.peer_id_bytes.clone());
254
255 if !public_key.validate(&self.signature, &signed_message) {
257 return Err(anyhow::anyhow!("Invalid signature"));
258 }
259
260 Ok(public_key)
261 }
262}
263
264pub fn construct_auth_message<S: SignatureKey + 'static>(
270 public_key: &S,
271 peer_id: &PeerId,
272 private_key: &S::PrivateKey,
273) -> AnyhowResult<Vec<u8>> {
274 let mut public_key_bytes = public_key.to_bytes();
276
277 let peer_id_bytes = peer_id.to_bytes();
279 public_key_bytes.extend_from_slice(&peer_id_bytes);
280
281 let signature =
283 S::sign(private_key, &public_key_bytes).with_context(|| "Failed to sign public key")?;
284
285 let auth_message = AuthMessage::<S> {
287 public_key_bytes,
288 peer_id_bytes,
289 signature,
290 };
291
292 bincode::serialize(&auth_message).with_context(|| "Failed to serialize auth message")
294}
295
296impl<T: Transport, S: SignatureKey + 'static, C: StreamMuxer + Unpin> Transport
297 for ConsensusKeyAuthentication<T, S, C>
298where
299 T::Dial: Future<Output = Result<T::Output, T::Error>> + Send + 'static,
300 T::ListenerUpgrade: Send + 'static,
301 T::Output: AsOutput<C> + Send,
302 T::Error: From<<C as StreamMuxer>::Error> + From<IoError>,
303
304 C::Substream: Unpin + Send,
305{
306 type Dial = Pin<Box<dyn Future<Output = Result<T::Output, T::Error>> + Send>>;
308 type ListenerUpgrade = Pin<Box<dyn Future<Output = Result<T::Output, T::Error>> + Send>>;
309
310 type Output = T::Output;
312 type Error = T::Error;
313
314 fn dial(
317 &mut self,
318 addr: libp2p::Multiaddr,
319 opts: DialOpts,
320 ) -> Result<Self::Dial, libp2p::TransportError<Self::Error>> {
321 let res = self.inner.dial(addr, opts);
323
324 let auth_message = Arc::clone(&self.auth_message);
326
327 match res {
329 Ok(dial) => Ok(Self::gen_handshake(
330 dial,
331 true,
332 auth_message,
333 Arc::clone(&self.consensus_key_to_pid_map),
334 )),
335 Err(err) => Err(err),
336 }
337 }
338
339 fn poll(
343 mut self: std::pin::Pin<&mut Self>,
344 cx: &mut std::task::Context<'_>,
345 ) -> std::task::Poll<libp2p::core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>>
346 {
347 match self.as_mut().project().inner.poll(cx) {
348 Poll::Ready(event) => Poll::Ready(match event {
349 TransportEvent::Incoming {
351 listener_id,
352 upgrade,
353 local_addr,
354 send_back_addr,
355 } => {
356 let auth_message = Arc::clone(&self.auth_message);
358
359 let auth_upgrade = Self::gen_handshake(
361 upgrade,
362 false,
363 auth_message,
364 Arc::clone(&self.consensus_key_to_pid_map),
365 );
366
367 TransportEvent::Incoming {
369 listener_id,
370 upgrade: auth_upgrade,
371 local_addr,
372 send_back_addr,
373 }
374 },
375
376 TransportEvent::AddressExpired {
378 listener_id,
379 listen_addr,
380 } => TransportEvent::AddressExpired {
381 listener_id,
382 listen_addr,
383 },
384 TransportEvent::ListenerClosed {
385 listener_id,
386 reason,
387 } => TransportEvent::ListenerClosed {
388 listener_id,
389 reason,
390 },
391 TransportEvent::ListenerError { listener_id, error } => {
392 TransportEvent::ListenerError { listener_id, error }
393 },
394 TransportEvent::NewAddress {
395 listener_id,
396 listen_addr,
397 } => TransportEvent::NewAddress {
398 listener_id,
399 listen_addr,
400 },
401 }),
402
403 Poll::Pending => Poll::Pending,
404 }
405 }
406
407 fn remove_listener(&mut self, id: libp2p::core::transport::ListenerId) -> bool {
410 self.inner.remove_listener(id)
411 }
412 fn listen_on(
413 &mut self,
414 id: libp2p::core::transport::ListenerId,
415 addr: libp2p::Multiaddr,
416 ) -> Result<(), libp2p::TransportError<Self::Error>> {
417 self.inner.listen_on(id, addr)
418 }
419}
420
421trait AsOutput<C: StreamMuxer + Unpin> {
424 fn as_connection(&mut self) -> &mut C;
426
427 fn as_peer_id(&mut self) -> &mut PeerId;
429}
430
431impl<C: StreamMuxer + Unpin> AsOutput<C> for (PeerId, C) {
434 fn as_connection(&mut self) -> &mut C {
436 &mut self.1
437 }
438
439 fn as_peer_id(&mut self) -> &mut PeerId {
441 &mut self.0
442 }
443}
444
445pub async fn read_length_delimited<S: AsyncRead + Unpin>(
452 stream: &mut S,
453 max_size: usize,
454) -> AnyhowResult<Vec<u8>> {
455 let mut len_bytes = [0u8; 4];
457 stream
458 .read_exact(&mut len_bytes)
459 .await
460 .with_context(|| "Failed to read message length")?;
461
462 let len = usize::try_from(u32::from_be_bytes(len_bytes))?;
464
465 ensure!(len <= max_size, "Message too large");
467
468 let mut message = vec![0u8; len];
470 stream
471 .read_exact(&mut message)
472 .await
473 .with_context(|| "Failed to read message")?;
474
475 Ok(message)
476}
477
478pub async fn write_length_delimited<S: AsyncWrite + Unpin>(
483 stream: &mut S,
484 message: &[u8],
485) -> AnyhowResult<()> {
486 stream
488 .write_all(&u32::try_from(message.len())?.to_be_bytes())
489 .await
490 .with_context(|| "Failed to write message length")?;
491
492 stream
494 .write_all(message)
495 .await
496 .with_context(|| "Failed to write message")?;
497
498 Ok(())
499}
500
#[cfg(test)]
mod test {
    use hotshot_types::{signature_key::BLSPubKey, traits::signature_key::SignatureKey};
    use libp2p::{core::transport::dummy::DummyTransport, quic::Connection};
    use rand::Rng;

    use super::*;

    /// Concrete instantiation used to call the associated handshake functions.
    type MockStakeTableAuth = ConsensusKeyAuthentication<DummyTransport, BLSPubKey, Connection>;

    /// Creates a random BLS keypair, a random libp2p peer ID, and the serialized auth
    /// message binding them, returned as `(keypair, peer_id, auth_message)`.
    macro_rules! new_identity {
        () => {{
            let seed = rand::rngs::OsRng.gen::<[u8; 32]>();

            let keypair = BLSPubKey::generated_from_seed_indexed(seed, 1337);

            let peer_id = libp2p::identity::Keypair::generate_ed25519()
                .public()
                .to_peer_id();

            let auth_message =
                super::construct_auth_message(&keypair.0, &peer_id, &keypair.1).unwrap();

            (keypair, peer_id, auth_message)
        }};
    }

    /// Writes `$auth_message` as a length-delimited frame into an in-memory cursor and
    /// rewinds it for reading. Must be used inside an async context.
    macro_rules! cursor_from {
        ($auth_message:expr) => {{
            let mut stream = futures::io::Cursor::new(vec![]);
            write_length_delimited(&mut stream, &$auth_message)
                .await
                .expect("Failed to write message");
            stream.set_position(0);
            stream
        }};
    }

    /// A freshly constructed auth message validates.
    #[test]
    fn signature_verify() {
        let (_, _, auth_message) = new_identity!();

        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_ok());
    }

    /// Tampering with the public key bytes makes validation fail.
    #[test]
    fn signature_verify_invalid_public_key() {
        let (_, _, auth_message) = new_identity!();

        let mut auth_message: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(&auth_message).unwrap();

        auth_message.public_key_bytes[0] ^= 0x01;

        let auth_message = bincode::serialize(&auth_message).unwrap();

        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_err());
    }

    /// Tampering with the peer ID bytes makes validation fail.
    #[test]
    fn signature_verify_invalid_peer_id() {
        let (_, _, auth_message) = new_identity!();

        let mut auth_message: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(&auth_message).unwrap();

        auth_message.peer_id_bytes[0] ^= 0x01;

        let auth_message = bincode::serialize(&auth_message).unwrap();

        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_err());
    }

    /// A valid message for the expected peer passes and records the key/peer pair.
    #[tokio::test(flavor = "multi_thread")]
    async fn valid_authentication() {
        let (keypair, peer_id, auth_message) = new_identity!();

        let mut stream = cursor_from!(auth_message);

        let consensus_key_to_pid_map = Arc::new(parking_lot::Mutex::new(BiMap::new()));

        let result = MockStakeTableAuth::verify_peer_authentication(
            &mut stream,
            &peer_id,
            Arc::clone(&consensus_key_to_pid_map),
        )
        .await;

        // Check the result first so a failure reports the actual error instead of
        // panicking on a missing map entry.
        assert!(
            result.is_ok(),
            "Should have passed authentication but did not: {result:?}"
        );

        assert_eq!(
            consensus_key_to_pid_map.lock().get_by_left(&keypair.0),
            Some(&peer_id),
            "Map does not have the correct entry"
        );
    }

    /// A valid message presented for a different peer ID is rejected and not recorded.
    #[tokio::test(flavor = "multi_thread")]
    async fn peer_id_mismatch() {
        let (_, _, auth_message) = new_identity!();

        let (_, malicious_peer_id, _) = new_identity!();

        let mut stream = cursor_from!(auth_message);

        let consensus_key_to_pid_map = Arc::new(parking_lot::Mutex::new(BiMap::new()));

        let result = MockStakeTableAuth::verify_peer_authentication(
            &mut stream,
            &malicious_peer_id,
            Arc::clone(&consensus_key_to_pid_map),
        )
        .await;

        assert!(
            result
                .expect_err("Should have failed authentication but did not")
                .to_string()
                .contains("Peer ID mismatch"),
            "Did not fail with the correct error"
        );

        assert!(
            consensus_key_to_pid_map.lock().is_empty(),
            "Malicious peer ID should not be in the map"
        );
    }

    /// A framed message round-trips through write and read.
    #[tokio::test(flavor = "multi_thread")]
    async fn read_and_write_length_delimited() {
        let message = b"Hello, world!";

        let mut buffer = Vec::new();
        write_length_delimited(&mut buffer, message).await.unwrap();

        let read_message = read_length_delimited(&mut buffer.as_slice(), 1024)
            .await
            .unwrap();

        assert_eq!(message, read_message.as_slice());
    }

    /// A frame whose declared length exceeds `max_size` is rejected.
    #[tokio::test(flavor = "multi_thread")]
    async fn read_length_delimited_rejects_oversized_message() {
        let message = vec![0u8; 16];

        let mut buffer = Vec::new();
        write_length_delimited(&mut buffer, &message).await.unwrap();

        let result = read_length_delimited(&mut buffer.as_slice(), 8).await;

        assert!(
            result
                .expect_err("Should have rejected an oversized message")
                .to_string()
                .contains("Message too large"),
            "Did not fail with the correct error"
        );
    }
}