1use std::{
2 collections::BTreeMap,
3 convert::Infallible,
4 fmt::{self, Display},
5 hash::Hash,
6 result::Result as StdResult,
7 sync::{
8 Arc,
9 atomic::{AtomicU64, Ordering},
10 },
11};
12
13use bytes::{Bytes, BytesMut};
14use hotshot_types::addr::NetAddr;
15use nohash_hasher::IntMap;
16use parking_lot::Mutex;
17use scopeguard::{ScopeGuard, guard};
18use tokio::{
19 spawn,
20 sync::mpsc::{Sender, error::TrySendError},
21 task::JoinHandle,
22 time::{self, Duration, Instant},
23};
24use tracing::warn;
25
26use crate::{
27 Id, NUM_DELAYS, NetConf, Network, NetworkDown, NetworkError, PublicKey, Role, net::Command,
28};
29
/// Result type used by this module, with [`NetworkError`] as the error type.
type Result<T> = std::result::Result<T, NetworkError>;
31
/// The largest possible [`Bucket`]; it wraps `u64::MAX`.
pub const MAX_BUCKET: Bucket = Bucket(u64::MAX);
34
/// A handle around a [`Network`] that buffers outgoing messages and re-sends
/// them periodically until their recipients have acknowledged them.
///
/// Cloning is cheap: all clones share the same state through an [`Arc`].
#[derive(Debug, Clone)]
pub struct Retry<K> {
    inner: Arc<Inner<K>>,
}
52
/// State shared by all clones of a [`Retry`] handle.
#[derive(Debug)]
struct Inner<K> {
    /// The underlying network used to send and receive messages.
    net: Network<K>,
    /// Command channel of `net`, obtained from `Network::sender`.
    sender: Sender<Command<K>>,
    /// Counter from which the ID of the next outgoing message is taken.
    id: AtomicU64,
    /// Sent messages still awaiting acknowledgements; shared with the retry task.
    buffer: Buffer<K>,
    /// Handle of the spawned task that periodically re-sends buffered messages.
    retry: JoinHandle<Infallible>,
    /// Received messages whose acknowledgement could not be enqueued right
    /// away because the command channel was full, keyed by their trailer.
    pending: Mutex<BTreeMap<Trailer, Pending<K>>>,
}
62
63impl<K> Drop for Retry<K> {
64 fn drop(&mut self) {
65 self.inner.retry.abort()
66 }
67}
68
/// Identifier of a group of buffered messages.
///
/// Buckets are totally ordered; `Retry::gc` relies on that order to discard
/// every bucket below a given one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Bucket(u64);
72
/// Shared store of sent messages that have not been fully acknowledged yet,
/// grouped by [`Bucket`] and keyed by message [`Id`] within each bucket.
///
/// Cloning only clones the [`Arc`], so all clones refer to the same map.
#[derive(Debug, Clone)]
// The nested map type is written out in full in this one place.
#[allow(clippy::type_complexity)]
struct Buffer<K>(Arc<Mutex<BTreeMap<Bucket, IntMap<Id, Message<K>>>>>);
81
82impl<K> Default for Buffer<K> {
83 fn default() -> Self {
84 Self(Default::default())
85 }
86}
87
/// A sent message that is kept until all of its recipients acknowledged it.
#[derive(Debug)]
struct Message<K> {
    /// The bytes as sent, i.e. the payload followed by the trailer.
    data: Bytes,
    /// Time of the initial send or of the most recent retry.
    time: Instant,
    /// Number of retries performed so far.
    retries: usize,
    /// Recipients that have not acknowledged the message yet.
    remaining: Vec<K>,
}
99
/// Metadata appended to every outgoing message.
///
/// A received message that consists of nothing but a trailer is treated as
/// an acknowledgement of the message with the same bucket and ID.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Trailer {
    /// The bucket the message was sent in.
    bucket: Bucket,
    /// The ID the sender assigned to the message.
    id: Id,
}
108
/// A received message whose acknowledgement has not been handed to the
/// network yet.
#[derive(Debug, Clone)]
struct Pending<K> {
    /// The peer the message was received from.
    src: K,
    /// The message payload, with the trailer already split off.
    data: Bytes,
    /// The raw trailer bytes that are sent back to `src` as acknowledgement.
    trailer: Bytes,
}
116
/// The recipients of an outgoing message.
enum Target<K> {
    /// A single peer (sent as `Command::Unicast`).
    Single(K),
    /// An explicit list of peers (sent as `Command::Multicast`).
    Multi(Vec<K>),
    /// Everyone (sent as `Command::Broadcast`); acknowledgements are awaited
    /// from the parties that have `Role::Active` at the time of sending.
    All,
}
122
123impl<K> Retry<K>
124where
125 K: Eq + Ord + Clone + Display + Hash + Send + Sync + 'static,
126{
127 pub async fn create(mut cfg: NetConf<K>) -> Result<Self> {
128 cfg.max_message_size += Trailer::SIZE;
129 let delays = cfg.retry_delays;
130 let net = Network::create(cfg).await?;
131 let buffer = Buffer::default();
132 let retry = spawn(retry(buffer.clone(), net.sender(), delays));
133 Ok(Self {
134 inner: Arc::new(Inner {
135 sender: net.sender(),
136 net,
137 buffer,
138 id: AtomicU64::new(0),
139 retry,
140 pending: Mutex::new(BTreeMap::new()),
141 }),
142 })
143 }
144
145 pub async fn close(&self) {
146 self.inner.retry.abort();
147 self.inner.net.close().await;
148 }
149
150 pub fn parties(&self, r: Option<Role>) -> Vec<K> {
151 self.inner.net.parties(r)
152 }
153
154 pub async fn broadcast<B>(&self, b: B, data: Vec<u8>) -> Result<Id>
155 where
156 B: Into<Bucket>,
157 {
158 self.send(b.into(), Target::All, data).await
159 }
160
161 pub async fn multicast<B>(&self, to: Vec<K>, b: B, data: Vec<u8>) -> Result<Id>
162 where
163 B: Into<Bucket>,
164 {
165 self.send(b.into(), Target::Multi(to), data).await
166 }
167
168 pub async fn unicast<B>(&self, to: K, b: B, data: Vec<u8>) -> Result<Id>
169 where
170 B: Into<Bucket>,
171 {
172 self.send(b.into(), Target::Single(to), data).await
173 }
174
175 pub async fn add(
176 &self,
177 r: Role,
178 peers: Vec<(K, PublicKey, NetAddr)>,
179 ) -> StdResult<(), NetworkDown> {
180 self.inner.net.add(r, peers).await
181 }
182
183 pub async fn remove(&self, peers: Vec<K>) -> StdResult<(), NetworkDown> {
184 self.inner.net.remove(peers).await
185 }
186
187 pub async fn assign(&self, r: Role, peers: Vec<K>) -> StdResult<(), NetworkDown> {
188 self.inner.net.assign(r, peers).await
189 }
190
/// Receives the next message payload together with the peer that sent it.
///
/// Every incoming message must end with a [`Trailer`]; messages without a
/// valid trailer are logged and skipped. What remains after splitting off
/// the trailer decides how the message is handled:
///
/// - Non-empty payload: a data message. The raw trailer bytes are sent back
///   to the source as acknowledgement and the payload is returned.
/// - Empty payload: an acknowledgement. The source is removed from the
///   outstanding recipients of the matching buffered message, and the
///   message is dropped from the buffer once no recipient is left. Nothing
///   is returned; the loop continues with the next message.
///
/// If the command channel is full, the data message is parked in `pending`
/// while this call waits for channel capacity. Should the future be dropped
/// (or fail) during that wait, the message stays in `pending` and is
/// delivered first by a later call.
///
/// # Errors
///
/// Returns [`NetworkError::ChannelClosed`] if the command channel is closed,
/// or the error of `Network::receive`.
///
/// NOTE(review): `pending` is keyed by the trailer only. If message IDs are
/// not unique across senders, entries of different sources with equal bucket
/// and ID would replace each other — confirm.
///
/// NOTE(review): with concurrent `receive` calls, one call could pop and
/// return an entry that another call parked while it is still waiting, and
/// the same message would be returned twice — confirm `receive` is only
/// driven from one task at a time.
pub async fn receive(&self) -> Result<(K, Bytes)> {
    // Deliver a message parked by an earlier call before reading new ones.
    let first_entry = self.inner.pending.lock().pop_first();
    if let Some((key, val)) = first_entry {
        let pending = &self.inner.pending;
        // Re-insert the entry if this future is dropped or returns early
        // before the acknowledgement has been enqueued.
        let guard = guard((key, val.clone()), |(k, v)| {
            pending.lock().insert(k, v);
        });
        self.inner
            .sender
            .send(Command::Unicast(val.src.clone(), None, val.trailer.clone()))
            .await
            .map_err(|_| NetworkError::ChannelClosed)?;
        // The acknowledgement is enqueued: defuse the guard so that the
        // entry is not put back.
        let _ = ScopeGuard::into_inner(guard);
        return Ok((val.src, val.data));
    }
    loop {
        let (src, mut bytes) = self.inner.net.receive().await?;

        // On success `bytes` holds only the payload afterwards.
        let Some((trailer, trailer_bytes)) = Trailer::from_bytes(&mut bytes) else {
            warn!(node = %self.inner.net.label, "invalid trailer");
            continue;
        };

        if !bytes.is_empty() {
            // Data message: acknowledge it, without waiting if possible.
            match self
                .inner
                .sender
                .try_send(Command::Unicast(src.clone(), None, trailer_bytes))
            {
                Ok(()) => return Ok((src, bytes)),
                Err(TrySendError::Closed(_)) => return Err(NetworkError::ChannelClosed),
                // The channel is full; the rejected command hands back the
                // source and trailer bytes it was built from.
                Err(TrySendError::Full(Command::Unicast(src, _, trailer_bytes))) => {
                    // Park the message so that it survives a drop of this
                    // future while waiting for channel capacity below.
                    self.inner.pending.lock().insert(
                        trailer,
                        Pending {
                            src: src.clone(),
                            data: bytes.clone(),
                            trailer: trailer_bytes.clone(),
                        },
                    );
                    self.inner
                        .sender
                        .send(Command::Unicast(src.clone(), None, trailer_bytes))
                        .await
                        .map_err(|_| NetworkError::ChannelClosed)?;
                    // Acknowledgement enqueued; the parked copy is obsolete.
                    self.inner.pending.lock().remove(&trailer);
                    return Ok((src, bytes));
                },
                Err(TrySendError::Full(_)) => {
                    unreachable!(
                        "We tried sending a Command::Unicast so this is what we get back."
                    )
                },
            }
        }

        // Acknowledgement: `src` has received the message named by the trailer.
        let mut messages = self.inner.buffer.0.lock();

        if let Some(buckets) = messages.get_mut(&trailer.bucket)
            && let Some(m) = buckets.get_mut(&trailer.id)
        {
            m.remaining.retain(|k| *k != src);
            // Fully acknowledged messages no longer need to be retried.
            if m.remaining.is_empty() {
                buckets.remove(&trailer.id);
            }
        }
    }
}
262
263 pub fn gc<B: Into<Bucket>>(&self, bucket: B) {
264 let bucket = bucket.into();
265 self.inner.buffer.0.lock().retain(|b, _| *b >= bucket);
266 }
267
268 pub fn rm<B: Into<Bucket>>(&self, bucket: B, id: Id) {
269 let bucket = bucket.into();
270 if let Some(messages) = self.inner.buffer.0.lock().get_mut(&bucket) {
271 messages.remove(&id);
272 }
273 }
274
275 async fn send(&self, b: Bucket, to: Target<K>, data: Vec<u8>) -> Result<Id> {
276 let id = self.next_id();
277
278 let trailer = Trailer { bucket: b, id };
279
280 let mut msg = BytesMut::from(Bytes::from(data));
281 msg.extend_from_slice(&trailer.to_bytes());
282 let msg = msg.freeze();
283
284 if msg.len() > self.inner.net.max_message_size {
285 warn!(
286 name = %self.inner.net.name,
287 node = %self.inner.net.label,
288 len = %msg.len(),
289 max = %self.inner.net.max_message_size,
290 "message too large to send"
291 );
292 return Err(NetworkError::MessageTooLarge);
293 }
294
295 let now = Instant::now();
296
297 let rem = match to {
298 Target::Single(to) => {
299 self.inner
300 .sender
301 .send(Command::Unicast(to.clone(), Some(id), msg.clone()))
302 .await
303 .map_err(|_| NetworkError::ChannelClosed)?;
304 vec![to]
305 },
306 Target::Multi(peers) => {
307 self.inner
308 .sender
309 .send(Command::Multicast(peers.clone(), Some(id), msg.clone()))
310 .await
311 .map_err(|_| NetworkError::ChannelClosed)?;
312 peers
313 },
314 Target::All => {
315 self.inner
316 .sender
317 .send(Command::Broadcast(Some(id), msg.clone()))
318 .await
319 .map_err(|_| NetworkError::ChannelClosed)?;
320 self.inner.net.parties(Some(Role::Active))
321 },
322 };
323
324 self.inner.buffer.0.lock().entry(b).or_default().insert(
325 id,
326 Message {
327 data: msg,
328 time: now,
329 retries: 0,
330 remaining: rem,
331 },
332 );
333
334 Ok(id)
335 }
336
337 fn next_id(&self) -> Id {
338 Id::from(self.inner.id.fetch_add(1, Ordering::Relaxed))
339 }
340}
341
342async fn retry<K>(buf: Buffer<K>, net: Sender<Command<K>>, delays: [u8; NUM_DELAYS]) -> Infallible
343where
344 K: Clone,
345{
346 let mut i = time::interval(Duration::from_secs(1));
347 i.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
348
349 let mut buckets = Vec::new();
350 let mut ids = Vec::new();
351
352 loop {
353 let now = i.tick().await;
354
355 debug_assert!(buckets.is_empty());
356 buckets.extend(buf.0.lock().keys().copied());
357
358 for b in buckets.drain(..) {
359 debug_assert!(ids.is_empty());
360 ids.extend(
361 buf.0
362 .lock()
363 .get(&b)
364 .into_iter()
365 .flat_map(|m| m.keys().copied()),
366 );
367
368 for id in ids.drain(..) {
369 let message;
370 let remaining;
371
372 {
373 let mut buf = buf.0.lock();
374 let Some(m) = buf.get_mut(&b).and_then(|m| m.get_mut(&id)) else {
375 continue;
376 };
377
378 let delay = delays
379 .get(m.retries)
380 .copied()
381 .or_else(|| delays.last().copied())
382 .unwrap_or(30);
383
384 if now.saturating_duration_since(m.time) < Duration::from_secs(delay.into()) {
385 continue;
386 }
387
388 m.time = now;
389 m.retries = m.retries.saturating_add(1);
390
391 message = m.data.clone();
392 remaining = m.remaining.clone();
393 }
394
395 let _ = net
396 .send(Command::Multicast(remaining, Some(id), message.clone()))
397 .await;
398 }
399 }
400 }
401}
402
403impl From<u64> for Bucket {
404 fn from(val: u64) -> Self {
405 Self(val)
406 }
407}
408
409impl From<Bucket> for u64 {
410 fn from(val: Bucket) -> Self {
411 val.0
412 }
413}
414
415impl fmt::Display for Bucket {
416 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
417 self.0.fmt(f)
418 }
419}
420
421impl Trailer {
424 const SIZE: usize = 17;
425
426 fn from_bytes(bytes: &mut Bytes) -> Option<(Self, Bytes)> {
427 let len: usize = bytes.last().copied()?.into();
428 if len < Self::SIZE || bytes.len() < len {
429 return None;
430 }
431 let slice = bytes.split_off(bytes.len() - len);
432 let id = u64::from_be_bytes(slice[len - 9..len - 1].try_into().ok()?);
433 let bucket = u64::from_be_bytes(slice[len - 17..len - 9].try_into().ok()?);
434 let this = Self {
435 bucket: bucket.into(),
436 id: id.into(),
437 };
438 Some((this, slice))
439 }
440
441 fn to_bytes(self) -> [u8; Self::SIZE] {
442 let mut buf = [0; Self::SIZE];
443 buf[..8].copy_from_slice(&self.bucket.0.to_be_bytes()[..]);
444 buf[8..16].copy_from_slice(&self.id.0.to_be_bytes()[..]);
445 buf[16] = Self::SIZE as u8;
446 buf
447 }
448}
449
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use quickcheck::quickcheck;

    use super::Trailer;

    quickcheck! {
        fn to_from_bytes(b: u64, i: u64) -> bool {
            // Encoding a trailer and decoding the result must give back the
            // trailer we started with.
            let original = Trailer {
                bucket: b.into(),
                id: i.into()
            };
            let mut encoded = Bytes::copy_from_slice(&original.to_bytes());
            let (decoded, _) = Trailer::from_bytes(&mut encoded).unwrap();
            original == decoded
        }
    }
}