1#![doc = include_str!("../README.md")]
2
3use std::{
4 collections::HashMap,
5 fmt::Display,
6 future::pending,
7 hash::Hash,
8 iter::{once, repeat},
9 result::Result as StdResult,
10 sync::Arc,
11 time::Duration,
12};
13
14use bimap::BiHashMap;
15use bytes::{Bytes, BytesMut};
16use hotshot_types::addr::NetAddr;
17use parking_lot::Mutex;
18use snow::{Builder, HandshakeState, TransportState};
19use tokio::{
20 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
21 net::{TcpListener, TcpStream},
22 spawn,
23 sync::{
24 Mutex as AsyncMutex, OwnedSemaphorePermit, Semaphore,
25 mpsc::{self, Receiver, Sender},
26 },
27 task::{self, AbortHandle, JoinHandle, JoinSet},
28 time::{Interval, MissedTickBehavior, sleep, timeout},
29};
30use tracing::{debug, error, info, trace, warn};
31
32#[cfg(feature = "metrics")]
33use crate::metrics::NetworkMetrics;
34use crate::{
35 Id, Keypair, LAST_DELAY, NUM_DELAYS, NetConf, NetworkDown, NetworkError, PublicKey, Role, chan,
36 error::Empty,
37 frame::{Header, Type},
38 time::{Countdown, Timestamp},
39};
40
/// Per-peer ingress budget: a shared semaphore bounding how many received
/// messages from one peer may be in flight (undelivered) at a time.
type Budget = Arc<Semaphore>;
type Result<T> = std::result::Result<T, NetworkError>;

/// Maximum size of a single Noise handshake message.
const MAX_NOISE_HANDSHAKE_SIZE: usize = 1024;

/// Maximum size of a single Noise transport message (one frame's payload).
const MAX_NOISE_MESSAGE_SIZE: usize = 64 * 1024;

/// Maximum application payload per frame; reserves 32 bytes of per-message
/// overhead (presumably the AEAD tag etc. — see the `frame` module) within
/// a Noise message.
const MAX_PAYLOAD_SIZE: usize = MAX_NOISE_MESSAGE_SIZE - 32;

/// Noise protocol: IK handshake pattern, X25519 DH, AES-GCM, BLAKE2s.
const NOISE_PARAMS: &str = "Noise_IK_25519_AESGCM_BLAKE2s";

/// How often each active connection is pinged.
const PING_INTERVAL: Duration = Duration::from_secs(15);

/// Maximum time to wait for a TCP connection attempt.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Maximum time to wait for a Noise handshake to complete.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Maximum time the read loop waits for any inbound frame after a ping
/// was sent before the connection is considered dead.
const REPLY_TIMEOUT: Duration = Duration::from_secs(30);
71
/// Handle to a running network: owns the background server task and the
/// channels used to exchange messages and commands with it.
#[derive(Debug)]
pub struct Network<K> {
    /// Network name; used as Noise prologue and in log messages.
    pub(crate) name: &'static str,

    /// Identifier of this node.
    pub(crate) label: K,

    /// Upper bound on the size of a single application message.
    pub(crate) max_message_size: usize,

    /// Known parties and their roles (local mirror of the server's view).
    parties: Mutex<HashMap<K, Role>>,

    /// Command channel into the server task.
    tx: Sender<Command<K>>,

    /// Inbound messages tagged with their sender. The optional permit is
    /// the sender's ingress budget slot; dropping it (with the message)
    /// frees capacity for that peer.
    rx: AsyncMutex<Receiver<(K, Bytes, Option<OwnedSemaphorePermit>)>>,

    /// Handle of the background server task; aborted on close/drop.
    srv: Mutex<Option<JoinHandle<Result<Empty>>>>,
}
98
99impl<K> Drop for Network<K> {
100 fn drop(&mut self) {
101 if let Some(srv) = self.srv.lock().take() {
102 srv.abort();
103 }
104 }
105}
106
/// Commands sent from the `Network` handle to the server task.
#[derive(Debug)]
pub(crate) enum Command<K> {
    /// Add (or replace) peers with the given role.
    Add(Role, Vec<(K, PublicKey, NetAddr)>),
    /// Remove peers and drop their connection state.
    Remove(Vec<K>),
    /// Assign a new role to the given peers.
    Assign(Role, Vec<K>),
    /// Send a message to a single peer.
    Unicast(K, Option<Id>, Bytes),
    /// Send a message to the given subset of peers.
    Multicast(Vec<K>, Option<Id>, Bytes),
    /// Send a message to all active peers (and ourselves if active).
    Broadcast(Option<Id>, Bytes),
}
123
/// State of the background server task that drives all connections.
#[derive(Debug)]
struct Server<K> {
    /// Network configuration (keys, addresses, capacities).
    conf: NetConf<K>,

    /// Role of this node.
    role: Role,

    /// Delivery channel for inbound messages (read via `Network::receive`).
    ibound: Sender<(K, Bytes, Option<OwnedSemaphorePermit>)>,

    /// Command channel from the `Network` handle.
    obound: Receiver<Command<K>>,

    /// All known peers.
    peers: HashMap<K, Peer>,

