cliquenet/
net.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    collections::HashMap,
5    fmt::Display,
6    future::pending,
7    hash::Hash,
8    iter::{once, repeat},
9    result::Result as StdResult,
10    sync::Arc,
11    time::Duration,
12};
13
14use bimap::BiHashMap;
15use bytes::{Bytes, BytesMut};
16use hotshot_types::addr::NetAddr;
17use parking_lot::Mutex;
18use snow::{Builder, HandshakeState, TransportState};
19use tokio::{
20    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
21    net::{TcpListener, TcpStream},
22    spawn,
23    sync::{
24        Mutex as AsyncMutex, OwnedSemaphorePermit, Semaphore,
25        mpsc::{self, Receiver, Sender},
26    },
27    task::{self, AbortHandle, JoinHandle, JoinSet},
28    time::{Interval, MissedTickBehavior, sleep, timeout},
29};
30use tracing::{debug, error, info, trace, warn};
31
32#[cfg(feature = "metrics")]
33use crate::metrics::NetworkMetrics;
34use crate::{
35    Id, Keypair, LAST_DELAY, NUM_DELAYS, NetConf, NetworkDown, NetworkError, PublicKey, Role, chan,
36    error::Empty,
37    frame::{Header, Type},
38    time::{Countdown, Timestamp},
39};
40
/// Shared semaphore handed out per peer; permits accompany inbound
/// messages (see the `ibound` channel), presumably bounding how much
/// data a single peer may have buffered — TODO confirm against the
/// reader task.
type Budget = Arc<Semaphore>;
/// Crate-local result alias defaulting the error type to `NetworkError`.
type Result<T> = std::result::Result<T, NetworkError>;
43
/// Max. message size using noise handshake.
const MAX_NOISE_HANDSHAKE_SIZE: usize = 1024;

/// Max. message size using noise protocol.
const MAX_NOISE_MESSAGE_SIZE: usize = 64 * 1024;

/// Max. number of bytes for payload data.
///
/// Leaves 32 bytes of headroom per noise message for protocol
/// overhead (NOTE(review): AESGCM tags are 16 bytes — confirm what
/// the remaining headroom covers).
const MAX_PAYLOAD_SIZE: usize = MAX_NOISE_MESSAGE_SIZE - 32;

/// Noise parameters to initialize the builders.
///
/// IK pattern: the initiator knows the responder's static key up front
/// and transmits its own, which lets each side authenticate the peer
/// (see `lookup_peer` in the server loop).
const NOISE_PARAMS: &str = "Noise_IK_25519_AESGCM_BLAKE2s";

/// Interval between ping protocol.
const PING_INTERVAL: Duration = Duration::from_secs(15);

/// Max. allowed duration of a single TCP connect attempt.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Max. allowed duration of a Noise handshake.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Max. allowed duration to wait for a peer to answer.
///
/// This is started when we have sent a ping. Unless we receive
/// some data back within this duration, the connection times
/// out and is dropped.
const REPLY_TIMEOUT: Duration = Duration::from_secs(30);
71
/// `Network` is the API facade of this crate.
///
/// It owns the channels to and from the background `Server` task and
/// aborts that task when dropped or explicitly `close()`d.
#[derive(Debug)]
pub struct Network<K> {
    /// Name of this network.
    pub(crate) name: &'static str,

    /// Log label; also identifies this node among the parties.
    pub(crate) label: K,

    /// Max. number of bytes per message.
    ///
    /// Enforced by `unicast` and `broadcast` before enqueueing.
    pub(crate) max_message_size: usize,

    /// The network participants and their roles.
    parties: Mutex<HashMap<K, Role>>,

    /// MPSC sender of server task instructions.
    tx: Sender<Command<K>>,

    /// MPSC receiver of messages from a remote party.
    ///
    /// The public key identifies the remote. The optional permit is
    /// presumably drawn from the sending peer's budget and released
    /// when the message is consumed — TODO confirm in the reader task.
    rx: AsyncMutex<Receiver<(K, Bytes, Option<OwnedSemaphorePermit>)>>,

    /// Handle of the server task that has been spawned by `Network`.
    ///
    /// Aborted by `close()` and by `Drop`.
    srv: Mutex<Option<JoinHandle<Result<Empty>>>>,
}
98
99impl<K> Drop for Network<K> {
100    fn drop(&mut self) {
101        if let Some(srv) = self.srv.lock().take() {
102            srv.abort();
103        }
104    }
105}
106
/// Server task instructions.
///
/// Sent from the `Network` facade to the `Server` task over an MPSC
/// channel (`Network::tx` / `Server::obound`).
#[derive(Debug)]
pub(crate) enum Command<K> {
    /// Add the given peers (identity key, X25519 key, address) with the
    /// given role.
    Add(Role, Vec<(K, PublicKey, NetAddr)>),
    /// Remove the given peers.
    Remove(Vec<K>),
    /// Assign a `Role` to the given peers.
    Assign(Role, Vec<K>),
    /// Send a message to one peer.
    ///
    /// The optional `Id` is forwarded to the per-peer `chan::Sender`
    /// (its semantics are defined by the `chan` module).
    Unicast(K, Option<Id>, Bytes),
    /// Send a message to some peers.
    Multicast(Vec<K>, Option<Id>, Bytes),
    /// Send a message to all peers with `Role::Active`.
    Broadcast(Option<Id>, Bytes),
}
123
/// The `Server` is accepting connections and also establishing and
/// maintaining connections with all parties.
#[derive(Debug)]
struct Server<K> {
    /// Static configuration (name, label, parties, capacities, …).
    conf: NetConf<K>,

    /// This server's role.
    role: Role,

