cliquenet/
net.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    collections::HashMap,
5    fmt::Display,
6    future::pending,
7    hash::Hash,
8    iter::{once, repeat},
9    sync::Arc,
10    time::Duration,
11};
12
13use bimap::BiHashMap;
14use bytes::{Bytes, BytesMut};
15use parking_lot::Mutex;
16use snow::{Builder, HandshakeState, TransportState};
17use tokio::{
18    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
19    net::{TcpListener, TcpStream},
20    spawn,
21    sync::{
22        Mutex as AsyncMutex, OwnedSemaphorePermit, Semaphore,
23        mpsc::{self, Receiver, Sender},
24    },
25    task::{self, AbortHandle, JoinHandle, JoinSet},
26    time::{Interval, MissedTickBehavior, sleep, timeout},
27};
28use tracing::{debug, error, info, trace, warn};
29
30#[cfg(feature = "metrics")]
31use crate::metrics::NetworkMetrics;
32use crate::{
33    Address, Id, Keypair, LAST_DELAY, NUM_DELAYS, NetConf, NetworkError, PublicKey, Role, chan,
34    error::Empty,
35    frame::{Header, Type},
36    time::{Countdown, Timestamp},
37};
38
/// Per-peer budget, implemented as a shared semaphore.
///
/// NOTE(review): permits appear to travel with inbound messages (see the
/// `OwnedSemaphorePermit` in the ingress channel type), presumably to bound
/// per-peer in-flight data — confirm in the I/O task code.
type Budget = Arc<Semaphore>;

/// Crate-local result type, specialised to `NetworkError`.
type Result<T> = std::result::Result<T, NetworkError>;

/// Max. message size using noise handshake.
const MAX_NOISE_HANDSHAKE_SIZE: usize = 1024;

/// Max. message size using noise protocol.
const MAX_NOISE_MESSAGE_SIZE: usize = 64 * 1024;

/// Max. number of bytes for payload data.
///
/// NOTE(review): 32 bytes are reserved out of `MAX_NOISE_MESSAGE_SIZE`,
/// presumably for per-message Noise overhead (AEAD tag etc.) — confirm
/// against the framing code.
const MAX_PAYLOAD_SIZE: usize = MAX_NOISE_MESSAGE_SIZE - 32;

/// Noise parameters to initialize the builders.
const NOISE_PARAMS: &str = "Noise_IK_25519_AESGCM_BLAKE2s";

/// Interval between ping protocol.
const PING_INTERVAL: Duration = Duration::from_secs(15);

/// Max. allowed duration of a single TCP connect attempt.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Max. allowed duration of a Noise handshake.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Max. allowed duration to wait for a peer to answer.
///
/// This is started when we have sent a ping. Unless we receive
/// some data back within this duration, the connection times
/// out and is dropped.
const REPLY_TIMEOUT: Duration = Duration::from_secs(30);
69
/// `Network` is the API facade of this crate.
#[derive(Debug)]
pub struct Network<K> {
    /// Name of this network.
    name: &'static str,

    /// Log label.
    ///
    /// This is this node's own key; `Network::public_key` returns it.
    label: K,

    /// The network participants.
    parties: Mutex<HashMap<K, Role>>,

    /// MPSC sender of server task instructions.
    tx: Sender<Command<K>>,

    /// MPSC receiver of messages from a remote party.
    ///
    /// The public key identifies the remote.
    rx: AsyncMutex<Receiver<(K, Bytes, Option<OwnedSemaphorePermit>)>>,

    /// Handle of the server task that has been spawned by `Network`.
    srv: JoinHandle<Result<Empty>>,

    /// Max. number of bytes per message.
    max_message_size: usize,
}

// Stop the background server task when the `Network` handle goes away;
// otherwise the spawned task would keep running detached.
impl<K> Drop for Network<K> {
    fn drop(&mut self) {
        self.srv.abort()
    }
}
102
/// Server task instructions.
///
/// NOTE(review): the `Option<Id>` carried by the send variants is forwarded
/// verbatim to `chan::Sender::send`; its exact semantics live in the `chan`
/// module — confirm there.
#[derive(Debug)]
pub(crate) enum Command<K> {
    /// Add the given peers.
    Add(Vec<(K, PublicKey, Address)>),
    /// Remove the given peers.
    Remove(Vec<K>),
    /// Assign a `Role` to the given peers.
    Assign(Role, Vec<K>),
    /// Send a message to one peer.
    Unicast(K, Option<Id>, Bytes),
    /// Send a message to some peers.
    Multicast(Vec<K>, Option<Id>, Bytes),
    /// Send a message to all peers with `Role::Active`.
    Broadcast(Option<Id>, Bytes),
}
119
/// The `Server` is accepting connections and also establishing and
/// maintaining connections with all parties.
#[derive(Debug)]
struct Server<K> {
    /// Network configuration (name, label, bind address, capacities, …).
    conf: NetConf<K>,

    /// This server's role.
    role: Role,

    /// MPSC sender for messages received over a connection to a party.
    ///
    /// (see `Network` for the accompanying receiver).
    ibound: Sender<(K, Bytes, Option<OwnedSemaphorePermit>)>,

    /// MPSC receiver for server task instructions.
    ///
    /// (see `Network` for the accompanying sender).
    obound: Receiver<Command<K>>,

    /// All parties of the network and their addresses.
    peers: HashMap<K, Peer>,

    /// Bi-directional mapping of signing key and X25519 keys to identify
    /// remote parties.
    index: BiHashMap<K, PublicKey>,

