hotshot_task/
dependency.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::future::Future;
8
9use async_broadcast::{Receiver, RecvError};
10use futures::{
11    future::BoxFuture,
12    stream::{FuturesUnordered, StreamExt},
13    FutureExt,
14};
15
16/// Type which describes the idea of waiting for a dependency to complete
17pub trait Dependency<T> {
18    /// Complete will wait until it gets some value `T` then return the value
19    fn completed(self) -> impl Future<Output = Option<T>> + Send;
20    /// Create an or dependency from this dependency and another
21    fn or<D: Dependency<T> + Send + 'static>(self, dep: D) -> OrDependency<T>
22    where
23        T: Send + Sync + Clone + 'static,
24        Self: Sized + Send + 'static,
25    {
26        let mut or = OrDependency::from_deps(vec![self]);
27        or.add_dep(dep);
28        or
29    }
30    /// Create an and dependency from this dependency and another
31    fn and<D: Dependency<T> + Send + 'static>(self, dep: D) -> AndDependency<T>
32    where
33        T: Send + Sync + Clone + 'static,
34        Self: Sized + Send + 'static,
35    {
36        let mut and = AndDependency::from_deps(vec![self]);
37        and.add_dep(dep);
38        and
39    }
40}
41
42/// Defines a dependency that completes when all of its deps complete
43pub struct AndDependency<T> {
44    /// Dependencies being combined
45    deps: Vec<BoxFuture<'static, Option<T>>>,
46}
47impl<T: Clone + Send + Sync> Dependency<Vec<T>> for AndDependency<T> {
48    /// Returns a vector of all of the results from it's dependencies.
49    /// The results will be in a random order
50    async fn completed(self) -> Option<Vec<T>> {
51        let futures = FuturesUnordered::from_iter(self.deps);
52        futures
53            .collect::<Vec<Option<T>>>()
54            .await
55            .into_iter()
56            .collect()
57    }
58}
59
60impl<T: Clone + Send + Sync + 'static> AndDependency<T> {
61    /// Create from a vec of deps
62    #[must_use]
63    pub fn from_deps(deps: Vec<impl Dependency<T> + Send + 'static>) -> Self {
64        let mut pinned = vec![];
65        for dep in deps {
66            pinned.push(dep.completed().boxed());
67        }
68        Self { deps: pinned }
69    }
70    /// Add another dependency
71    pub fn add_dep(&mut self, dep: impl Dependency<T> + Send + 'static) {
72        self.deps.push(dep.completed().boxed());
73    }
74    /// Add multiple dependencies
75    pub fn add_deps(&mut self, deps: AndDependency<T>) {
76        for dep in deps.deps {
77            self.deps.push(dep);
78        }
79    }
80}
81
82/// Defines a dependency that completes when one of it's dependencies completes
83pub struct OrDependency<T> {
84    /// Dependencies being combined
85    deps: Vec<BoxFuture<'static, Option<T>>>,
86}
87impl<T: Clone + Send + Sync> Dependency<T> for OrDependency<T> {
88    /// Returns the value of the first completed dependency
89    async fn completed(self) -> Option<T> {
90        let mut futures = FuturesUnordered::from_iter(self.deps);
91        loop {
92            if let Some(maybe) = futures.next().await {
93                if maybe.is_some() {
94                    return maybe;
95                }
96            } else {
97                return None;
98            }
99        }
100    }
101}
102
103impl<T: Clone + Send + Sync + 'static> OrDependency<T> {
104    /// Creat an `OrDependency` from a vec of dependencies
105    #[must_use]
106    pub fn from_deps(deps: Vec<impl Dependency<T> + Send + 'static>) -> Self {
107        let mut pinned = vec![];
108        for dep in deps {
109            pinned.push(dep.completed().boxed());
110        }
111        Self { deps: pinned }
112    }
113    /// Add another dependency
114    pub fn add_dep(&mut self, dep: impl Dependency<T> + Send + 'static) {
115        self.deps.push(dep.completed().boxed());
116    }
117}
118
119/// A dependency that listens on a channel for an event
120/// that matches what some value it wants.
121pub struct EventDependency<T: Clone + Send + Sync> {
122    /// Channel of incoming events
123    pub(crate) event_rx: Receiver<T>,
124
125    /// Closure which returns true if the incoming `T` is the
126    /// thing that completes this dependency
127    pub(crate) match_fn: Box<dyn Fn(&T) -> bool + Send>,
128
129    /// The potentially externally completed dependency. If the dependency was seeded from an event
130    /// message, we can mark it as already done in lieu of other events still pending.
131    completed_dependency: Option<T>,
132
133    cancel_receiver: Receiver<()>,
134
135    dependency_name: String,
136}
137
138impl<T: Clone + Send + Sync + 'static> EventDependency<T> {
139    /// Create a new `EventDependency`
140    #[must_use]
141    pub fn new(
142        receiver: Receiver<T>,
143        cancel_receiver: Receiver<()>,
144        dependency_name: String,
145        match_fn: Box<dyn Fn(&T) -> bool + Send>,
146    ) -> Self {
147        Self {
148            event_rx: receiver,
149            match_fn: Box::new(match_fn),
150            completed_dependency: None,
151            cancel_receiver,
152            dependency_name,
153        }
154    }
155
156    /// Mark a dependency as completed.
157    pub fn mark_as_completed(&mut self, dependency: T) {
158        self.completed_dependency = Some(dependency);
159    }
160}
161
162impl<T: Clone + Send + Sync + 'static> Dependency<T> for EventDependency<T> {
163    async fn completed(mut self) -> Option<T> {
164        if let Some(dependency) = self.completed_dependency {
165            return Some(dependency);
166        }
167        loop {
168            if let Some(dependency) = self.completed_dependency {
169                return Some(dependency);
170            }
171
172            tokio::select! {
173                recv_event = self.event_rx.recv() => {
174                    match recv_event {
175                        Ok(event) => {
176                            if (self.match_fn)(&event) {
177                                return Some(event);
178                            }
179                        },
180                        Err(RecvError::Overflowed(n)) => {
181                            tracing::error!("Dependency Task overloaded, skipping {} events", n);
182                        },
183                        Err(RecvError::Closed) => {
184                            return None;
185                        },
186                    }
187                }
188                _ = self.cancel_receiver.recv() => {
189                   tracing::error!("{} dependency cancelled", self.dependency_name);
190                   return None;
191                }
192            }
193        }
194    }
195}
196
#[cfg(test)]
mod tests {
    use async_broadcast::{broadcast, Receiver};

