hotshot_libp2p_networking/network/node/handle.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::{collections::HashSet, fmt::Debug, sync::Arc, time::Duration};
8
9use bimap::BiMap;
10use hotshot_types::traits::{
11    network::NetworkError, node_implementation::NodeType, signature_key::SignatureKey,
12};
13use libp2p::{Multiaddr, request_response::ResponseChannel};
14use libp2p_identity::PeerId;
15use parking_lot::Mutex;
16use tokio::{
17    sync::mpsc::{Receiver, UnboundedReceiver, UnboundedSender},
18    time::{sleep, timeout},
19};
20use tracing::{debug, info, instrument};
21
22use crate::network::{
23    ClientRequest, NetworkEvent, NetworkNode, NetworkNodeConfig,
24    behaviours::dht::{
25        record::{Namespace, RecordKey, RecordValue},
26        store::persistent::DhtPersistentStorage,
27    },
28    gen_multiaddr,
29};
30
31/// A handle containing:
32/// - A reference to the state
33/// - Controls for the swarm
/// A handle containing:
/// - A reference to the state
/// - Controls for the swarm
///
/// Cloning the handle is cheap: all clones share the same request channel
/// and consensus-key map.
#[derive(Debug, Clone)]
pub struct NetworkNodeHandle<T: NodeType> {
    /// network configuration the node was spawned with
    network_config: NetworkNodeConfig,

    /// channel for sending a `ClientRequest` action to the network behaviour task
    send_network: UnboundedSender<ClientRequest>,

    /// The map from consensus keys to peer IDs, shared with the network node
    consensus_key_to_pid_map: Arc<Mutex<BiMap<T::SignatureKey, PeerId>>>,

    /// the local address we're listening on (with the concrete port filled in)
    listen_addr: Multiaddr,

    /// the peer id of the networkbehaviour
    peer_id: PeerId,

    /// human readable id, used for logging only
    id: usize,
}
54
55/// internal network node receiver
/// internal network node receiver: the event half of a spawned network node
#[derive(Debug)]
pub struct NetworkNodeReceiver {
    /// the receiver for `NetworkEvent`s emitted by the node task
    receiver: UnboundedReceiver<NetworkEvent>,

