#![doc = include_str!("../README.md")]

use std::{
    collections::HashMap,
    fmt::Display,
    future::pending,
    hash::Hash,
    iter::{once, repeat},
    sync::Arc,
    time::Duration,
};

use bimap::BiHashMap;
use bytes::{Bytes, BytesMut};
use parking_lot::Mutex;
use snow::{Builder, HandshakeState, TransportState};
use tokio::{
    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
    net::{TcpListener, TcpStream},
    spawn,
    sync::{
        Mutex as AsyncMutex, OwnedSemaphorePermit, Semaphore,
        mpsc::{self, Receiver, Sender},
    },
    task::{self, AbortHandle, JoinHandle, JoinSet},
    time::{Interval, MissedTickBehavior, sleep, timeout},
};
use tracing::{debug, error, info, trace, warn};

#[cfg(feature = "metrics")]
use crate::metrics::NetworkMetrics;
use crate::{
    Address, Id, Keypair, LAST_DELAY, NUM_DELAYS, NetConf, NetworkError, PublicKey, Role, chan,
    error::Empty,
    frame::{Header, Type},
    time::{Countdown, Timestamp},
};

type Budget = Arc<Semaphore>;
type Result<T> = std::result::Result<T, NetworkError>;

/// Maximum size of a Noise handshake message.
const MAX_NOISE_HANDSHAKE_SIZE: usize = 1024;

/// Maximum size of a single Noise transport message.
const MAX_NOISE_MESSAGE_SIZE: usize = 64 * 1024;

/// Maximum payload bytes per data frame; the difference to
/// `MAX_NOISE_MESSAGE_SIZE` leaves headroom for Noise encryption overhead.
const MAX_PAYLOAD_SIZE: usize = MAX_NOISE_MESSAGE_SIZE - 32;

/// Noise protocol parameters (IK handshake, X25519, AES-GCM, BLAKE2s).
const NOISE_PARAMS: &str = "Noise_IK_25519_AESGCM_BLAKE2s";

/// Interval between pings to each connected peer.
const PING_INTERVAL: Duration = Duration::from_secs(15);

/// Max. time to wait for a TCP connection to be established.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Max. time to wait for a Noise handshake to complete.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Max. time to wait for a peer's reply (e.g. a pong) before the
/// connection is considered dead.
const REPLY_TIMEOUT: Duration = Duration::from_secs(30);
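
// Wire format (as implemented by `send_frame`/`recv_frame` below): every
// frame is a `Header::SIZE`-byte header followed by at most
// `MAX_NOISE_MESSAGE_SIZE` bytes of Noise ciphertext; data larger than
// `MAX_PAYLOAD_SIZE` is split into partial frames and reassembled on receipt.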

/// A `Network` is the public handle to this crate's networking layer.
#[derive(Debug)]
pub struct Network<K> {
    /// Name of this network.
    name: &'static str,

    /// Label identifying this node.
    label: K,

    /// The parties of this network and their roles.
    parties: Mutex<HashMap<K, Role>>,

    /// Channel of commands to the server task.
    tx: Sender<Command<K>>,

    /// Channel of inbound messages from the server task.
    rx: AsyncMutex<Receiver<(K, Bytes, Option<OwnedSemaphorePermit>)>>,

    /// Handle of the server task.
    srv: JoinHandle<Result<Empty>>,

    /// Maximum message size accepted for sending.
    max_message_size: usize,
}

impl<K> Drop for Network<K> {
    fn drop(&mut self) {
        self.srv.abort()
    }
}

/// Commands from the `Network` handle to its server task.
#[derive(Debug)]
pub(crate) enum Command<K> {
    /// Add the given peers.
    Add(Vec<(K, PublicKey, Address)>),
    /// Remove the given peers.
    Remove(Vec<K>),
    /// Assign a role to the given peers.
    Assign(Role, Vec<K>),
    /// Send a message to a single peer.
    Unicast(K, Option<Id>, Bytes),
    /// Send a message to a subset of peers.
    Multicast(Vec<K>, Option<Id>, Bytes),
    /// Send a message to all active peers.
    Broadcast(Option<Id>, Bytes),
}

/// The server owns all connections and performs the actual network I/O.
#[derive(Debug)]
struct Server<K> {
    /// Network configuration.
    conf: NetConf<K>,

    /// The role of this node.
    role: Role,

    /// Channel of inbound messages, delivered to the `Network` handle.
    ibound: Sender<(K, Bytes, Option<OwnedSemaphorePermit>)>,

    /// Channel of outbound commands from the `Network` handle.
    obound: Receiver<Command<K>>,

    /// All known peers.
    peers: HashMap<K, Peer>,

    /// Bidirectional mapping of peer labels and public keys.
    index: BiHashMap<K, PublicKey>,

    /// Mapping of task IDs to peer labels.
    task2key: HashMap<task::Id, K>,

    /// Peers we are currently connecting to.
    connecting: HashMap<K, ConnectTask>,

    /// Peers with an established connection.
    active: HashMap<K, IoTask>,

    /// Tasks performing the responder side of handshakes.
    handshake_tasks: JoinSet<Result<(TcpStream, TransportState)>>,

    /// Tasks establishing outbound connections.
    connect_tasks: JoinSet<(TcpStream, TransportState)>,

    /// Read and write tasks of established connections.
    io_tasks: JoinSet<Result<()>>,

    /// Interval at which pings are sent to peers.
    ping_interval: Interval,

    #[cfg(feature = "metrics")]
    metrics: Arc<NetworkMetrics<K>>,
}

/// A known peer.
#[derive(Debug)]
struct Peer {
    addr: Address,
    role: Role,
    budget: Budget,
}

/// Handle of a task establishing an outbound connection.
///
/// The task is aborted when this handle is dropped.
#[derive(Debug)]
struct ConnectTask {
    h: AbortHandle,
}

impl Drop for ConnectTask {
    fn drop(&mut self) {
        self.h.abort();
    }
}

/// Handles of the read and write tasks of an established connection.
///
/// Both tasks are aborted when this handle is dropped.
#[derive(Debug)]
struct IoTask {
    /// Read task handle.
    rh: AbortHandle,

    /// Write task handle.
    wh: AbortHandle,

    /// Channel of messages to the write task.
    tx: chan::Sender<Message>,
}

impl Drop for IoTask {
    fn drop(&mut self) {
        self.rh.abort();
        self.wh.abort();
    }
}

/// Messages handed to the write task of a connection.
#[derive(Debug)]
enum Message {
    Data(Bytes),
    Ping(Timestamp),
    Pong(Timestamp),
}

impl<K> Network<K>
where
    K: Eq + Ord + Clone + Display + Hash + Send + Sync + 'static,
{
    pub async fn create(cfg: NetConf<K>) -> Result<Self> {
        let listener = TcpListener::bind(cfg.bind.to_string())
            .await
            .map_err(|e| NetworkError::Bind(cfg.bind.clone(), e))?;

        debug!(
            name = %cfg.name,
            node = %cfg.label,
            addr = %listener.local_addr()?,
            "listening"
        );

        let mut parties = HashMap::new();
        let mut peers = HashMap::new();
        let mut index = BiHashMap::new();

        for (k, x, a) in cfg.parties.iter().cloned() {
            parties.insert(k.clone(), Role::Active);
            index.insert(k.clone(), x);
            peers.insert(
                k,
                Peer {
                    addr: a,
                    role: Role::Active,
                    budget: cfg.new_budget(),
                },
            );
        }

        // Commands from the `Network` handle to the server task.
        let (otx, orx) = mpsc::channel(cfg.total_capacity_egress);

        // Messages from the server task to the `Network` handle.
        let (itx, irx) = mpsc::channel(cfg.total_capacity_ingress);

        let mut interval = tokio::time::interval(PING_INTERVAL);
        interval.set_missed_tick_behavior(MissedTickBehavior::Delay);

        let name = cfg.name;
        let label = cfg.label.clone();
        let mmsze = cfg.max_message_size;

        #[cfg(feature = "metrics")]
        let metrics = {
            let it = parties.keys().filter(|k| **k != label).cloned();
            NetworkMetrics::new(name, &*cfg.metrics, it)
        };

        let server = Server {
            conf: cfg,
            role: Role::Active,
            ibound: itx,
            obound: orx,
            peers,
            index,
            connecting: HashMap::new(),
            active: HashMap::new(),
            task2key: HashMap::new(),
            handshake_tasks: JoinSet::new(),
            connect_tasks: JoinSet::new(),
            io_tasks: JoinSet::new(),
            ping_interval: interval,
            #[cfg(feature = "metrics")]
            metrics: Arc::new(metrics),
        };

        Ok(Self {
            name,
            label,
            parties: Mutex::new(parties),
            rx: AsyncMutex::new(irx),
            tx: otx,
            srv: spawn(server.run(listener)),
            max_message_size: mmsze,
        })
    }

    /// The label identifying this node.
    pub fn public_key(&self) -> &K {
        &self.label
    }

    /// The name of this network.
    pub fn name(&self) -> &str {
        self.name
    }

    /// All parties with the given role.
    pub fn parties(&self, r: Role) -> Vec<K> {
        self.parties
            .lock()
            .iter()
            .filter(|&(_, x)| r == *x)
            .map(|(k, _)| k.clone())
            .collect()
    }

    /// Send a message to a single party.
    pub async fn unicast(&self, to: K, msg: Bytes) -> Result<()> {
        if msg.len() > self.max_message_size {
            warn!(
                name = %self.name,
                node = %self.label,
                to = %to,
                len = %msg.len(),
                max = %self.max_message_size,
                "message too large to send"
            );
            return Err(NetworkError::MessageTooLarge);
        }
        self.tx
            .send(Command::Unicast(to, None, msg))
            .await
            .map_err(|_| NetworkError::ChannelClosed)
    }

    /// Send a message to all active parties.
    pub async fn broadcast(&self, msg: Bytes) -> Result<()> {
        if msg.len() > self.max_message_size {
            warn!(
                name = %self.name,
                node = %self.label,
                len = %msg.len(),
                max = %self.max_message_size,
                "message too large to broadcast"
            );
            return Err(NetworkError::MessageTooLarge);
        }
        self.tx
            .send(Command::Broadcast(None, msg))
            .await
            .map_err(|_| NetworkError::ChannelClosed)
    }

    /// Receive the next inbound message.
    pub async fn receive(&self) -> Result<(K, Bytes)> {
        let mut rx = self.rx.lock().await;
        let (k, b, _) = rx.recv().await.ok_or(NetworkError::ChannelClosed)?;
        Ok((k, b))
    }

    /// Add the given peers (initially with a passive role).
    pub async fn add(&self, peers: Vec<(K, PublicKey, Address)>) -> Result<()> {
        self.parties
            .lock()
            .extend(peers.iter().map(|(p, ..)| (p.clone(), Role::Passive)));
        self.tx
            .send(Command::Add(peers))
            .await
            .map_err(|_| NetworkError::ChannelClosed)
    }

    /// Remove the given peers.
    pub async fn remove(&self, peers: Vec<K>) -> Result<()> {
        {
            let mut parties = self.parties.lock();
            for p in &peers {
                parties.remove(p);
            }
        }
        self.tx
            .send(Command::Remove(peers))
            .await
            .map_err(|_| NetworkError::ChannelClosed)
    }

    /// Assign a role to the given peers.
    pub async fn assign(&self, r: Role, peers: Vec<K>) -> Result<()> {
        {
            let mut parties = self.parties.lock();
            for p in &peers {
                if let Some(role) = parties.get_mut(p) {
                    *role = r
                }
            }
        }
        self.tx
            .send(Command::Assign(r, peers))
            .await
            .map_err(|_| NetworkError::ChannelClosed)
    }

    /// A clone of the command channel's sender.
    pub(crate) fn sender(&self) -> Sender<Command<K>> {
        self.tx.clone()
    }
}
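
// A minimal usage sketch (illustrative only; nothing in this crate calls it):
// given a `NetConf` assembled by the caller, create a `Network`, broadcast a
// payload to all active parties and await the next inbound message. The
// function name and payload are hypothetical.
#[allow(dead_code)]
async fn usage_sketch<K>(cfg: NetConf<K>) -> Result<()>
where
    K: Eq + Ord + Clone + Display + Hash + Send + Sync + 'static,
{
    let net = Network::create(cfg).await?;
    net.broadcast(Bytes::from_static(b"hello")).await?;
    let (from, msg) = net.receive().await?;
    debug!(%from, len = %msg.len(), "received message");
    Ok(())
}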

impl<K> Server<K>
where
    K: Eq + Ord + Clone + Display + Hash + Send + Sync + 'static,
{
    /// The server's main loop.
    ///
    /// Accepts inbound connections, drives handshake, connect and I/O tasks
    /// to completion, executes commands from the `Network` handle and
    /// periodically pings all connected peers.
    async fn run(mut self, listener: TcpListener) -> Result<Empty> {
        // Seed the join sets with never-completing tasks so `join_next()`
        // always has something to poll in the select loop below.
        self.handshake_tasks.spawn(pending());
        self.io_tasks.spawn(pending());

        for k in self
            .peers
            .keys()
            .filter(|k| **k != self.conf.label)
            .cloned()
            .collect::<Vec<_>>()
        {
            self.spawn_connect(k)
        }

        loop {
            trace!(
                name = %self.conf.name,
                node = %self.conf.label,
                active = %self.active.len(),
                connects = %self.connect_tasks.len(),
                handshakes = %self.handshake_tasks.len().saturating_sub(1),
                io_tasks = %self.io_tasks.len().saturating_sub(1),
                task_ids = %self.task2key.len(),
                iqueue = %self.ibound.capacity(),
                oqueue = %self.obound.capacity(),
            );

            #[cfg(feature = "metrics")]
            {
                self.metrics.iqueue.set(self.ibound.capacity());
                self.metrics.oqueue.set(self.obound.capacity());
            }

            tokio::select! {
                i = listener.accept() => match i {
                    Ok((s, a)) => {
                        debug!(
                            name = %self.conf.name,
                            node = %self.conf.label,
                            addr = %a,
                            "accepted connection"
                        );
                        self.spawn_handshake(s)
                    }
                    Err(e) => {
                        warn!(
                            name = %self.conf.name,
                            node = %self.conf.label,
                            err = %e,
                            "error accepting connection"
                        )
                    }
                },
                Some(h) = self.handshake_tasks.join_next() => match h {
                    Ok(Ok((s, t))) => {
                        let Some((k, peer)) = self.lookup_peer(&t) else {
                            info!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = ?t.get_remote_static().and_then(|k| PublicKey::try_from(k).ok()),
                                addr = ?s.peer_addr().ok(),
                                "unknown peer"
                            );
                            continue
                        };
                        if !self.is_valid_ip(&k, &s) {
                            warn!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = %k,
                                addr = ?s.peer_addr().ok(),
                                "invalid peer ip addr"
                            );
                            continue
                        }
                        // Tie-break for simultaneous connections: keep the
                        // inbound connection if the peer's key sorts above
                        // ours, or if there is no active connection yet.
                        if k > self.conf.label || !self.active.contains_key(&k) {
                            self.spawn_io(k, s, t, peer.budget.clone())
                        } else {
                            debug!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = %k,
                                "dropping accepted connection"
                            );
                        }
                    }
                    Ok(Err(e)) => {
                        warn!(
                            name = %self.conf.name,
                            node = %self.conf.label,
                            err = %e,
                            "handshake failed"
                        )
                    }
                    Err(e) => {
                        if !e.is_cancelled() {
                            error!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                err = %e,
                                "handshake task panic"
                            )
                        }
                    }
                },
                Some(tt) = self.connect_tasks.join_next_with_id() => {
                    match tt {
                        Ok((id, (s, t))) => {
                            self.on_connect_task_end(id);
                            let Some((k, peer)) = self.lookup_peer(&t) else {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    peer = ?t.get_remote_static().and_then(|k| PublicKey::try_from(k).ok()),
                                    addr = ?s.peer_addr().ok(),
                                    "connected to unknown peer"
                                );
                                continue
                            };
                            // Mirror of the inbound tie-break: keep the
                            // outbound connection if the peer's key sorts
                            // below ours, or if none is active yet.
                            if k < self.conf.label || !self.active.contains_key(&k) {
                                self.spawn_io(k, s, t, peer.budget.clone())
                            } else {
                                debug!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    peer = %k,
                                    "dropping new connection"
                                )
                            }
                        }
                        Err(e) => {
                            if !e.is_cancelled() {
                                error!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err = %e,
                                    "connect task panic"
                                )
                            }
                            self.on_connect_task_end(e.id());
                        }
                    }
                },
                Some(io) = self.io_tasks.join_next_with_id() => {
                    match io {
                        Ok((id, r)) => {
                            if let Err(e) = r {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err = %e,
                                    "i/o error"
                                )
                            }
                            self.on_io_task_end(id);
                        }
                        Err(e) => {
                            if e.is_cancelled() {
                                // Cancelled tasks were aborted deliberately
                                // (their `IoTask` was dropped), so only the
                                // task-to-key mapping needs cleaning up.
                                self.task2key.remove(&e.id());
                                continue
                            }
                            error!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                err = %e,
                                "i/o task panic"
                            );
                            self.on_io_task_end(e.id())
                        }
                    };
                },
                cmd = self.obound.recv() => match cmd {
                    Some(Command::Add(peers)) => {
                        #[cfg(feature = "metrics")]
                        Arc::make_mut(&mut self.metrics).add_parties(peers.iter().map(|(k, ..)| k).cloned());
                        for (k, x, a) in peers {
                            if self.peers.contains_key(&k) {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    peer = %k,
                                    "peer to add already exists"
                                );
                                continue
                            }
                            info!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = %k,
                                "adding peer"
                            );
                            let p = Peer {
                                addr: a,
                                role: Role::Passive,
                                budget: self.conf.new_budget()
                            };
                            self.peers.insert(k.clone(), p);
                            self.index.insert(k.clone(), x);
                            self.spawn_connect(k)
                        }
                    }
                    Some(Command::Remove(peers)) => {
                        for k in &peers {
                            info!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = %k,
                                "removing peer"
                            );
                            self.peers.remove(k);
                            self.index.remove_by_left(k);
                            self.connecting.remove(k);
                            self.active.remove(k);
                        }
                        #[cfg(feature = "metrics")]
                        Arc::make_mut(&mut self.metrics).remove_parties(&peers)
                    }
                    Some(Command::Assign(role, peers)) => {
                        for k in &peers {
                            if let Some(p) = self.peers.get_mut(k) {
                                p.role = role
                            } else {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    peer = %k,
                                    role = ?role,
                                    "peer to assign role to not found"
                                );
                            }
                        }
                    }
                    Some(Command::Unicast(to, id, m)) => {
                        // Messages to ourselves bypass the network entirely.
                        if to == self.conf.label {
                            trace!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                to = %to,
                                len = %m.len(),
                                queue = self.ibound.capacity(),
                                "sending message"
                            );
                            if let Err(err) = self.ibound.try_send((self.conf.label.clone(), m, None)) {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err = %err,
                                    cap = %self.ibound.capacity(),
                                    "channel full => dropping unicast message"
                                )
                            }
                            continue
                        }
                        if let Some(task) = self.active.get(&to) {
                            trace!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                to = %to,
                                len = %m.len(),
                                queue = task.tx.capacity(),
                                "sending message"
                            );
                            #[cfg(feature = "metrics")]
                            self.metrics.set_peer_oqueue_cap(&to, task.tx.capacity());
                            task.tx.send(id, Message::Data(m))
                        }
                    }
                    Some(Command::Multicast(peers, id, m)) => {
                        if peers.contains(&self.conf.label) {
                            trace!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                to = %self.conf.label,
                                len = %m.len(),
                                queue = self.ibound.capacity(),
                                "sending message"
                            );
                            if let Err(err) = self.ibound.try_send((self.conf.label.clone(), m.clone(), None)) {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err = %err,
                                    cap = %self.ibound.capacity(),
                                    "channel full => dropping multicast message"
                                )
                            }
                        }
                        for (to, task) in &self.active {
                            if !peers.contains(to) {
                                continue
                            }
                            trace!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                to = %to,
                                len = %m.len(),
                                queue = task.tx.capacity(),
                                "sending message"
                            );
                            #[cfg(feature = "metrics")]
                            self.metrics.set_peer_oqueue_cap(to, task.tx.capacity());
                            task.tx.send(id, Message::Data(m.clone()))
                        }
                    }
                    Some(Command::Broadcast(id, m)) => {
                        if self.role.is_active() {
                            trace!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                to = %self.conf.label,
                                len = %m.len(),
                                queue = self.ibound.capacity(),
                                "sending message"
                            );
                            if let Err(err) = self.ibound.try_send((self.conf.label.clone(), m.clone(), None)) {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err = %err,
                                    cap = %self.ibound.capacity(),
                                    "channel full => dropping broadcast message"
                                )
                            }
                        }
                        for (to, task) in &self.active {
                            if Some(Role::Active) != self.peers.get(to).map(|p| p.role) {
                                continue
                            }
                            trace!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                to = %to,
                                len = %m.len(),
                                queue = task.tx.capacity(),
                                "sending message"
                            );
                            #[cfg(feature = "metrics")]
                            self.metrics.set_peer_oqueue_cap(to, task.tx.capacity());
                            task.tx.send(id, Message::Data(m.clone()))
                        }
                    }
                    None => {
                        return Err(NetworkError::ChannelClosed)
                    }
                },
                _ = self.ping_interval.tick() => {
                    let now = Timestamp::now();
                    for task in self.active.values() {
                        task.tx.send(None, Message::Ping(now))
                    }
                }
            }
        }
    }

    /// Clean up after a connect task has ended.
    fn on_connect_task_end(&mut self, id: task::Id) {
        let Some(k) = self.task2key.remove(&id) else {
            error!(name = %self.conf.name, node = %self.conf.label, "no key for connect task");
            return;
        };
        self.connecting.remove(&k);
    }

    /// Clean up after an I/O task has ended.
    ///
    /// If either half of an active connection closed, the connection is
    /// dropped and a reconnect is scheduled.
    fn on_io_task_end(&mut self, id: task::Id) {
        let Some(k) = self.task2key.remove(&id) else {
            error!(name = %self.conf.name, node = %self.conf.label, "no key for i/o task");
            return;
        };
        let Some(task) = self.active.get(&k) else {
            return;
        };
        if task.rh.id() == id {
            debug!(
                name = %self.conf.name,
                node = %self.conf.label,
                peer = %k,
                "read-half closed => dropping connection"
            );
            self.active.remove(&k);
            self.spawn_connect(k)
        } else if task.wh.id() == id {
            debug!(
                name = %self.conf.name,
                node = %self.conf.label,
                peer = %k,
                "write-half closed => dropping connection"
            );
            self.active.remove(&k);
            self.spawn_connect(k)
        } else {
            debug!(
                name = %self.conf.name,
                node = %self.conf.label,
                peer = %k,
                "i/o task was previously replaced"
            );
        }
    }

    /// Spawn a task connecting to the given peer, unless one is already
    /// running.
    fn spawn_connect(&mut self, k: K) {
        if self.connecting.contains_key(&k) {
            debug!(
                name = %self.conf.name,
                node = %self.conf.label,
                peer = %k,
                "connect task already started"
            );
            return;
        }
        let x = self.index.get_by_left(&k).expect("known public key");
        let p = self.peers.get(&k).expect("known peer");
        let h = self.connect_tasks.spawn(connect(
            self.conf.name,
            (self.conf.label.clone(), self.conf.keypair.clone()),
            (k.clone(), *x),
            p.addr.clone(),
            self.conf.retry_delays,
            #[cfg(feature = "metrics")]
            self.metrics.clone(),
        ));
        assert!(self.task2key.insert(h.id(), k.clone()).is_none());
        self.connecting.insert(k, ConnectTask { h });
    }

    /// Spawn a task performing the responder side of a Noise handshake on
    /// an accepted connection.
    fn spawn_handshake(&mut self, s: TcpStream) {
        let h = Builder::new(NOISE_PARAMS.parse().expect("valid noise params"))
            .local_private_key(&self.conf.keypair.secret_key().as_bytes())
            .expect("valid private key")
            .prologue(self.conf.name.as_bytes())
            .expect("1st time we set the prologue")
            .build_responder()
            .expect("valid noise params yield valid handshake state");
        self.handshake_tasks.spawn(async move {
            timeout(HANDSHAKE_TIMEOUT, on_handshake(h, s))
                .await
                .or(Err(NetworkError::Timeout))?
        });
    }

    /// Spawn the read and write tasks of an established connection.
    fn spawn_io(&mut self, k: K, s: TcpStream, t: TransportState, b: Budget) {
        debug!(
            name = %self.conf.name,
            node = %self.conf.label,
            peer = %k,
            addr = ?s.peer_addr().ok(),
            "starting i/o tasks"
        );
        let (to_remote, from_remote) = chan::channel(self.conf.peer_capacity_egress);
        let (r, w) = s.into_split();
        let t1 = Arc::new(Mutex::new(t));
        let t2 = t1.clone();
        let ibound = self.ibound.clone();
        let to_write = to_remote.clone();
        let countdown = Countdown::new();
        let rh = self.io_tasks.spawn(recv_loop(
            self.conf.name,
            k.clone(),
            r,
            t1,
            ibound,
            to_write,
            #[cfg(feature = "metrics")]
            self.metrics.clone(),
            b,
            countdown.clone(),
            self.conf.max_message_size,
        ));
        let wh = self
            .io_tasks
            .spawn(send_loop(w, t2, from_remote, countdown));
        assert!(self.task2key.insert(rh.id(), k.clone()).is_none());
        assert!(self.task2key.insert(wh.id(), k.clone()).is_none());
        let io = IoTask {
            rh,
            wh,
            tx: to_remote,
        };
        self.active.insert(k, io);
        #[cfg(feature = "metrics")]
        self.metrics.connections.set(self.active.len());
    }

    /// Look up the peer matching the remote static key of the given
    /// transport state.
    fn lookup_peer(&self, t: &TransportState) -> Option<(K, &Peer)> {
        let x = t.get_remote_static()?;
        let x = PublicKey::try_from(x).ok()?;
        let k = self.index.get_by_right(&x)?;
        self.peers.get(k).map(|p| (k.clone(), p))
    }

    /// Check that the stream's remote IP matches the peer's known address.
    fn is_valid_ip(&self, k: &K, s: &TcpStream) -> bool {
        self.peers
            .get(k)
            .map(|p| {
                let Address::Inet(ip, _) = p.addr else {
                    return true;
                };
                Some(ip) == s.peer_addr().ok().map(|a| a.ip())
            })
            .unwrap_or(false)
    }
}

/// Connect to the given peer, retrying with the configured delays until a
/// connection and handshake succeed.
async fn connect<K>(
    name: &'static str,
    this: (K, Keypair),
    to: (K, PublicKey),
    addr: Address,
    delays: [u8; NUM_DELAYS],
    #[cfg(feature = "metrics")] metrics: Arc<NetworkMetrics<K>>,
) -> (TcpStream, TransportState)
where
    K: Eq + Hash + Display + Clone,
{
    use rand::prelude::*;

    let new_handshake_state = || {
        Builder::new(NOISE_PARAMS.parse().expect("valid noise params"))
            .local_private_key(this.1.secret_key().as_slice())
            .expect("valid private key")
            .remote_public_key(to.1.as_slice())
            .expect("valid remote pub key")
            .prologue(name.as_bytes())
            .expect("1st time we set the prologue")
            .build_initiator()
            .expect("valid noise params yield valid handshake state")
    };

    // Retry schedule: a random initial delay of up to one second (to spread
    // simultaneous reconnects), then the configured delays in seconds, then
    // the last configured delay repeated forever.
    let delays = once(rand::rng().random_range(0..=1000))
        .chain(delays.into_iter().map(|d| u64::from(d) * 1000))
        .chain(repeat(u64::from(delays[LAST_DELAY]) * 1000));

    let addr = addr.to_string();

    for d in delays {
        sleep(Duration::from_millis(d)).await;
        debug!(%name, node = %this.0, peer = %to.0, %addr, "connecting");
        #[cfg(feature = "metrics")]
        metrics.add_connect_attempt(&to.0);
        match timeout(CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
            Ok(Ok(s)) => {
                if let Err(err) = s.set_nodelay(true) {
                    error!(%name, node = %this.0, %err, "failed to set NO_DELAY socket option");
                    continue;
                }
                match timeout(HANDSHAKE_TIMEOUT, handshake(new_handshake_state(), s)).await {
                    Ok(Ok(x)) => {
                        debug!(%name, node = %this.0, peer = %to.0, %addr, "connection established");
                        return x;
                    },
                    Ok(Err(err)) => {
                        warn!(%name, node = %this.0, peer = %to.0, %addr, %err, "handshake failure");
                    },
                    Err(_) => {
                        warn!(%name, node = %this.0, peer = %to.0, %addr, "handshake timeout");
                    },
                }
            },
            Ok(Err(err)) => {
                warn!(%name, node = %this.0, peer = %to.0, %addr, %err, "failed to connect");
            },
            Err(_) => {
                warn!(%name, node = %this.0, peer = %to.0, %addr, "connect timeout");
            },
        }
    }

    unreachable!("for loop repeats forever")
}

/// Initiator side of the Noise handshake.
async fn handshake(
    mut hs: HandshakeState,
    mut stream: TcpStream,
) -> Result<(TcpStream, TransportState)> {
    let mut b = vec![0; MAX_NOISE_HANDSHAKE_SIZE];
    let n = hs.write_message(&[], &mut b[Header::SIZE..])?;
    let h = Header::data(n as u16);
    send_frame(&mut stream, h, &mut b[..Header::SIZE + n]).await?;
    let (h, m) = recv_frame(&mut stream).await?;
    if !h.is_data() || h.is_partial() {
        return Err(NetworkError::InvalidHandshakeMessage);
    }
    hs.read_message(&m, &mut b)?;
    Ok((stream, hs.into_transport_mode()?))
}

/// Responder side of the Noise handshake.
async fn on_handshake(
    mut hs: HandshakeState,
    mut stream: TcpStream,
) -> Result<(TcpStream, TransportState)> {
    stream.set_nodelay(true)?;
    let (h, m) = recv_frame(&mut stream).await?;
    if !h.is_data() || h.is_partial() {
        return Err(NetworkError::InvalidHandshakeMessage);
    }
    let mut b = vec![0; MAX_NOISE_HANDSHAKE_SIZE];
    hs.read_message(&m, &mut b)?;
    let n = hs.write_message(&[], &mut b[Header::SIZE..])?;
    let h = Header::data(n as u16);
    send_frame(&mut stream, h, &mut b[..Header::SIZE + n]).await?;
    Ok((stream, hs.into_transport_mode()?))
}
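
// For orientation, a sketch of the message flow that `handshake` and
// `on_handshake` implement (the Noise IK pattern from `NOISE_PARAMS`; the
// responder's static key is known to the initiator up front):
//
//   initiator                     responder
//   ---------                     ---------
//   -> e, es, s, ss    (write_message / read_message)
//   <- e, ee, se       (read_message  / write_message)
//
// After this single round trip both sides call `into_transport_mode` and all
// further frames are transport messages encrypted under the derived keys.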

/// Read loop of a connection: receive frames, decrypt them, reassemble
/// messages and deliver them, answering pings with pongs along the way.
#[allow(clippy::too_many_arguments)]
async fn recv_loop<R, K>(
    name: &'static str,
    id: K,
    mut reader: R,
    state: Arc<Mutex<TransportState>>,
    to_deliver: Sender<(K, Bytes, Option<OwnedSemaphorePermit>)>,
    to_writer: chan::Sender<Message>,
    #[cfg(feature = "metrics")] metrics: Arc<NetworkMetrics<K>>,
    budget: Arc<Semaphore>,
    mut countdown: Countdown,
    max_message_size: usize,
) -> Result<()>
where
    R: AsyncRead + Unpin,
    K: Eq + Hash + Display + Clone,
{
    let mut buf = vec![0; MAX_NOISE_MESSAGE_SIZE];
    loop {
        #[cfg(feature = "metrics")]
        metrics.set_peer_iqueue_cap(&id, budget.available_permits());
        // Backpressure: each undelivered message from this peer holds one
        // permit of its budget; the permit travels with the message and is
        // released once the application has received it.
        let permit = budget
            .clone()
            .acquire_owned()
            .await
            .map_err(|_| NetworkError::BudgetClosed)?;
        let mut msg = BytesMut::new();
        loop {
            tokio::select! {
                val = recv_frame(&mut reader) => {
                    countdown.stop();
                    match val {
                        Ok((h, f)) => {
                            match h.frame_type() {
                                Ok(Type::Ping) => {
                                    let n = state.lock().read_message(&f, &mut buf)?;
                                    if let Some(ping) = Timestamp::try_from_slice(&buf[..n]) {
                                        to_writer.send(None, Message::Pong(ping))
                                    }
                                }
                                Ok(Type::Pong) => {
                                    let _n = state.lock().read_message(&f, &mut buf)?;
                                    #[cfg(feature = "metrics")]
                                    if let Some(ping) = Timestamp::try_from_slice(&buf[.._n])
                                        && let Some(delay) = Timestamp::now().diff(ping)
                                    {
                                        metrics.set_latency(&id, delay)
                                    }
                                }
                                Ok(Type::Data) => {
                                    let n = state.lock().read_message(&f, &mut buf)?;
                                    msg.extend_from_slice(&buf[..n]);
                                    if !h.is_partial() {
                                        break;
                                    }
                                    if msg.len() > max_message_size {
                                        return Err(NetworkError::MessageTooLarge);
                                    }
                                }
                                Err(t) => return Err(NetworkError::UnknownFrameType(t)),
                            }
                        }
                        Err(e) => return Err(e)
                    }
                },
                () = &mut countdown => {
                    warn!(%name, node = %id, "timeout waiting for peer");
                    return Err(NetworkError::Timeout)
                }
            }
        }
        if to_deliver
            .send((id.clone(), msg.freeze(), Some(permit)))
            .await
            .is_err()
        {
            break;
        }
    }
    Ok(())
}

/// Write loop of a connection: encrypt and send messages, splitting data
/// larger than `MAX_PAYLOAD_SIZE` into partial frames.
async fn send_loop<W>(
    mut writer: W,
    state: Arc<Mutex<TransportState>>,
    rx: chan::Receiver<Message>,
    countdown: Countdown,
) -> Result<()>
where
    W: AsyncWrite + Unpin,
{
    let mut buf = vec![0; MAX_NOISE_MESSAGE_SIZE];

    while let Some(msg) = rx.recv().await {
        match msg {
            Message::Ping(ping) => {
                let n = state
                    .lock()
                    .write_message(&ping.to_bytes()[..], &mut buf[Header::SIZE..])?;
                let h = Header::ping(n as u16);
                send_frame(&mut writer, h, &mut buf[..Header::SIZE + n]).await?;
                // Expect a reply within this window; the read loop's
                // countdown enforces it.
                countdown.start(REPLY_TIMEOUT)
            },
            Message::Pong(pong) => {
                let n = state
                    .lock()
                    .write_message(&pong.to_bytes()[..], &mut buf[Header::SIZE..])?;
                let h = Header::pong(n as u16);
                send_frame(&mut writer, h, &mut buf[..Header::SIZE + n]).await?
            },
            Message::Data(msg) => {
                let mut it = msg.chunks(MAX_PAYLOAD_SIZE).peekable();
                while let Some(m) = it.next() {
                    let n = state.lock().write_message(m, &mut buf[Header::SIZE..])?;
                    let h = if it.peek().is_some() {
                        Header::data(n as u16).partial()
                    } else {
                        Header::data(n as u16)
                    };
                    send_frame(&mut writer, h, &mut buf[..Header::SIZE + n]).await?
                }
            },
        }
    }
    Ok(())
}

/// Read a single frame: a 4-byte header followed by `Header::len` bytes of
/// payload.
async fn recv_frame<R>(r: &mut R) -> Result<(Header, Vec<u8>)>
where
    R: AsyncRead + Unpin,
{
    let b = r.read_u32().await?;
    let h = Header::try_from(b.to_be_bytes())?;
    let mut v = vec![0; h.len().into()];
    r.read_exact(&mut v).await?;
    Ok((h, v))
}

/// Write a single frame. The first `Header::SIZE` bytes of `msg` are
/// overwritten with the header.
async fn send_frame<W>(w: &mut W, hdr: Header, msg: &mut [u8]) -> Result<()>
where
    W: AsyncWrite + Unpin,
{
    debug_assert!(msg.len() <= MAX_NOISE_MESSAGE_SIZE);
    msg[..Header::SIZE].copy_from_slice(&hdr.to_bytes());
    w.write_all(msg).await?;
    Ok(())
}
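
// A minimal round-trip sketch of the framing helpers above: write one frame
// into an in-memory duplex pipe and read it back. Payload and buffer sizes
// are arbitrary test values. Assumes the dev profile enables tokio's `rt`
// and `macros` features (the crate already uses `tokio::select!`/`spawn`).
#[cfg(test)]
mod frame_tests {
    use super::*;

    #[tokio::test]
    async fn frame_roundtrip() {
        let (mut a, mut b) = tokio::io::duplex(MAX_NOISE_MESSAGE_SIZE);
        let payload = b"hello frame";
        // Reserve space for the header; `send_frame` fills it in.
        let mut buf = vec![0u8; Header::SIZE + payload.len()];
        buf[Header::SIZE..].copy_from_slice(payload);
        let hdr = Header::data(payload.len() as u16);
        send_frame(&mut a, hdr, &mut buf).await.unwrap();
        let (h, body) = recv_frame(&mut b).await.unwrap();
        assert!(h.is_data() && !h.is_partial());
        assert_eq!(&body[..], payload);
    }
}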