hotshot/traits/networking/compat_network.rs

1use std::sync::{
2    Arc, OnceLock,
3    atomic::{AtomicBool, Ordering},
4};
5#[cfg(feature = "hotshot-testing")]
6use std::{collections::HashMap, time::Duration};
7
8use async_trait::async_trait;
9use hotshot_types::{
10    BoxSyncFuture, boxed_sync,
11    data::{EpochNumber, ViewNumber},
12    epoch_membership::EpochMembershipCoordinator,
13    message::UpgradeLock,
14    traits::{
15        network::{BroadcastDelay, ConnectedNetwork, NetworkError, Topic},
16        node_implementation::NodeType,
17    },
18};
19#[cfg(feature = "hotshot-testing")]
20use hotshot_types::{
21    PeerConnectInfo,
22    traits::network::{AsyncGenerator, NetworkReliability, TestableNetworkingImplementation},
23};
24use tokio::{join, select};
25use tracing::info;
26use versions::CLIQUENET_VERSION;
27
28use crate::traits::networking::Cliquenet;
29
30/// Compatibility network.
31///
32/// Uses either a fallback network (any impl of `ConnectedNetwork`), or else
33/// `Cliquenet` once the protocol has been upgraded to `CLIQUENET_VERSION`.
34///
35/// Receiving listens on both networks simultaneously so that messages arriving
36/// on either side are never lost. Sending is routed to the active network only.
37#[derive(Clone)]
38pub struct CompatNetwork<A, TYPES: NodeType> {
39    cliquenet: Cliquenet<TYPES::SignatureKey>,
40    fallback: A,
41    use_cliquenet: Arc<AtomicBool>,
42    upgrade_lock: Arc<OnceLock<UpgradeLock<TYPES>>>,
43}
44
45impl<A, TYPES> CompatNetwork<A, TYPES>
46where
47    TYPES: NodeType,
48{
49    pub async fn new(cliquenet: Cliquenet<TYPES::SignatureKey>, fallback: A) -> Self {
50        Self {
51            cliquenet,
52            fallback,
53            use_cliquenet: Arc::new(AtomicBool::new(false)),
54            upgrade_lock: Arc::new(OnceLock::new()),
55        }
56    }
57
58    pub fn cliquenet(&self) -> &Cliquenet<TYPES::SignatureKey> {
59        &self.cliquenet
60    }
61
62    pub fn fallback(&self) -> &A {
63        &self.fallback
64    }
65
66    pub fn set_upgrade_lock(&self, lock: UpgradeLock<TYPES>) {
67        let _ = self.upgrade_lock.set(lock);
68    }
69
70    pub fn use_cliquenet(&self) {
71        self.use_cliquenet.store(true, Ordering::Relaxed)
72    }
73
74    pub fn is_cliquenet(&self) -> bool {
75        self.use_cliquenet.load(Ordering::Relaxed)
76    }
77
78    fn maybe_switch_to_cliquenet(&self, view: ViewNumber) {
79        if self.is_cliquenet() {
80            return;
81        }
82        if let Some(lock) = self.upgrade_lock.get()
83            && lock.version_infallible(view) >= CLIQUENET_VERSION
84        {
85            info!("switching to cliquenet network");
86            self.use_cliquenet();
87        }
88    }
89}
90
91#[async_trait]
92impl<A, TYPES> ConnectedNetwork<TYPES::SignatureKey> for CompatNetwork<A, TYPES>
93where
94    A: ConnectedNetwork<TYPES::SignatureKey>,
95    TYPES: NodeType,
96{
97    async fn broadcast_message(
98        &self,
99        v: ViewNumber,
100        m: Vec<u8>,
101        t: Topic,
102        d: BroadcastDelay,
103    ) -> Result<(), NetworkError> {
104        if self.is_cliquenet() {
105            self.cliquenet.broadcast_message(v, m, t, d).await
106        } else {
107            self.fallback.broadcast_message(v, m, t, d).await
108        }
109    }
110
111    async fn da_broadcast_message(
112        &self,
113        v: ViewNumber,
114        m: Vec<u8>,
115        recipients: Vec<TYPES::SignatureKey>,
116        d: BroadcastDelay,
117    ) -> Result<(), NetworkError> {
118        if self.is_cliquenet() {
119            self.cliquenet
120                .da_broadcast_message(v, m, recipients, d)
121                .await
122        } else {
123            self.fallback
124                .da_broadcast_message(v, m, recipients, d)
125                .await
126        }
127    }
128
129    async fn direct_message(
130        &self,
131        v: ViewNumber,
132        m: Vec<u8>,
133        recipient: TYPES::SignatureKey,
134    ) -> Result<(), NetworkError> {
135        if self.is_cliquenet() {
136            self.cliquenet.direct_message(v, m, recipient).await
137        } else {
138            self.fallback.direct_message(v, m, recipient).await
139        }
140    }
141
142    async fn recv_message(&self) -> Result<Vec<u8>, NetworkError> {
143        select! {
144            m = self.cliquenet.recv_message() => m,
145            m = self.fallback.recv_message() => m
146        }
147    }
148
149    async fn update_view<U>(
150        &self,
151        v: ViewNumber,
152        e: Option<EpochNumber>,
153        m: EpochMembershipCoordinator<U>,
154    ) where
155        U: NodeType<SignatureKey = TYPES::SignatureKey>,
156    {
157        self.maybe_switch_to_cliquenet(v);
158        join! {
159            self.cliquenet.update_view(v, e, m.clone()),
160            self.fallback.update_view(v, e, m)
161        };
162    }
163
164    async fn wait_for_ready(&self) {
165        if self.is_cliquenet() {
166            self.cliquenet.wait_for_ready().await
167        } else {
168            self.fallback.wait_for_ready().await
169        }
170    }
171
172    fn pause(&self) {
173        self.cliquenet.pause();
174        self.fallback.pause()
175    }
176
177    fn resume(&self) {
178        self.cliquenet.resume();
179        self.fallback.resume()
180    }
181
182    fn shut_down<'a, 'b>(&'a self) -> BoxSyncFuture<'b, ()>
183    where
184        'a: 'b,
185        Self: 'b,
186    {
187        let a = self.cliquenet.shut_down();
188        let b = self.fallback.shut_down();
189        boxed_sync(async {
190            join!(a, b);
191        })
192    }
193}
194
#[cfg(feature = "hotshot-testing")]
impl<T, A> TestableNetworkingImplementation<T> for CompatNetwork<A, T>
where
    T: NodeType,
    A: TestableNetworkingImplementation<T> + Clone + Send + 'static,
{
    /// Build a generator that yields one `CompatNetwork` per node index.
    ///
    /// Two generators are created with identical parameters: one for Cliquenet
    /// and one for the fallback network `A`. For each index `i`, the two
    /// networks they produce are combined via `CompatNetwork::new`, so the
    /// result starts out sending over the fallback network.
    fn generator(
        nodes: usize,
        num_bootstrap: usize,
        network_id: usize,
        da_committee_size: usize,
        reliability_config: Option<Box<dyn NetworkReliability>>,
        secondary_network_delay: Duration,
        connect_infos: &mut HashMap<T::SignatureKey, PeerConnectInfo>,
    ) -> AsyncGenerator<Arc<Self>> {
        // NOTE(review): both generators receive the same `connect_infos` map.
        // Confirm they do not write conflicting entries for the same key.
        let cliquenet =
            <Cliquenet<T::SignatureKey> as TestableNetworkingImplementation<T>>::generator(
                nodes,
                num_bootstrap,
                network_id,
                da_committee_size,
                reliability_config.clone(),
                secondary_network_delay,
                connect_infos,
            );

        // This is the last use of `reliability_config`, so move it instead of cloning again.
        let fallback = <A as TestableNetworkingImplementation<T>>::generator(
            nodes,
            num_bootstrap,
            network_id,
            da_committee_size,
            reliability_config,
            secondary_network_delay,
            connect_infos,
        );

        Box::pin(move |i: u64| {
            let cliquenet = cliquenet(i);
            let fallback = fallback(i);

            let future = async move {
                // Take each network out of its `Arc`, cloning only if another
                // reference to it is still alive.
                let cliquenet = Arc::unwrap_or_clone(cliquenet.await);
                let fallback = Arc::unwrap_or_clone(fallback.await);
                Arc::new(Self::new(cliquenet, fallback).await)
            };

            Box::pin(future)
        })
    }

    /// In-flight message counting is not tracked for the combined network; always `None`.
    fn in_flight_message_count(&self) -> Option<usize> {
        None
    }
}