    /// MPSC sender for messages received over a connection to a party.
    ///
    /// (see `Network` for the accompanying receiver).
    ibound: Sender<(K, Bytes, Option<OwnedSemaphorePermit>)>,

    /// MPSC receiver for server task instructions.
    ///
    /// (see `Network` for the accompanying sender).
    obound: Receiver<Command<K>>,

    /// All parties of the network and their addresses.
    peers: HashMap<K, Peer>,

    /// Bi-directional mapping of signing key and X25519 keys to identify
    /// remote parties.
    index: BiHashMap<K, PublicKey>,

    /// Find the public key given a tokio task ID.
    ///
    /// Populated for connect and I/O tasks so a completed task can be
    /// mapped back to its peer.
    task2key: HashMap<task::Id, K>,

    /// Currently active connect attempts.
    connecting: HashMap<K, ConnectTask>,

    /// Currently active connections (post handshake).
    active: HashMap<K, IoTask>,

    /// Tasks performing a handshake with a remote party.
    ///
    /// `run` seeds this set with a `pending()` sentinel so it is never
    /// empty and its select arm stays enabled.
    handshake_tasks: JoinSet<Result<(TcpStream, TransportState)>>,

    /// Tasks connecting to a remote party and performing a handshake.
    connect_tasks: JoinSet<(TcpStream, TransportState)>,

    /// Active I/O tasks, exchanging data with remote parties.
    ///
    /// Also seeded with a `pending()` sentinel by `run`.
    io_tasks: JoinSet<Result<()>>,

    /// Interval at which to ping peers.
    ping_interval: Interval,

    /// For gathering network metrics.
    #[cfg(feature = "metrics")]
    metrics: Arc<NetworkMetrics<K>>,
}
175
/// Per-peer bookkeeping held by the `Server`.
#[derive(Debug)]
struct Peer {
    /// Network address of the peer.
    addr: NetAddr,
    /// Role of the peer; broadcasts only go to `Role::Active` peers.
    role: Role,
    /// Permit budget shared with this peer's I/O task; reused when the
    /// peer is replaced via `Command::Add` so outstanding permits are
    /// preserved.
    budget: Budget,
}
182
/// A connect task.
#[derive(Debug)]
struct ConnectTask {
    /// Abort handle of the spawned connect attempt.
    h: AbortHandle,
}
188
// Make sure the task is stopped when `ConnectTask` is dropped.
impl Drop for ConnectTask {
    fn drop(&mut self) {
        // Aborting an already-finished task is a no-op, so this is safe
        // regardless of the task's state.
        self.h.abort();
    }
}
195
/// An I/O task, reading data from and writing data to a remote party.
#[derive(Debug)]
struct IoTask {
    /// Abort handle of the read-half of the connection.
    rh: AbortHandle,

    /// Abort handle of the write-half of the connection.
    wh: AbortHandle,

    /// MPSC sender of outgoing messages to the remote
    /// (see `Message` for the variants the writer understands).
    tx: chan::Sender<Message>,
}
208
// Make sure all tasks are stopped when `IoTask` is dropped.
impl Drop for IoTask {
    fn drop(&mut self) {
        // Abort both halves; aborting an already-finished task is a no-op.
        self.rh.abort();
        self.wh.abort();
    }
}
216
/// Unify the various data types we want to send to the writer task.
#[derive(Debug)]
enum Message {
    /// Application payload bytes.
    Data(Bytes),
    /// Keep-alive probe carrying the send time (sent every
    /// `PING_INTERVAL`).
    Ping(Timestamp),
    /// Answer to a ping — presumably echoes the ping's timestamp so the
    /// sender can measure round-trip time; confirm in the reader task.
    Pong(Timestamp),
}
224
impl<K> Network<K> {
    /// The key identifying this node.
    ///
    /// NOTE(review): this returns the `label` field despite the name —
    /// confirm that `K` is in fact the public-key type in all
    /// instantiations.
    pub fn public_key(&self) -> &K {
        &self.label
    }

    /// Name of this network.
    pub fn name(&self) -> &str {
        self.name
    }