    /// Map a tokio task ID to the peer key the task belongs to.
    task2key: HashMap<task::Id, K>,

    /// Currently active connect attempts.
    connecting: HashMap<K, ConnectTask>,

    /// Currently active connections (post handshake).
    active: HashMap<K, IoTask>,

    /// Tasks performing a handshake with a remote party.
    ///
    /// NB: `run` seeds this set with a forever-`pending()` future,
    /// so it is never empty.
    handshake_tasks: JoinSet<Result<(TcpStream, TransportState)>>,

    /// Tasks connecting to a remote party and performing a handshake.
    connect_tasks: JoinSet<(TcpStream, TransportState)>,

    /// Active I/O tasks, exchanging data with remote parties.
    ///
    /// NB: like `handshake_tasks`, seeded with a `pending()` sentinel.
    io_tasks: JoinSet<Result<()>>,

    /// Interval at which to ping peers.
    ping_interval: Interval,

    /// For gathering network metrics.
    #[cfg(feature = "metrics")]
    metrics: Arc<NetworkMetrics<K>>,
}
171
/// Per-peer bookkeeping.
#[derive(Debug)]
struct Peer {
    /// Network address of the peer.
    addr: Address,
    /// The peer's current role (only `Role::Active` peers receive broadcasts).
    role: Role,
    /// Budget handed to this peer's I/O task (see `Budget`).
    budget: Budget,
}
178
/// A connect task.
#[derive(Debug)]
struct ConnectTask {
    /// Abort handle of the spawned connect attempt.
    h: AbortHandle,
}

// Make sure the task is stopped when `ConnectTask` is dropped.
impl Drop for ConnectTask {
    fn drop(&mut self) {
        self.h.abort();
    }
}
191
/// An I/O task, reading data from and writing data to a remote party.
///
/// The read and write halves of the connection run as two separate tokio
/// tasks; both abort handles are kept so the pair can be torn down together.
#[derive(Debug)]
struct IoTask {
    /// Abort handle of the read-half of the connection.
    rh: AbortHandle,

    /// Abort handle of the write-half of the connection.
    wh: AbortHandle,

    /// MPSC sender of outgoing messages to the remote.
    tx: chan::Sender<Message>,
}