    /// Bidirectional mapping between peer labels and x25519 public keys.
    index: BiHashMap<K, PublicKey>,

    /// Maps tokio task ids (connect and i/o tasks) back to peer labels.
    task2key: HashMap<task::Id, K>,

    /// Outstanding connect attempts, by peer.
    connecting: HashMap<K, ConnectTask>,

    /// Established connections, by peer.
    active: HashMap<K, IoTask>,

    /// Handshakes for inbound (accepted) connections.
    handshake_tasks: JoinSet<Result<(TcpStream, TransportState)>>,

    /// Outbound connect-and-handshake attempts (these retry forever).
    connect_tasks: JoinSet<(TcpStream, TransportState)>,

    /// Read/write halves of established connections.
    io_tasks: JoinSet<Result<()>>,

    /// Triggers periodic pings on all active connections.
    ping_interval: Interval,

    #[cfg(feature = "metrics")]
    metrics: Arc<NetworkMetrics<K>>,
}
175
/// Static information about a known peer.
#[derive(Debug)]
struct Peer {
    /// Address to connect to.
    addr: NetAddr,
    /// Role of the peer (determines e.g. broadcast membership).
    role: Role,
    /// Per-peer ingress budget, shared with the connection's recv loop.
    budget: Budget,
}
182
/// Handle to an outstanding connect attempt; aborts the task when dropped.
#[derive(Debug)]
struct ConnectTask {
    h: AbortHandle,
}
188
impl Drop for ConnectTask {
    // Ensure the connect attempt does not outlive its handle.
    fn drop(&mut self) {
        self.h.abort();
    }
}
195
/// Handles of an established connection; aborts both halves when dropped.
#[derive(Debug)]
struct IoTask {
    /// Read-half task (`recv_loop`).
    rh: AbortHandle,

    /// Write-half task (`send_loop`).
    wh: AbortHandle,

    /// Channel feeding outgoing messages to the write-half.
    tx: chan::Sender<Message>,
}
208
impl Drop for IoTask {
    // Tear down both halves of the connection together with this handle.
    fn drop(&mut self) {
        self.rh.abort();
        self.wh.abort();
    }
}
216
/// Messages handed to a connection's write-half.
#[derive(Debug)]
enum Message {
    /// Application data.
    Data(Bytes),
    /// Keep-alive probe carrying its send time.
    Ping(Timestamp),
    /// Reply to a ping, echoing the ping's timestamp.
    Pong(Timestamp),
}
224
impl<K> Network<K> {
    /// Identifier of this node.
    pub fn public_key(&self) -> &K {
        &self.label
    }

    /// Name of this network.
    pub fn name(&self) -> &str {
        self.name
    }