    /// kill switch; `None` until installed via `set_kill_switch`
    recv_kill: Option<Receiver<()>>,
}
64
65impl NetworkNodeReceiver {
66    /// recv a network event
67    /// # Errors
68    /// Errors if the receiver channel is closed
69    pub async fn recv(&mut self) -> Result<NetworkEvent, NetworkError> {
70        self.receiver
71            .recv()
72            .await
73            .ok_or(NetworkError::ChannelReceiveError(
74                "Receiver channel closed".to_string(),
75            ))
76    }
77    /// Add a kill switch to the receiver
78    pub fn set_kill_switch(&mut self, kill_switch: Receiver<()>) {
79        self.recv_kill = Some(kill_switch);
80    }
81
82    /// Take the kill switch to allow killing the receiver task
83    pub fn take_kill_switch(&mut self) -> Option<Receiver<()>> {
84        self.recv_kill.take()
85    }
86}
87
88/// Spawn a network node task task and return the handle and the receiver for it
89/// # Errors
90/// Errors if spawning the task fails
91pub async fn spawn_network_node<T: NodeType, D: DhtPersistentStorage>(
92    config: NetworkNodeConfig,
93    dht_persistent_storage: D,
94    consensus_key_to_pid_map: Arc<Mutex<BiMap<T::SignatureKey, PeerId>>>,
95    id: usize,
96) -> Result<(NetworkNodeReceiver, NetworkNodeHandle<T>), NetworkError> {
97    let mut network: NetworkNode<T, _> = NetworkNode::new(
98        config.clone(),
99        dht_persistent_storage,
100        Arc::clone(&consensus_key_to_pid_map),
101    )
102    .await
103    .map_err(|e| NetworkError::ConfigError(format!("failed to create network node: {e}")))?;
104    // randomly assigned port
105    let listen_addr = config
106        .bind_address
107        .clone()
108        .unwrap_or_else(|| gen_multiaddr(0));
109    let peer_id = network.peer_id();
110    let listen_addr = network.start_listen(listen_addr).await.map_err(|e| {
111        NetworkError::ListenError(format!("failed to start listening on Libp2p: {e}"))
112    })?;
113    // pin here to force the future onto the heap since it can be large
114    // in the case of flume
115    let (send_chan, recv_chan) = network.spawn_listeners().map_err(|err| {
116        NetworkError::ListenError(format!("failed to spawn listeners for Libp2p: {err}"))
117    })?;
118    let receiver = NetworkNodeReceiver {
119        receiver: recv_chan,
120        recv_kill: None,
121    };
122
123    let handle = NetworkNodeHandle::<T> {
124        network_config: config,
125        send_network: send_chan,
126        consensus_key_to_pid_map,
127        listen_addr,
128        peer_id,
129        id,
130    };
131    Ok((receiver, handle))
132}
133
134impl<T: NodeType> NetworkNodeHandle<T> {
135    /// Cleanly shuts down a swarm node
136    /// This is done by sending a message to
137    /// the swarm itself to spin down
138    #[instrument]
139    pub async fn shutdown(&self) -> Result<(), NetworkError> {
140        self.send_request(ClientRequest::Shutdown)?;
141        Ok(())
142    }
143    /// Notify the network to begin the bootstrap process
144    /// # Errors
145    /// If unable to send via `send_network`. This should only happen
146    /// if the network is shut down.
147    pub fn begin_bootstrap(&self) -> Result<(), NetworkError> {
148        let req = ClientRequest::BeginBootstrap;
149        self.send_request(req)
150    }
151
152    /// Get a reference to the network node handle's listen addr.
153    #[must_use]
154    pub fn listen_addr(&self) -> Multiaddr {
155        self.listen_addr.clone()
156    }
157
158    /// Print out the routing table used by kademlia
159    /// NOTE: only for debugging purposes currently
160    /// # Errors
161    /// if the client has stopped listening for a response
162    pub async fn print_routing_table(&self) -> Result<(), NetworkError> {
163        let (s, r) = futures::channel::oneshot::channel();
164        let req = ClientRequest::GetRoutingTable(s);
165        self.send_request(req)?;
166        r.await
167            .map_err(|e| NetworkError::ChannelReceiveError(e.to_string()))
168    }
169    /// Wait until at least `num_peers` have connected
170    ///
171    /// # Errors
172    /// If the channel closes before the result can be sent back
173    pub async fn wait_to_connect(
174        &self,
175        num_required_peers: usize,
176        node_id: usize,
177    ) -> Result<(), NetworkError> {
178        // Wait for the required number of peers to connect
179        loop {
180            // Get the number of currently connected peers
181            let num_connected = self.num_connected().await?;
182            if num_connected >= num_required_peers {
183                break;
184            }
185
186            // Log the number of connected peers
187            info!(
188                "Node {} connected to {}/{} peers",
189                node_id, num_connected, num_required_peers
190            );
191
192            // Sleep for a second before checking again
193            sleep(Duration::from_secs(1)).await;
194        }
195
196        Ok(())
197    }
198
199    /// Look up a peer's addresses in kademlia
200    /// NOTE: this should always be called before any `request_response` is initiated
201    /// # Errors
202    /// if the client has stopped listening for a response
203    pub async fn lookup_pid(&self, peer_id: PeerId) -> Result<(), NetworkError> {
204        let (s, r) = futures::channel::oneshot::channel();
205        let req = ClientRequest::LookupPeer(peer_id, s);
206        self.send_request(req)?;
207        r.await
208            .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
209    }
210
211    /// Looks up a node's `PeerId` by its consensus key.
212    ///
213    /// # Errors
214    /// If the DHT lookup fails
215    pub async fn lookup_node(
216        &self,
217        consensus_key: &T::SignatureKey,
218        dht_timeout: Duration,
219    ) -> Result<PeerId, NetworkError> {
220        // First check if we already have an open connection to the peer
221        if let Some(pid) = self
222            .consensus_key_to_pid_map
223            .lock()
224            .get_by_left(consensus_key)
225        {
226            return Ok(*pid);
227        }
228
229        // Create the record key
230        let key = RecordKey::new(Namespace::Lookup, consensus_key.to_bytes());
231
232        // Get the record from the DHT
233        let pid = self.get_record_timeout(key, dht_timeout).await?;
234
235        PeerId::from_bytes(&pid).map_err(|err| NetworkError::FailedToDeserialize(err.to_string()))
236    }
237
238    /// Insert a record into the kademlia DHT
239    /// # Errors
240    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize the key or value
241    pub async fn put_record(
242        &self,
243        key: RecordKey,
244        value: RecordValue<T::SignatureKey>,
245    ) -> Result<(), NetworkError> {
246        // Serialize the key
247        let key = key.to_bytes();
248
249        // Serialize the record
250        let value = bincode::serialize(&value)
251            .map_err(|e| NetworkError::FailedToSerialize(e.to_string()))?;
252
253        let (s, r) = futures::channel::oneshot::channel();
254        let req = ClientRequest::PutDHT {
255            key: key.clone(),
256            value,
257            notify: s,
258        };
259
260        self.send_request(req)?;
261
262        r.await.map_err(|_| NetworkError::RequestCancelled)
263    }
264
265    /// Receive a record from the kademlia DHT if it exists.
266    /// Must be replicated on at least 2 nodes
267    /// # Errors
268    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize the key
269    /// - Will return [`NetworkError::FailedToDeserialize`] when unable to deserialize the returned value
270    pub async fn get_record(
271        &self,
272        key: RecordKey,
273        retry_count: u8,
274    ) -> Result<Vec<u8>, NetworkError> {
275        // Serialize the key
276        let serialized_key = key.to_bytes();
277
278        let (s, r) = futures::channel::oneshot::channel();
279        let req = ClientRequest::GetDHT {
280            key: serialized_key.clone(),
281            notify: vec![s],
282            retry_count,
283        };
284        self.send_request(req)?;
285
286        // Map the error
287        let result = r.await.map_err(|_| NetworkError::RequestCancelled)?;
288
289        // Deserialize the record's value
290        let record: RecordValue<T::SignatureKey> = bincode::deserialize(&result)
291            .map_err(|e| NetworkError::FailedToDeserialize(e.to_string()))?;
292
293        Ok(record.value().to_vec())
294    }
295
296    /// Get a record from the kademlia DHT with a timeout
297    /// # Errors
298    /// - Will return [`NetworkError::Timeout`] when times out
299    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize the key or value
300    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
301    pub async fn get_record_timeout(
302        &self,
303        key: RecordKey,
304        timeout_duration: Duration,
305    ) -> Result<Vec<u8>, NetworkError> {
306        timeout(timeout_duration, self.get_record(key, 3))
307            .await
308            .map_err(|err| NetworkError::Timeout(err.to_string()))?
309    }
310
311    /// Insert a record into the kademlia DHT with a timeout
312    /// # Errors
313    /// - Will return [`NetworkError::Timeout`] when times out
314    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize the key or value
315    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
316    pub async fn put_record_timeout(
317        &self,
318        key: RecordKey,
319        value: RecordValue<T::SignatureKey>,
320        timeout_duration: Duration,
321    ) -> Result<(), NetworkError> {
322        timeout(timeout_duration, self.put_record(key, value))
323            .await
324            .map_err(|err| NetworkError::Timeout(err.to_string()))?
325    }
326
327    /// Subscribe to a topic
328    /// # Errors
329    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
330    pub async fn subscribe(&self, topic: String) -> Result<(), NetworkError> {
331        let (s, r) = futures::channel::oneshot::channel();
332        let req = ClientRequest::Subscribe(topic, Some(s));
333        self.send_request(req)?;
334        r.await
335            .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
336    }
337
338    /// Unsubscribe from a topic
339    /// # Errors
340    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
341    pub async fn unsubscribe(&self, topic: String) -> Result<(), NetworkError> {
342        let (s, r) = futures::channel::oneshot::channel();
343        let req = ClientRequest::Unsubscribe(topic, Some(s));
344        self.send_request(req)?;
345        r.await
346            .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
347    }
348
349    /// Ignore `peers` when pruning
350    /// e.g. maintain their connection
351    /// # Errors
352    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
353    pub fn ignore_peers(&self, peers: Vec<PeerId>) -> Result<(), NetworkError> {
354        let req = ClientRequest::IgnorePeers(peers);
355        self.send_request(req)
356    }
357
358    /// Make a direct request to `peer_id` containing `msg`
359    /// # Errors
360    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
361    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize `msg`
362    pub fn direct_request(&self, pid: PeerId, msg: &[u8]) -> Result<(), NetworkError> {
363        self.direct_request_no_serialize(pid, msg.to_vec())
364    }
365
366    /// Make a direct request to `peer_id` containing `msg` without serializing
367    /// # Errors
368    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
369    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize `msg`
370    pub fn direct_request_no_serialize(
371        &self,
372        pid: PeerId,
373        contents: Vec<u8>,
374    ) -> Result<(), NetworkError> {
375        let req = ClientRequest::DirectRequest {
376            pid,
377            contents,
378            retry_count: 1,
379        };
380        self.send_request(req)
381    }
382
383    /// Reply with `msg` to a request over `chan`
384    /// # Errors
385    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
386    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize `msg`
387    pub fn direct_response(
388        &self,
389        chan: ResponseChannel<Vec<u8>>,
390        msg: &[u8],
391    ) -> Result<(), NetworkError> {
392        let req = ClientRequest::DirectResponse(chan, msg.to_vec());
393        self.send_request(req)
394    }
395
396    /// Forcefully disconnect from a peer
397    /// # Errors
398    /// If the channel is closed somehow
399    /// Shouldnt' happen.
400    /// # Panics
401    /// If channel errors out
402    /// shouldn't happen.
403    pub fn prune_peer(&self, pid: PeerId) -> Result<(), NetworkError> {
404        let req = ClientRequest::Prune(pid);
405        self.send_request(req)
406    }
407
408    /// Gossip a message to peers
409    /// # Errors
410    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
411    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize `msg`
412    pub fn gossip(&self, topic: String, msg: &[u8]) -> Result<(), NetworkError> {
413        self.gossip_no_serialize(topic, msg.to_vec())
414    }
415
416    /// Gossip a message to peers without serializing
417    /// # Errors
418    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
419    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize `msg`
420    pub fn gossip_no_serialize(&self, topic: String, msg: Vec<u8>) -> Result<(), NetworkError> {
421        let req = ClientRequest::GossipMsg(topic, msg);
422        self.send_request(req)
423    }
424
425    /// Tell libp2p about known network nodes
426    /// # Errors
427    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
428    pub fn add_known_peers(
429        &self,
430        known_peers: Vec<(PeerId, Multiaddr)>,
431    ) -> Result<(), NetworkError> {
432        debug!("Adding {} known peers", known_peers.len());
433        let req = ClientRequest::AddKnownPeers(known_peers);
434        self.send_request(req)
435    }
436
437    /// Send a client request to the network
438    ///
439    /// # Errors
440    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
441    fn send_request(&self, req: ClientRequest) -> Result<(), NetworkError> {
442        self.send_network
443            .send(req)
444            .map_err(|err| NetworkError::ChannelSendError(err.to_string()))
445    }
446
447    /// Returns number of peers this node is connected to
448    /// # Errors
449    /// If the channel is closed somehow
450    /// Shouldnt' happen.
451    /// # Panics
452    /// If channel errors out
453    /// shouldn't happen.
454    pub async fn num_connected(&self) -> Result<usize, NetworkError> {
455        let (s, r) = futures::channel::oneshot::channel();
456        let req = ClientRequest::GetConnectedPeerNum(s);
457        self.send_request(req)?;
458        Ok(r.await.unwrap())
459    }
460
461    /// return hashset of PIDs this node is connected to
462    /// # Errors
463    /// If the channel is closed somehow
464    /// Shouldnt' happen.
465    /// # Panics
466    /// If channel errors out
467    /// shouldn't happen.
468    pub async fn connected_pids(&self) -> Result<HashSet<PeerId>, NetworkError> {
469        let (s, r) = futures::channel::oneshot::channel();
470        let req = ClientRequest::GetConnectedPeers(s);
471        self.send_request(req)?;
472        Ok(r.await.unwrap())
473    }
474
475    /// Get a reference to the network node handle's id.
476    #[must_use]
477    pub fn id(&self) -> usize {
478        self.id
479    }
480
481    /// Get a reference to the network node handle's peer id.
482    #[must_use]
483    pub fn peer_id(&self) -> PeerId {
484        self.peer_id
485    }
486
487    /// Return a reference to the network config
488    #[must_use]
489    pub fn config(&self) -> &NetworkNodeConfig {
490        &self.network_config
491    }
492}