// Make sure all tasks are stopped when `IoTask` is dropped.
impl Drop for IoTask {
    fn drop(&mut self) {
        self.rh.abort();
        self.wh.abort();
    }
}
212
/// Unify the various data types we want to send to the writer task.
#[derive(Debug)]
enum Message {
    /// Application payload bytes.
    Data(Bytes),
    /// Keep-alive probe, carrying a timestamp.
    Ping(Timestamp),
    /// Reply to a `Ping`, carrying a timestamp.
    Pong(Timestamp),
}
220
impl<K> Network<K>
where
    K: Eq + Ord + Clone + Display + Hash + Send + Sync + 'static,
{
    /// Create a new `Network`: bind the listener, initialise the peer
    /// tables (all configured parties start as `Role::Active`), wire up
    /// the command and ingress channels, and spawn the background
    /// `Server` task.
    ///
    /// # Errors
    ///
    /// Fails if the configured bind address cannot be bound or the local
    /// address cannot be read back.
    pub async fn create(cfg: NetConf<K>) -> Result<Self> {
        let listener = TcpListener::bind(cfg.bind.to_string())
            .await
            .map_err(|e| NetworkError::Bind(cfg.bind.clone(), e))?;

        debug!(
            name = %cfg.name,
            node = %cfg.label,
            addr = %listener.local_addr()?,
            "listening"
        );

        let mut parties = HashMap::new();
        let mut peers = HashMap::new();
        let mut index = BiHashMap::new();

        for (k, x, a) in cfg.parties.iter().cloned() {
            parties.insert(k.clone(), Role::Active);
            index.insert(k.clone(), x);
            peers.insert(
                k,
                Peer {
                    addr: a,
                    role: Role::Active,
                    budget: cfg.new_budget(),
                },
            );
        }

        // Command channel from application to network.
        let (otx, orx) = mpsc::channel(cfg.total_capacity_egress);

        // Channel of messages from peers to the application.
        let (itx, irx) = mpsc::channel(cfg.total_capacity_ingress);

        let mut interval = tokio::time::interval(PING_INTERVAL);
        interval.set_missed_tick_behavior(MissedTickBehavior::Delay);

        let name = cfg.name;
        let label = cfg.label.clone();
        let mmsze = cfg.max_message_size;

        // Metrics track every party except ourselves.
        #[cfg(feature = "metrics")]
        let metrics = {
            let it = parties.keys().filter(|k| **k != label).cloned();
            NetworkMetrics::new(name, &*cfg.metrics, it)
        };

        let server = Server {
            conf: cfg,
            role: Role::Active,
            ibound: itx,
            obound: orx,
            peers,
            index,
            connecting: HashMap::new(),
            active: HashMap::new(),
            task2key: HashMap::new(),
            handshake_tasks: JoinSet::new(),
            connect_tasks: JoinSet::new(),
            io_tasks: JoinSet::new(),
            ping_interval: interval,
            #[cfg(feature = "metrics")]
            metrics: Arc::new(metrics),
        };

        Ok(Self {
            name,
            label,
            parties: Mutex::new(parties),
            rx: AsyncMutex::new(irx),
            tx: otx,
            srv: spawn(server.run(listener)),
            max_message_size: mmsze,
        })
    }

    /// This node's own key (the label used in logs and party maps).
    pub fn public_key(&self) -> &K {
        &self.label
    }

    /// Name of this network.
    pub fn name(&self) -> &str {
        self.name
    }

    /// Return all parties currently assigned the given role.
    pub fn parties(&self, r: Role) -> Vec<K> {
        self.parties
            .lock()
            .iter()
            .filter(|&(_, x)| r == *x)
            .map(|(k, _)| k.clone())
            .collect()
    }

    /// Send a message to a party, identified by the given public key.
    ///
    /// # Errors
    ///
    /// `MessageTooLarge` if the message exceeds the configured maximum;
    /// `ChannelClosed` if the server task is gone.
    pub async fn unicast(&self, to: K, msg: Bytes) -> Result<()> {
        if msg.len() > self.max_message_size {
            warn!(
                name = %self.name,
                node = %self.label,
                to   = %to,
                len  = %msg.len(),
                max  = %self.max_message_size,
                "message too large to send"
            );
            return Err(NetworkError::MessageTooLarge);
        }
        self.tx
            .send(Command::Unicast(to, None, msg))
            .await
            .map_err(|_| NetworkError::ChannelClosed)
    }

    /// Send a message to all parties.
    ///
    /// # Errors
    ///
    /// `MessageTooLarge` if the message exceeds the configured maximum;
    /// `ChannelClosed` if the server task is gone.
    pub async fn broadcast(&self, msg: Bytes) -> Result<()> {
        if msg.len() > self.max_message_size {
            warn!(
                name = %self.name,
                node = %self.label,
                len  = %msg.len(),
                max  = %self.max_message_size,
                "message too large to broadcast"
            );
            return Err(NetworkError::MessageTooLarge);
        }
        self.tx
            .send(Command::Broadcast(None, msg))
            .await
            .map_err(|_| NetworkError::ChannelClosed)
    }

    /// Receive a message from a remote party.
    ///
    /// The accompanying semaphore permit (if any) is released on return.
    pub async fn receive(&self) -> Result<(K, Bytes)> {
        let mut rx = self.rx.lock().await;
        let (k, b, _) = rx.recv().await.ok_or(NetworkError::ChannelClosed)?;
        Ok((k, b))
    }

    /// Add the given peers to the network.
    ///
    /// NB that peers added here are passive. See `Network::assign` for
    /// giving peers a different `Role`.
    pub async fn add(&self, peers: Vec<(K, PublicKey, Address)>) -> Result<()> {
        self.parties
            .lock()
            .extend(peers.iter().map(|(p, ..)| (p.clone(), Role::Passive)));
        self.tx
            .send(Command::Add(peers))
            .await
            .map_err(|_| NetworkError::ChannelClosed)
    }

    /// Remove the given peers from the network.
    pub async fn remove(&self, peers: Vec<K>) -> Result<()> {
        // Scope the lock so the (non-async) guard is released before the
        // await below.
        {
            let mut parties = self.parties.lock();
            for p in &peers {
                parties.remove(p);
            }
        }
        self.tx
            .send(Command::Remove(peers))
            .await
            .map_err(|_| NetworkError::ChannelClosed)
    }

    /// Assign the given role to the given peers.
    pub async fn assign(&self, r: Role, peers: Vec<K>) -> Result<()> {
        // Scope the lock so the (non-async) guard is released before the
        // await below.
        {
            let mut parties = self.parties.lock();
            for p in &peers {
                if let Some(role) = parties.get_mut(p) {
                    *role = r
                }
            }
        }
        self.tx
            .send(Command::Assign(r, peers))
            .await
            .map_err(|_| NetworkError::ChannelClosed)
    }

    /// Get a clone of the MPSC sender.
    pub(crate) fn sender(&self) -> Sender<Command<K>> {
        self.tx.clone()
    }
}
412
413impl<K> Server<K>
414where
415    K: Eq + Ord + Clone + Display + Hash + Send + Sync + 'static,
416{
    /// Runs the main loop of this network node.
    ///
    /// This function:
    ///
    /// - Tries to connect to each remote peer in the committee.
    /// - Handles tasks that have been completed or terminated.
    /// - Processes new messages we received on the network.
    ///
    /// It never returns normally; it only returns an error when the
    /// command channel to the `Network` facade has been closed.
    async fn run(mut self, listener: TcpListener) -> Result<Empty> {
        // Seed both sets with a forever-pending task so they are never
        // empty, which keeps `join_next()` from resolving to `None` and
        // disabling the corresponding `select!` branches below.
        self.handshake_tasks.spawn(pending());
        self.io_tasks.spawn(pending());

        // Connect to all peers.
        for k in self
            .peers
            .keys()
            .filter(|k| **k != self.conf.label)
            .cloned()
            .collect::<Vec<_>>()
        {
            self.spawn_connect(k)
        }

        loop {
            trace!(
                name       = %self.conf.name,
                node       = %self.conf.label,
                active     = %self.active.len(),
                connects   = %self.connect_tasks.len(),
                handshakes = %self.handshake_tasks.len().saturating_sub(1), // -1 for `pending()`
                io_tasks   = %self.io_tasks.len().saturating_sub(1), // -1 for `pending()`
                tasks_ids  = %self.task2key.len(),
                iqueue     = %self.ibound.capacity(),
                oqueue     = %self.obound.capacity(),
            );

            #[cfg(feature = "metrics")]
            {
                self.metrics.iqueue.set(self.ibound.capacity());
                self.metrics.oqueue.set(self.obound.capacity());
            }

            tokio::select! {
                // Accepted a new connection.
                i = listener.accept() => match i {
                    Ok((s, a)) => {
                        debug!(
                            name = %self.conf.name,
                            node = %self.conf.label,
                            addr = %a,
                            "accepted connection"
                        );
                        self.spawn_handshake(s)
                    }
                    Err(e) => {
                        warn!(
                            name = %self.conf.name,
                            node = %self.conf.label,
                            err  = %e,
                            "error accepting connection"
                        )
                    }
                },
                // The handshake of an inbound connection completed.
                Some(h) = self.handshake_tasks.join_next() => match h {
                    Ok(Ok((s, t))) => {
                        let Some((k, peer)) = self.lookup_peer(&t) else {
                            info!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = ?t.get_remote_static().and_then(|k| PublicKey::try_from(k).ok()),
                                addr = ?s.peer_addr().ok(),
                                "unknown peer"
                            );
                            continue
                        };
                        if !self.is_valid_ip(&k, &s) {
                            warn!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = %k,
                                addr = ?s.peer_addr().ok(), "invalid peer ip addr"
                            );
                            continue
                        }
                        // We only accept connections whose party has a public key that
                        // is larger than ours, or if we do not have a connection for
                        // that key at the moment.
                        //
                        // (Tie-breaking on key order: the mirror-image rule in the
                        // connect branch below means at most one of two simultaneous
                        // connections between a pair of nodes survives.)
                        if k > self.conf.label || !self.active.contains_key(&k) {
                            self.spawn_io(k, s, t, peer.budget.clone())
                        } else {
                            debug!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = %k,
                                "dropping accepted connection"
                            );
                        }
                    }
                    Ok(Err(e)) => {
                        warn!(
                            name = %self.conf.name,
                            node = %self.conf.label,
                            err  = %e,
                            "handshake failed"
                        )
                    }
                    Err(e) => {
                        if !e.is_cancelled() {
                            error!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                err  = %e,
                                "handshake task panic"
                            )
                        }
                    }
                },
                // One of our connection attempts completed.
                Some(tt) = self.connect_tasks.join_next_with_id() => {
                    match tt {
                        Ok((id, (s, t))) => {
                            self.on_connect_task_end(id);
                            let Some((k, peer)) = self.lookup_peer(&t) else {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    peer = ?t.get_remote_static().and_then(|k| PublicKey::try_from(k).ok()),
                                    addr = ?s.peer_addr().ok(),
                                    "connected to unknown peer"
                                );
                                continue
                            };
                            // We only keep the connection if our key is larger than the remote,
                            // or if we do not have a connection for that key at the moment.
                            if k < self.conf.label || !self.active.contains_key(&k) {
                                self.spawn_io(k, s, t, peer.budget.clone())
                            } else {
                                debug!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    peer = %k,
                                    "dropping new connection"
                                )
                            }
                        }
                        Err(e) => {
                            if !e.is_cancelled() {
                                error!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err  = %e,
                                    "connect task panic"
                                )
                            }
                            self.on_connect_task_end(e.id());
                        }
                    }
                },
                // A read or write task completed.
                Some(io) = self.io_tasks.join_next_with_id() => {
                    match io {
                        Ok((id, r)) => {
                            if let Err(e) = r {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err  = %e,
                                    "i/o error"
                                )
                            }
                            self.on_io_task_end(id);
                        }
                        Err(e) => {
                            if e.is_cancelled() {
                                // If one half completes we cancel the other, so there is
                                // nothing else to do here, except to remove the cancelled
                                // tasks's ID. Same if we kill the connection, both tasks
                                // get cancelled.
                                self.task2key.remove(&e.id());
                                continue
                            }
                            // If the task has not been cancelled, it must have panicked.
                            error!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                err  = %e,
                                "i/o task panic"
                            );
                            self.on_io_task_end(e.id())
                        }
                    };
                },
                // An instruction arrived from the `Network` facade.
                cmd = self.obound.recv() => match cmd {
                    Some(Command::Add(peers)) => {
                        #[cfg(feature = "metrics")]
                        Arc::make_mut(&mut self.metrics).add_parties(peers.iter().map(|(k, ..)| k).cloned());
                        for (k, x, a) in peers {
                            if self.peers.contains_key(&k) {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    peer = %k,
                                    "peer to add already exists"
                                );
                                continue
                            }
                            info!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = %k,
                                "adding peer"
                            );
                            // Added peers start passive; see `Command::Assign`.
                            let p = Peer {
                                addr: a,
                                role: Role::Passive,
                                budget: self.conf.new_budget()
                            };
                            self.peers.insert(k.clone(), p);
                            self.index.insert(k.clone(), x);
                            self.spawn_connect(k)
                        }
                    }
                    Some(Command::Remove(peers)) => {
                        for k in &peers {
                            info!(
                                name = %self.conf.name,
                                node = %self.conf.label,
                                peer = %k,
                                "removing peer"
                            );
                            // Dropping the `ConnectTask`/`IoTask` entries aborts
                            // their tasks via the `Drop` impls.
                            self.peers.remove(k);
                            self.index.remove_by_left(k);
                            self.connecting.remove(k);
                            self.active.remove(k);
                        }
                        #[cfg(feature = "metrics")]
                        Arc::make_mut(&mut self.metrics).remove_parties(&peers)
                    }
                    Some(Command::Assign(role, peers)) => {
                        for k in &peers {
                            if let Some(p) = self.peers.get_mut(k) {
                                p.role = role
                            } else {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    peer = %k,
                                    role = ?role,
                                    "peer to assign role to not found"
                                );
                            }
                        }
                    }
                    Some(Command::Unicast(to, id, m)) => {
                        // Messages to ourselves are looped straight back to the
                        // application via the ingress channel.
                        if to == self.conf.label {
                            trace!(
                                name  = %self.conf.name,
                                node  = %self.conf.label,
                                to    = %to,
                                len   = %m.len(),
                                queue = self.ibound.capacity(),
                                "sending message"
                            );
                            if let Err(err) = self.ibound.try_send((self.conf.label.clone(), m, None)) {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err  = %err,
                                    cap  = %self.ibound.capacity(),
                                    "channel full => dropping unicast message"
                                )
                            }
                            continue
                        }
                        if let Some(task) = self.active.get(&to) {
                            trace!(
                                name  = %self.conf.name,
                                node  = %self.conf.label,
                                to    = %to,
                                len   = %m.len(),
                                queue = task.tx.capacity(),
                                "sending message"
                            );
                            #[cfg(feature = "metrics")]
                            self.metrics.set_peer_oqueue_cap(&to, task.tx.capacity());
                            task.tx.send(id, Message::Data(m))
                        }
                    }
                    Some(Command::Multicast(peers, id, m)) => {
                        // Local delivery first, if we are among the recipients.
                        if peers.contains(&self.conf.label) {
                            trace!(
                                name  = %self.conf.name,
                                node  = %self.conf.label,
                                to    = %self.conf.label,
                                len   = %m.len(),
                                queue = self.ibound.capacity(),
                                "sending message"
                            );
                            if let Err(err) = self.ibound.try_send((self.conf.label.clone(), m.clone(), None)) {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err  = %err,
                                    cap  = %self.ibound.capacity(),
                                    "channel full => dropping multicast message"
                                )
                            }
                        }
                        for (to, task) in &self.active {
                            if !peers.contains(to) {
                                continue
                            }
                            trace!(
                                name  = %self.conf.name,
                                node  = %self.conf.label,
                                to    = %to,
                                len   = %m.len(),
                                queue = task.tx.capacity(),
                                "sending message"
                            );
                            #[cfg(feature = "metrics")]
                            self.metrics.set_peer_oqueue_cap(to, task.tx.capacity());
                            task.tx.send(id, Message::Data(m.clone()))
                        }
                    }
                    Some(Command::Broadcast(id, m)) => {
                        // Local delivery first, if we ourselves are active.
                        if self.role.is_active() {
                            trace!(
                                name  = %self.conf.name,
                                node  = %self.conf.label,
                                to    = %self.conf.label,
                                len   = %m.len(),
                                queue = self.ibound.capacity(),
                                "sending message"
                            );
                            if let Err(err) = self.ibound.try_send((self.conf.label.clone(), m.clone(), None)) {
                                warn!(
                                    name = %self.conf.name,
                                    node = %self.conf.label,
                                    err  = %err,
                                    cap  = %self.ibound.capacity(),
                                    "channel full => dropping broadcast message"
                                )
                            }
                        }
                        // Broadcasts go to active peers only.
                        for (to, task) in &self.active {
                            if Some(Role::Active) != self.peers.get(to).map(|p| p.role) {
                                continue
                            }
                            trace!(
                                name  = %self.conf.name,
                                node  = %self.conf.label,
                                to    = %to,
                                len   = %m.len(),
                                queue = task.tx.capacity(),
                                "sending message"
                            );
                            #[cfg(feature = "metrics")]
                            self.metrics.set_peer_oqueue_cap(to, task.tx.capacity());
                            task.tx.send(id, Message::Data(m.clone()))
                        }
                    }
                    None => {
                        return Err(NetworkError::ChannelClosed)
                    }
                },
                // Periodic keep-alive to every connected peer.
                _ = self.ping_interval.tick() => {
                    let now = Timestamp::now();
                    for task in self.active.values() {
                        task.tx.send(None, Message::Ping(now))
                    }
                }
            }
        }
    }