    /// Clone of the command channel into the server task.
    pub(crate) fn sender(&self) -> Sender<Command<K>> {
        self.tx.clone()
    }
}
239
240impl<K: Display> Network<K> {
241 pub async fn close(&self) {
242 let srv = self.srv.lock().take();
243 if let Some(srv) = srv {
244 info!(name = %self.name, node = %self.label, "aborting server task");
245 srv.abort();
246 let _ = srv.await;
247 }
248 }
249}
250
impl<K> Network<K>
where
    K: Eq + Ord + Clone + Display + Hash + Send + Sync + 'static,
{
    /// Create a network: bind the listener, initialise peer state and
    /// spawn the background server task.
    ///
    /// # Errors
    ///
    /// Fails if the configured address cannot be bound or the local
    /// address cannot be read back.
    pub async fn create(cfg: NetConf<K>) -> Result<Self> {
        let listener = TcpListener::bind(cfg.bind.to_string())
            .await
            .map_err(|e| NetworkError::Bind(cfg.bind.clone(), e))?;

        debug!(
            name = %cfg.name,
            node = %cfg.label,
            addr = %listener.local_addr()?,
            "listening"
        );

        let mut parties = HashMap::new();
        let mut peers = HashMap::new();
        let mut index = BiHashMap::new();

        // All configured parties (including ourselves) start out active.
        for (k, x, a) in cfg.parties.iter().cloned() {
            parties.insert(k.clone(), Role::Active);
            index.insert(k.clone(), x);
            peers.insert(
                k,
                Peer {
                    addr: a,
                    role: Role::Active,
                    budget: cfg.new_budget(),
                },
            );
        }

        // Commands from the handle to the server.
        let (otx, orx) = mpsc::channel(cfg.total_capacity_egress);

        // Inbound messages from the server to the handle.
        let (itx, irx) = mpsc::channel(cfg.total_capacity_ingress);

        let mut interval = tokio::time::interval(PING_INTERVAL);
        interval.set_missed_tick_behavior(MissedTickBehavior::Delay);

        let name = cfg.name;
        let label = cfg.label.clone();
        let mmsze = cfg.max_message_size;

        #[cfg(feature = "metrics")]
        let metrics = {
            // Track every party except ourselves.
            let it = parties.keys().filter(|k| **k != label).cloned();
            NetworkMetrics::new(name, &*cfg.metrics, it)
        };

        let server = Server {
            conf: cfg,
            role: Role::Active,
            ibound: itx,
            obound: orx,
            peers,
            index,
            connecting: HashMap::new(),
            active: HashMap::new(),
            task2key: HashMap::new(),
            handshake_tasks: JoinSet::new(),
            connect_tasks: JoinSet::new(),
            io_tasks: JoinSet::new(),
            ping_interval: interval,
            #[cfg(feature = "metrics")]
            metrics: Arc::new(metrics),
        };

        Ok(Self {
            name,
            label,
            parties: Mutex::new(parties),
            rx: AsyncMutex::new(irx),
            tx: otx,
            srv: Mutex::new(Some(spawn(server.run(listener)))),
            max_message_size: mmsze,
        })
    }

    /// All known parties, optionally filtered by role.
    pub fn parties(&self, r: Option<Role>) -> Vec<K> {
        self.parties
            .lock()
            .iter()
            .filter(|&(_, x)| r.map(|r| r == *x).unwrap_or(true))
            .map(|(k, _)| k.clone())
            .collect()
    }

    /// Send a message to a single party (delivery is best-effort).
    ///
    /// # Errors
    ///
    /// Fails if the message exceeds the configured maximum size or the
    /// server task is gone.
    pub async fn unicast(&self, to: K, msg: Bytes) -> Result<()> {
        if msg.len() > self.max_message_size {
            warn!(
                name = %self.name,
                node = %self.label,
                to = %to,
                len = %msg.len(),
                max = %self.max_message_size,
                "message too large to send"
            );
            return Err(NetworkError::MessageTooLarge);
        }
        self.tx
            .send(Command::Unicast(to, None, msg))
            .await
            .map_err(|_| NetworkError::ChannelClosed)
    }

    /// Send a message to all active parties (delivery is best-effort).
    ///
    /// # Errors
    ///
    /// Fails if the message exceeds the configured maximum size or the
    /// server task is gone.
    pub async fn broadcast(&self, msg: Bytes) -> Result<()> {
        if msg.len() > self.max_message_size {
            warn!(
                name = %self.name,
                node = %self.label,
                len = %msg.len(),
                max = %self.max_message_size,
                "message too large to broadcast"
            );
            return Err(NetworkError::MessageTooLarge);
        }
        self.tx
            .send(Command::Broadcast(None, msg))
            .await
            .map_err(|_| NetworkError::ChannelClosed)
    }

    /// Wait for the next inbound message and its sender.
    ///
    /// The ingress budget permit attached to the message is dropped here,
    /// freeing capacity for the sending peer.
    pub async fn receive(&self) -> StdResult<(K, Bytes), NetworkDown> {
        let mut rx = self.rx.lock().await;
        let (k, b, _) = rx.recv().await.ok_or(NetworkDown::new())?;
        Ok((k, b))
    }

    /// Add (or replace) peers with the given role.
    pub async fn add(
        &self,
        r: Role,
        peers: Vec<(K, PublicKey, NetAddr)>,
    ) -> StdResult<(), NetworkDown> {
        // Update the local mirror first, then tell the server task.
        self.parties
            .lock()
            .extend(peers.iter().map(|(p, ..)| (p.clone(), r)));
        self.tx
            .send(Command::Add(r, peers))
            .await
            .map_err(|_| NetworkDown::new())
    }

    /// Remove peers and drop their connections.
    pub async fn remove(&self, peers: Vec<K>) -> StdResult<(), NetworkDown> {
        {
            let mut parties = self.parties.lock();
            for p in &peers {
                parties.remove(p);
            }
        }
        self.tx
            .send(Command::Remove(peers))
            .await
            .map_err(|_| NetworkDown::new())
    }

    /// Assign a new role to the given (already known) peers.
    pub async fn assign(&self, r: Role, peers: Vec<K>) -> StdResult<(), NetworkDown> {
        {
            let mut parties = self.parties.lock();
            for p in &peers {
                if let Some(role) = parties.get_mut(p) {
                    *role = r
                }
            }
        }
        self.tx
            .send(Command::Assign(r, peers))
            .await
            .map_err(|_| NetworkDown::new())
    }
}
430
431impl<K> Server<K>
432where
433 K: Eq + Ord + Clone + Display + Hash + Send + Sync + 'static,
434{
435 async fn run(mut self, listener: TcpListener) -> Result<Empty> {
443 self.handshake_tasks.spawn(pending());
444 self.io_tasks.spawn(pending());
445
446 for k in self
448 .peers
449 .keys()
450 .filter(|k| **k != self.conf.label)
451 .cloned()
452 .collect::<Vec<_>>()
453 {
454 self.spawn_connect(k)
455 }
456
457 loop {
458 trace!(
459 name = %self.conf.name,
460 node = %self.conf.label,
461 active = %self.active.len(),
462 connects = %self.connect_tasks.len(),
463 handshakes = %self.handshake_tasks.len().saturating_sub(1), io_tasks = %self.io_tasks.len().saturating_sub(1), tasks_ids = %self.task2key.len(),
466 iqueue = %self.ibound.capacity(),
467 oqueue = %self.obound.capacity(),
468 );
469
470 #[cfg(feature = "metrics")]
471 {
472 self.metrics.iqueue.set(self.ibound.capacity());
473 self.metrics.oqueue.set(self.obound.capacity());
474 }
475
476 tokio::select! {
477 i = listener.accept() => match i {
479 Ok((s, a)) => {
480 debug!(
481 name = %self.conf.name,
482 node = %self.conf.label,
483 addr = %a,
484 "accepted connection"
485 );
486 self.spawn_handshake(s)
487 }
488 Err(e) => {
489 warn!(
490 name = %self.conf.name,
491 node = %self.conf.label,
492 err = %e,
493 "error accepting connection"
494 )
495 }
496 },
497 Some(h) = self.handshake_tasks.join_next() => match h {
499 Ok(Ok((s, t))) => {
500 let Some((k, peer)) = self.lookup_peer(&t) else {
501 info!(
502 name = %self.conf.name,
503 node = %self.conf.label,
504 peer = ?t.get_remote_static().and_then(|k| PublicKey::try_from(k).ok()),
505 addr = ?s.peer_addr().ok(),
506 "unknown peer"
507 );
508 continue
509 };
510 if !self.is_valid_ip(&k, &s) {
511 warn!(
512 name = %self.conf.name,
513 node = %self.conf.label,
514 peer = %k,
515 addr = ?s.peer_addr().ok(), "invalid peer ip addr"
516 );
517 continue
518 }
519 if k > self.conf.label || !self.active.contains_key(&k) {
523 self.spawn_io(k, s, t, peer.budget.clone())
524 } else {
525 debug!(
526 name = %self.conf.name,
527 node = %self.conf.label,
528 peer = %k,
529 "dropping accepted connection"
530 );
531 }
532 }
533 Ok(Err(e)) => {
534 warn!(
535 name = %self.conf.name,
536 node = %self.conf.label,
537 err = %e,
538 "handshake failed"
539 )
540 }
541 Err(e) => {
542 if !e.is_cancelled() {
543 error!(
544 name = %self.conf.name,
545 node = %self.conf.label,
546 err = %e,
547 "handshake task panic"
548 )
549 }
550 }
551 },
552 Some(tt) = self.connect_tasks.join_next_with_id() => {
554 match tt {
555 Ok((id, (s, t))) => {
556 self.on_connect_task_end(id);
557 let Some((k, peer)) = self.lookup_peer(&t) else {
558 warn!(
559 name = %self.conf.name,
560 node = %self.conf.label,
561 peer = ?t.get_remote_static().and_then(|k| PublicKey::try_from(k).ok()),
562 addr = ?s.peer_addr().ok(),
563 "connected to unknown peer"
564 );
565 continue
566 };
567 if k < self.conf.label || !self.active.contains_key(&k) {
570 self.spawn_io(k, s, t, peer.budget.clone())
571 } else {
572 debug!(
573 name = %self.conf.name,
574 node = %self.conf.label,
575 peer = %k,
576 "dropping new connection"
577 )
578 }
579 }
580 Err(e) => {
581 if !e.is_cancelled() {
582 error!(
583 name = %self.conf.name,
584 node = %self.conf.label,
585 err = %e,
586 "connect task panic"
587 )
588 }
589 self.on_connect_task_end(e.id());
590 }
591 }
592 },
593 Some(io) = self.io_tasks.join_next_with_id() => {
595 match io {
596 Ok((id, r)) => {
597 if let Err(e) = r {
598 warn!(
599 name = %self.conf.name,
600 node = %self.conf.label,
601 err = %e,
602 "i/o error"
603 )
604 }
605 self.on_io_task_end(id);
606 }
607 Err(e) => {
608 if e.is_cancelled() {
609 self.task2key.remove(&e.id());
614 continue
615 }
616 error!(
618 name = %self.conf.name,
619 node = %self.conf.label,
620 err = %e,
621 "i/o task panic"
622 );
623 self.on_io_task_end(e.id())
624 }
625 };
626 },
627 cmd = self.obound.recv() => match cmd {
628 Some(Command::Add(role, peers)) => {
629 #[cfg(feature = "metrics")]
630 Arc::make_mut(&mut self.metrics).add_parties(peers.iter().map(|(k, ..)| k).cloned());
631 for (k, x, a) in peers {
632 let mut existing_budget = None;
633 if let Some(peer) = self.peers.get(&k) {
634 let Some(y) = self.index.get_by_left(&k) else {
635 error!(
636 name = %self.conf.name,
637 node = %self.conf.label,
638 peer = %k,
639 addr = %a,
640 "peer is known, but x25519 key not found"
641 );
642 continue
643 };
644 if a == peer.addr && x == *y {
645 debug!(
646 name = %self.conf.name,
647 node = %self.conf.label,
648 peer = %k,
649 addr = %a,
650 "peer to add already exists"
651 );
652 continue
653 }
654 existing_budget = Some(peer.budget.clone());
655 }
656 info!(
657 name = %self.conf.name,
658 node = %self.conf.label,
659 role = %role,
660 peer = %k,
661 addr = %a,
662 replace = %existing_budget.is_some(),
663 "adding or replacing peer"
664 );
665 let p = Peer {
666 addr: a,
667 role,
668 budget: existing_budget.unwrap_or_else(|| self.conf.new_budget())
669 };
670 self.peers.insert(k.clone(), p);
671 self.index.insert(k.clone(), x);
672 self.connecting.remove(&k);
673 self.active.remove(&k);
674 self.spawn_connect(k)
675 }
676 }
677 Some(Command::Remove(peers)) => {
678 for k in &peers {
679 info!(
680 name = %self.conf.name,
681 node = %self.conf.label,
682 peer = %k,
683 "removing peer"
684 );
685 self.peers.remove(k);
686 self.index.remove_by_left(k);
687 self.connecting.remove(k);
688 self.active.remove(k);
689 }
690 #[cfg(feature = "metrics")]
691 Arc::make_mut(&mut self.metrics).remove_parties(&peers)
692 }
693 Some(Command::Assign(role, peers)) => {
694 for k in &peers {
695 if let Some(p) = self.peers.get_mut(k) {
696 p.role = role
697 } else {
698 warn!(
699 name = %self.conf.name,
700 node = %self.conf.label,
701 peer = %k,
702 role = ?role,
703 "peer to assign role to not found"
704 );
705 }
706 }
707 }
708 Some(Command::Unicast(to, id, m)) => {
709 if to == self.conf.label {
710 trace!(
711 name = %self.conf.name,
712 node = %self.conf.label,
713 to = %to,
714 len = %m.len(),
715 queue = self.ibound.capacity(),
716 "sending message"
717 );
718 if let Err(err) = self.ibound.try_send((self.conf.label.clone(), m, None)) {
719 warn!(
720 name = %self.conf.name,
721 node = %self.conf.label,
722 err = %err,
723 cap = %self.ibound.capacity(),
724 "channel full => dropping unicast message"
725 )
726 }
727 continue
728 }
729 if let Some(task) = self.active.get(&to) {
730 trace!(
731 name = %self.conf.name,
732 node = %self.conf.label,
733 to = %to,
734 len = %m.len(),
735 queue = task.tx.capacity(),
736 "sending message"
737 );
738 #[cfg(feature = "metrics")]
739 self.metrics.set_peer_oqueue_cap(&to, task.tx.capacity());
740 task.tx.send(id, Message::Data(m))
741 }
742 }
743 Some(Command::Multicast(peers, id, m)) => {
744 if peers.contains(&self.conf.label) {
745 trace!(
746 name = %self.conf.name,
747 node = %self.conf.label,
748 to = %self.conf.label,
749 len = %m.len(),
750 queue = self.ibound.capacity(),
751 "sending message"
752 );
753 if let Err(err) = self.ibound.try_send((self.conf.label.clone(), m.clone(), None)) {
754 warn!(
755 name = %self.conf.name,
756 node = %self.conf.label,
757 err = %err,
758 cap = %self.ibound.capacity(),
759 "channel full => dropping multicast message"
760 )
761 }
762 }
763 for (to, task) in &self.active {
764 if !peers.contains(to) {
765 continue
766 }
767 trace!(
768 name = %self.conf.name,
769 node = %self.conf.label,
770 to = %to,
771 len = %m.len(),
772 queue = task.tx.capacity(),
773 "sending message"
774 );
775 #[cfg(feature = "metrics")]
776 self.metrics.set_peer_oqueue_cap(to, task.tx.capacity());
777 task.tx.send(id, Message::Data(m.clone()))
778 }
779 }
780 Some(Command::Broadcast(id, m)) => {
781 if self.role.is_active() {
782 trace!(
783 name = %self.conf.name,
784 node = %self.conf.label,
785 to = %self.conf.label,
786 len = %m.len(),
787 queue = self.ibound.capacity(),
788 "sending message"
789 );
790 if let Err(err) = self.ibound.try_send((self.conf.label.clone(), m.clone(), None)) {
791 warn!(
792 name = %self.conf.name,
793 node = %self.conf.label,
794 err = %err,
795 cap = %self.ibound.capacity(),
796 "channel full => dropping broadcast message"
797 )
798 }
799 }
800 for (to, task) in &self.active {
801 if Some(Role::Active) != self.peers.get(to).map(|p| p.role) {
802 continue
803 }
804 trace!(
805 name = %self.conf.name,
806 node = %self.conf.label,
807 to = %to,
808 len = %m.len(),
809 queue = task.tx.capacity(),
810 "sending message"
811 );
812 #[cfg(feature = "metrics")]
813 self.metrics.set_peer_oqueue_cap(to, task.tx.capacity());
814 task.tx.send(id, Message::Data(m.clone()))
815 }
816 }
817 None => {
818 return Err(NetworkError::ChannelClosed)
819 }
820 },
821 _ = self.ping_interval.tick() => {
822 let now = Timestamp::now();
823 for task in self.active.values() {
824 task.tx.send(None, Message::Ping(now))
825 }
826 }
827 }
828 }
829 }
830
831 fn on_connect_task_end(&mut self, id: task::Id) {
833 let Some(k) = self.task2key.remove(&id) else {
834 error!(name = %self.conf.name, node = %self.conf.label, "no key for connect task");
835 return;
836 };
837 self.connecting.remove(&k);
838 }
839
840 fn on_io_task_end(&mut self, id: task::Id) {
846 let Some(k) = self.task2key.remove(&id) else {
847 error!(name = %self.conf.name, node = %self.conf.label, "no key for i/o task");
848 return;
849 };
850 let Some(task) = self.active.get(&k) else {
851 return;
852 };
853 if task.rh.id() == id {
854 debug!(
855 name = %self.conf.name,
856 node = %self.conf.label,
857 peer = %k,
858 "read-half closed => dropping connection"
859 );
860 self.active.remove(&k);
861 self.spawn_connect(k)
862 } else if task.wh.id() == id {
863 debug!(
864 name = %self.conf.name,
865 node = %self.conf.label,
866 peer = %k,
867 "write-half closed => dropping connection"
868 );
869 self.active.remove(&k);
870 self.spawn_connect(k)
871 } else {
872 debug!(
873 name = %self.conf.name,
874 node = %self.conf.label,
875 peer = %k,
876 "i/o task was previously replaced"
877 );
878 }
879 }
880
881 fn spawn_connect(&mut self, k: K) {
886 if self.connecting.contains_key(&k) {
887 debug!(
888 name = %self.conf.name,
889 node = %self.conf.label,
890 peer = %k,
891 "connect task already started"
892 );
893 return;
894 }
895 let x = self.index.get_by_left(&k).expect("known public key");
896 let p = self.peers.get(&k).expect("known peer");
897 let h = self.connect_tasks.spawn(connect(
898 self.conf.name,
899 (self.conf.label.clone(), self.conf.keypair.clone()),
900 (k.clone(), *x),
901 p.addr.clone(),
902 self.conf.retry_delays,
903 #[cfg(feature = "metrics")]
904 self.metrics.clone(),
905 ));
906 assert!(self.task2key.insert(h.id(), k.clone()).is_none());
907 self.connecting.insert(k, ConnectTask { h });
908 }
909
910 fn spawn_handshake(&mut self, s: TcpStream) {
916 let h = Builder::new(NOISE_PARAMS.parse().expect("valid noise params"))
917 .local_private_key(&self.conf.keypair.secret_key().as_bytes())
918 .expect("valid private key")
919 .prologue(self.conf.name.as_bytes())
920 .expect("1st time we set the prologue")
921 .build_responder()
922 .expect("valid noise params yield valid handshake state");
923 self.handshake_tasks.spawn(async move {
924 timeout(HANDSHAKE_TIMEOUT, on_handshake(h, s))
925 .await
926 .or(Err(NetworkError::Timeout))?
927 });
928 }
929
930 fn spawn_io(&mut self, k: K, s: TcpStream, t: TransportState, b: Budget) {
934 debug!(
935 name = %self.conf.name,
936 node = %self.conf.label,
937 peer = %k,
938 addr = ?s.peer_addr().ok(),
939 "starting i/o tasks"
940 );
941 let (to_remote, from_remote) = chan::channel(self.conf.peer_capacity_egress);
942 let (r, w) = s.into_split();
943 let t1 = Arc::new(Mutex::new(t));
944 let t2 = t1.clone();
945 let ibound = self.ibound.clone();
946 let to_write = to_remote.clone();
947 let countdown = Countdown::new();
948 let rh = self.io_tasks.spawn(recv_loop(
949 self.conf.name,
950 k.clone(),
951 r,
952 t1,
953 ibound,
954 to_write,
955 #[cfg(feature = "metrics")]
956 self.metrics.clone(),
957 b,
958 countdown.clone(),
959 self.conf.max_message_size,
960 ));
961 let wh = self
962 .io_tasks
963 .spawn(send_loop(w, t2, from_remote, countdown));
964 assert!(self.task2key.insert(rh.id(), k.clone()).is_none());
965 assert!(self.task2key.insert(wh.id(), k.clone()).is_none());
966 let io = IoTask {
967 rh,
968 wh,
969 tx: to_remote,
970 };
971 self.active.insert(k, io);
972 #[cfg(feature = "metrics")]
973 self.metrics.connections.set(self.active.len());
974 }
975
976 fn lookup_peer(&self, t: &TransportState) -> Option<(K, &Peer)> {
978 let x = t.get_remote_static()?;
979 let x = PublicKey::try_from(x).ok()?;
980 let k = self.index.get_by_right(&x)?;
981 self.peers.get(k).map(|p| (k.clone(), p))
982 }
983
984 fn is_valid_ip(&self, k: &K, s: &TcpStream) -> bool {
986 self.peers
987 .get(k)
988 .map(|p| {
989 let NetAddr::Inet(ip, _) = p.addr else {
990 return true;
991 };
992 Some(ip) == s.peer_addr().ok().map(|a| a.ip())
993 })
994 .unwrap_or(false)
995 }
996}
997
/// Repeatedly try to connect to `to` at `addr` and run the Noise IK
/// handshake as initiator; returns only once a connection is fully
/// established (the retry loop runs forever).
async fn connect<K>(
    name: &'static str,
    this: (K, Keypair),
    to: (K, PublicKey),
    addr: NetAddr,
    delays: [u8; NUM_DELAYS],
    #[cfg(feature = "metrics")] metrics: Arc<NetworkMetrics<K>>,
) -> (TcpStream, TransportState)
where
    K: Eq + Hash + Display + Clone,
{
    use rand::prelude::*;

    // A fresh handshake state is needed per attempt (it is consumed on use).
    let new_handshake_state = || {
        Builder::new(NOISE_PARAMS.parse().expect("valid noise params"))
            .local_private_key(this.1.secret_key().as_slice())
            .expect("valid private key")
            .remote_public_key(to.1.as_slice())
            .expect("valid remote pub key")
            .prologue(name.as_bytes())
            .expect("1st time we set the prologue")
            .build_initiator()
            .expect("valid noise params yield valid handshake state")
    };

    // Delay schedule (milliseconds): a random initial delay of up to 1s
    // (jitter so peers do not all dial at once), then the configured
    // delays (given in seconds), then the last delay repeated forever.
    let delays = once(rand::rng().random_range(0..=1000))
        .chain(delays.into_iter().map(|d| u64::from(d) * 1000))
        .chain(repeat(u64::from(delays[LAST_DELAY]) * 1000));

    let addr = addr.to_string();

    for d in delays {
        sleep(Duration::from_millis(d)).await;
        debug!(%name, node = %this.0, peer = %to.0, %addr, "connecting");
        #[cfg(feature = "metrics")]
        metrics.add_connect_attempt(&to.0);
        match timeout(CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
            Ok(Ok(s)) => {
                if let Err(err) = s.set_nodelay(true) {
                    error!(%name, node = %this.0, %err, "failed to set NO_DELAY socket option");
                    continue;
                }
                match timeout(HANDSHAKE_TIMEOUT, handshake(new_handshake_state(), s)).await {
                    Ok(Ok(x)) => {
                        debug!(%name, node = %this.0, peer = %to.0, %addr, "connection established");
                        return x;
                    },
                    Ok(Err(err)) => {
                        warn!(%name, node = %this.0, peer = %to.0, %addr, %err, "handshake failure");
                    },
                    Err(_) => {
                        warn!(%name, node = %this.0, peer = %to.0, %addr, "handshake timeout");
                    },
                }
            },
            Ok(Err(err)) => {
                warn!(%name, node = %this.0, peer = %to.0, %addr, %err, "failed to connect");
            },
            Err(_) => {
                warn!(%name, node = %this.0, peer = %to.0, %addr, "connect timeout");
            },
        }
    }

    unreachable!("for loop repeats forever")
}
1068
1069async fn handshake(
1071 mut hs: HandshakeState,
1072 mut stream: TcpStream,
1073) -> Result<(TcpStream, TransportState)> {
1074 let mut b = vec![0; MAX_NOISE_HANDSHAKE_SIZE];
1075 let n = hs.write_message(&[], &mut b[Header::SIZE..])?;
1076 let h = Header::data(n as u16);
1077 send_frame(&mut stream, h, &mut b[..Header::SIZE + n]).await?;
1078 let (h, m) = recv_frame(&mut stream).await?;
1079 if !h.is_data() || h.is_partial() {
1080 return Err(NetworkError::InvalidHandshakeMessage);
1081 }
1082 hs.read_message(&m, &mut b)?;
1083 Ok((stream, hs.into_transport_mode()?))
1084}
1085
/// Responder side of the Noise IK handshake over an accepted TCP stream:
/// read the initiator's message, reply, and switch into transport mode.
async fn on_handshake(
    mut hs: HandshakeState,
    mut stream: TcpStream,
) -> Result<(TcpStream, TransportState)> {
    stream.set_nodelay(true)?;
    // <- initiator's first message; must be a single, complete data frame.
    let (h, m) = recv_frame(&mut stream).await?;
    if !h.is_data() || h.is_partial() {
        return Err(NetworkError::InvalidHandshakeMessage);
    }
    let mut b = vec![0; MAX_NOISE_HANDSHAKE_SIZE];
    hs.read_message(&m, &mut b)?;
    // -> second (final) handshake message (no payload).
    let n = hs.write_message(&[], &mut b[Header::SIZE..])?;
    let h = Header::data(n as u16);
    send_frame(&mut stream, h, &mut b[..Header::SIZE + n]).await?;
    Ok((stream, hs.into_transport_mode()?))
}
1103
/// Read frames from `reader`, decrypt them, reassemble (possibly partial)
/// data frames into whole messages and deliver them, answering pings
/// along the way. One ingress budget permit is held per delivered message
/// for back-pressure; the countdown (armed by `send_loop` after a ping)
/// aborts the connection if the peer stays silent.
#[allow(clippy::too_many_arguments)]
async fn recv_loop<R, K>(
    name: &'static str,
    id: K,
    mut reader: R,
    state: Arc<Mutex<TransportState>>,
    to_deliver: Sender<(K, Bytes, Option<OwnedSemaphorePermit>)>,
    to_writer: chan::Sender<Message>,
    #[cfg(feature = "metrics")] metrics: Arc<NetworkMetrics<K>>,
    budget: Arc<Semaphore>,
    mut countdown: Countdown,
    max_message_size: usize,
) -> Result<()>
where
    R: AsyncRead + Unpin,
    K: Eq + Hash + Display + Clone,
{
    let mut buf = vec![0; MAX_NOISE_MESSAGE_SIZE];
    loop {
        #[cfg(feature = "metrics")]
        metrics.set_peer_iqueue_cap(&id, budget.available_permits());
        // Acquire a permit before reading the next message; the permit
        // travels with the delivered message and is released when the
        // application drops it.
        let permit = budget
            .clone()
            .acquire_owned()
            .await
            .map_err(|_| NetworkError::BudgetClosed)?;
        let mut msg = BytesMut::new();
        loop {
            tokio::select! {
                val = recv_frame(&mut reader) => {
                    // Any inbound frame proves the peer is alive.
                    countdown.stop();
                    match val {
                        Ok((h, f)) => {
                            match h.frame_type() {
                                Ok(Type::Ping) => {
                                    // Echo the ping's timestamp back.
                                    let n = state.lock().read_message(&f, &mut buf)?;
                                    if let Some(ping) = Timestamp::try_from_slice(&buf[..n]) {
                                        to_writer.send(None, Message::Pong(ping))
                                    }
                                }
                                Ok(Type::Pong) => {
                                    let _n = state.lock().read_message(&f, &mut buf)?;
                                    // Round-trip time = now - echoed ping time.
                                    #[cfg(feature = "metrics")]
                                    if let Some(ping) = Timestamp::try_from_slice(&buf[.._n])
                                        && let Some(delay) = Timestamp::now().diff(ping)
                                    {
                                        metrics.set_latency(&id, delay)
                                    }
                                }
                                Ok(Type::Data) => {
                                    let n = state.lock().read_message(&f, &mut buf)?;
                                    msg.extend_from_slice(&buf[..n]);
                                    if msg.len() > max_message_size {
                                        return Err(NetworkError::MessageTooLarge);
                                    }
                                    // A non-partial frame completes the message.
                                    if !h.is_partial() {
                                        break;
                                    }
                                }
                                Err(t) => return Err(NetworkError::UnknownFrameType(t)),
                            }
                        }
                        Err(e) => return Err(e)
                    }
                },
                () = &mut countdown => {
                    // No frame arrived within the reply window after a ping.
                    warn!(%name, node = %id, "timeout waiting for peer");
                    return Err(NetworkError::Timeout)
                }
            }
        }
        if to_deliver
            .send((id.clone(), msg.freeze(), Some(permit)))
            .await
            .is_err()
        {
            break;
        }
    }
    Ok(())
}
1190
/// Encrypt and write outgoing messages to the peer. Data larger than
/// `MAX_PAYLOAD_SIZE` is split into partial frames; after a ping the
/// shared countdown is armed so `recv_loop` can detect a dead peer.
async fn send_loop<W>(
    mut writer: W,
    state: Arc<Mutex<TransportState>>,
    rx: chan::Receiver<Message>,
    countdown: Countdown,
) -> Result<()>
where
    W: AsyncWrite + Unpin,
{
    let mut buf = vec![0; MAX_NOISE_MESSAGE_SIZE];

    while let Some(msg) = rx.recv().await {
        match msg {
            Message::Ping(ping) => {
                let n = state
                    .lock()
                    .write_message(&ping.to_bytes()[..], &mut buf[Header::SIZE..])?;
                let h = Header::ping(n as u16);
                send_frame(&mut writer, h, &mut buf[..Header::SIZE + n]).await?;
                // Expect *some* inbound frame before this fires; the
                // read loop stops the countdown on any received frame.
                countdown.start(REPLY_TIMEOUT)
            },
            Message::Pong(pong) => {
                let n = state
                    .lock()
                    .write_message(&pong.to_bytes()[..], &mut buf[Header::SIZE..])?;
                let h = Header::pong(n as u16);
                send_frame(&mut writer, h, &mut buf[..Header::SIZE + n]).await?
            },
            Message::Data(msg) => {
                // Chunk the message; every frame except the last is
                // marked partial so the receiver can reassemble.
                let mut it = msg.chunks(MAX_PAYLOAD_SIZE).peekable();
                while let Some(m) = it.next() {
                    let n = state.lock().write_message(m, &mut buf[Header::SIZE..])?;
                    let h = if it.peek().is_some() {
                        Header::data(n as u16).partial()
                    } else {
                        Header::data(n as u16)
                    };
                    send_frame(&mut writer, h, &mut buf[..Header::SIZE + n]).await?
                }
            },
        }
    }
    Ok(())
}
1239
1240async fn recv_frame<R>(r: &mut R) -> Result<(Header, Vec<u8>)>
1242where
1243 R: AsyncRead + Unpin,
1244{
1245 let b = r.read_u32().await?;
1246 let h = Header::try_from(b.to_be_bytes())?;
1247 let mut v = vec![0; h.len().into()];
1248 r.read_exact(&mut v).await?;
1249 Ok((h, v))
1250}
1251
1252async fn send_frame<W>(w: &mut W, hdr: Header, msg: &mut [u8]) -> Result<()>
1257where
1258 W: AsyncWrite + Unpin,
1259{
1260 debug_assert!(msg.len() <= MAX_NOISE_MESSAGE_SIZE);
1261 msg[..Header::SIZE].copy_from_slice(&hdr.to_bytes());
1262 w.write_all(msg).await?;
1263 Ok(())
1264}