hotshot_libp2p_networking/network/node/
handle.rs1use std::{collections::HashSet, fmt::Debug, sync::Arc, time::Duration};
8
9use bimap::BiMap;
10use hotshot_types::traits::{
11 network::NetworkError, node_implementation::NodeType, signature_key::SignatureKey,
12};
13use libp2p::{Multiaddr, request_response::ResponseChannel};
14use libp2p_identity::PeerId;
15use parking_lot::Mutex;
16use tokio::{
17 sync::mpsc::{Receiver, UnboundedReceiver, UnboundedSender},
18 time::{sleep, timeout},
19};
20use tracing::{debug, info, instrument};
21
22use crate::network::{
23 ClientRequest, NetworkEvent, NetworkNode, NetworkNodeConfig,
24 behaviours::dht::{
25 record::{Namespace, RecordKey, RecordValue},
26 store::persistent::DhtPersistentStorage,
27 },
28 gen_multiaddr,
29};
30
/// Handle used to send requests to a running libp2p network node.
///
/// Returned by [`spawn_network_node`]. The request methods push a [`ClientRequest`] onto
/// `send_network`; the ones that need an answer also pass a oneshot sender and await it.
/// Cloning the handle clones the request sender and the `Arc` around the key map, so all
/// clones address the same node and share the same map.
#[derive(Debug, Clone)]
pub struct NetworkNodeHandle<T: NodeType> {
    /// Configuration the node was created with.
    network_config: NetworkNodeConfig,

    /// Sending half of the request channel returned by `NetworkNode::spawn_listeners`.
    send_network: UnboundedSender<ClientRequest>,

    /// Bidirectional map between consensus (signature) keys and libp2p peer IDs.
    ///
    /// The same `Arc` is also handed to the `NetworkNode` when it is created, so both sides
    /// observe each other's updates.
    consensus_key_to_pid_map: Arc<Mutex<BiMap<T::SignatureKey, PeerId>>>,

    /// Address the node is listening on, as reported by `start_listen`.
    listen_addr: Multiaddr,

    /// The local node's libp2p peer ID.
    peer_id: PeerId,

    /// Caller-assigned numeric identifier for this node.
    id: usize,
}
54
/// Receiving side of a network node spawned by [`spawn_network_node`].
///
/// Yields the [`NetworkEvent`]s produced by the node and can hold an optional kill-switch
/// receiver on behalf of whoever drives the receive loop.
#[derive(Debug)]
pub struct NetworkNodeReceiver {
    /// Events emitted by the network node (receiving half returned by `spawn_listeners`).
    receiver: UnboundedReceiver<NetworkEvent>,

