1use std::{
10 collections::{BTreeMap, HashMap},
11 future::Future,
12 num::NonZeroUsize,
13 sync::{
14 Arc,
15 atomic::{AtomicBool, AtomicU64, Ordering},
16 },
17 time::Duration,
18};
19
20use async_broadcast::{InactiveReceiver, Sender, broadcast};
21use async_lock::RwLock;
22use async_trait::async_trait;
23use futures::{FutureExt, join, select};
24use hotshot_types::{
25 BoxSyncFuture, boxed_sync,
26 constants::{
27 COMBINED_NETWORK_CACHE_SIZE, COMBINED_NETWORK_DELAY_DURATION,
28 COMBINED_NETWORK_MIN_PRIMARY_FAILURES, COMBINED_NETWORK_PRIMARY_CHECK_INTERVAL,
29 },
30 data::{EpochNumber, ViewNumber},
31 epoch_membership::EpochMembershipCoordinator,
32 traits::{
33 network::{BroadcastDelay, ConnectedNetwork, Topic},
34 node_implementation::NodeType,
35 },
36};
37#[cfg(feature = "hotshot-testing")]
38use hotshot_types::{
39 PeerConnectInfo,
40 traits::network::{AsyncGenerator, NetworkReliability, TestableNetworkingImplementation},
41};
42use lru::LruCache;
43use parking_lot::RwLock as PlRwLock;
44use tokio::{spawn, sync::mpsc::error::TrySendError, time::sleep};
45use tracing::{debug, info, warn};
46
47use super::{NetworkError, push_cdn_network::PushCdnNetwork};
48use crate::traits::implementations::Libp2pNetwork;
49
/// Shared, async-locked map of cancellation channels for delayed secondary sends.
///
/// Keys are the `u64` carried by `BroadcastDelay::View`; each value is the broadcast
/// `Sender` paired with an inactive receiver that delayed tasks clone and activate.
type DelayedTasksChannelsMap = Arc<RwLock<BTreeMap<u64, (Sender<()>, InactiveReceiver<()>)>>>;
52
/// A network handle that sends over a primary `PushCdnNetwork` and a secondary
/// `Libp2pNetwork`, and receives from both while discarding duplicates.
#[derive(Clone)]
pub struct CombinedNetworks<TYPES: NodeType> {
    /// The two underlying networks: primary at index 0, secondary at index 1.
    networks: Arc<UnderlyingCombinedNetworks<TYPES>>,

    /// Cache consulted on receive to drop a message already seen on the other network.
    message_deduplication_cache: Arc<PlRwLock<MessageDeduplicationCache>>,

    /// Incremented when a primary send errors or a delayed send falls through to the
    /// secondary; decremented when a delayed send is cancelled.
    primary_fail_counter: Arc<AtomicU64>,

    /// Set once the primary is considered down; while set, sends skip the delay.
    primary_down: Arc<AtomicBool>,

    /// How long a delayed send sleeps before it uses the secondary network.
    delay_duration: Arc<RwLock<Duration>>,

    /// Per-view channels used to cancel delayed secondary sends.
    delayed_tasks_channels: DelayedTasksChannelsMap,

