cliquenet/
retry.rs

1use std::{
2    collections::BTreeMap,
3    convert::Infallible,
4    fmt::{self, Display},
5    hash::Hash,
6    result::Result as StdResult,
7    sync::{
8        Arc,
9        atomic::{AtomicU64, Ordering},
10    },
11};
12
13use bytes::{Bytes, BytesMut};
14use hotshot_types::addr::NetAddr;
15use nohash_hasher::IntMap;
16use parking_lot::Mutex;
17use scopeguard::{ScopeGuard, guard};
18use tokio::{
19    spawn,
20    sync::mpsc::{Sender, error::TrySendError},
21    task::JoinHandle,
22    time::{self, Duration, Instant},
23};
24use tracing::warn;
25
26use crate::{
27    Id, NUM_DELAYS, NetConf, Network, NetworkDown, NetworkError, PublicKey, Role, net::Command,
28};
29
/// Module-local result alias; all fallible operations here fail with [`NetworkError`].
type Result<T> = std::result::Result<T, NetworkError>;

/// Max. bucket number.
pub const MAX_BUCKET: Bucket = Bucket(u64::MAX);
34
/// `Retry` wraps a [`Network`] and returns acknowledgements to senders.
///
/// It also retries messages until either an acknowledgement has been received
/// or client code has indicated that the messages are no longer of interest
/// by invoking `Retry::gc`.
///
/// Each message that is sent has a trailer appended that contains the bucket
/// number and ID of the message. Receivers will send this trailer back. The
/// sender then stops retrying the corresponding message.
///
/// Note that if malicious parties modify the trailer and have it point to a
/// different message, they can only remove themselves from the set of parties
/// the sender is expecting an acknowledgement from.
///
/// Cloning is cheap: all clones share the same `Inner` state via an `Arc`.
#[derive(Debug, Clone)]
pub struct Retry<K> {
    // Shared state; see `Inner` for the individual fields.
    inner: Arc<Inner<K>>,
}
52
/// State shared by all clones of a [`Retry`] handle.
#[derive(Debug)]
struct Inner<K> {
    /// The wrapped network.
    net: Network<K>,
    /// Command channel into the network (obtained from `net.sender()`).
    sender: Sender<Command<K>>,
    /// Monotonic counter from which message IDs are drawn.
    id: AtomicU64,
    /// Unacknowledged outbound messages, grouped by bucket.
    buffer: Buffer<K>,
    /// Handle of the background task that re-sends unacknowledged messages.
    retry: JoinHandle<Infallible>,
    /// Received data whose acknowledgement could not be sent yet; kept so a
    /// cancelled `receive` call does not lose messages.
    pending: Mutex<BTreeMap<Trailer, Pending<K>>>,
}
62
63impl<K> Drop for Retry<K> {
64    fn drop(&mut self) {
65        self.inner.retry.abort()
66    }
67}
68
/// Buckets conceptually contain messages.
///
/// A bucket is identified by a `u64`; `MAX_BUCKET` is the largest value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Bucket(u64);
72
/// Messages are associated with IDs and put into buckets.
///
/// Bucket numbers are given to us by clients which also garbage collect
/// explicitly by specifying the bucket up to which to remove messages.
/// Buckets often correspond to rounds elsewhere.
///
/// The buffer is shared (`Arc`) between the `Retry` handle and the
/// background retry task.
#[derive(Debug, Clone)]
#[allow(clippy::type_complexity)]
struct Buffer<K>(Arc<Mutex<BTreeMap<Bucket, IntMap<Id, Message<K>>>>>);
81
82impl<K> Default for Buffer<K> {
83    fn default() -> Self {
84        Self(Default::default())
85    }
86}
87
/// An outbound message tracked until every recipient acknowledged it.
#[derive(Debug)]
struct Message<K> {
    /// The message bytes to (re-)send.
    data: Bytes,
    /// The time we started sending this message.
    time: Instant,
    /// The number of times we have sent this message.
    retries: usize,
    /// The remaining number of parties that have to acknowledge the message.
    remaining: Vec<K>,
}
99
/// Meta information appended at the end of a message.
///
/// Wire format (see `Trailer::to_bytes`): 8 bytes big-endian bucket number,
/// 8 bytes big-endian message ID, 1 length byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Trailer {
    /// The bucket number the message corresponds to.
    bucket: Bucket,
    /// The message ID.
    id: Id,
}
108
/// Data we have received but could not acknowledge yet.
///
/// Parked in `Inner::pending` so that a cancelled `receive` call does not
/// drop a message before its acknowledgement went out.
#[derive(Debug, Clone)]
struct Pending<K> {
    /// The party that sent us the data.
    src: K,
    /// The received payload (trailer already stripped).
    data: Bytes,
    /// The raw trailer bytes to send back as acknowledgement.
    trailer: Bytes,
}
116
/// The recipient(s) of an outbound message.
enum Target<K> {
    /// A single peer.
    Single(K),
    /// An explicit set of peers.
    Multi(Vec<K>),
    /// Every party of the network.
    All,
}
122
impl<K> Retry<K>
where
    K: Eq + Ord + Clone + Display + Hash + Send + Sync + 'static,
{
    /// Create a `Retry` on top of a freshly created [`Network`].
    ///
    /// The configured max. message size is grown by the trailer size so
    /// payloads of the original max. size still fit after the trailer is
    /// appended. A background task is spawned that periodically re-sends
    /// unacknowledged messages.
    pub async fn create(mut cfg: NetConf<K>) -> Result<Self> {
        cfg.max_message_size += Trailer::SIZE;
        let delays = cfg.retry_delays;
        let net = Network::create(cfg).await?;
        let buffer = Buffer::default();
        let retry = spawn(retry(buffer.clone(), net.sender(), delays));
        Ok(Self {
            inner: Arc::new(Inner {
                sender: net.sender(),
                net,
                buffer,
                id: AtomicU64::new(0),
                retry,
                pending: Mutex::new(BTreeMap::new()),
            }),
        })
    }

    /// Abort the retry task and shut the underlying network down.
    pub async fn close(&self) {
        self.inner.retry.abort();
        self.inner.net.close().await;
    }

    /// The parties of the underlying network, optionally filtered by role.
    pub fn parties(&self, r: Option<Role>) -> Vec<K> {
        self.inner.net.parties(r)
    }

    /// Send `data` to all parties, retrying until acknowledged or removed
    /// (via `gc`/`rm`). Returns the ID assigned to the message.
    pub async fn broadcast<B>(&self, b: B, data: Vec<u8>) -> Result<Id>
    where
        B: Into<Bucket>,
    {
        self.send(b.into(), Target::All, data).await
    }

    /// Send `data` to the given parties, retrying until acknowledged or
    /// removed. Returns the ID assigned to the message.
    pub async fn multicast<B>(&self, to: Vec<K>, b: B, data: Vec<u8>) -> Result<Id>
    where
        B: Into<Bucket>,
    {
        self.send(b.into(), Target::Multi(to), data).await
    }

    /// Send `data` to a single party, retrying until acknowledged or
    /// removed. Returns the ID assigned to the message.
    pub async fn unicast<B>(&self, to: K, b: B, data: Vec<u8>) -> Result<Id>
    where
        B: Into<Bucket>,
    {
        self.send(b.into(), Target::Single(to), data).await
    }

    /// Add peers to the underlying network with the given role.
    pub async fn add(
        &self,
        r: Role,
        peers: Vec<(K, PublicKey, NetAddr)>,
    ) -> StdResult<(), NetworkDown> {
        self.inner.net.add(r, peers).await
    }

    /// Remove peers from the underlying network.
    pub async fn remove(&self, peers: Vec<K>) -> StdResult<(), NetworkDown> {
        self.inner.net.remove(peers).await
    }

    /// Assign the given role to peers of the underlying network.
    pub async fn assign(&self, r: Role, peers: Vec<K>) -> StdResult<(), NetworkDown> {
        self.inner.net.assign(r, peers).await
    }

    /// Receive the next payload message, acknowledging it to its sender.
    ///
    /// Pure acknowledgements (messages that are empty once the trailer is
    /// stripped) are consumed internally: they shrink the set of parties a
    /// buffered message still waits on and are not handed to the caller.
    ///
    /// Cancellation safety: data whose acknowledgement could not be sent
    /// before the future was dropped is parked in `pending` and delivered
    /// by a later call.
    pub async fn receive(&self) -> Result<(K, Bytes)> {
        // First serve data a previous (cancelled) call left behind.
        let first_entry = self.inner.pending.lock().pop_first();
        if let Some((key, val)) = first_entry {
            let pending = &self.inner.pending;
            // Put the pending value back if the future is dropped:
            let guard = guard((key, val.clone()), |(k, v)| {
                pending.lock().insert(k, v);
            });
            self.inner
                .sender
                .send(Command::Unicast(val.src.clone(), None, val.trailer.clone()))
                .await
                .map_err(|_| NetworkError::ChannelClosed)?;
            // Ack went out; disarm the guard so the entry is not re-inserted.
            let _ = ScopeGuard::into_inner(guard);
            return Ok((val.src, val.data));
        }
        loop {
            let (src, mut bytes) = self.inner.net.receive().await?;

            // Every message must end in a valid trailer; drop it otherwise.
            let Some((trailer, trailer_bytes)) = Trailer::from_bytes(&mut bytes) else {
                warn!(node = %self.inner.net.label, "invalid trailer");
                continue;
            };

            // A non-empty remainder is actual payload; an empty one is a
            // pure acknowledgement (handled below the match).
            if !bytes.is_empty() {
                // Send the trailer back as acknowledgement:
                match self
                    .inner
                    .sender
                    .try_send(Command::Unicast(src.clone(), None, trailer_bytes))
                {
                    Ok(()) => return Ok((src, bytes)),
                    Err(TrySendError::Closed(_)) => return Err(NetworkError::ChannelClosed),
                    Err(TrySendError::Full(Command::Unicast(src, _, trailer_bytes))) => {
                        // Save received data for cancellation safety:
                        self.inner.pending.lock().insert(
                            trailer,
                            Pending {
                                src: src.clone(),
                                data: bytes.clone(),
                                trailer: trailer_bytes.clone(),
                            },
                        );
                        // Channel was full; wait until the ack can be sent.
                        self.inner
                            .sender
                            .send(Command::Unicast(src.clone(), None, trailer_bytes))
                            .await
                            .map_err(|_| NetworkError::ChannelClosed)?;
                        // Ack sent; the parked entry is no longer needed.
                        self.inner.pending.lock().remove(&trailer);
                        return Ok((src, bytes));
                    },
                    Err(TrySendError::Full(_)) => {
                        unreachable!(
                            "We tried sending a Command::Unicast so this is what we get back."
                        )
                    },
                }
            }

            // Pure acknowledgement: mark `src` as done for this message and
            // drop the buffer entry once everyone acknowledged.
            let mut messages = self.inner.buffer.0.lock();

            if let Some(buckets) = messages.get_mut(&trailer.bucket)
                && let Some(m) = buckets.get_mut(&trailer.id)
            {
                m.remaining.retain(|k| *k != src);
                if m.remaining.is_empty() {
                    buckets.remove(&trailer.id);
                }
            }
        }
    }

    /// Stop retrying all messages in buckets strictly below `bucket`.
    pub fn gc<B: Into<Bucket>>(&self, bucket: B) {
        let bucket = bucket.into();
        self.inner.buffer.0.lock().retain(|b, _| *b >= bucket);
    }

    /// Stop retrying a single message, identified by bucket and ID.
    pub fn rm<B: Into<Bucket>>(&self, bucket: B, id: Id) {
        let bucket = bucket.into();
        if let Some(messages) = self.inner.buffer.0.lock().get_mut(&bucket) {
            messages.remove(&id);
        }
    }

    /// Append the trailer, hand the message to the network and record it in
    /// the buffer so the retry task can re-send it until acknowledged.
    async fn send(&self, b: Bucket, to: Target<K>, data: Vec<u8>) -> Result<Id> {
        let id = self.next_id();

        let trailer = Trailer { bucket: b, id };

        let mut msg = BytesMut::from(Bytes::from(data));
        msg.extend_from_slice(&trailer.to_bytes());
        let msg = msg.freeze();

        // Checked against the (trailer-adjusted) limit set in `create`.
        if msg.len() > self.inner.net.max_message_size {
            warn!(
                name = %self.inner.net.name,
                node = %self.inner.net.label,
                len  = %msg.len(),
                max  = %self.inner.net.max_message_size,
                "message too large to send"
            );
            return Err(NetworkError::MessageTooLarge);
        }

        let now = Instant::now();

        // Send and remember whose acknowledgements we are waiting for.
        let rem = match to {
            Target::Single(to) => {
                self.inner
                    .sender
                    .send(Command::Unicast(to.clone(), Some(id), msg.clone()))
                    .await
                    .map_err(|_| NetworkError::ChannelClosed)?;
                vec![to]
            },
            Target::Multi(peers) => {
                self.inner
                    .sender
                    .send(Command::Multicast(peers.clone(), Some(id), msg.clone()))
                    .await
                    .map_err(|_| NetworkError::ChannelClosed)?;
                peers
            },
            Target::All => {
                self.inner
                    .sender
                    .send(Command::Broadcast(Some(id), msg.clone()))
                    .await
                    .map_err(|_| NetworkError::ChannelClosed)?;
                self.inner.net.parties(Some(Role::Active))
            },
        };

        // Buffer the message for the retry task; it stays there until all
        // remaining parties acknowledged or it is gc-ed / removed.
        self.inner.buffer.0.lock().entry(b).or_default().insert(
            id,
            Message {
                data: msg,
                time: now,
                retries: 0,
                remaining: rem,
            },
        );

        Ok(id)
    }

    /// The next message ID (monotonically increasing per `Retry`).
    fn next_id(&self) -> Id {
        Id::from(self.inner.id.fetch_add(1, Ordering::Relaxed))
    }
}
341
/// Background task re-sending unacknowledged messages.
///
/// Once per second every buffered message is inspected; if more time than
/// the delay for its current retry count has elapsed since it was last
/// sent, it is multicast again to the parties that have not acknowledged
/// it yet.
async fn retry<K>(buf: Buffer<K>, net: Sender<Command<K>>, delays: [u8; NUM_DELAYS]) -> Infallible
where
    K: Clone,
{
    let mut i = time::interval(Duration::from_secs(1));
    i.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

    // Scratch buffers, reused across ticks to avoid re-allocating.
    let mut buckets = Vec::new();
    let mut ids = Vec::new();

    loop {
        let now = i.tick().await;

        debug_assert!(buckets.is_empty());
        // Snapshot bucket numbers; the lock is never held across an await.
        buckets.extend(buf.0.lock().keys().copied());

        for b in buckets.drain(..) {
            debug_assert!(ids.is_empty());
            ids.extend(
                buf.0
                    .lock()
                    .get(&b)
                    .into_iter()
                    .flat_map(|m| m.keys().copied()),
            );

            for id in ids.drain(..) {
                let message;
                let remaining;

                {
                    let mut buf = buf.0.lock();
                    // The message may have been acknowledged or removed
                    // since the IDs were snapshotted:
                    let Some(m) = buf.get_mut(&b).and_then(|m| m.get_mut(&id)) else {
                        continue;
                    };

                    // Delay for the current retry count; beyond the table
                    // the last entry applies (30s if the table is empty).
                    let delay = delays
                        .get(m.retries)
                        .copied()
                        .or_else(|| delays.last().copied())
                        .unwrap_or(30);

                    if now.saturating_duration_since(m.time) < Duration::from_secs(delay.into()) {
                        continue;
                    }

                    m.time = now;
                    m.retries = m.retries.saturating_add(1);

                    message = m.data.clone();
                    remaining = m.remaining.clone();
                }

                // Re-send outside the lock; a send error (closed channel)
                // is deliberately ignored here.
                let _ = net
                    .send(Command::Multicast(remaining, Some(id), message.clone()))
                    .await;
            }
        }
    }
}
402
403impl From<u64> for Bucket {
404    fn from(val: u64) -> Self {
405        Self(val)
406    }
407}
408
409impl From<Bucket> for u64 {
410    fn from(val: Bucket) -> Self {
411        val.0
412    }
413}
414
415impl fmt::Display for Bucket {
416    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
417        self.0.fmt(f)
418    }
419}
420
421// Besides the actual meta information, the last byte of the trailer encodes its
422// total length (including the length byte itself).
423impl Trailer {
424    const SIZE: usize = 17;
425
426    fn from_bytes(bytes: &mut Bytes) -> Option<(Self, Bytes)> {
427        let len: usize = bytes.last().copied()?.into();
428        if len < Self::SIZE || bytes.len() < len {
429            return None;
430        }
431        let slice = bytes.split_off(bytes.len() - len);
432        let id = u64::from_be_bytes(slice[len - 9..len - 1].try_into().ok()?);
433        let bucket = u64::from_be_bytes(slice[len - 17..len - 9].try_into().ok()?);
434        let this = Self {
435            bucket: bucket.into(),
436            id: id.into(),
437        };
438        Some((this, slice))
439    }
440
441    fn to_bytes(self) -> [u8; Self::SIZE] {
442        let mut buf = [0; Self::SIZE];
443        buf[..8].copy_from_slice(&self.bucket.0.to_be_bytes()[..]);
444        buf[8..16].copy_from_slice(&self.id.0.to_be_bytes()[..]);
445        buf[16] = Self::SIZE as u8;
446        buf
447    }
448}
449
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use quickcheck::quickcheck;

    use super::Trailer;

    quickcheck! {
        // Round-trip: encoding a trailer and decoding it yields the original.
        fn to_from_bytes(b: u64, i: u64) -> bool {
            let original = Trailer {
                bucket: b.into(),
                id: i.into()
            };
            let mut encoded = Bytes::copy_from_slice(&original.to_bytes());
            let (decoded, _) = Trailer::from_bytes(&mut encoded).unwrap();
            original == decoded
        }
    }
}