    /// Get a clone of the MPSC sender.
    pub(crate) fn sender(&self) -> Sender<Command<K>> {
        self.tx.clone()
    }
}
239
240impl<K: Display> Network<K> {
241    pub async fn close(&self) {
242        let srv = self.srv.lock().take();
243        if let Some(srv) = srv {
244            info!(name = %self.name, node = %self.label, "aborting server task");
245            srv.abort();
246            let _ = srv.await;
247        }
248    }
249}
250
251impl<K> Network<K>
252where
253    K: Eq + Ord + Clone + Display + Hash + Send + Sync + 'static,
254{
255    pub async fn create(cfg: NetConf<K>) -> Result<Self> {
256        let listener = TcpListener::bind(cfg.bind.to_string())
257            .await
258            .map_err(|e| NetworkError::Bind(cfg.bind.clone(), e))?;
259
260        debug!(
261            name = %cfg.name,
262            node = %cfg.label,
263            addr = %listener.local_addr()?,
264            "listening"
265        );
266
267        let mut parties = HashMap::new();
268        let mut peers = HashMap::new();
269        let mut index = BiHashMap::new();
270
271        for (k, x, a) in cfg.parties.iter().cloned() {
272            parties.insert(k.clone(), Role::Active);
273            index.insert(k.clone(), x);
274            peers.insert(
275                k,
276                Peer {
277                    addr: a,
278                    role: Role::Active,
279                    budget: cfg.new_budget(),
280                },
281            );
282        }
283
284        // Command channel from application to network.
285        let (otx, orx) = mpsc::channel(cfg.total_capacity_egress);
286
287        // Channel of messages from peers to the application.
288        let (itx, irx) = mpsc::channel(cfg.total_capacity_ingress);
289
290        let mut interval = tokio::time::interval(PING_INTERVAL);
291        interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
292
293        let name = cfg.name;
294        let label = cfg.label.clone();
295        let mmsze = cfg.max_message_size;
296
297        #[cfg(feature = "metrics")]
298        let metrics = {
299            let it = parties.keys().filter(|k| **k != label).cloned();
300            NetworkMetrics::new(name, &*cfg.metrics, it)
301        };
302
303        let server = Server {
304            conf: cfg,
305            role: Role::Active,
306            ibound: itx,
307            obound: orx,
308            peers,
309            index,
310            connecting: HashMap::new(),
311            active: HashMap::new(),
312            task2key: HashMap::new(),
313            handshake_tasks: JoinSet::new(),
314            connect_tasks: JoinSet::new(),
315            io_tasks: JoinSet::new(),
316            ping_interval: interval,
317            #[cfg(feature = "metrics")]
318            metrics: Arc::new(metrics),
319        };
320
321        Ok(Self {
322            name,
323            label,
324            parties: Mutex::new(parties),
325            rx: AsyncMutex::new(irx),
326            tx: otx,
327            srv: Mutex::new(Some(spawn(server.run(listener)))),
328            max_message_size: mmsze,
329        })
330    }
331
332    pub fn parties(&self, r: Option<Role>) -> Vec<K> {
333        self.parties
334            .lock()
335            .iter()
336            .filter(|&(_, x)| r.map(|r| r == *x).unwrap_or(true))
337            .map(|(k, _)| k.clone())
338            .collect()
339    }
340
341    /// Send a message to a party, identified by the given public key.
342    pub async fn unicast(&self, to: K, msg: Bytes) -> Result<()> {
343        if msg.len() > self.max_message_size {
344            warn!(
345                name = %self.name,
346                node = %self.label,
347                to   = %to,
348                len  = %msg.len(),
349                max  = %self.max_message_size,
350                "message too large to send"
351            );
352            return Err(NetworkError::MessageTooLarge);
353        }
354        self.tx
355            .send(Command::Unicast(to, None, msg))
356            .await
357            .map_err(|_| NetworkError::ChannelClosed)
358    }
359
360    /// Send a message to all parties.
361    pub async fn broadcast(&self, msg: Bytes) -> Result<()> {
362        if msg.len() > self.max_message_size {
363            warn!(
364                name = %self.name,
365                node = %self.label,
366                len  = %msg.len(),
367                max  = %self.max_message_size,
368                "message too large to broadcast"
369            );
370            return Err(NetworkError::MessageTooLarge);
371        }
372        self.tx
373            .send(Command::Broadcast(None, msg))
374            .await
375            .map_err(|_| NetworkError::ChannelClosed)
376    }
377
378    /// Receive a message from a remote party.
379    pub async fn receive(&self) -> StdResult<(K, Bytes), NetworkDown> {
380        let mut rx = self.rx.lock().await;
381        let (k, b, _) = rx.recv().await.ok_or(NetworkDown::new())?;
382        Ok((k, b))
383    }
384
385    /// Add the given peers to the network.
386    pub async fn add(
387        &self,
388        r: Role,
389        peers: Vec<(K, PublicKey, NetAddr)>,
390    ) -> StdResult<(), NetworkDown> {
391        self.parties
392            .lock()
393            .extend(peers.iter().map(|(p, ..)| (p.clone(), r)));
394        self.tx
395            .send(Command::Add(r, peers))
396            .await
397            .map_err(|_| NetworkDown::new())
398    }
399
400    /// Remove the given peers from the network.
401    pub async fn remove(&self, peers: Vec<K>) -> StdResult<(), NetworkDown> {
402        {
403            let mut parties = self.parties.lock();
404            for p in &peers {
405                parties.remove(p);
406            }
407        }
408        self.tx
409            .send(Command::Remove(peers))
410            .await
411            .map_err(|_| NetworkDown::new())
412    }
413
414    /// Assign the given role to the given peers.
415    pub async fn assign(&self, r: Role, peers: Vec<K>) -> StdResult<(), NetworkDown> {
416        {
417            let mut parties = self.parties.lock();
418            for p in &peers {
419                if let Some(role) = parties.get_mut(p) {
420                    *role = r
421                }
422            }
423        }
424        self.tx
425            .send(Command::Assign(r, peers))
426            .await
427            .map_err(|_| NetworkDown::new())
428    }
429}
430
431impl<K> Server<K>
432where
433    K: Eq + Ord + Clone + Display + Hash + Send + Sync + 'static,
434{
    /// Runs the main loop of this network node.
    ///
    /// This function:
    ///
    /// - Tries to connect to each remote peer in the committee.
    /// - Handles tasks that have been completed or terminated.
    /// - Processes new messages we received on the network.
    /// - Executes commands received from the `Network` facade.
    /// - Periodically pings all active connections.
    ///
    /// Only returns (with an error) when the command channel from the
    /// `Network` facade is closed.
    async fn run(mut self, listener: TcpListener) -> Result<Empty> {
        // Sentinel tasks: a `JoinSet`'s `join_next` yields `None` when the
        // set is empty, which would disable the corresponding select arms
        // below. The never-completing `pending()` keeps both sets non-empty.
        self.handshake_tasks.spawn(pending());
        self.io_tasks.spawn(pending());

        // Connect to all peers.
        for k in self
            .peers
            .keys()
            .filter(|k| **k != self.conf.label)
            .cloned()
            .collect::<Vec<_>>()
        {
            self.spawn_connect(k)
        }

        loop {
            trace!(
                name       = %self.conf.name,
                node       = %self.conf.label,
                active     = %self.active.len(),
                connects   = %self.connect_tasks.len(),
                handshakes = %self.handshake_tasks.len().saturating_sub(1), // -1 for `pending()`
                io_tasks   = %self.io_tasks.len().saturating_sub(1), // -1 for `pending()`
                tasks_ids  = %self.task2key.len(),
                iqueue     = %self.ibound.capacity(),
                oqueue     = %self.obound.capacity(),
            );

            #[cfg(feature = "metrics")]
            {
                self.metrics.iqueue.set(self.ibound.capacity());
                self.metrics.oqueue.set(self.obound.capacity());
            }

            tokio::select! {
                // Accepted a new connection.
                i = listener.accept() => match i {
                    Ok((s, a)) => {
                        debug!(
                            name = %self.conf.name,
                            node = %self.conf.label,
                            addr = %a,
                            "accepted connection"
                        );
                        self.spawn_handshake(s)
                    }
                    Err(e) => {
                        warn!(
                            name = %self.conf.name,
                            node = %self.conf.label,
                            err  = %e,
                            "error accepting connection"
                        )
                    }
                },
                // The handshake of an inbound connection completed.
                Some(h) = self.handshake_tasks.join_next() => match h {
                    Ok(Ok((s, t))) => {
                        // The remote static key learned during the handshake
                        // identifies the peer.
                        let Some((k, peer)) = self.lookup_peer(&t) else {
                            info!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = ?t.get_remote_static().and_then(|k| PublicKey::try_from(k).ok()),
                                addr = ?s.peer_addr().ok(),
                                "unknown peer"
                            );
                            continue
                        };
                        if !self.is_valid_ip(&k, &s) {
                            warn!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = %k,
                                addr = ?s.peer_addr().ok(), "invalid peer ip addr"
                            );
                            continue
                        }
                        // We only accept connections whose party has a public key that
                        // is larger than ours, or if we do not have a connection for
                        // that key at the moment. (Mirror image of the outbound
                        // tie-break below, so at most one of two simultaneous
                        // connections between the same pair survives.)
                        if k > self.conf.label || !self.active.contains_key(&k) {
                            self.spawn_io(k, s, t, peer.budget.clone())
                        } else {
                            debug!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = %k,
                                "dropping accepted connection"
                            );
                        }
                    }
                    Ok(Err(e)) => {
                        warn!(
                            name = %self.conf.name,
                            node = %self.conf.label,
                            err  = %e,
                            "handshake failed"
                        )
                    }
                    Err(e) => {
                        if !e.is_cancelled() {
                            error!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                err  = %e,
                                "handshake task panic"
                            )
                        }
                    }
                },
                // One of our connection attempts completed.
                Some(tt) = self.connect_tasks.join_next_with_id() => {
                    match tt {
                        Ok((id, (s, t))) => {
                            self.on_connect_task_end(id);
                            let Some((k, peer)) = self.lookup_peer(&t) else {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    peer = ?t.get_remote_static().and_then(|k| PublicKey::try_from(k).ok()),
                                    addr = ?s.peer_addr().ok(),
                                    "connected to unknown peer"
                                );
                                continue
                            };
                            // We only keep the connection if our key is larger than the remote,
                            // or if we do not have a connection for that key at the moment.
                            // (Mirror image of the inbound tie-break above.)
                            if k < self.conf.label || !self.active.contains_key(&k) {
                                self.spawn_io(k, s, t, peer.budget.clone())
                            } else {
                                debug!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    peer = %k,
                                    "dropping new connection"
                                )
                            }
                        }
                        Err(e) => {
                            if !e.is_cancelled() {
                                error!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err  = %e,
                                    "connect task panic"
                                )
                            }
                            self.on_connect_task_end(e.id());
                        }
                    }
                },
                // A read or write task completed.
                Some(io) = self.io_tasks.join_next_with_id() => {
                    match io {
                        Ok((id, r)) => {
                            if let Err(e) = r {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err  = %e,
                                    "i/o error"
                                )
                            }
                            self.on_io_task_end(id);
                        }
                        Err(e) => {
                            if e.is_cancelled() {
                                // If one half completes we cancel the other, so there is
                                // nothing else to do here, except to remove the cancelled
                                // task's ID. Same if we kill the connection, both tasks
                                // get cancelled.
                                self.task2key.remove(&e.id());
                                continue
                            }
                            // If the task has not been cancelled, it must have panicked.
                            error!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                err  = %e,
                                "i/o task panic"
                            );
                            self.on_io_task_end(e.id())
                        }
                    };
                },
                // A command from the `Network` facade.
                cmd = self.obound.recv() => match cmd {
                    Some(Command::Add(role, peers)) => {
                        #[cfg(feature = "metrics")]
                        Arc::make_mut(&mut self.metrics).add_parties(peers.iter().map(|(k, ..)| k).cloned());
                        for (k, x, a) in peers {
                            let mut existing_budget = None;
                            if let Some(peer) = self.peers.get(&k) {
                                let Some(y) = self.index.get_by_left(&k) else {
                                    error!(
                                        name = %self.conf.name,
                                        node = %self.conf.label,
                                        peer = %k,
                                        addr = %a,
                                        "peer is known, but x25519 key not found"
                                    );
                                    continue
                                };
                                // Nothing changed => keep the existing entry and
                                // its connection untouched.
                                if a == peer.addr && x == *y {
                                    debug!(
                                        name = %self.conf.name,
                                        node = %self.conf.label,
                                        peer = %k,
                                        addr = %a,
                                        "peer to add already exists"
                                    );
                                    continue
                                }
                                // Keep the old budget so permits already handed
                                // out stay accounted for.
                                existing_budget = Some(peer.budget.clone());
                            }
                            info!(
                                name    = %self.conf.name,
                                node    = %self.conf.label,
                                role    = %role,
                                peer    = %k,
                                addr    = %a,
                                replace = %existing_budget.is_some(),
                                "adding or replacing peer"
                            );
                            let p = Peer {
                                addr: a,
                                role,
                                budget: existing_budget.unwrap_or_else(|| self.conf.new_budget())
                            };
                            self.peers.insert(k.clone(), p);
                            self.index.insert(k.clone(), x);
                            // Dropping these entries aborts any in-flight
                            // connect / I/O tasks (see the `Drop` impls).
                            self.connecting.remove(&k);
                            self.active.remove(&k);
                            self.spawn_connect(k)
                        }
                    }
                    Some(Command::Remove(peers)) => {
                        for k in &peers {
                            info!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = %k,
                                "removing peer"
                            );
                            self.peers.remove(k);
                            self.index.remove_by_left(k);
                            self.connecting.remove(k);
                            self.active.remove(k);
                        }
                        #[cfg(feature = "metrics")]
                        Arc::make_mut(&mut self.metrics).remove_parties(&peers)
                    }
                    Some(Command::Assign(role, peers)) => {
                        for k in &peers {
                            if let Some(p) = self.peers.get_mut(k) {
                                p.role = role
                            } else {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    peer = %k,
                                    role = ?role,
                                    "peer to assign role to not found"
                                );
                            }
                        }
                    }
                    Some(Command::Unicast(to, id, m)) => {
                        // Messages to ourselves are looped back directly to
                        // the inbound channel, bypassing the network.
                        if to == self.conf.label {
                            trace!(
                                name  = %self.conf.name,
                                node  = %self.conf.label,
                                to    = %to,
                                len   = %m.len(),
                                queue = self.ibound.capacity(),
                                "sending message"
                            );
                            if let Err(err) = self.ibound.try_send((self.conf.label.clone(), m, None)) {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err  = %err,
                                    cap  = %self.ibound.capacity(),
                                    "channel full => dropping unicast message"
                                )
                            }
                            continue
                        }
                        // Silently dropped if there is no active connection.
                        if let Some(task) = self.active.get(&to) {
                            trace!(
                                name  = %self.conf.name,
                                node  = %self.conf.label,
                                to    = %to,
                                len   = %m.len(),
                                queue = task.tx.capacity(),
                                "sending message"
                            );
                            #[cfg(feature = "metrics")]
                            self.metrics.set_peer_oqueue_cap(&to, task.tx.capacity());
                            task.tx.send(id, Message::Data(m))
                        }
                    }
                    Some(Command::Multicast(peers, id, m)) => {
                        if peers.contains(&self.conf.label) {
                            trace!(
                                name  = %self.conf.name,
                                node  = %self.conf.label,
                                to    = %self.conf.label,
                                len   = %m.len(),
                                queue = self.ibound.capacity(),
                                "sending message"
                            );
                            if let Err(err) = self.ibound.try_send((self.conf.label.clone(), m.clone(), None)) {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err  = %err,
                                    cap  = %self.ibound.capacity(),
                                    "channel full => dropping multicast message"
                                )
                            }
                        }
                        for (to, task) in &self.active {
                            if !peers.contains(to) {
                                continue
                            }
                            trace!(
                                name  = %self.conf.name,
                                node  = %self.conf.label,
                                to    = %to,
                                len   = %m.len(),
                                queue = task.tx.capacity(),
                                "sending message"
                            );
                            #[cfg(feature = "metrics")]
                            self.metrics.set_peer_oqueue_cap(to, task.tx.capacity());
                            task.tx.send(id, Message::Data(m.clone()))
                        }
                    }
                    Some(Command::Broadcast(id, m)) => {
                        // Loop back to ourselves only if we are active …
                        if self.role.is_active() {
                            trace!(
                                name  = %self.conf.name,
                                node  = %self.conf.label,
                                to    = %self.conf.label,
                                len   = %m.len(),
                                queue = self.ibound.capacity(),
                                "sending message"
                            );
                            if let Err(err) = self.ibound.try_send((self.conf.label.clone(), m.clone(), None)) {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err  = %err,
                                    cap  = %self.ibound.capacity(),
                                    "channel full => dropping broadcast message"
                                )
                            }
                        }
                        // … and only deliver to peers that are active, too.
                        for (to, task) in &self.active {
                            if Some(Role::Active) != self.peers.get(to).map(|p| p.role) {
                                continue
                            }
                            trace!(
                                name  = %self.conf.name,
                                node  = %self.conf.label,
                                to    = %to,
                                len   = %m.len(),
                                queue = task.tx.capacity(),
                                "sending message"
                            );
                            #[cfg(feature = "metrics")]
                            self.metrics.set_peer_oqueue_cap(to, task.tx.capacity());
                            task.tx.send(id, Message::Data(m.clone()))
                        }
                    }
                    None => {
                        return Err(NetworkError::ChannelClosed)
                    }
                },
                // Time to ping all active connections.
                _ = self.ping_interval.tick() => {
                    let now = Timestamp::now();
                    for task in self.active.values() {
                        task.tx.send(None, Message::Ping(now))
                    }
                }
            }
        }
    }