    /// Counts undelayed sends made while `primary_down` is set.
    no_delay_counter: Arc<AtomicU64>,
}
79
80impl<TYPES: NodeType> CombinedNetworks<TYPES> {
81 #[must_use]
87 pub fn new(
88 primary_network: PushCdnNetwork<TYPES::SignatureKey>,
89 secondary_network: Libp2pNetwork<TYPES>,
90 delay_duration: Option<Duration>,
91 ) -> Self {
92 let networks = Arc::from(UnderlyingCombinedNetworks(
94 primary_network,
95 secondary_network,
96 ));
97
98 Self {
99 networks,
100 message_deduplication_cache: Arc::new(PlRwLock::new(MessageDeduplicationCache::new())),
101 primary_fail_counter: Arc::new(AtomicU64::new(0)),
102 primary_down: Arc::new(AtomicBool::new(false)),
103 delay_duration: Arc::new(RwLock::new(
104 delay_duration.unwrap_or(Duration::from_millis(COMBINED_NETWORK_DELAY_DURATION)),
105 )),
106 delayed_tasks_channels: Arc::default(),
107 no_delay_counter: Arc::new(AtomicU64::new(0)),
108 }
109 }
110
111 #[must_use]
113 pub fn primary(&self) -> &PushCdnNetwork<TYPES::SignatureKey> {
114 &self.networks.0
115 }
116
117 #[must_use]
119 pub fn secondary(&self) -> &Libp2pNetwork<TYPES> {
120 &self.networks.1
121 }
122
123 async fn send_both_networks(
125 &self,
126 _message: Vec<u8>,
127 primary_future: impl Future<Output = Result<(), NetworkError>> + Send + 'static,
128 secondary_future: impl Future<Output = Result<(), NetworkError>> + Send + 'static,
129 broadcast_delay: BroadcastDelay,
130 ) -> Result<(), NetworkError> {
131 let mut primary_failed = false;
133 if self.primary_down.load(Ordering::Relaxed) {
134 primary_failed = true;
136 } else if self.primary_fail_counter.load(Ordering::Relaxed)
137 > COMBINED_NETWORK_MIN_PRIMARY_FAILURES
138 {
139 info!(
142 "View progression is slower than normally, stop delaying messages on the secondary"
143 );
144 self.primary_down.store(true, Ordering::Relaxed);
145 primary_failed = true;
146 }
147
148 if let Err(e) = primary_future.await {
150 warn!("Error on primary network: {}", e);
152 self.primary_fail_counter.fetch_add(1, Ordering::Relaxed);
153 primary_failed = true;
154 };
155
156 if let (BroadcastDelay::View(view), false) = (broadcast_delay, primary_failed) {
157 let duration = *self.delay_duration.read().await;
159 let primary_down = Arc::clone(&self.primary_down);
160 let primary_fail_counter = Arc::clone(&self.primary_fail_counter);
161 let mut receiver = self
164 .delayed_tasks_channels
165 .write()
166 .await
167 .entry(view)
168 .or_insert_with(|| {
169 let (s, r) = broadcast(1);
170 (s, r.deactivate())
171 })
172 .1
173 .activate_cloned();
174 spawn(async move {
176 sleep(duration).await;
177 if receiver.try_recv().is_ok() {
178 debug!(
180 "Not sending on secondary after delay, task was canceled in view update"
181 );
182 match primary_fail_counter.load(Ordering::Relaxed) {
183 0u64 => {
184 primary_down.store(false, Ordering::Relaxed);
186 debug!("primary_fail_counter reached zero, primary_down set to false");
187 },
188 c => {
189 primary_fail_counter.store(c - 1, Ordering::Relaxed);
191 debug!("primary_fail_counter set to {:?}", c - 1);
192 },
193 }
194 return Ok(());
195 }
196 debug!(
199 "Sending on secondary after delay, message possibly has not reached recipient \
200 on primary"
201 );
202 primary_fail_counter.fetch_add(1, Ordering::Relaxed);
203 secondary_future.await
204 });
205 Ok(())
206 } else {
207 if self.primary_down.load(Ordering::Relaxed) {
209 match self.no_delay_counter.load(Ordering::Relaxed) {
213 c if c < COMBINED_NETWORK_PRIMARY_CHECK_INTERVAL => {
214 self.no_delay_counter.store(c + 1, Ordering::Relaxed);
216 },
217 _ => {
218 debug!(
220 "Sent on secondary without delay more than {} times,try delaying to \
221 check primary",
222 COMBINED_NETWORK_PRIMARY_CHECK_INTERVAL
223 );
224 self.no_delay_counter.store(0u64, Ordering::Relaxed);
226 self.primary_down.store(false, Ordering::Relaxed);
228 self.primary_fail_counter
230 .store(COMBINED_NETWORK_MIN_PRIMARY_FAILURES, Ordering::Relaxed);
231 },
232 }
233 }
234 tokio::time::timeout(Duration::from_secs(2), secondary_future)
236 .await
237 .map_err(|e| NetworkError::Timeout(e.to_string()))?
238 }
239 }
240}
241
/// Tuple wrapper around the two networks used by `CombinedNetworks`.
///
/// Field `0` is the primary `PushCdnNetwork`; field `1` is the secondary `Libp2pNetwork`.
#[derive(Clone)]
pub struct UnderlyingCombinedNetworks<TYPES: NodeType>(
    pub PushCdnNetwork<TYPES::SignatureKey>,
    pub Libp2pNetwork<TYPES>,
);
250
#[cfg(feature = "hotshot-testing")]
impl<TYPES: NodeType> TestableNetworkingImplementation<TYPES> for CombinedNetworks<TYPES> {
    /// Returns a generator that builds one `CombinedNetworks` per node id.
    ///
    /// The primary and secondary generators are created up front with the same
    /// sizing parameters. Only the secondary receives `reliability_config`, and
    /// `secondary_network_delay` becomes the combined network's delay duration.
    fn generator(
        expected_node_count: usize,
        num_bootstrap: usize,
        network_id: usize,
        da_committee_size: usize,
        reliability_config: Option<Box<dyn NetworkReliability>>,
        secondary_network_delay: Duration,
        connect_infos: &mut HashMap<TYPES::SignatureKey, PeerConnectInfo>,
    ) -> AsyncGenerator<Arc<Self>> {
        let cdn_generator =
            <PushCdnNetwork<TYPES::SignatureKey> as TestableNetworkingImplementation<TYPES>>::generator(
                expected_node_count,
                num_bootstrap,
                network_id,
                da_committee_size,
                None,
                Duration::default(),
                connect_infos,
            );
        let p2p_generator =
            <Libp2pNetwork<TYPES> as TestableNetworkingImplementation<TYPES>>::generator(
                expected_node_count,
                num_bootstrap,
                network_id,
                da_committee_size,
                reliability_config,
                Duration::default(),
                connect_infos,
            );

        Box::pin(move |node_id| {
            let cdn_pending = cdn_generator(node_id);
            let p2p_pending = p2p_generator(node_id);

            Box::pin(async move {
                // The primary must be uniquely owned here; this panics otherwise.
                let cdn = Arc::into_inner(cdn_pending.await).unwrap();
                let p2p = p2p_pending.await;

                let networks = Arc::new(UnderlyingCombinedNetworks(
                    cdn.clone(),
                    Arc::unwrap_or_clone(p2p),
                ));

                Arc::new(Self {
                    networks,
                    primary_fail_counter: Arc::new(AtomicU64::new(0)),
                    primary_down: Arc::new(AtomicBool::new(false)),
                    message_deduplication_cache: Arc::new(PlRwLock::new(
                        MessageDeduplicationCache::new(),
                    )),
                    delay_duration: Arc::new(RwLock::new(secondary_network_delay)),
                    delayed_tasks_channels: Arc::default(),
                    no_delay_counter: Arc::new(AtomicU64::new(0)),
                })
            })
        })
    }