    /// Optional kill switch installed via `set_kill_switch`.
    ///
    /// `None` initially and after `take_kill_switch`. Nothing in this type reads it;
    /// presumably the consumer of `take_kill_switch` watches it to stop its loop (TODO confirm).
    recv_kill: Option<Receiver<()>>,
}
64
65impl NetworkNodeReceiver {
66 pub async fn recv(&mut self) -> Result<NetworkEvent, NetworkError> {
70 self.receiver
71 .recv()
72 .await
73 .ok_or(NetworkError::ChannelReceiveError(
74 "Receiver channel closed".to_string(),
75 ))
76 }
77 pub fn set_kill_switch(&mut self, kill_switch: Receiver<()>) {
79 self.recv_kill = Some(kill_switch);
80 }
81
82 pub fn take_kill_switch(&mut self) -> Option<Receiver<()>> {
84 self.recv_kill.take()
85 }
86}
87
88pub async fn spawn_network_node<T: NodeType, D: DhtPersistentStorage>(
92 config: NetworkNodeConfig,
93 dht_persistent_storage: D,
94 consensus_key_to_pid_map: Arc<Mutex<BiMap<T::SignatureKey, PeerId>>>,
95 id: usize,
96) -> Result<(NetworkNodeReceiver, NetworkNodeHandle<T>), NetworkError> {
97 let mut network: NetworkNode<T, _> = NetworkNode::new(
98 config.clone(),
99 dht_persistent_storage,
100 Arc::clone(&consensus_key_to_pid_map),
101 )
102 .await
103 .map_err(|e| NetworkError::ConfigError(format!("failed to create network node: {e}")))?;
104 let listen_addr = config
106 .bind_address
107 .clone()
108 .unwrap_or_else(|| gen_multiaddr(0));
109 let peer_id = network.peer_id();
110 let listen_addr = network.start_listen(listen_addr).await.map_err(|e| {
111 NetworkError::ListenError(format!("failed to start listening on Libp2p: {e}"))
112 })?;
113 let (send_chan, recv_chan) = network.spawn_listeners().map_err(|err| {
116 NetworkError::ListenError(format!("failed to spawn listeners for Libp2p: {err}"))
117 })?;
118 let receiver = NetworkNodeReceiver {
119 receiver: recv_chan,
120 recv_kill: None,
121 };
122
123 let handle = NetworkNodeHandle::<T> {
124 network_config: config,
125 send_network: send_chan,
126 consensus_key_to_pid_map,
127 listen_addr,
128 peer_id,
129 id,
130 };
131 Ok((receiver, handle))
132}
133
134impl<T: NodeType> NetworkNodeHandle<T> {
135 #[instrument]
139 pub async fn shutdown(&self) -> Result<(), NetworkError> {
140 self.send_request(ClientRequest::Shutdown)?;
141 Ok(())
142 }
143 pub fn begin_bootstrap(&self) -> Result<(), NetworkError> {
148 let req = ClientRequest::BeginBootstrap;
149 self.send_request(req)
150 }
151
152 #[must_use]
154 pub fn listen_addr(&self) -> Multiaddr {
155 self.listen_addr.clone()
156 }
157
158 pub async fn print_routing_table(&self) -> Result<(), NetworkError> {
163 let (s, r) = futures::channel::oneshot::channel();
164 let req = ClientRequest::GetRoutingTable(s);
165 self.send_request(req)?;
166 r.await
167 .map_err(|e| NetworkError::ChannelReceiveError(e.to_string()))
168 }
169 pub async fn wait_to_connect(
174 &self,
175 num_required_peers: usize,
176 node_id: usize,
177 ) -> Result<(), NetworkError> {
178 loop {
180 let num_connected = self.num_connected().await?;
182 if num_connected >= num_required_peers {
183 break;
184 }
185
186 info!(
188 "Node {} connected to {}/{} peers",
189 node_id, num_connected, num_required_peers
190 );
191
192 sleep(Duration::from_secs(1)).await;
194 }
195
196 Ok(())
197 }
198
199 pub async fn lookup_pid(&self, peer_id: PeerId) -> Result<(), NetworkError> {
204 let (s, r) = futures::channel::oneshot::channel();
205 let req = ClientRequest::LookupPeer(peer_id, s);
206 self.send_request(req)?;
207 r.await
208 .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
209 }
210
211 pub async fn lookup_node(
216 &self,
217 consensus_key: &T::SignatureKey,
218 dht_timeout: Duration,
219 ) -> Result<PeerId, NetworkError> {
220 if let Some(pid) = self
222 .consensus_key_to_pid_map
223 .lock()
224 .get_by_left(consensus_key)
225 {
226 return Ok(*pid);
227 }
228
229 let key = RecordKey::new(Namespace::Lookup, consensus_key.to_bytes());
231
232 let pid = self.get_record_timeout(key, dht_timeout).await?;
234
235 PeerId::from_bytes(&pid).map_err(|err| NetworkError::FailedToDeserialize(err.to_string()))
236 }
237
238 pub async fn put_record(
242 &self,
243 key: RecordKey,
244 value: RecordValue<T::SignatureKey>,
245 ) -> Result<(), NetworkError> {
246 let key = key.to_bytes();
248
249 let value = bincode::serialize(&value)
251 .map_err(|e| NetworkError::FailedToSerialize(e.to_string()))?;
252
253 let (s, r) = futures::channel::oneshot::channel();
254 let req = ClientRequest::PutDHT {
255 key: key.clone(),
256 value,
257 notify: s,
258 };
259
260 self.send_request(req)?;
261
262 r.await.map_err(|_| NetworkError::RequestCancelled)
263 }
264
265 pub async fn get_record(
271 &self,
272 key: RecordKey,
273 retry_count: u8,
274 ) -> Result<Vec<u8>, NetworkError> {
275 let serialized_key = key.to_bytes();
277
278 let (s, r) = futures::channel::oneshot::channel();
279 let req = ClientRequest::GetDHT {
280 key: serialized_key.clone(),
281 notify: vec![s],
282 retry_count,
283 };
284 self.send_request(req)?;
285
286 let result = r.await.map_err(|_| NetworkError::RequestCancelled)?;
288
289 let record: RecordValue<T::SignatureKey> = bincode::deserialize(&result)
291 .map_err(|e| NetworkError::FailedToDeserialize(e.to_string()))?;
292
293 Ok(record.value().to_vec())
294 }
295
296 pub async fn get_record_timeout(
302 &self,
303 key: RecordKey,
304 timeout_duration: Duration,
305 ) -> Result<Vec<u8>, NetworkError> {
306 timeout(timeout_duration, self.get_record(key, 3))
307 .await
308 .map_err(|err| NetworkError::Timeout(err.to_string()))?
309 }
310
311 pub async fn put_record_timeout(
317 &self,
318 key: RecordKey,
319 value: RecordValue<T::SignatureKey>,
320 timeout_duration: Duration,
321 ) -> Result<(), NetworkError> {
322 timeout(timeout_duration, self.put_record(key, value))
323 .await
324 .map_err(|err| NetworkError::Timeout(err.to_string()))?
325 }
326
327 pub async fn subscribe(&self, topic: String) -> Result<(), NetworkError> {
331 let (s, r) = futures::channel::oneshot::channel();
332 let req = ClientRequest::Subscribe(topic, Some(s));
333 self.send_request(req)?;
334 r.await
335 .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
336 }
337
338 pub async fn unsubscribe(&self, topic: String) -> Result<(), NetworkError> {
342 let (s, r) = futures::channel::oneshot::channel();
343 let req = ClientRequest::Unsubscribe(topic, Some(s));
344 self.send_request(req)?;
345 r.await
346 .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
347 }
348
349 pub fn ignore_peers(&self, peers: Vec<PeerId>) -> Result<(), NetworkError> {
354 let req = ClientRequest::IgnorePeers(peers);
355 self.send_request(req)
356 }
357
358 pub fn direct_request(&self, pid: PeerId, msg: &[u8]) -> Result<(), NetworkError> {
363 self.direct_request_no_serialize(pid, msg.to_vec())
364 }
365
366 pub fn direct_request_no_serialize(
371 &self,
372 pid: PeerId,
373 contents: Vec<u8>,
374 ) -> Result<(), NetworkError> {
375 let req = ClientRequest::DirectRequest {
376 pid,
377 contents,
378 retry_count: 1,
379 };
380 self.send_request(req)
381 }
382
383 pub fn direct_response(
388 &self,
389 chan: ResponseChannel<Vec<u8>>,
390 msg: &[u8],
391 ) -> Result<(), NetworkError> {
392 let req = ClientRequest::DirectResponse(chan, msg.to_vec());
393 self.send_request(req)
394 }
395
396 pub fn prune_peer(&self, pid: PeerId) -> Result<(), NetworkError> {
404 let req = ClientRequest::Prune(pid);
405 self.send_request(req)
406 }
407
408 pub fn gossip(&self, topic: String, msg: &[u8]) -> Result<(), NetworkError> {
413 self.gossip_no_serialize(topic, msg.to_vec())
414 }
415
416 pub fn gossip_no_serialize(&self, topic: String, msg: Vec<u8>) -> Result<(), NetworkError> {
421 let req = ClientRequest::GossipMsg(topic, msg);
422 self.send_request(req)
423 }
424
425 pub fn add_known_peers(
429 &self,
430 known_peers: Vec<(PeerId, Multiaddr)>,
431 ) -> Result<(), NetworkError> {
432 debug!("Adding {} known peers", known_peers.len());
433 let req = ClientRequest::AddKnownPeers(known_peers);
434 self.send_request(req)
435 }
436
437 fn send_request(&self, req: ClientRequest) -> Result<(), NetworkError> {
442 self.send_network
443 .send(req)
444 .map_err(|err| NetworkError::ChannelSendError(err.to_string()))
445 }
446
447 pub async fn num_connected(&self) -> Result<usize, NetworkError> {
455 let (s, r) = futures::channel::oneshot::channel();
456 let req = ClientRequest::GetConnectedPeerNum(s);
457 self.send_request(req)?;
458 Ok(r.await.unwrap())
459 }
460
461 pub async fn connected_pids(&self) -> Result<HashSet<PeerId>, NetworkError> {
469 let (s, r) = futures::channel::oneshot::channel();
470 let req = ClientRequest::GetConnectedPeers(s);
471 self.send_request(req)?;
472 Ok(r.await.unwrap())
473 }
474
475 #[must_use]
477 pub fn id(&self) -> usize {
478 self.id
479 }
480
481 #[must_use]
483 pub fn peer_id(&self) -> PeerId {
484 self.peer_id
485 }
486
487 #[must_use]
489 pub fn config(&self) -> &NetworkNodeConfig {
490 &self.network_config
491 }
492}