hotshot_libp2p_networking/network/node/
handle.rs1use std::{collections::HashSet, fmt::Debug, sync::Arc, time::Duration};
8
9use bimap::BiMap;
10use hotshot_types::traits::{
11 network::NetworkError, node_implementation::NodeType, signature_key::SignatureKey,
12};
13use libp2p::{request_response::ResponseChannel, Multiaddr};
14use libp2p_identity::PeerId;
15use parking_lot::Mutex;
16use tokio::{
17 sync::mpsc::{Receiver, UnboundedReceiver, UnboundedSender},
18 time::{sleep, timeout},
19};
20use tracing::{debug, info, instrument};
21
22use crate::network::{
23 behaviours::dht::{
24 record::{Namespace, RecordKey, RecordValue},
25 store::persistent::DhtPersistentStorage,
26 },
27 gen_multiaddr, ClientRequest, NetworkEvent, NetworkNode, NetworkNodeConfig,
28};
29
30#[derive(Debug, Clone)]
34pub struct NetworkNodeHandle<T: NodeType> {
35 network_config: NetworkNodeConfig<T>,
37
38 send_network: UnboundedSender<ClientRequest>,
40
41 consensus_key_to_pid_map: Arc<Mutex<BiMap<T::SignatureKey, PeerId>>>,
43
44 listen_addr: Multiaddr,
46
47 peer_id: PeerId,
49
50 id: usize,
52}
53
54#[derive(Debug)]
56pub struct NetworkNodeReceiver {
57 receiver: UnboundedReceiver<NetworkEvent>,
59
60 recv_kill: Option<Receiver<()>>,
62}
63
64impl NetworkNodeReceiver {
65 pub async fn recv(&mut self) -> Result<NetworkEvent, NetworkError> {
69 self.receiver
70 .recv()
71 .await
72 .ok_or(NetworkError::ChannelReceiveError(
73 "Receiver channel closed".to_string(),
74 ))
75 }
76 pub fn set_kill_switch(&mut self, kill_switch: Receiver<()>) {
78 self.recv_kill = Some(kill_switch);
79 }
80
81 pub fn take_kill_switch(&mut self) -> Option<Receiver<()>> {
83 self.recv_kill.take()
84 }
85}
86
87pub async fn spawn_network_node<T: NodeType, D: DhtPersistentStorage>(
91 config: NetworkNodeConfig<T>,
92 dht_persistent_storage: D,
93 consensus_key_to_pid_map: Arc<Mutex<BiMap<T::SignatureKey, PeerId>>>,
94 id: usize,
95) -> Result<(NetworkNodeReceiver, NetworkNodeHandle<T>), NetworkError> {
96 let mut network = NetworkNode::new(
97 config.clone(),
98 dht_persistent_storage,
99 Arc::clone(&consensus_key_to_pid_map),
100 )
101 .await
102 .map_err(|e| NetworkError::ConfigError(format!("failed to create network node: {e}")))?;
103 let listen_addr = config
105 .bind_address
106 .clone()
107 .unwrap_or_else(|| gen_multiaddr(0));
108 let peer_id = network.peer_id();
109 let listen_addr = network.start_listen(listen_addr).await.map_err(|e| {
110 NetworkError::ListenError(format!("failed to start listening on Libp2p: {e}"))
111 })?;
112 let (send_chan, recv_chan) = network.spawn_listeners().map_err(|err| {
115 NetworkError::ListenError(format!("failed to spawn listeners for Libp2p: {err}"))
116 })?;
117 let receiver = NetworkNodeReceiver {
118 receiver: recv_chan,
119 recv_kill: None,
120 };
121
122 let handle = NetworkNodeHandle::<T> {
123 network_config: config,
124 send_network: send_chan,
125 consensus_key_to_pid_map,
126 listen_addr,
127 peer_id,
128 id,
129 };
130 Ok((receiver, handle))
131}
132
133impl<T: NodeType> NetworkNodeHandle<T> {
134 #[instrument]
138 pub async fn shutdown(&self) -> Result<(), NetworkError> {
139 self.send_request(ClientRequest::Shutdown)?;
140 Ok(())
141 }
142 pub fn begin_bootstrap(&self) -> Result<(), NetworkError> {
147 let req = ClientRequest::BeginBootstrap;
148 self.send_request(req)
149 }
150
151 #[must_use]
153 pub fn listen_addr(&self) -> Multiaddr {
154 self.listen_addr.clone()
155 }
156
157 pub async fn print_routing_table(&self) -> Result<(), NetworkError> {
162 let (s, r) = futures::channel::oneshot::channel();
163 let req = ClientRequest::GetRoutingTable(s);
164 self.send_request(req)?;
165 r.await
166 .map_err(|e| NetworkError::ChannelReceiveError(e.to_string()))
167 }
168 pub async fn wait_to_connect(
173 &self,
174 num_required_peers: usize,
175 node_id: usize,
176 ) -> Result<(), NetworkError> {
177 loop {
179 let num_connected = self.num_connected().await?;
181 if num_connected >= num_required_peers {
182 break;
183 }
184
185 info!(
187 "Node {} connected to {}/{} peers",
188 node_id, num_connected, num_required_peers
189 );
190
191 sleep(Duration::from_secs(1)).await;
193 }
194
195 Ok(())
196 }
197
198 pub async fn lookup_pid(&self, peer_id: PeerId) -> Result<(), NetworkError> {
203 let (s, r) = futures::channel::oneshot::channel();
204 let req = ClientRequest::LookupPeer(peer_id, s);
205 self.send_request(req)?;
206 r.await
207 .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
208 }
209
210 pub async fn lookup_node(
215 &self,
216 consensus_key: &T::SignatureKey,
217 dht_timeout: Duration,
218 ) -> Result<PeerId, NetworkError> {
219 if let Some(pid) = self
221 .consensus_key_to_pid_map
222 .lock()
223 .get_by_left(consensus_key)
224 {
225 return Ok(*pid);
226 }
227
228 let key = RecordKey::new(Namespace::Lookup, consensus_key.to_bytes());
230
231 let pid = self.get_record_timeout(key, dht_timeout).await?;
233
234 PeerId::from_bytes(&pid).map_err(|err| NetworkError::FailedToDeserialize(err.to_string()))
235 }
236
237 pub async fn put_record(
241 &self,
242 key: RecordKey,
243 value: RecordValue<T::SignatureKey>,
244 ) -> Result<(), NetworkError> {
245 let key = key.to_bytes();
247
248 let value = bincode::serialize(&value)
250 .map_err(|e| NetworkError::FailedToSerialize(e.to_string()))?;
251
252 let (s, r) = futures::channel::oneshot::channel();
253 let req = ClientRequest::PutDHT {
254 key: key.clone(),
255 value,
256 notify: s,
257 };
258
259 self.send_request(req)?;
260
261 r.await.map_err(|_| NetworkError::RequestCancelled)
262 }
263
264 pub async fn get_record(
270 &self,
271 key: RecordKey,
272 retry_count: u8,
273 ) -> Result<Vec<u8>, NetworkError> {
274 let serialized_key = key.to_bytes();
276
277 let (s, r) = futures::channel::oneshot::channel();
278 let req = ClientRequest::GetDHT {
279 key: serialized_key.clone(),
280 notify: vec![s],
281 retry_count,
282 };
283 self.send_request(req)?;
284
285 let result = r.await.map_err(|_| NetworkError::RequestCancelled)?;
287
288 let record: RecordValue<T::SignatureKey> = bincode::deserialize(&result)
290 .map_err(|e| NetworkError::FailedToDeserialize(e.to_string()))?;
291
292 Ok(record.value().to_vec())
293 }
294
295 pub async fn get_record_timeout(
301 &self,
302 key: RecordKey,
303 timeout_duration: Duration,
304 ) -> Result<Vec<u8>, NetworkError> {
305 timeout(timeout_duration, self.get_record(key, 3))
306 .await
307 .map_err(|err| NetworkError::Timeout(err.to_string()))?
308 }
309
310 pub async fn put_record_timeout(
316 &self,
317 key: RecordKey,
318 value: RecordValue<T::SignatureKey>,
319 timeout_duration: Duration,
320 ) -> Result<(), NetworkError> {
321 timeout(timeout_duration, self.put_record(key, value))
322 .await
323 .map_err(|err| NetworkError::Timeout(err.to_string()))?
324 }
325
326 pub async fn subscribe(&self, topic: String) -> Result<(), NetworkError> {
330 let (s, r) = futures::channel::oneshot::channel();
331 let req = ClientRequest::Subscribe(topic, Some(s));
332 self.send_request(req)?;
333 r.await
334 .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
335 }
336
337 pub async fn unsubscribe(&self, topic: String) -> Result<(), NetworkError> {
341 let (s, r) = futures::channel::oneshot::channel();
342 let req = ClientRequest::Unsubscribe(topic, Some(s));
343 self.send_request(req)?;
344 r.await
345 .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
346 }
347
348 pub fn ignore_peers(&self, peers: Vec<PeerId>) -> Result<(), NetworkError> {
353 let req = ClientRequest::IgnorePeers(peers);
354 self.send_request(req)
355 }
356
357 pub fn direct_request(&self, pid: PeerId, msg: &[u8]) -> Result<(), NetworkError> {
362 self.direct_request_no_serialize(pid, msg.to_vec())
363 }
364
365 pub fn direct_request_no_serialize(
370 &self,
371 pid: PeerId,
372 contents: Vec<u8>,
373 ) -> Result<(), NetworkError> {
374 let req = ClientRequest::DirectRequest {
375 pid,
376 contents,
377 retry_count: 1,
378 };
379 self.send_request(req)
380 }
381
382 pub fn direct_response(
387 &self,
388 chan: ResponseChannel<Vec<u8>>,
389 msg: &[u8],
390 ) -> Result<(), NetworkError> {
391 let req = ClientRequest::DirectResponse(chan, msg.to_vec());
392 self.send_request(req)
393 }
394
395 pub fn prune_peer(&self, pid: PeerId) -> Result<(), NetworkError> {
403 let req = ClientRequest::Prune(pid);
404 self.send_request(req)
405 }
406
407 pub fn gossip(&self, topic: String, msg: &[u8]) -> Result<(), NetworkError> {
412 self.gossip_no_serialize(topic, msg.to_vec())
413 }
414
415 pub fn gossip_no_serialize(&self, topic: String, msg: Vec<u8>) -> Result<(), NetworkError> {
420 let req = ClientRequest::GossipMsg(topic, msg);
421 self.send_request(req)
422 }
423
424 pub fn add_known_peers(
428 &self,
429 known_peers: Vec<(PeerId, Multiaddr)>,
430 ) -> Result<(), NetworkError> {
431 debug!("Adding {} known peers", known_peers.len());
432 let req = ClientRequest::AddKnownPeers(known_peers);
433 self.send_request(req)
434 }
435
436 fn send_request(&self, req: ClientRequest) -> Result<(), NetworkError> {
441 self.send_network
442 .send(req)
443 .map_err(|err| NetworkError::ChannelSendError(err.to_string()))
444 }
445
446 pub async fn num_connected(&self) -> Result<usize, NetworkError> {
454 let (s, r) = futures::channel::oneshot::channel();
455 let req = ClientRequest::GetConnectedPeerNum(s);
456 self.send_request(req)?;
457 Ok(r.await.unwrap())
458 }
459
460 pub async fn connected_pids(&self) -> Result<HashSet<PeerId>, NetworkError> {
468 let (s, r) = futures::channel::oneshot::channel();
469 let req = ClientRequest::GetConnectedPeers(s);
470 self.send_request(req)?;
471 Ok(r.await.unwrap())
472 }
473
474 #[must_use]
476 pub fn id(&self) -> usize {
477 self.id
478 }
479
480 #[must_use]
482 pub fn peer_id(&self) -> PeerId {
483 self.peer_id
484 }
485
486 #[must_use]
488 pub fn config(&self) -> &NetworkNodeConfig<T> {
489 &self.network_config
490 }
491}