    use super::{AndDependency, Dependency, EventDependency, OrDependency};

    /// Build an `EventDependency` that completes when it receives exactly `val`.
    fn eq_dep(
        rx: Receiver<usize>,
        cancel_rx: Receiver<()>,
        dep_name: String,
        val: usize,
    ) -> EventDependency<usize> {
        EventDependency {
            event_rx: rx,
            match_fn: Box::new(move |v| *v == val),
            completed_dependency: None,
            dependency_name: dep_name,
            cancel_receiver: cancel_rx,
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn it_works() {
        let (tx, rx) = broadcast(10);
        let (_cancel_tx, cancel_rx) = broadcast(1);

        let mut deps = vec![];
        for i in 0..5 {
            tx.broadcast(i).await.unwrap();
            deps.push(eq_dep(
                rx.clone(),
                cancel_rx.clone(),
                format!("it_works {i}"),
                5,
            ));
        }

        let and = AndDependency::from_deps(deps);
        tx.broadcast(5).await.unwrap();
        let result = and.completed().await;
        assert_eq!(result, Some(vec![5; 5]));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn or_dep() {
        let (tx, rx) = broadcast(10);
        let (_cancel_tx, cancel_rx) = broadcast(1);

        tx.broadcast(5).await.unwrap();
        let mut deps = vec![];
        for i in 0..5 {
            deps.push(eq_dep(
                rx.clone(),
                cancel_rx.clone(),
                format!("or_dep {i}"),
                5,
            ));
        }
        let or = OrDependency::from_deps(deps);
        let result = or.completed().await;
        assert_eq!(result, Some(5));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn and_or_dep() {
        let (tx, rx) = broadcast(10);
        let (_cancel_tx, cancel_rx) = broadcast(1);

        for v in [1, 2, 3, 5, 6] {
            tx.broadcast(v).await.unwrap();
        }

        let or1 = OrDependency::from_deps(
            [
                eq_dep(
                    rx.clone(),
                    cancel_rx.clone(),
                    format!("and_or_dep or1 {}", 4),
                    4,
                ),
                eq_dep(
                    rx.clone(),
                    cancel_rx.clone(),
                    format!("and_or_dep or1 {}", 6),
                    6,
                ),
            ]
            .into(),
        );
        let or2 = OrDependency::from_deps(
            [
                eq_dep(
                    rx.clone(),
                    cancel_rx.clone(),
                    format!("and_or_dep or2 {}", 4),
                    4,
                ),
                eq_dep(
                    rx.clone(),
                    cancel_rx.clone(),
                    format!("and_or_dep or2 {}", 5),
                    5,
                ),
            ]
            .into(),
        );
        let and = AndDependency::from_deps([or1, or2].into());
        let result = and.completed().await;
        assert_eq!(result, Some(vec![6, 5]));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn or_and_dep() {
        let (tx, rx) = broadcast(10);
        let (_cancel_tx, cancel_rx) = broadcast(1);

        for v in 1..=5 {
            tx.broadcast(v).await.unwrap();
        }

        let and1 = eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("or_and_dep and1 {}", 4),
            4,
        )
        .and(eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("or_and_dep and1 {}", 6),
            6,
        ));
        let and2 = eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("or_and_dep and2 {}", 4),
            4,
        )
        .and(eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("or_and_dep and2 {}", 5),
            5,
        ));
        let or = and1.or(and2);
        let result = or.completed().await;
        assert_eq!(result, Some(vec![4, 5]));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn many_and_dep() {
        let (tx, rx) = broadcast(10);
        let (_cancel_tx, cancel_rx) = broadcast(1);

        for v in 1..=6 {
            tx.broadcast(v).await.unwrap();
        }

        let mut and1 = eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("many_and_dep and1 {}", 4),
            4,
        )
        .and(eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("many_and_dep and1 {}", 6),
            6,
        ));
        let and2 = eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("many_and_dep and2 {}", 4),
            4,
        )
        .and(eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("many_and_dep and2 {}", 5),
            5,
        ));
        and1.add_deps(and2);
        let result = and1.completed().await;
        assert_eq!(result, Some(vec![4, 6, 4, 5]));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn cancel_event_dep() {
        let (tx, rx) = broadcast(10);
        let (cancel_tx, cancel_rx) = broadcast(1);

        for i in 0..=5 {
            tx.broadcast(i).await.unwrap();
        }
        cancel_tx.broadcast(()).await.unwrap();
        let dep = eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("cancel_event_dep {}", 6),
            6,
        );
        let result = dep.completed().await;
        assert_eq!(result, None);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn drop_cancel_dep() {
        let (tx, rx) = broadcast(10);
        let (cancel_tx, cancel_rx) = broadcast(1);

        for i in 0..=5 {
            tx.broadcast(i).await.unwrap();
        }
        drop(cancel_tx);
        let dep = eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("drop_cancel_dep {}", 6),
            6,
        );
        let result = dep.completed().await;
        assert_eq!(result, None);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn cancel_and_dep() {
        let (tx, rx) = broadcast(10);
        let (cancel_tx, cancel_rx) = broadcast(1);

        let mut deps = vec![];
        for i in 0..=5 {
            tx.broadcast(i).await.unwrap();
            deps.push(eq_dep(
                rx.clone(),
                cancel_rx.clone(),
                format!("cancel_and_dep {i}"),
                i,
            ));
        }
        deps.push(eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("cancel_and_dep {}", 6),
            6,
        ));
        cancel_tx.broadcast(()).await.unwrap();
        let result = AndDependency::from_deps(deps).completed().await;
        assert_eq!(result, None);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn cancel_or_dep() {
        // Keep the event sender alive: if it were dropped, every dependency would return
        // `None` because the event channel closed, and cancellation would never be tested.
        let (_tx, rx) = broadcast(10);
        let (cancel_tx, cancel_rx) = broadcast(1);

        let mut deps = vec![];
        for i in 0..=5 {
            deps.push(eq_dep(
                rx.clone(),
                cancel_rx.clone(),
                format!("cancel_or_dep {i}"),
                i,
            ));
        }
        cancel_tx.broadcast(()).await.unwrap();
        let result = OrDependency::from_deps(deps).completed().await;
        assert_eq!(result, None);
    }
}