// cliquenet/retry.rs

1use std::{
2    collections::BTreeMap,
3    convert::Infallible,
4    fmt::{self, Display},
5    hash::Hash,
6    sync::{
7        Arc,
8        atomic::{AtomicU64, Ordering},
9    },
10};
11
12use bytes::{Bytes, BytesMut};
13use nohash_hasher::IntMap;
14use parking_lot::Mutex;
15use tokio::{
16    spawn,
17    sync::mpsc::{Sender, error::TrySendError},
18    task::JoinHandle,
19    time::{self, Duration, Instant},
20};
21use tracing::warn;
22
23use crate::{
24    Address, Id, NUM_DELAYS, NetConf, Network, NetworkError, PublicKey, Role, net::Command,
25};
26
/// Crate-local result alias using [`NetworkError`].
type Result<T> = std::result::Result<T, NetworkError>;

/// Max. bucket number.
pub const MAX_BUCKET: Bucket = Bucket(u64::MAX);
31
32/// `Retry` wraps a [`Network`] and returns acknowledgements to senders.
33///
34/// It also retries messages until either an acknowledgement has been received
35/// or client code has indicated that the messages are no longer of interest
36/// by invoking `Retry::gc`.
37///
38/// Each message that is sent has a trailer appended that contains the bucket
39/// number and ID of the message. Receivers will send this trailer back. The
40/// sender then stops retrying the corresponding message.
41///
42/// Note that if malicious parties modify the trailer and have it point to a
43/// different message, they can only remove themselves from the set of parties
44/// the sender is expecting an acknowledgement from.
#[derive(Debug, Clone)]
pub struct Retry<K> {
    // Shared state; all clones of `Retry` point at the same `Inner`.
    inner: Arc<Inner<K>>,
}
49
#[derive(Debug)]
struct Inner<K> {
    /// Our own key (used when logging warnings).
    this: K,
    /// The wrapped network.
    net: Network<K>,
    /// Command channel into the network (obtained via `net.sender()`).
    sender: Sender<Command<K>>,
    /// Counter from which fresh message IDs are drawn.
    id: AtomicU64,
    /// Unacknowledged outgoing messages, shared with the retry task.
    buffer: Buffer<K>,
    /// Handle of the background task that re-sends messages.
    retry: JoinHandle<Infallible>,
    /// Received data whose acknowledgement could not be sent yet.
    pending: Mutex<BTreeMap<Trailer, Pending<K>>>,
}
60
61impl<K> Drop for Retry<K> {
62    fn drop(&mut self) {
63        self.inner.retry.abort()
64    }
65}
66
/// Buckets conceptually contain messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Bucket(u64);
70
/// Messages are associated with IDs and put into buckets.
///
/// Bucket numbers are given to us by clients which also garbage collect
/// explicitly by specifying the bucket up to which to remove messages.
/// Buckets often correspond to rounds elsewhere.
///
/// Cloning is cheap: clones share the same underlying map via `Arc`.
#[derive(Debug, Clone)]
#[allow(clippy::type_complexity)]
struct Buffer<K>(Arc<Mutex<BTreeMap<Bucket, IntMap<Id, Message<K>>>>>);
79
80impl<K> Default for Buffer<K> {
81    fn default() -> Self {
82        Self(Default::default())
83    }
84}
85
#[derive(Debug)]
struct Message<K> {
    /// The message bytes to (re-)send, trailer included.
    data: Bytes,
    /// The time this message was last sent (updated on every retry).
    time: Instant,
    /// The number of times we have sent this message.
    retries: usize,
    /// The parties that still have to acknowledge the message.
    remaining: Vec<K>,
}
97
/// Meta information appended at the end of a message.
///
/// On the wire this is one big-endian `u128`: the bucket number in the
/// upper 64 bits and the message ID in the lower 64 bits (see
/// `Trailer::to_bytes` / `Trailer::from_bytes`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Trailer {
    /// The bucket number the message corresponds to.
    bucket: Bucket,
    /// The message ID.
    id: Id,
}
106
/// Data we have received but could not acknowledge yet.
#[derive(Debug)]
struct Pending<K> {
    /// The peer the data came from (and the ack goes back to).
    src: K,
    /// The received payload (trailer already stripped).
    data: Bytes,
    /// The raw trailer bytes to send back as acknowledgement.
    trailer: Bytes,
}
114
/// Addressing modes for an outgoing message.
enum Target<K> {
    /// Send to a single peer.
    Single(K),
    /// Send to the given set of peers.
    Multi(Vec<K>),
    /// Send to all active parties.
    All,
}
120
impl<K> Retry<K>
where
    K: Eq + Ord + Clone + Display + Hash + Send + Sync + 'static,
{
    /// Create the underlying [`Network`] and spawn the background task
    /// that re-sends unacknowledged messages.
    pub async fn create(mut cfg: NetConf<K>) -> Result<Self> {
        // Reserve room for the trailer appended to every message.
        cfg.max_message_size += Trailer::SIZE;
        let delays = cfg.retry_delays;
        let net = Network::create(cfg).await?;
        let buffer = Buffer::default();
        let retry = spawn(retry(buffer.clone(), net.sender(), delays));
        Ok(Self {
            inner: Arc::new(Inner {
                this: net.public_key().clone(),
                sender: net.sender(),
                net,
                buffer,
                id: AtomicU64::new(0),
                retry,
                pending: Mutex::new(BTreeMap::new()),
            }),
        })
    }

    /// Send `data` to all active parties, retrying until acknowledged
    /// or garbage-collected.
    pub async fn broadcast<B>(&self, b: B, data: Vec<u8>) -> Result<Id>
    where
        B: Into<Bucket>,
    {
        self.send(b.into(), Target::All, data).await
    }

    /// Send `data` to the given peers, retrying until acknowledged
    /// or garbage-collected.
    pub async fn multicast<B>(&self, to: Vec<K>, b: B, data: Vec<u8>) -> Result<Id>
    where
        B: Into<Bucket>,
    {
        self.send(b.into(), Target::Multi(to), data).await
    }

    /// Send `data` to a single peer, retrying until acknowledged
    /// or garbage-collected.
    pub async fn unicast<B>(&self, to: K, b: B, data: Vec<u8>) -> Result<Id>
    where
        B: Into<Bucket>,
    {
        self.send(b.into(), Target::Single(to), data).await
    }

    /// Add peers to the underlying network.
    pub async fn add(&self, peers: Vec<(K, PublicKey, Address)>) -> Result<()> {
        self.inner.net.add(peers).await
    }

    /// Remove peers from the underlying network.
    pub async fn remove(&self, peers: Vec<K>) -> Result<()> {
        self.inner.net.remove(peers).await
    }

    /// Assign a role to the given peers in the underlying network.
    pub async fn assign(&self, r: Role, peers: Vec<K>) -> Result<()> {
        self.inner.net.assign(r, peers).await
    }

    /// Receive the next payload, sending the trailer back to its sender
    /// as acknowledgement. Empty-payload messages are acknowledgements of
    /// our own messages and are consumed internally.
    pub async fn receive(&self) -> Result<(K, Bytes)> {
        // First deliver data received earlier whose ack could not be
        // enqueued without blocking (see the `Full` branch below).
        let pending = self.inner.pending.lock().pop_first();
        if let Some((_, Pending { src, data, trailer })) = pending {
            // NOTE(review): if this future is dropped while awaiting the
            // send below, the popped entry is gone; presumably the sender
            // keeps retrying until it observes the ack, so the data would
            // be redelivered — confirm there is no window where the ack is
            // enqueued but the data is never returned to the caller.
            self.inner
                .sender
                .send(Command::Unicast(src.clone(), None, trailer.clone()))
                .await
                .map_err(|_| NetworkError::ChannelClosed)?;
            return Ok((src, data));
        }
        loop {
            let (src, mut bytes) = self.inner.net.receive().await?;

            // Every message must end in a trailer; otherwise drop it.
            let Some((trailer, trailer_bytes)) = Trailer::from_bytes(&mut bytes) else {
                warn!(node = %self.inner.this, "invalid trailer");
                continue;
            };

            // A non-empty remainder is an actual payload; an empty one is
            // an acknowledgement for a message we sent (handled further
            // below).
            if !bytes.is_empty() {
                // Send the trailer back as acknowledgement:
                match self
                    .inner
                    .sender
                    .try_send(Command::Unicast(src.clone(), None, trailer_bytes))
                {
                    Ok(()) => return Ok((src, bytes)),
                    Err(TrySendError::Closed(_)) => return Err(NetworkError::ChannelClosed),
                    Err(TrySendError::Full(Command::Unicast(src, _, trailer_bytes))) => {
                        // Save received data for cancellation safety:
                        // should the `send` below be cancelled, the next
                        // call to `receive` picks the entry up again.
                        self.inner.pending.lock().insert(
                            trailer,
                            Pending {
                                src: src.clone(),
                                data: bytes.clone(),
                                trailer: trailer_bytes.clone(),
                            },
                        );
                        self.inner
                            .sender
                            .send(Command::Unicast(src.clone(), None, trailer_bytes))
                            .await
                            .map_err(|_| NetworkError::ChannelClosed)?;
                        self.inner.pending.lock().remove(&trailer);
                        return Ok((src, bytes));
                    },
                    Err(TrySendError::Full(_)) => {
                        unreachable!(
                            "We tried sending a Command::Unicast so this is what we get back."
                        )
                    },
                }
            }

            // The message was an ack: stop retrying it for this peer and
            // forget it entirely once everyone has acknowledged.
            let mut messages = self.inner.buffer.0.lock();

            if let Some(buckets) = messages.get_mut(&trailer.bucket)
                && let Some(m) = buckets.get_mut(&trailer.id)
            {
                m.remaining.retain(|k| *k != src);
                if m.remaining.is_empty() {
                    buckets.remove(&trailer.id);
                }
            }
        }
    }

    /// Stop retrying all messages in buckets strictly below `bucket`.
    pub fn gc<B: Into<Bucket>>(&self, bucket: B) {
        let bucket = bucket.into();
        self.inner.buffer.0.lock().retain(|b, _| *b >= bucket);
    }

    /// Stop retrying the single message identified by `bucket` and `id`.
    pub fn rm<B: Into<Bucket>>(&self, bucket: B, id: Id) {
        let bucket = bucket.into();
        if let Some(messages) = self.inner.buffer.0.lock().get_mut(&bucket) {
            messages.remove(&id);
        }
    }

    /// Append the trailer, hand the message to the network and record it
    /// for retries until every addressed party has acknowledged it.
    async fn send(&self, b: Bucket, to: Target<K>, data: Vec<u8>) -> Result<Id> {
        let id = self.next_id();

        let trailer = Trailer { bucket: b, id };

        let mut msg = BytesMut::from(Bytes::from(data));
        msg.extend_from_slice(&trailer.to_bytes());
        let msg = msg.freeze();

        let now = Instant::now();

        let rem = match to {
            Target::Single(to) => {
                self.inner
                    .sender
                    .send(Command::Unicast(to.clone(), Some(id), msg.clone()))
                    .await
                    .map_err(|_| NetworkError::ChannelClosed)?;
                vec![to]
            },
            Target::Multi(peers) => {
                self.inner
                    .sender
                    .send(Command::Multicast(peers.clone(), Some(id), msg.clone()))
                    .await
                    .map_err(|_| NetworkError::ChannelClosed)?;
                peers
            },
            Target::All => {
                self.inner
                    .sender
                    .send(Command::Broadcast(Some(id), msg.clone()))
                    .await
                    .map_err(|_| NetworkError::ChannelClosed)?;
                self.inner.net.parties(Role::Active)
            },
        };

        // NOTE(review): the message is handed to the network above before
        // it is recorded below, so an ack could in principle be processed
        // before the insert; the entry would then linger and be retried
        // and re-acked rather than removed immediately — presumably
        // benign, but worth confirming.
        self.inner.buffer.0.lock().entry(b).or_default().insert(
            id,
            Message {
                data: msg,
                time: now,
                retries: 0,
                remaining: rem,
            },
        );

        Ok(id)
    }

    /// Next fresh message ID (relaxed counter; per-process uniqueness is
    /// all that is required).
    fn next_id(&self) -> Id {
        Id::from(self.inner.id.fetch_add(1, Ordering::Relaxed))
    }
}
310
/// Background task that periodically re-sends unacknowledged messages.
///
/// Once per second it walks every bucket and re-multicasts each message to
/// the parties that have not acknowledged it yet, spacing the retries of a
/// single message according to `delays` (the last entry applies to all
/// further retries; 30s if the table is empty).
async fn retry<K>(buf: Buffer<K>, net: Sender<Command<K>>, delays: [u8; NUM_DELAYS]) -> Infallible
where
    K: Clone,
{
    let mut i = time::interval(Duration::from_secs(1));
    i.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

    // Scratch vectors, reused across ticks to avoid re-allocating.
    let mut buckets = Vec::new();
    let mut ids = Vec::new();

    loop {
        let now = i.tick().await;

        // Snapshot the bucket numbers so the lock is never held across an
        // await point.
        debug_assert!(buckets.is_empty());
        buckets.extend(buf.0.lock().keys().copied());

        for b in buckets.drain(..) {
            // Likewise snapshot the message IDs of this bucket.
            debug_assert!(ids.is_empty());
            ids.extend(
                buf.0
                    .lock()
                    .get(&b)
                    .into_iter()
                    .flat_map(|m| m.keys().copied()),
            );

            for id in ids.drain(..) {
                let message;
                let remaining;

                // Re-check under the lock: the message may have been acked
                // or garbage-collected since the snapshot was taken.
                {
                    let mut buf = buf.0.lock();
                    let Some(m) = buf.get_mut(&b).and_then(|m| m.get_mut(&id)) else {
                        continue;
                    };

                    // Delay (in seconds) to wait before the next retry.
                    let delay = delays
                        .get(m.retries)
                        .copied()
                        .or_else(|| delays.last().copied())
                        .unwrap_or(30);

                    if now.saturating_duration_since(m.time) < Duration::from_secs(delay.into()) {
                        continue;
                    }

                    m.time = now;
                    m.retries = m.retries.saturating_add(1);

                    message = m.data.clone();
                    remaining = m.remaining.clone();
                }

                // Send outside the lock; a send error (channel closed) is
                // ignored here as it will surface to users elsewhere.
                let _ = net
                    .send(Command::Multicast(remaining, Some(id), message.clone()))
                    .await;
            }
        }
    }
}
371
372impl From<u64> for Bucket {
373    fn from(val: u64) -> Self {
374        Self(val)
375    }
376}
377
378impl From<Bucket> for u64 {
379    fn from(val: Bucket) -> Self {
380        val.0
381    }
382}
383
384impl fmt::Display for Bucket {
385    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
386        self.0.fmt(f)
387    }
388}
389
390impl Trailer {
391    const SIZE: usize = 16;
392
393    fn from_bytes(bytes: &mut Bytes) -> Option<(Self, Bytes)> {
394        if bytes.len() < Self::SIZE {
395            return None;
396        }
397        let slice = bytes.split_off(bytes.len() - Self::SIZE);
398        let both = u128::from_be_bytes(slice[..].try_into().ok()?);
399        let this = Self {
400            bucket: ((both >> 64) as u64).into(),
401            id: (both as u64).into(),
402        };
403        Some((this, slice))
404    }
405
406    fn to_bytes(self) -> [u8; Self::SIZE] {
407        (u128::from(self.bucket.0) << 64 | u128::from(self.id.0)).to_be_bytes()
408    }
409}
410
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use quickcheck::quickcheck;

    use super::Trailer;

    // Property: a trailer survives a round-trip through its wire encoding.
    quickcheck! {
        fn to_from_bytes(b: u64, i: u64) -> bool {
            let a = Trailer {
                bucket: b.into(),
                id: i.into()
            };
            let mut bytes = Bytes::copy_from_slice(&a.to_bytes());
            let (b, _) = Trailer::from_bytes(&mut bytes).unwrap();
            a == b
        }
    }
}