    /// Always `None`: the combined network does not report an in-flight count.
    fn in_flight_message_count(&self) -> Option<usize> {
        None
    }
}
327
328#[async_trait]
329impl<TYPES: NodeType> ConnectedNetwork<TYPES::SignatureKey> for CombinedNetworks<TYPES> {
330 fn pause(&self) {
331 self.networks.0.pause();
332 }
333
334 fn resume(&self) {
335 self.networks.0.resume();
336 }
337
338 async fn wait_for_ready(&self) {
339 join!(
340 self.primary().wait_for_ready(),
341 self.secondary().wait_for_ready()
342 );
343 }
344
345 fn shut_down<'a, 'b>(&'a self) -> BoxSyncFuture<'b, ()>
346 where
347 'a: 'b,
348 Self: 'b,
349 {
350 let closure = async move {
351 join!(self.primary().shut_down(), self.secondary().shut_down());
352 };
353 boxed_sync(closure)
354 }
355
356 async fn broadcast_message(
357 &self,
358 view: ViewNumber,
359 message: Vec<u8>,
360 topic: Topic,
361 broadcast_delay: BroadcastDelay,
362 ) -> Result<(), NetworkError> {
363 let primary = self.primary().clone();
364 let secondary = self.secondary().clone();
365 let primary_message = message.clone();
366 let secondary_message = message.clone();
367 self.send_both_networks(
368 message,
369 async move {
370 primary
371 .broadcast_message(view, primary_message, topic, BroadcastDelay::None)
372 .await
373 },
374 async move {
375 secondary
376 .broadcast_message(view, secondary_message, topic, BroadcastDelay::None)
377 .await
378 },
379 broadcast_delay,
380 )
381 .await
382 }
383
384 async fn da_broadcast_message(
385 &self,
386 view: ViewNumber,
387 message: Vec<u8>,
388 recipients: Vec<TYPES::SignatureKey>,
389 broadcast_delay: BroadcastDelay,
390 ) -> Result<(), NetworkError> {
391 let primary = self.primary().clone();
392 let secondary = self.secondary().clone();
393 let primary_message = message.clone();
394 let secondary_message = message.clone();
395 let primary_recipients = recipients.clone();
396 self.send_both_networks(
397 message,
398 async move {
399 primary
400 .da_broadcast_message(
401 view,
402 primary_message,
403 primary_recipients,
404 BroadcastDelay::None,
405 )
406 .await
407 },
408 async move {
409 secondary
410 .da_broadcast_message(view, secondary_message, recipients, BroadcastDelay::None)
411 .await
412 },
413 broadcast_delay,
414 )
415 .await
416 }
417
418 async fn direct_message(
419 &self,
420 view: ViewNumber,
421 message: Vec<u8>,
422 recipient: TYPES::SignatureKey,
423 ) -> Result<(), NetworkError> {
424 let primary = self.primary().clone();
425 let secondary = self.secondary().clone();
426 let primary_message = message.clone();
427 let secondary_message = message.clone();
428 let primary_recipient = recipient.clone();
429 self.send_both_networks(
430 message,
431 async move {
432 primary
433 .direct_message(view, primary_message, primary_recipient)
434 .await
435 },
436 async move {
437 secondary
438 .direct_message(view, secondary_message, recipient)
439 .await
440 },
441 BroadcastDelay::None,
442 )
443 .await
444 }
445
446 async fn vid_broadcast_message(
447 &self,
448 messages: HashMap<TYPES::SignatureKey, (ViewNumber, Vec<u8>)>,
449 ) -> Result<(), NetworkError> {
450 self.networks.0.vid_broadcast_message(messages).await
451 }
452
453 async fn recv_message(&self) -> Result<Vec<u8>, NetworkError> {
458 loop {
459 let mut primary_fut = self.primary().recv_message().fuse();
461 let mut secondary_fut = self.secondary().recv_message().fuse();
462
463 let (message, from_primary) = select! {
465 p = primary_fut => (p?, true),
466 s = secondary_fut => (s?, false),
467 };
468
469 if self
471 .message_deduplication_cache
472 .write()
473 .is_unique(&message, from_primary)
474 {
475 break Ok(message);
476 }
477 }
478 }
479
480 fn queue_node_lookup(
481 &self,
482 view_number: ViewNumber,
483 pk: TYPES::SignatureKey,
484 ) -> Result<(), TrySendError<Option<(ViewNumber, TYPES::SignatureKey)>>> {
485 self.primary().queue_node_lookup(view_number, pk.clone())?;
486 self.secondary().queue_node_lookup(view_number, pk)
487 }
488
489 async fn update_view<T>(
490 &self,
491 view: ViewNumber,
492 epoch: Option<EpochNumber>,
493 membership: EpochMembershipCoordinator<T>,
494 ) where
495 T: NodeType<SignatureKey = TYPES::SignatureKey>,
496 {
497 let delayed_tasks_channels = Arc::clone(&self.delayed_tasks_channels);
498 spawn(async move {
499 let mut map_lock = delayed_tasks_channels.write().await;
500 while let Some((first_view, _)) = map_lock.first_key_value() {
501 if *first_view < *view {
503 if let Some((_, (sender, _))) = map_lock.pop_first() {
504 let _ = sender.try_broadcast(());
505 } else {
506 break;
507 }
508 } else {
509 break;
510 }
511 }
512 });
513 self.networks
515 .1
516 .update_view::<T>(view, epoch, membership)
517 .await;
518 }
519
520 fn is_primary_down(&self) -> bool {
521 self.primary_down.load(Ordering::Relaxed)
522 }
523}
524
/// Pair of LRU caches, keyed by `blake3` hash, used to detect a message that has
/// already been seen on the other network.
struct MessageDeduplicationCache {
    /// Hashes of messages seen on the primary and not yet matched on the secondary.
    primary_message_cache: LruCache<blake3::Hash, ()>,

    /// Hashes of messages seen on the secondary and not yet matched on the primary.
    secondary_message_cache: LruCache<blake3::Hash, ()>,
}
534
535impl MessageDeduplicationCache {
536 fn new() -> Self {
538 Self {
539 primary_message_cache: LruCache::new(
540 NonZeroUsize::new(COMBINED_NETWORK_CACHE_SIZE).unwrap(),
541 ),
542 secondary_message_cache: LruCache::new(
543 NonZeroUsize::new(COMBINED_NETWORK_CACHE_SIZE).unwrap(),
544 ),
545 }
546 }
547
548 fn is_unique(&mut self, message: &[u8], from_primary: bool) -> bool {
550 let message_hash = blake3::hash(message);
552
553 let (this_cache, other_cache) = if from_primary {
555 (
556 &mut self.primary_message_cache,
557 &mut self.secondary_message_cache,
558 )
559 } else {
560 (
561 &mut self.secondary_message_cache,
562 &mut self.primary_message_cache,
563 )
564 };
565
566 if other_cache.pop(&message_hash).is_some() {
569 false
571 } else {
572 this_cache.put(message_hash, ());
574 true
575 }
576 }
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives one cache through a fixed sequence of arrivals and checks the verdict
    /// for each one.
    #[test]
    fn test_message_deduplication() {
        let message = b"hello";
        let mut cache = MessageDeduplicationCache::new();

        // Each step is (from_primary, expected `is_unique` result), in order.
        let steps = [
            // Primary first, then the secondary copy is a duplicate.
            (true, true),
            (false, false),
            // The same pair again behaves the same way.
            (true, true),
            (false, false),
            // Secondary first, then the primary copy is a duplicate (twice).
            (false, true),
            (true, false),
            (false, true),
            (true, false),
            // Repeats on one network stay unique; one secondary copy cancels them.
            (true, true),
            (true, true),
            (true, true),
            (false, false),
            (false, true),
        ];

        for (from_primary, expected) in steps {
            assert_eq!(cache.is_unique(message, from_primary), expected);
        }
    }
}