792
793    /// Handles a completed connect task.
794    fn on_connect_task_end(&mut self, id: task::Id) {
795        let Some(k) = self.task2key.remove(&id) else {
796            error!(name = %self.conf.name, node = %self.conf.label, "no key for connect task");
797            return;
798        };
799        self.connecting.remove(&k);
800    }
801
802    /// Handles a completed I/O task.
803    ///
804    /// This function will get the public key of the task that was terminated
805    /// and then cleanly removes the associated I/O task data and re-connects
806    /// to the peer node it was interacting with.
807    fn on_io_task_end(&mut self, id: task::Id) {
808        let Some(k) = self.task2key.remove(&id) else {
809            error!(name = %self.conf.name, node = %self.conf.label, "no key for i/o task");
810            return;
811        };
812        let Some(task) = self.active.get(&k) else {
813            return;
814        };
815        if task.rh.id() == id {
816            debug!(
817                name = %self.conf.name,
818                node = %self.conf.label,
819                peer = %k,
820                "read-half closed => dropping connection"
821            );
822            self.active.remove(&k);
823            self.spawn_connect(k)
824        } else if task.wh.id() == id {
825            debug!(
826                name = %self.conf.name,
827                node = %self.conf.label,
828                peer = %k,
829                "write-half closed => dropping connection"
830            );
831            self.active.remove(&k);
832            self.spawn_connect(k)
833        } else {
834            debug!(
835                name = %self.conf.name,
836                node = %self.conf.label,
837                peer = %k,
838                "i/o task was previously replaced"
839            );
840        }
841    }
842
843    /// Spawns a new connection task to a peer identified by public key.
844    ///
845    /// This function will look up the x25519 public key of the ed25519 key
846    /// and the remote address and then spawn a connection task.
847    fn spawn_connect(&mut self, k: K) {
848        if self.connecting.contains_key(&k) {
849            debug!(
850                name = %self.conf.name,
851                node = %self.conf.label,
852                peer = %k,
853                "connect task already started"
854            );
855            return;
856        }
857        let x = self.index.get_by_left(&k).expect("known public key");
858        let p = self.peers.get(&k).expect("known peer");
859        let h = self.connect_tasks.spawn(connect(
860            self.conf.name,
861            (self.conf.label.clone(), self.conf.keypair.clone()),
862            (k.clone(), *x),
863            p.addr.clone(),
864            self.conf.retry_delays,
865            #[cfg(feature = "metrics")]
866            self.metrics.clone(),
867        ));
868        assert!(self.task2key.insert(h.id(), k.clone()).is_none());
869        self.connecting.insert(k, ConnectTask { h });
870    }
871
872    /// Spawns a new `Noise` responder handshake task using the IK pattern.
873    ///
874    /// This function will create the responder handshake machine using its
875    /// own private key and then spawn a task that awaits an initiator handshake
876    /// to which it will respond.
877    fn spawn_handshake(&mut self, s: TcpStream) {
878        let h = Builder::new(NOISE_PARAMS.parse().expect("valid noise params"))
879            .local_private_key(&self.conf.keypair.secret_key().as_bytes())
880            .expect("valid private key")
881            .prologue(self.conf.name.as_bytes())
882            .expect("1st time we set the prologue")
883            .build_responder()
884            .expect("valid noise params yield valid handshake state");
885        self.handshake_tasks.spawn(async move {
886            timeout(HANDSHAKE_TIMEOUT, on_handshake(h, s))
887                .await
888                .or(Err(NetworkError::Timeout))?
889        });
890    }
891
892    /// Spawns a new I/O task for handling communication with a remote peer over
893    /// a TCP connection using the noise framework to create an authenticated
894    /// secure link.
895    fn spawn_io(&mut self, k: K, s: TcpStream, t: TransportState, b: Budget) {
896        debug!(
897            name = %self.conf.name,
898            node = %self.conf.label,
899            peer = %k,
900            addr = ?s.peer_addr().ok(),
901            "starting i/o tasks"
902        );
903        let (to_remote, from_remote) = chan::channel(self.conf.peer_capacity_egress);
904        let (r, w) = s.into_split();
905        let t1 = Arc::new(Mutex::new(t));
906        let t2 = t1.clone();
907        let ibound = self.ibound.clone();
908        let to_write = to_remote.clone();
909        let countdown = Countdown::new();
910        let rh = self.io_tasks.spawn(recv_loop(
911            self.conf.name,
912            k.clone(),
913            r,
914            t1,
915            ibound,
916            to_write,
917            #[cfg(feature = "metrics")]
918            self.metrics.clone(),
919            b,
920            countdown.clone(),
921            self.conf.max_message_size,
922        ));
923        let wh = self
924            .io_tasks
925            .spawn(send_loop(w, t2, from_remote, countdown));
926        assert!(self.task2key.insert(rh.id(), k.clone()).is_none());
927        assert!(self.task2key.insert(wh.id(), k.clone()).is_none());
928        let io = IoTask {
929            rh,
930            wh,
931            tx: to_remote,
932        };
933        self.active.insert(k, io);
934        #[cfg(feature = "metrics")]
935        self.metrics.connections.set(self.active.len());
936    }
937
938    /// Get the public key of a party by their static X25519 public key.
939    fn lookup_peer(&self, t: &TransportState) -> Option<(K, &Peer)> {
940        let x = t.get_remote_static()?;
941        let x = PublicKey::try_from(x).ok()?;
942        let k = self.index.get_by_right(&x)?;
943        self.peers.get(k).map(|p| (k.clone(), p))
944    }
945
946    /// Check if the socket's peer IP address corresponds to the configured one.
947    fn is_valid_ip(&self, k: &K, s: &TcpStream) -> bool {
948        self.peers
949            .get(k)
950            .map(|p| {
951                let Address::Inet(ip, _) = p.addr else {
952                    return true;
953                };
954                Some(ip) == s.peer_addr().ok().map(|a| a.ip())
955            })
956            .unwrap_or(false)
957    }
958}
959
/// Connect to the given socket address.
///
/// This function will only return, when a connection has been established and the handshake
/// has been completed. It retries forever, sleeping between attempts according to the
/// delay schedule derived from `delays`.
async fn connect<K>(
    name: &'static str,
    this: (K, Keypair),
    to: (K, PublicKey),
    addr: Address,
    delays: [u8; NUM_DELAYS],
    #[cfg(feature = "metrics")] metrics: Arc<NetworkMetrics<K>>,
) -> (TcpStream, TransportState)
where
    K: Eq + Hash + Display + Clone,
{
    use rand::prelude::*;

    // Every attempt needs a fresh handshake state: an IK initiator keyed with
    // our static secret and the peer's expected static public key, bound to
    // the network name via the prologue.
    let new_handshake_state = || {
        Builder::new(NOISE_PARAMS.parse().expect("valid noise params"))
            .local_private_key(this.1.secret_key().as_slice())
            .expect("valid private key")
            .remote_public_key(to.1.as_slice())
            .expect("valid remote pub key")
            .prologue(name.as_bytes())
            .expect("1st time we set the prologue")
            .build_initiator()
            .expect("valid noise params yield valid handshake state")
    };

    // Delay schedule in milliseconds: a random initial jitter of up to 1s,
    // then the configured delays (seconds scaled to millis), then the last
    // configured delay repeated forever — the iterator never terminates.
    let delays = once(rand::rng().random_range(0..=1000))
        .chain(delays.into_iter().map(|d| u64::from(d) * 1000))
        .chain(repeat(u64::from(delays[LAST_DELAY]) * 1000));

    let addr = addr.to_string();

    for d in delays {
        sleep(Duration::from_millis(d)).await;
        debug!(%name, node = %this.0, peer = %to.0, %addr, "connecting");
        #[cfg(feature = "metrics")]
        metrics.add_connect_attempt(&to.0);
        // Bound the TCP connect and the subsequent noise handshake separately;
        // any failure or timeout falls through to the next scheduled delay.
        match timeout(CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
            Ok(Ok(s)) => {
                if let Err(err) = s.set_nodelay(true) {
                    error!(%name, node = %this.0, %err, "failed to set NO_DELAY socket option");
                    continue;
                }
                match timeout(HANDSHAKE_TIMEOUT, handshake(new_handshake_state(), s)).await {
                    Ok(Ok(x)) => {
                        debug!(%name, node = %this.0, peer = %to.0, %addr, "connection established");
                        return x;
                    },
                    Ok(Err(err)) => {
                        warn!(%name, node = %this.0, peer = %to.0, %addr, %err, "handshake failure");
                    },
                    Err(_) => {
                        warn!(%name, node = %this.0, peer = %to.0, %addr, "handshake timeout");
                    },
                }
            },
            Ok(Err(err)) => {
                warn!(%name, node = %this.0, peer = %to.0, %addr, %err, "failed to connect");
            },
            Err(_) => {
                warn!(%name, node = %this.0, peer = %to.0, %addr, "connect timeout");
            },
        }
    }

    // `delays` is infinite by construction, so the loop only exits via `return`.
    unreachable!("for loop repeats forever")
}
1030
1031/// Perform a noise handshake as initiator with the remote party.
1032async fn handshake(
1033    mut hs: HandshakeState,
1034    mut stream: TcpStream,
1035) -> Result<(TcpStream, TransportState)> {
1036    let mut b = vec![0; MAX_NOISE_HANDSHAKE_SIZE];
1037    let n = hs.write_message(&[], &mut b[Header::SIZE..])?;
1038    let h = Header::data(n as u16);
1039    send_frame(&mut stream, h, &mut b[..Header::SIZE + n]).await?;
1040    let (h, m) = recv_frame(&mut stream).await?;
1041    if !h.is_data() || h.is_partial() {
1042        return Err(NetworkError::InvalidHandshakeMessage);
1043    }
1044    hs.read_message(&m, &mut b)?;
1045    Ok((stream, hs.into_transport_mode()?))
1046}
1047
1048/// Perform a noise handshake as responder with a remote party.
1049async fn on_handshake(
1050    mut hs: HandshakeState,
1051    mut stream: TcpStream,
1052) -> Result<(TcpStream, TransportState)> {
1053    stream.set_nodelay(true)?;
1054    let (h, m) = recv_frame(&mut stream).await?;
1055    if !h.is_data() || h.is_partial() {
1056        return Err(NetworkError::InvalidHandshakeMessage);
1057    }
1058    let mut b = vec![0; MAX_NOISE_HANDSHAKE_SIZE];
1059    hs.read_message(&m, &mut b)?;
1060    let n = hs.write_message(&[], &mut b[Header::SIZE..])?;
1061    let h = Header::data(n as u16);
1062    send_frame(&mut stream, h, &mut b[..Header::SIZE + n]).await?;
1063    Ok((stream, hs.into_transport_mode()?))
1064}
1065
/// Read messages from the remote by assembling frames together.
///
/// Once complete the message will be handed over to the given MPSC sender.
/// One semaphore permit from `budget` is acquired per message and travels
/// with it, so delivery to the application back-pressures further reads.
#[allow(clippy::too_many_arguments)]
async fn recv_loop<R, K>(
    name: &'static str,
    id: K,
    mut reader: R,
    state: Arc<Mutex<TransportState>>,
    to_deliver: Sender<(K, Bytes, Option<OwnedSemaphorePermit>)>,
    to_writer: chan::Sender<Message>,
    #[cfg(feature = "metrics")] metrics: Arc<NetworkMetrics<K>>,
    budget: Arc<Semaphore>,
    mut countdown: Countdown,
    max_message_size: usize,
) -> Result<()>
where
    R: AsyncRead + Unpin,
    K: Eq + Hash + Display + Clone,
{
    // Scratch buffer for decrypted noise payloads.
    let mut buf = vec![0; MAX_NOISE_MESSAGE_SIZE];
    loop {
        #[cfg(feature = "metrics")]
        metrics.set_peer_iqueue_cap(&id, budget.available_permits());
        // Wait for capacity before starting on the next message.
        let permit = budget
            .clone()
            .acquire_owned()
            .await
            .map_err(|_| NetworkError::BudgetClosed)?;
        let mut msg = BytesMut::new();
        // Inner loop: assemble one complete data message from frames while
        // answering pings and recording pongs that interleave with it.
        loop {
            tokio::select! {
                val = recv_frame(&mut reader) => {
                    // Any inbound frame proves the peer is alive.
                    countdown.stop();
                    match val {
                        Ok((h, f)) => {
                            match h.frame_type() {
                                Ok(Type::Ping) => {
                                    // Received ping message; sending pong to writer
                                    let n = state.lock().read_message(&f, &mut buf)?;
                                    if let Some(ping) = Timestamp::try_from_slice(&buf[..n]) {
                                        to_writer.send(None, Message::Pong(ping))
                                    }
                                }
                                Ok(Type::Pong) => {
                                    // Received pong message; measure elapsed time
                                    let _n = state.lock().read_message(&f, &mut buf)?;
                                    #[cfg(feature = "metrics")]
                                    if let Some(ping) = Timestamp::try_from_slice(&buf[.._n])
                                        && let Some(delay) = Timestamp::now().diff(ping)
                                    {
                                        metrics.set_latency(&id, delay)
                                    }
                                }
                                Ok(Type::Data) => {
                                    // Append this fragment; a non-partial frame
                                    // completes the message.
                                    let n = state.lock().read_message(&f, &mut buf)?;
                                    msg.extend_from_slice(&buf[..n]);
                                    if !h.is_partial() {
                                        break;
                                    }
                                    // NOTE(review): this limit is only enforced
                                    // after partial frames, so a message whose
                                    // final frame pushes it past
                                    // `max_message_size` is still delivered —
                                    // confirm this slack is intended.
                                    if msg.len() > max_message_size {
                                        return Err(NetworkError::MessageTooLarge);
                                    }
                                }
                                Err(t) => return Err(NetworkError::UnknownFrameType(t)),
                            }
                        }
                        Err(e) => return Err(e)
                    }
                },
                () = &mut countdown => {
                    // The reply countdown (armed by the send side after a
                    // ping, see `send_loop`) fired before anything arrived.
                    warn!(%name, node = %id, "timeout waiting for peer");
                    return Err(NetworkError::Timeout)
                }
            }
        }
        // Hand the completed message (and its permit) to the application;
        // a closed receiver ends this loop cleanly.
        if to_deliver
            .send((id.clone(), msg.freeze(), Some(permit)))
            .await
            .is_err()
        {
            break;
        }
    }
    Ok(())
}
1152
1153/// Consume messages to be delivered to remote parties and send them.
1154///
1155/// The function automatically splits large messages into chunks that fit into
1156/// a noise package.
1157async fn send_loop<W>(
1158    mut writer: W,
1159    state: Arc<Mutex<TransportState>>,
1160    rx: chan::Receiver<Message>,
1161    countdown: Countdown,
1162) -> Result<()>
1163where
1164    W: AsyncWrite + Unpin,
1165{
1166    let mut buf = vec![0; MAX_NOISE_MESSAGE_SIZE];
1167
1168    while let Some(msg) = rx.recv().await {
1169        match msg {
1170            Message::Ping(ping) => {
1171                let n = state
1172                    .lock()
1173                    .write_message(&ping.to_bytes()[..], &mut buf[Header::SIZE..])?;
1174                let h = Header::ping(n as u16);
1175                send_frame(&mut writer, h, &mut buf[..Header::SIZE + n]).await?;
1176                countdown.start(REPLY_TIMEOUT)
1177            },
1178            Message::Pong(pong) => {
1179                let n = state
1180                    .lock()
1181                    .write_message(&pong.to_bytes()[..], &mut buf[Header::SIZE..])?;
1182                let h = Header::pong(n as u16);
1183                send_frame(&mut writer, h, &mut buf[..Header::SIZE + n]).await?
1184            },
1185            Message::Data(msg) => {
1186                let mut it = msg.chunks(MAX_PAYLOAD_SIZE).peekable();
1187                while let Some(m) = it.next() {
1188                    let n = state.lock().write_message(m, &mut buf[Header::SIZE..])?;
1189                    let h = if it.peek().is_some() {
1190                        Header::data(n as u16).partial()
1191                    } else {
1192                        Header::data(n as u16)
1193                    };
1194                    send_frame(&mut writer, h, &mut buf[..Header::SIZE + n]).await?
1195                }
1196            },
1197        }
1198    }
1199    Ok(())
1200}
1201
1202/// Read a single frame (header + payload) from the remote.
1203async fn recv_frame<R>(r: &mut R) -> Result<(Header, Vec<u8>)>
1204where
1205    R: AsyncRead + Unpin,
1206{
1207    let b = r.read_u32().await?;
1208    let h = Header::try_from(b.to_be_bytes())?;
1209    let mut v = vec![0; h.len().into()];
1210    r.read_exact(&mut v).await?;
1211    Ok((h, v))
1212}
1213
1214/// Write a single frame (header + payload) to the remote.
1215///
1216/// The header is serialised into the first 4 bytes of `msg`. It is the
1217/// caller's responsibility to ensure there is room at the beginning.
1218async fn send_frame<W>(w: &mut W, hdr: Header, msg: &mut [u8]) -> Result<()>
1219where
1220    W: AsyncWrite + Unpin,
1221{
1222    debug_assert!(msg.len() <= MAX_NOISE_MESSAGE_SIZE);
1223    msg[..Header::SIZE].copy_from_slice(&hdr.to_bytes());
1224    w.write_all(msg).await?;
1225    Ok(())
1226}