830
831    /// Handles a completed connect task.
832    fn on_connect_task_end(&mut self, id: task::Id) {
833        let Some(k) = self.task2key.remove(&id) else {
834            error!(name = %self.conf.name, node = %self.conf.label, "no key for connect task");
835            return;
836        };
837        self.connecting.remove(&k);
838    }
839
840    /// Handles a completed I/O task.
841    ///
842    /// This function will get the public key of the task that was terminated
843    /// and then cleanly removes the associated I/O task data and re-connects
844    /// to the peer node it was interacting with.
845    fn on_io_task_end(&mut self, id: task::Id) {
846        let Some(k) = self.task2key.remove(&id) else {
847            error!(name = %self.conf.name, node = %self.conf.label, "no key for i/o task");
848            return;
849        };
850        let Some(task) = self.active.get(&k) else {
851            return;
852        };
853        if task.rh.id() == id {
854            debug!(
855                name = %self.conf.name,
856                node = %self.conf.label,
857                peer = %k,
858                "read-half closed => dropping connection"
859            );
860            self.active.remove(&k);
861            self.spawn_connect(k)
862        } else if task.wh.id() == id {
863            debug!(
864                name = %self.conf.name,
865                node = %self.conf.label,
866                peer = %k,
867                "write-half closed => dropping connection"
868            );
869            self.active.remove(&k);
870            self.spawn_connect(k)
871        } else {
872            debug!(
873                name = %self.conf.name,
874                node = %self.conf.label,
875                peer = %k,
876                "i/o task was previously replaced"
877            );
878        }
879    }
880
881    /// Spawns a new connection task to a peer identified by public key.
882    ///
883    /// This function will look up the x25519 public key of the ed25519 key
884    /// and the remote address and then spawn a connection task.
885    fn spawn_connect(&mut self, k: K) {
886        if self.connecting.contains_key(&k) {
887            debug!(
888                name = %self.conf.name,
889                node = %self.conf.label,
890                peer = %k,
891                "connect task already started"
892            );
893            return;
894        }
895        let x = self.index.get_by_left(&k).expect("known public key");
896        let p = self.peers.get(&k).expect("known peer");
897        let h = self.connect_tasks.spawn(connect(
898            self.conf.name,
899            (self.conf.label.clone(), self.conf.keypair.clone()),
900            (k.clone(), *x),
901            p.addr.clone(),
902            self.conf.retry_delays,
903            #[cfg(feature = "metrics")]
904            self.metrics.clone(),
905        ));
906        assert!(self.task2key.insert(h.id(), k.clone()).is_none());
907        self.connecting.insert(k, ConnectTask { h });
908    }
909
910    /// Spawns a new `Noise` responder handshake task using the IK pattern.
911    ///
912    /// This function will create the responder handshake machine using its
913    /// own private key and then spawn a task that awaits an initiator handshake
914    /// to which it will respond.
915    fn spawn_handshake(&mut self, s: TcpStream) {
916        let h = Builder::new(NOISE_PARAMS.parse().expect("valid noise params"))
917            .local_private_key(&self.conf.keypair.secret_key().as_bytes())
918            .expect("valid private key")
919            .prologue(self.conf.name.as_bytes())
920            .expect("1st time we set the prologue")
921            .build_responder()
922            .expect("valid noise params yield valid handshake state");
923        self.handshake_tasks.spawn(async move {
924            timeout(HANDSHAKE_TIMEOUT, on_handshake(h, s))
925                .await
926                .or(Err(NetworkError::Timeout))?
927        });
928    }
929
930    /// Spawns a new I/O task for handling communication with a remote peer over
931    /// a TCP connection using the noise framework to create an authenticated
932    /// secure link.
933    fn spawn_io(&mut self, k: K, s: TcpStream, t: TransportState, b: Budget) {
934        debug!(
935            name = %self.conf.name,
936            node = %self.conf.label,
937            peer = %k,
938            addr = ?s.peer_addr().ok(),
939            "starting i/o tasks"
940        );
941        let (to_remote, from_remote) = chan::channel(self.conf.peer_capacity_egress);
942        let (r, w) = s.into_split();
943        let t1 = Arc::new(Mutex::new(t));
944        let t2 = t1.clone();
945        let ibound = self.ibound.clone();
946        let to_write = to_remote.clone();
947        let countdown = Countdown::new();
948        let rh = self.io_tasks.spawn(recv_loop(
949            self.conf.name,
950            k.clone(),
951            r,
952            t1,
953            ibound,
954            to_write,
955            #[cfg(feature = "metrics")]
956            self.metrics.clone(),
957            b,
958            countdown.clone(),
959            self.conf.max_message_size,
960        ));
961        let wh = self
962            .io_tasks
963            .spawn(send_loop(w, t2, from_remote, countdown));
964        assert!(self.task2key.insert(rh.id(), k.clone()).is_none());
965        assert!(self.task2key.insert(wh.id(), k.clone()).is_none());
966        let io = IoTask {
967            rh,
968            wh,
969            tx: to_remote,
970        };
971        self.active.insert(k, io);
972        #[cfg(feature = "metrics")]
973        self.metrics.connections.set(self.active.len());
974    }
975
976    /// Get the public key of a party by their static X25519 public key.
977    fn lookup_peer(&self, t: &TransportState) -> Option<(K, &Peer)> {
978        let x = t.get_remote_static()?;
979        let x = PublicKey::try_from(x).ok()?;
980        let k = self.index.get_by_right(&x)?;
981        self.peers.get(k).map(|p| (k.clone(), p))
982    }
983
984    /// Check if the socket's peer IP address corresponds to the configured one.
985    fn is_valid_ip(&self, k: &K, s: &TcpStream) -> bool {
986        self.peers
987            .get(k)
988            .map(|p| {
989                let NetAddr::Inet(ip, _) = p.addr else {
990                    return true;
991                };
992                Some(ip) == s.peer_addr().ok().map(|a| a.ip())
993            })
994            .unwrap_or(false)
995    }
996}
997
998/// Connect to the given socket address.
999///
1000/// This function will only return, when a connection has been established and the handshake
1001/// has been completed.
1002async fn connect<K>(
1003    name: &'static str,
1004    this: (K, Keypair),
1005    to: (K, PublicKey),
1006    addr: NetAddr,
1007    delays: [u8; NUM_DELAYS],
1008    #[cfg(feature = "metrics")] metrics: Arc<NetworkMetrics<K>>,
1009) -> (TcpStream, TransportState)
1010where
1011    K: Eq + Hash + Display + Clone,
1012{
1013    use rand::prelude::*;
1014
1015    let new_handshake_state = || {
1016        Builder::new(NOISE_PARAMS.parse().expect("valid noise params"))
1017            .local_private_key(this.1.secret_key().as_slice())
1018            .expect("valid private key")
1019            .remote_public_key(to.1.as_slice())
1020            .expect("valid remote pub key")
1021            .prologue(name.as_bytes())
1022            .expect("1st time we set the prologue")
1023            .build_initiator()
1024            .expect("valid noise params yield valid handshake state")
1025    };
1026
1027    let delays = once(rand::rng().random_range(0..=1000))
1028        .chain(delays.into_iter().map(|d| u64::from(d) * 1000))
1029        .chain(repeat(u64::from(delays[LAST_DELAY]) * 1000));
1030
1031    let addr = addr.to_string();
1032
1033    for d in delays {
1034        sleep(Duration::from_millis(d)).await;
1035        debug!(%name, node = %this.0, peer = %to.0, %addr, "connecting");
1036        #[cfg(feature = "metrics")]
1037        metrics.add_connect_attempt(&to.0);
1038        match timeout(CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
1039            Ok(Ok(s)) => {
1040                if let Err(err) = s.set_nodelay(true) {
1041                    error!(%name, node = %this.0, %err, "failed to set NO_DELAY socket option");
1042                    continue;
1043                }
1044                match timeout(HANDSHAKE_TIMEOUT, handshake(new_handshake_state(), s)).await {
1045                    Ok(Ok(x)) => {
1046                        debug!(%name, node = %this.0, peer = %to.0, %addr, "connection established");
1047                        return x;
1048                    },
1049                    Ok(Err(err)) => {
1050                        warn!(%name, node = %this.0, peer = %to.0, %addr, %err, "handshake failure");
1051                    },
1052                    Err(_) => {
1053                        warn!(%name, node = %this.0, peer = %to.0, %addr, "handshake timeout");
1054                    },
1055                }
1056            },
1057            Ok(Err(err)) => {
1058                warn!(%name, node = %this.0, peer = %to.0, %addr, %err, "failed to connect");
1059            },
1060            Err(_) => {
1061                warn!(%name, node = %this.0, peer = %to.0, %addr, "connect timeout");
1062            },
1063        }
1064    }
1065
1066    unreachable!("for loop repeats forever")
1067}
1068
1069/// Perform a noise handshake as initiator with the remote party.
1070async fn handshake(
1071    mut hs: HandshakeState,
1072    mut stream: TcpStream,
1073) -> Result<(TcpStream, TransportState)> {
1074    let mut b = vec![0; MAX_NOISE_HANDSHAKE_SIZE];
1075    let n = hs.write_message(&[], &mut b[Header::SIZE..])?;
1076    let h = Header::data(n as u16);
1077    send_frame(&mut stream, h, &mut b[..Header::SIZE + n]).await?;
1078    let (h, m) = recv_frame(&mut stream).await?;
1079    if !h.is_data() || h.is_partial() {
1080        return Err(NetworkError::InvalidHandshakeMessage);
1081    }
1082    hs.read_message(&m, &mut b)?;
1083    Ok((stream, hs.into_transport_mode()?))
1084}
1085
1086/// Perform a noise handshake as responder with a remote party.
1087async fn on_handshake(
1088    mut hs: HandshakeState,
1089    mut stream: TcpStream,
1090) -> Result<(TcpStream, TransportState)> {
1091    stream.set_nodelay(true)?;
1092    let (h, m) = recv_frame(&mut stream).await?;
1093    if !h.is_data() || h.is_partial() {
1094        return Err(NetworkError::InvalidHandshakeMessage);
1095    }
1096    let mut b = vec![0; MAX_NOISE_HANDSHAKE_SIZE];
1097    hs.read_message(&m, &mut b)?;
1098    let n = hs.write_message(&[], &mut b[Header::SIZE..])?;
1099    let h = Header::data(n as u16);
1100    send_frame(&mut stream, h, &mut b[..Header::SIZE + n]).await?;
1101    Ok((stream, hs.into_transport_mode()?))
1102}
1103
/// Read messages from the remote by assembling frames together.
///
/// Once complete the message will be handed over to the given MPSC sender.
///
/// Notes:
/// - One semaphore permit is acquired from `budget` per message; the permit
///   travels with the delivered message, so ingress capacity is only freed
///   once the consumer drops it.
/// - Ping frames are answered with a pong via `to_writer`; pong frames feed
///   the latency metric (when the `metrics` feature is enabled).
/// - `countdown` is armed elsewhere (by the send loop after a ping); any
///   received frame stops it, and its expiry aborts this loop with
///   `NetworkError::Timeout`.
#[allow(clippy::too_many_arguments)]
async fn recv_loop<R, K>(
    name: &'static str,
    id: K,
    mut reader: R,
    state: Arc<Mutex<TransportState>>,
    to_deliver: Sender<(K, Bytes, Option<OwnedSemaphorePermit>)>,
    to_writer: chan::Sender<Message>,
    #[cfg(feature = "metrics")] metrics: Arc<NetworkMetrics<K>>,
    budget: Arc<Semaphore>,
    mut countdown: Countdown,
    max_message_size: usize,
) -> Result<()>
where
    R: AsyncRead + Unpin,
    K: Eq + Hash + Display + Clone,
{
    // Scratch buffer for decrypted noise payloads (one noise message max).
    let mut buf = vec![0; MAX_NOISE_MESSAGE_SIZE];
    loop {
        #[cfg(feature = "metrics")]
        metrics.set_peer_iqueue_cap(&id, budget.available_permits());
        // Wait for delivery capacity before reading the next message.
        let permit = budget
            .clone()
            .acquire_owned()
            .await
            .map_err(|_| NetworkError::BudgetClosed)?;
        // Accumulates the plaintext of one (possibly multi-frame) message.
        let mut msg = BytesMut::new();
        // Inner loop: collect data frames until a non-partial frame completes
        // the message. Ping/pong frames are handled inline and do not
        // contribute to `msg`.
        loop {
            tokio::select! {
                val = recv_frame(&mut reader) => {
                    // Any inbound frame proves the peer is alive.
                    countdown.stop();
                    match val {
                        Ok((h, f)) => {
                            match h.frame_type() {
                                Ok(Type::Ping) => {
                                    // Received ping message; sending pong to writer
                                    let n = state.lock().read_message(&f, &mut buf)?;
                                    if let Some(ping) = Timestamp::try_from_slice(&buf[..n]) {
                                        to_writer.send(None, Message::Pong(ping))
                                    }
                                }
                                Ok(Type::Pong) => {
                                    // Received pong message; measure elapsed time
                                    let _n = state.lock().read_message(&f, &mut buf)?;
                                    #[cfg(feature = "metrics")]
                                    if let Some(ping) = Timestamp::try_from_slice(&buf[.._n])
                                        && let Some(delay) = Timestamp::now().diff(ping)
                                    {
                                        metrics.set_latency(&id, delay)
                                    }
                                }
                                Ok(Type::Data) => {
                                    // Append this fragment; enforce the overall
                                    // message size limit as we grow.
                                    let n = state.lock().read_message(&f, &mut buf)?;
                                    msg.extend_from_slice(&buf[..n]);
                                    if msg.len() > max_message_size {
                                        return Err(NetworkError::MessageTooLarge);
                                    }
                                    // A non-partial frame terminates the message.
                                    if !h.is_partial() {
                                        break;
                                    }
                                }
                                Err(t) => return Err(NetworkError::UnknownFrameType(t)),
                            }
                        }
                        Err(e) => return Err(e)
                    }
                },
                () = &mut countdown => {
                    warn!(%name, node = %id, "timeout waiting for peer");
                    return Err(NetworkError::Timeout)
                }
            }
        }
        // Hand the completed message (with its permit) to the application;
        // a closed receiver means orderly shutdown, not an error.
        if to_deliver
            .send((id.clone(), msg.freeze(), Some(permit)))
            .await
            .is_err()
        {
            break;
        }
    }
    Ok(())
}
1190
1191/// Consume messages to be delivered to remote parties and send them.
1192///
1193/// The function automatically splits large messages into chunks that fit into
1194/// a noise package.
1195async fn send_loop<W>(
1196    mut writer: W,
1197    state: Arc<Mutex<TransportState>>,
1198    rx: chan::Receiver<Message>,
1199    countdown: Countdown,
1200) -> Result<()>
1201where
1202    W: AsyncWrite + Unpin,
1203{
1204    let mut buf = vec![0; MAX_NOISE_MESSAGE_SIZE];
1205
1206    while let Some(msg) = rx.recv().await {
1207        match msg {
1208            Message::Ping(ping) => {
1209                let n = state
1210                    .lock()
1211                    .write_message(&ping.to_bytes()[..], &mut buf[Header::SIZE..])?;
1212                let h = Header::ping(n as u16);
1213                send_frame(&mut writer, h, &mut buf[..Header::SIZE + n]).await?;
1214                countdown.start(REPLY_TIMEOUT)
1215            },
1216            Message::Pong(pong) => {
1217                let n = state
1218                    .lock()
1219                    .write_message(&pong.to_bytes()[..], &mut buf[Header::SIZE..])?;
1220                let h = Header::pong(n as u16);
1221                send_frame(&mut writer, h, &mut buf[..Header::SIZE + n]).await?
1222            },
1223            Message::Data(msg) => {
1224                let mut it = msg.chunks(MAX_PAYLOAD_SIZE).peekable();
1225                while let Some(m) = it.next() {
1226                    let n = state.lock().write_message(m, &mut buf[Header::SIZE..])?;
1227                    let h = if it.peek().is_some() {
1228                        Header::data(n as u16).partial()
1229                    } else {
1230                        Header::data(n as u16)
1231                    };
1232                    send_frame(&mut writer, h, &mut buf[..Header::SIZE + n]).await?
1233                }
1234            },
1235        }
1236    }
1237    Ok(())
1238}
1239
1240/// Read a single frame (header + payload) from the remote.
1241async fn recv_frame<R>(r: &mut R) -> Result<(Header, Vec<u8>)>
1242where
1243    R: AsyncRead + Unpin,
1244{
1245    let b = r.read_u32().await?;
1246    let h = Header::try_from(b.to_be_bytes())?;
1247    let mut v = vec![0; h.len().into()];
1248    r.read_exact(&mut v).await?;
1249    Ok((h, v))
1250}
1251
1252/// Write a single frame (header + payload) to the remote.
1253///
1254/// The header is serialised into the first 4 bytes of `msg`. It is the
1255/// caller's responsibility to ensure there is room at the beginning.
1256async fn send_frame<W>(w: &mut W, hdr: Header, msg: &mut [u8]) -> Result<()>
1257where
1258    W: AsyncWrite + Unpin,
1259{
1260    debug_assert!(msg.len() <= MAX_NOISE_MESSAGE_SIZE);
1261    msg[..Header::SIZE].copy_from_slice(&hdr.to_bytes());
1262    w.write_all(msg).await?;
1263    Ok(())
1264}