hotshot_libp2p_networking/network/node/handle.rs

// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
// This file is part of the HotShot repository.

// You should have received a copy of the MIT License
// along with the HotShot repository. If not, see <https://mit-license.org/>.

use std::{collections::HashSet, fmt::Debug, sync::Arc, time::Duration};

use bimap::BiMap;
use hotshot_types::traits::{
    network::NetworkError, node_implementation::NodeType, signature_key::SignatureKey,
};
use libp2p::{request_response::ResponseChannel, Multiaddr};
use libp2p_identity::PeerId;
use parking_lot::Mutex;
use tokio::{
    sync::mpsc::{Receiver, UnboundedReceiver, UnboundedSender},
    time::{sleep, timeout},
};
use tracing::{debug, info, instrument};

use crate::network::{
    behaviours::dht::{
        record::{Namespace, RecordKey, RecordValue},
        store::persistent::DhtPersistentStorage,
    },
    gen_multiaddr, ClientRequest, NetworkEvent, NetworkNode, NetworkNodeConfig,
};

/// A handle containing:
/// - A reference to the state
/// - Controls for the swarm
#[derive(Debug, Clone)]
pub struct NetworkNodeHandle<T: NodeType> {
    /// The network configuration
    network_config: NetworkNodeConfig<T>,

    /// Send an action to the network behaviour
    send_network: UnboundedSender<ClientRequest>,

    /// The map from consensus keys to peer IDs
    consensus_key_to_pid_map: Arc<Mutex<BiMap<T::SignatureKey, PeerId>>>,

    /// The local address we're listening on
    listen_addr: Multiaddr,

    /// The peer ID of the network behaviour
    peer_id: PeerId,

    /// A human-readable ID
    id: usize,
}

/// Internal network node receiver
#[derive(Debug)]
pub struct NetworkNodeReceiver {
    /// The receiver
    receiver: UnboundedReceiver<NetworkEvent>,

    /// Kill switch
    recv_kill: Option<Receiver<()>>,
}

impl NetworkNodeReceiver {
    /// Receive a network event
    /// # Errors
    /// Errors if the receiver channel is closed
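    /// # Example
    /// A minimal sketch (not compiled as a doctest) of draining this receiver.
    /// `MyTypes`, `config`, `storage`, and `key_map` are hypothetical values of the
    /// types [`spawn_network_node`] expects:
    /// ```rust,ignore
    /// let (mut receiver, handle) =
    ///     spawn_network_node::<MyTypes, _>(config, storage, key_map, 0).await?;
    /// tokio::spawn(async move {
    ///     while let Ok(event) = receiver.recv().await {
    ///         // Handle gossip, direct messages, DHT results, etc.
    ///         println!("network event: {event:?}");
    ///     }
    /// });
    /// ```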
    pub async fn recv(&mut self) -> Result<NetworkEvent, NetworkError> {
        self.receiver
            .recv()
            .await
            .ok_or(NetworkError::ChannelReceiveError(
                "Receiver channel closed".to_string(),
            ))
    }

    /// Add a kill switch to the receiver
    pub fn set_kill_switch(&mut self, kill_switch: Receiver<()>) {
        self.recv_kill = Some(kill_switch);
    }

    /// Take the kill switch to allow killing the receiver task
    pub fn take_kill_switch(&mut self) -> Option<Receiver<()>> {
        self.recv_kill.take()
    }
}

/// Spawn a network node task and return the receiver and handle for it
/// # Errors
/// Errors if spawning the task fails
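/// # Example
/// A minimal sketch (not compiled as a doctest); `MyTypes` is a hypothetical
/// [`NodeType`] implementation, and `config` / `storage` are assumed to be an
/// existing [`NetworkNodeConfig`] and [`DhtPersistentStorage`]:
/// ```rust,ignore
/// let key_map = Arc::new(Mutex::new(BiMap::new()));
/// let (receiver, handle) =
///     spawn_network_node::<MyTypes, _>(config, storage, Arc::clone(&key_map), 0).await?;
/// handle.begin_bootstrap()?;
/// ```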
pub async fn spawn_network_node<T: NodeType, D: DhtPersistentStorage>(
    config: NetworkNodeConfig<T>,
    dht_persistent_storage: D,
    consensus_key_to_pid_map: Arc<Mutex<BiMap<T::SignatureKey, PeerId>>>,
    id: usize,
) -> Result<(NetworkNodeReceiver, NetworkNodeHandle<T>), NetworkError> {
    let mut network = NetworkNode::new(
        config.clone(),
        dht_persistent_storage,
        Arc::clone(&consensus_key_to_pid_map),
    )
    .await
    .map_err(|e| NetworkError::ConfigError(format!("failed to create network node: {e}")))?;

    // Use the configured bind address, or a randomly assigned port if none was provided
    let listen_addr = config
        .bind_address
        .clone()
        .unwrap_or_else(|| gen_multiaddr(0));
    let peer_id = network.peer_id();
    let listen_addr = network.start_listen(listen_addr).await.map_err(|e| {
        NetworkError::ListenError(format!("failed to start listening on Libp2p: {e}"))
    })?;

    // Spawn the tasks that drive the swarm, returning the channels used to
    // send requests to it and receive events from it
    let (send_chan, recv_chan) = network.spawn_listeners().map_err(|err| {
        NetworkError::ListenError(format!("failed to spawn listeners for Libp2p: {err}"))
    })?;
    let receiver = NetworkNodeReceiver {
        receiver: recv_chan,
        recv_kill: None,
    };

    let handle = NetworkNodeHandle::<T> {
        network_config: config,
        send_network: send_chan,
        consensus_key_to_pid_map,
        listen_addr,
        peer_id,
        id,
    };
    Ok((receiver, handle))
}

impl<T: NodeType> NetworkNodeHandle<T> {
    /// Cleanly shuts down a swarm node.
    /// This is done by sending a message to
    /// the swarm itself to spin down
    #[instrument]
    pub async fn shutdown(&self) -> Result<(), NetworkError> {
        self.send_request(ClientRequest::Shutdown)?;
        Ok(())
    }

    /// Notify the network to begin the bootstrap process
    /// # Errors
    /// If unable to send via `send_network`. This should only happen
    /// if the network is shut down.
    pub fn begin_bootstrap(&self) -> Result<(), NetworkError> {
        let req = ClientRequest::BeginBootstrap;
        self.send_request(req)
    }

    /// Get a reference to the network node handle's listen address.
    #[must_use]
    pub fn listen_addr(&self) -> Multiaddr {
        self.listen_addr.clone()
    }

    /// Print out the routing table used by Kademlia
    /// NOTE: only for debugging purposes currently
    /// # Errors
    /// If the client has stopped listening for a response
    pub async fn print_routing_table(&self) -> Result<(), NetworkError> {
        let (s, r) = futures::channel::oneshot::channel();
        let req = ClientRequest::GetRoutingTable(s);
        self.send_request(req)?;
        r.await
            .map_err(|e| NetworkError::ChannelReceiveError(e.to_string()))
    }

    /// Wait until at least `num_required_peers` have connected
    ///
    /// # Errors
    /// If the channel closes before the result can be sent back
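    /// # Example
    /// A minimal sketch (not compiled as a doctest), assuming `handle` came from
    /// [`spawn_network_node`] and bootstrap has already begun:
    /// ```rust,ignore
    /// // Poll once per second until at least 4 peers are connected
    /// handle.wait_to_connect(4, handle.id()).await?;
    /// ```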
    pub async fn wait_to_connect(
        &self,
        num_required_peers: usize,
        node_id: usize,
    ) -> Result<(), NetworkError> {
        // Wait for the required number of peers to connect
        loop {
            // Get the number of currently connected peers
            let num_connected = self.num_connected().await?;
            if num_connected >= num_required_peers {
                break;
            }

            // Log the number of connected peers
            info!(
                "Node {} connected to {}/{} peers",
                node_id, num_connected, num_required_peers
            );

            // Sleep for a second before checking again
            sleep(Duration::from_secs(1)).await;
        }

        Ok(())
    }

    /// Look up a peer's addresses in Kademlia
    /// NOTE: this should always be called before any `request_response` is initiated
    /// # Errors
    /// If the client has stopped listening for a response
    pub async fn lookup_pid(&self, peer_id: PeerId) -> Result<(), NetworkError> {
        let (s, r) = futures::channel::oneshot::channel();
        let req = ClientRequest::LookupPeer(peer_id, s);
        self.send_request(req)?;
        r.await
            .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
    }

    /// Looks up a node's `PeerId` by its consensus key.
    ///
    /// # Errors
    /// If the DHT lookup fails
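    /// # Example
    /// A minimal sketch (not compiled as a doctest); `their_key` is a hypothetical
    /// consensus key belonging to another node:
    /// ```rust,ignore
    /// let peer_id = handle
    ///     .lookup_node(&their_key, Duration::from_secs(5))
    ///     .await?;
    /// // Populate Kademlia's address book before any request/response exchange
    /// handle.lookup_pid(peer_id).await?;
    /// ```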
    pub async fn lookup_node(
        &self,
        consensus_key: &T::SignatureKey,
        dht_timeout: Duration,
    ) -> Result<PeerId, NetworkError> {
        // First check if we already have an open connection to the peer
        if let Some(pid) = self
            .consensus_key_to_pid_map
            .lock()
            .get_by_left(consensus_key)
        {
            return Ok(*pid);
        }

        // Create the record key
        let key = RecordKey::new(Namespace::Lookup, consensus_key.to_bytes());

        // Get the record from the DHT
        let pid = self.get_record_timeout(key, dht_timeout).await?;

        PeerId::from_bytes(&pid).map_err(|err| NetworkError::FailedToDeserialize(err.to_string()))
    }

    /// Insert a record into the Kademlia DHT
    /// # Errors
    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize the key or value
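    /// # Example
    /// A minimal sketch (not compiled as a doctest). The key mirrors the one
    /// [`Self::lookup_node`] reads; `my_key` and the signed [`RecordValue`] `value`
    /// are assumed to be built elsewhere:
    /// ```rust,ignore
    /// let key = RecordKey::new(Namespace::Lookup, my_key.to_bytes());
    /// handle.put_record(key, value).await?;
    /// ```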
    pub async fn put_record(
        &self,
        key: RecordKey,
        value: RecordValue<T::SignatureKey>,
    ) -> Result<(), NetworkError> {
        // Serialize the key
        let key = key.to_bytes();

        // Serialize the record
        let value = bincode::serialize(&value)
            .map_err(|e| NetworkError::FailedToSerialize(e.to_string()))?;

        let (s, r) = futures::channel::oneshot::channel();
        let req = ClientRequest::PutDHT {
            key: key.clone(),
            value,
            notify: s,
        };

        self.send_request(req)?;

        r.await.map_err(|_| NetworkError::RequestCancelled)
    }

    /// Receive a record from the Kademlia DHT if it exists.
    /// The record must be replicated on at least 2 nodes.
    /// # Errors
    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize the key
    /// - Will return [`NetworkError::FailedToDeserialize`] when unable to deserialize the returned value
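    /// # Example
    /// A minimal sketch (not compiled as a doctest); `their_key` is a hypothetical
    /// consensus key:
    /// ```rust,ignore
    /// let key = RecordKey::new(Namespace::Lookup, their_key.to_bytes());
    /// // Retry the DHT query up to 3 times before giving up
    /// let bytes = handle.get_record(key, 3).await?;
    /// ```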
    pub async fn get_record(
        &self,
        key: RecordKey,
        retry_count: u8,
    ) -> Result<Vec<u8>, NetworkError> {
        // Serialize the key
        let serialized_key = key.to_bytes();

        let (s, r) = futures::channel::oneshot::channel();
        let req = ClientRequest::GetDHT {
            key: serialized_key.clone(),
            notify: vec![s],
            retry_count,
        };
        self.send_request(req)?;

        // Map the error
        let result = r.await.map_err(|_| NetworkError::RequestCancelled)?;

        // Deserialize the record's value
        let record: RecordValue<T::SignatureKey> = bincode::deserialize(&result)
            .map_err(|e| NetworkError::FailedToDeserialize(e.to_string()))?;

        Ok(record.value().to_vec())
    }

    /// Get a record from the Kademlia DHT with a timeout
    /// # Errors
    /// - Will return [`NetworkError::Timeout`] when the operation times out
    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize the key or value
    /// - Will return [`NetworkError::ChannelSendError`] when the underlying `NetworkNode` has been killed
    pub async fn get_record_timeout(
        &self,
        key: RecordKey,
        timeout_duration: Duration,
    ) -> Result<Vec<u8>, NetworkError> {
        timeout(timeout_duration, self.get_record(key, 3))
            .await
            .map_err(|err| NetworkError::Timeout(err.to_string()))?
    }

    /// Insert a record into the Kademlia DHT with a timeout
    /// # Errors
    /// - Will return [`NetworkError::Timeout`] when the operation times out
    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize the key or value
    /// - Will return [`NetworkError::ChannelSendError`] when the underlying `NetworkNode` has been killed
    pub async fn put_record_timeout(
        &self,
        key: RecordKey,
        value: RecordValue<T::SignatureKey>,
        timeout_duration: Duration,
    ) -> Result<(), NetworkError> {
        timeout(timeout_duration, self.put_record(key, value))
            .await
            .map_err(|err| NetworkError::Timeout(err.to_string()))?
    }

    /// Subscribe to a topic
    /// # Errors
    /// - Will return [`NetworkError::ChannelSendError`] when the underlying `NetworkNode` has been killed
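    /// # Example
    /// A minimal sketch (not compiled as a doctest); the topic name is arbitrary:
    /// ```rust,ignore
    /// handle.subscribe("global".to_string()).await?;
    /// ```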
    pub async fn subscribe(&self, topic: String) -> Result<(), NetworkError> {
        let (s, r) = futures::channel::oneshot::channel();
        let req = ClientRequest::Subscribe(topic, Some(s));
        self.send_request(req)?;
        r.await
            .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
    }

    /// Unsubscribe from a topic
    /// # Errors
    /// - Will return [`NetworkError::ChannelSendError`] when the underlying `NetworkNode` has been killed
    pub async fn unsubscribe(&self, topic: String) -> Result<(), NetworkError> {
        let (s, r) = futures::channel::oneshot::channel();
        let req = ClientRequest::Unsubscribe(topic, Some(s));
        self.send_request(req)?;
        r.await
            .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
    }

    /// Ignore `peers` when pruning,
    /// i.e. maintain their connections
    /// # Errors
    /// - Will return [`NetworkError::ChannelSendError`] when the underlying `NetworkNode` has been killed
    pub fn ignore_peers(&self, peers: Vec<PeerId>) -> Result<(), NetworkError> {
        let req = ClientRequest::IgnorePeers(peers);
        self.send_request(req)
    }

    /// Make a direct request to `pid` containing `msg`
    /// # Errors
    /// - Will return [`NetworkError::ChannelSendError`] when the underlying `NetworkNode` has been killed
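    /// # Example
    /// A minimal sketch (not compiled as a doctest); `peer_id` would typically come
    /// from [`Self::lookup_node`]:
    /// ```rust,ignore
    /// // Look up the peer's addresses first, per the note on `lookup_pid`
    /// handle.lookup_pid(peer_id).await?;
    /// handle.direct_request(peer_id, b"ping")?;
    /// ```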
    pub fn direct_request(&self, pid: PeerId, msg: &[u8]) -> Result<(), NetworkError> {
        self.direct_request_no_serialize(pid, msg.to_vec())
    }

    /// Make a direct request to `pid` containing `contents` without serializing
    /// # Errors
    /// - Will return [`NetworkError::ChannelSendError`] when the underlying `NetworkNode` has been killed
    pub fn direct_request_no_serialize(
        &self,
        pid: PeerId,
        contents: Vec<u8>,
    ) -> Result<(), NetworkError> {
        let req = ClientRequest::DirectRequest {
            pid,
            contents,
            retry_count: 1,
        };
        self.send_request(req)
    }

    /// Reply with `msg` to a request over `chan`
    /// # Errors
    /// - Will return [`NetworkError::ChannelSendError`] when the underlying `NetworkNode` has been killed
    pub fn direct_response(
        &self,
        chan: ResponseChannel<Vec<u8>>,
        msg: &[u8],
    ) -> Result<(), NetworkError> {
        let req = ClientRequest::DirectResponse(chan, msg.to_vec());
        self.send_request(req)
    }

    /// Forcefully disconnect from a peer
    /// # Errors
    /// - Will return [`NetworkError::ChannelSendError`] when the underlying `NetworkNode` has been killed
    pub fn prune_peer(&self, pid: PeerId) -> Result<(), NetworkError> {
        let req = ClientRequest::Prune(pid);
        self.send_request(req)
    }

    /// Gossip a message to peers
    /// # Errors
    /// - Will return [`NetworkError::ChannelSendError`] when the underlying `NetworkNode` has been killed
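    /// # Example
    /// A minimal sketch (not compiled as a doctest), publishing to a topic previously
    /// joined via [`Self::subscribe`]:
    /// ```rust,ignore
    /// handle.gossip("global".to_string(), b"hello, world")?;
    /// ```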
    pub fn gossip(&self, topic: String, msg: &[u8]) -> Result<(), NetworkError> {
        self.gossip_no_serialize(topic, msg.to_vec())
    }

    /// Gossip a message to peers without serializing
    /// # Errors
    /// - Will return [`NetworkError::ChannelSendError`] when the underlying `NetworkNode` has been killed
    pub fn gossip_no_serialize(&self, topic: String, msg: Vec<u8>) -> Result<(), NetworkError> {
        let req = ClientRequest::GossipMsg(topic, msg);
        self.send_request(req)
    }

    /// Tell libp2p about known network nodes
    /// # Errors
    /// - Will return [`NetworkError::ChannelSendError`] when the underlying `NetworkNode` has been killed
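    /// # Example
    /// A minimal sketch (not compiled as a doctest); `bootstrap_pid` and
    /// `bootstrap_addr` are hypothetical bootstrap-node details known out of band:
    /// ```rust,ignore
    /// handle.add_known_peers(vec![(bootstrap_pid, bootstrap_addr)])?;
    /// handle.begin_bootstrap()?;
    /// ```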
    pub fn add_known_peers(
        &self,
        known_peers: Vec<(PeerId, Multiaddr)>,
    ) -> Result<(), NetworkError> {
        debug!("Adding {} known peers", known_peers.len());
        let req = ClientRequest::AddKnownPeers(known_peers);
        self.send_request(req)
    }

    /// Send a client request to the network
    ///
    /// # Errors
    /// - Will return [`NetworkError::ChannelSendError`] when the underlying `NetworkNode` has been killed
    fn send_request(&self, req: ClientRequest) -> Result<(), NetworkError> {
        self.send_network
            .send(req)
            .map_err(|err| NetworkError::ChannelSendError(err.to_string()))
    }

    /// Returns the number of peers this node is connected to
    /// # Errors
    /// If the channel is closed somehow.
    /// Shouldn't happen.
    /// # Panics
    /// If the channel errors out.
    /// Shouldn't happen.
    pub async fn num_connected(&self) -> Result<usize, NetworkError> {
        let (s, r) = futures::channel::oneshot::channel();
        let req = ClientRequest::GetConnectedPeerNum(s);
        self.send_request(req)?;
        Ok(r.await.unwrap())
    }

    /// Return a `HashSet` of the `PeerId`s this node is connected to
    /// # Errors
    /// If the channel is closed somehow.
    /// Shouldn't happen.
    /// # Panics
    /// If the channel errors out.
    /// Shouldn't happen.
    pub async fn connected_pids(&self) -> Result<HashSet<PeerId>, NetworkError> {
        let (s, r) = futures::channel::oneshot::channel();
        let req = ClientRequest::GetConnectedPeers(s);
        self.send_request(req)?;
        Ok(r.await.unwrap())
    }

    /// Get a reference to the network node handle's id.
    #[must_use]
    pub fn id(&self) -> usize {
        self.id
    }

    /// Get a reference to the network node handle's peer id.
    #[must_use]
    pub fn peer_id(&self) -> PeerId {
        self.peer_id
    }

    /// Return a reference to the network config
    #[must_use]
    pub fn config(&self) -> &NetworkNodeConfig<T> {
        &self.network_config
    }
}