use std::{
    collections::BTreeMap,
    convert::Infallible,
    fmt::{self, Display},
    hash::Hash,
    sync::{
        Arc,
        atomic::{AtomicU64, Ordering},
    },
};

use bytes::{Bytes, BytesMut};
use nohash_hasher::IntMap;
use parking_lot::Mutex;
use tokio::{
    spawn,
    sync::mpsc::{Sender, error::TrySendError},
    task::JoinHandle,
    time::{self, Duration, Instant},
};
use tracing::warn;

use crate::{
    Address, Id, NUM_DELAYS, NetConf, Network, NetworkError, PublicKey, Role, net::Command,
};

type Result<T> = std::result::Result<T, NetworkError>;

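/// The largest possible bucket value.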
pub const MAX_BUCKET: Bucket = Bucket(u64::MAX);

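/// A wrapper around [`Network`] that keeps re-sending messages until every
/// addressed peer has acknowledged them.
///
/// Messages are grouped into [`Bucket`]s and tracked by [`Id`]; acknowledged
/// or obsolete entries can be dropped with [`Retry::rm`] and [`Retry::gc`].
/// Cloning is cheap, as all clones share the same inner state.
///
/// A minimal usage sketch (marked `ignore` because it assumes an existing
/// `NetConf` value `cfg` and a key type `K = String`, which this module does
/// not provide):
///
/// ```ignore
/// let retry = Retry::create(cfg).await?;
/// // Broadcast into bucket 1; the message is retried until all active
/// // parties acknowledge it or the bucket is garbage-collected.
/// let id = retry.broadcast(1u64, b"hello".to_vec()).await?;
/// let (src, bytes) = retry.receive().await?;
/// // Once bucket 1 is no longer relevant, drop everything below bucket 2.
/// retry.gc(2u64);
/// ```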
#[derive(Debug, Clone)]
pub struct Retry<K> {
    inner: Arc<Inner<K>>,
}

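/// State shared by all clones of a [`Retry`] handle: the underlying network,
/// the buffer of unacknowledged outgoing messages, and the handle of the
/// background task that re-sends them.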
#[derive(Debug)]
struct Inner<K> {
    this: K,
    net: Network<K>,
    sender: Sender<Command<K>>,
    id: AtomicU64,
    buffer: Buffer<K>,
    retry: JoinHandle<Infallible>,
    pending: Mutex<BTreeMap<Trailer, Pending<K>>>,
}

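// Abort the background retry task when the last `Retry` clone (and thus the
// shared `Inner`) is dropped.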
impl<K> Drop for Inner<K> {
    fn drop(&mut self) {
        self.retry.abort()
    }
}

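/// A grouping key for messages, used to garbage-collect whole ranges at once.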
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Bucket(u64);

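/// Outgoing messages that have not yet been acknowledged by every recipient,
/// indexed by bucket and message id. Shared between [`Retry`] and the
/// background retry task.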
#[derive(Debug, Clone)]
#[allow(clippy::type_complexity)]
struct Buffer<K>(Arc<Mutex<BTreeMap<Bucket, IntMap<Id, Message<K>>>>>);

impl<K> Default for Buffer<K> {
    fn default() -> Self {
        Self(Default::default())
    }
}

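/// A buffered outgoing message together with its retry bookkeeping: the time
/// of the last send attempt, the number of retries so far, and the peers that
/// have not yet acknowledged it.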
#[derive(Debug)]
struct Message<K> {
    data: Bytes,
    time: Instant,
    retries: usize,
    remaining: Vec<K>,
}

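/// The 16-byte (bucket, id) pair appended to every outgoing message and
/// echoed back by the receiver as an acknowledgement.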
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Trailer {
    bucket: Bucket,
    id: Id,
}

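/// A received message whose acknowledgement could not be sent immediately
/// because the command channel was full; it is delivered (and the
/// acknowledgement re-sent) on the next call to [`Retry::receive`].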
#[derive(Debug)]
struct Pending<K> {
    src: K,
    data: Bytes,
    trailer: Bytes,
}

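/// Addressing modes for [`Retry::send`].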
enum Target<K> {
    Single(K),
    Multi(Vec<K>),
    All,
}

impl<K> Retry<K>
where
    K: Eq + Ord + Clone + Display + Hash + Send + Sync + 'static,
{
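    /// Create the underlying [`Network`] and spawn the background task that
    /// periodically re-sends unacknowledged messages.
    ///
    /// The configured maximum message size is increased by `Trailer::SIZE`
    /// to account for the trailer appended to every message.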
    pub async fn create(mut cfg: NetConf<K>) -> Result<Self> {
        cfg.max_message_size += Trailer::SIZE;
        let delays = cfg.retry_delays;
        let net = Network::create(cfg).await?;
        let buffer = Buffer::default();
        let retry = spawn(retry(buffer.clone(), net.sender(), delays));
        Ok(Self {
            inner: Arc::new(Inner {
                this: net.public_key().clone(),
                sender: net.sender(),
                net,
                buffer,
                id: AtomicU64::new(0),
                retry,
                pending: Mutex::new(BTreeMap::new()),
            }),
        })
    }

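    /// Send `data` to all active parties, retrying until each of them has
    /// acknowledged receipt or the message is removed via [`Retry::rm`] or
    /// [`Retry::gc`].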
    pub async fn broadcast<B>(&self, b: B, data: Vec<u8>) -> Result<Id>
    where
        B: Into<Bucket>,
    {
        self.send(b.into(), Target::All, data).await
    }

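    /// Send `data` to the given peers, retrying until each of them has
    /// acknowledged receipt or the message is removed.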
    pub async fn multicast<B>(&self, to: Vec<K>, b: B, data: Vec<u8>) -> Result<Id>
    where
        B: Into<Bucket>,
    {
        self.send(b.into(), Target::Multi(to), data).await
    }

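    /// Send `data` to a single peer, retrying until it has acknowledged
    /// receipt or the message is removed.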
    pub async fn unicast<B>(&self, to: K, b: B, data: Vec<u8>) -> Result<Id>
    where
        B: Into<Bucket>,
    {
        self.send(b.into(), Target::Single(to), data).await
    }

    pub async fn add(&self, peers: Vec<(K, PublicKey, Address)>) -> Result<()> {
        self.inner.net.add(peers).await
    }

    pub async fn remove(&self, peers: Vec<K>) -> Result<()> {
        self.inner.net.remove(peers).await
    }

    pub async fn assign(&self, r: Role, peers: Vec<K>) -> Result<()> {
        self.inner.net.assign(r, peers).await
    }

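    /// Receive the next application message.
    ///
    /// Every data message is acknowledged by echoing its trailer back to the
    /// sender; a trailer received without a payload is an acknowledgement and
    /// clears the sender from the corresponding entry in the retry buffer.
    /// Acknowledgements that could not be sent immediately are stored and
    /// re-sent on the next call.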
    pub async fn receive(&self) -> Result<(K, Bytes)> {
        let pending = self.inner.pending.lock().pop_first();
        if let Some((_, Pending { src, data, trailer })) = pending {
            self.inner
                .sender
                .send(Command::Unicast(src.clone(), None, trailer.clone()))
                .await
                .map_err(|_| NetworkError::ChannelClosed)?;
            return Ok((src, data));
        }
        loop {
            let (src, mut bytes) = self.inner.net.receive().await?;

            let Some((trailer, trailer_bytes)) = Trailer::from_bytes(&mut bytes) else {
                warn!(node = %self.inner.this, "invalid trailer");
                continue;
            };

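            // A non-empty payload is a data message: acknowledge it by echoing
            // the trailer back to the sender, then hand the payload to the
            // caller.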
            if !bytes.is_empty() {
                match self
                    .inner
                    .sender
                    .try_send(Command::Unicast(src.clone(), None, trailer_bytes))
                {
                    Ok(()) => return Ok((src, bytes)),
                    Err(TrySendError::Closed(_)) => return Err(NetworkError::ChannelClosed),
                    Err(TrySendError::Full(Command::Unicast(src, _, trailer_bytes))) => {
                        self.inner.pending.lock().insert(
                            trailer,
                            Pending {
                                src: src.clone(),
                                data: bytes.clone(),
                                trailer: trailer_bytes.clone(),
                            },
                        );
                        self.inner
                            .sender
                            .send(Command::Unicast(src.clone(), None, trailer_bytes))
                            .await
                            .map_err(|_| NetworkError::ChannelClosed)?;
                        self.inner.pending.lock().remove(&trailer);
                        return Ok((src, bytes));
                    },
                    Err(TrySendError::Full(_)) => {
                        unreachable!(
                            "We tried sending a Command::Unicast so this is what we get back."
                        )
                    },
                }
            }

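            // An empty payload means the trailer itself is an acknowledgement:
            // remove the acknowledging peer from the message's remaining set.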
            let mut messages = self.inner.buffer.0.lock();

            if let Some(buckets) = messages.get_mut(&trailer.bucket)
                && let Some(m) = buckets.get_mut(&trailer.id)
            {
                m.remaining.retain(|k| *k != src);
                if m.remaining.is_empty() {
                    buckets.remove(&trailer.id);
                }
            }
        }
    }

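    /// Drop all buffered messages in buckets strictly below `bucket`.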
    pub fn gc<B: Into<Bucket>>(&self, bucket: B) {
        let bucket = bucket.into();
        self.inner.buffer.0.lock().retain(|b, _| *b >= bucket);
    }

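    /// Stop retrying a single message, identified by its bucket and id.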
    pub fn rm<B: Into<Bucket>>(&self, bucket: B, id: Id) {
        let bucket = bucket.into();
        if let Some(messages) = self.inner.buffer.0.lock().get_mut(&bucket) {
            messages.remove(&id);
        }
    }

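    /// Append the trailer to `data`, send it to the given target, and record
    /// the message in the retry buffer until all recipients acknowledge it.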
    async fn send(&self, b: Bucket, to: Target<K>, data: Vec<u8>) -> Result<Id> {
        let id = self.next_id();

        let trailer = Trailer { bucket: b, id };

        let mut msg = BytesMut::from(Bytes::from(data));
        msg.extend_from_slice(&trailer.to_bytes());
        let msg = msg.freeze();

        let now = Instant::now();

        let rem = match to {
            Target::Single(to) => {
                self.inner
                    .sender
                    .send(Command::Unicast(to.clone(), Some(id), msg.clone()))
                    .await
                    .map_err(|_| NetworkError::ChannelClosed)?;
                vec![to]
            },
            Target::Multi(peers) => {
                self.inner
                    .sender
                    .send(Command::Multicast(peers.clone(), Some(id), msg.clone()))
                    .await
                    .map_err(|_| NetworkError::ChannelClosed)?;
                peers
            },
            Target::All => {
                self.inner
                    .sender
                    .send(Command::Broadcast(Some(id), msg.clone()))
                    .await
                    .map_err(|_| NetworkError::ChannelClosed)?;
                self.inner.net.parties(Role::Active)
            },
        };

        self.inner.buffer.0.lock().entry(b).or_default().insert(
            id,
            Message {
                data: msg,
                time: now,
                retries: 0,
                remaining: rem,
            },
        );

        Ok(id)
    }

    fn next_id(&self) -> Id {
        Id::from(self.inner.id.fetch_add(1, Ordering::Relaxed))
    }
}

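/// Background task that periodically re-sends buffered messages to the peers
/// that have not yet acknowledged them. The delay between attempts follows
/// `delays`, sticking with the last entry once the table is exhausted (or 30
/// seconds if it is empty).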
async fn retry<K>(buf: Buffer<K>, net: Sender<Command<K>>, delays: [u8; NUM_DELAYS]) -> Infallible
where
    K: Clone,
{
    let mut i = time::interval(Duration::from_secs(1));
    i.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

    let mut buckets = Vec::new();
    let mut ids = Vec::new();

    loop {
        let now = i.tick().await;

        debug_assert!(buckets.is_empty());
        buckets.extend(buf.0.lock().keys().copied());

        for b in buckets.drain(..) {
            debug_assert!(ids.is_empty());
            ids.extend(
                buf.0
                    .lock()
                    .get(&b)
                    .into_iter()
                    .flat_map(|m| m.keys().copied()),
            );

            for id in ids.drain(..) {
                let message;
                let remaining;

                {
                    let mut buf = buf.0.lock();
                    let Some(m) = buf.get_mut(&b).and_then(|m| m.get_mut(&id)) else {
                        continue;
                    };

                    let delay = delays
                        .get(m.retries)
                        .copied()
                        .or_else(|| delays.last().copied())
                        .unwrap_or(30);

                    if now.saturating_duration_since(m.time) < Duration::from_secs(delay.into()) {
                        continue;
                    }

                    m.time = now;
                    m.retries = m.retries.saturating_add(1);

                    message = m.data.clone();
                    remaining = m.remaining.clone();
                }

                let _ = net
                    .send(Command::Multicast(remaining, Some(id), message.clone()))
                    .await;
            }
        }
    }
}

impl From<u64> for Bucket {
    fn from(val: u64) -> Self {
        Self(val)
    }
}

impl From<Bucket> for u64 {
    fn from(val: Bucket) -> Self {
        val.0
    }
}

impl fmt::Display for Bucket {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.0.fmt(f)
    }
}

impl Trailer {
    const SIZE: usize = 16;

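    /// Split the trailer off the end of `bytes`, returning the parsed trailer
    /// together with its raw byte representation (for echoing it back as an
    /// acknowledgement).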
    fn from_bytes(bytes: &mut Bytes) -> Option<(Self, Bytes)> {
        if bytes.len() < Self::SIZE {
            return None;
        }
        let slice = bytes.split_off(bytes.len() - Self::SIZE);
        let both = u128::from_be_bytes(slice[..].try_into().ok()?);
        let this = Self {
            bucket: ((both >> 64) as u64).into(),
            id: (both as u64).into(),
        };
        Some((this, slice))
    }

    fn to_bytes(self) -> [u8; Self::SIZE] {
        (u128::from(self.bucket.0) << 64 | u128::from(self.id.0)).to_be_bytes()
    }
}

#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use quickcheck::quickcheck;

    use super::Trailer;

    quickcheck! {
        fn to_from_bytes(b: u64, i: u64) -> bool {
            let a = Trailer {
                bucket: b.into(),
                id: i.into()
            };
            let mut bytes = Bytes::copy_from_slice(&a.to_bytes());
            let (b, _) = Trailer::from_bytes(&mut bytes).unwrap();
            a == b
        }
    }
}