hotshot_task/
dependency_task.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use futures::Future;
8use tokio::task::{spawn, JoinHandle};
9
10use crate::dependency::Dependency;
11
/// Defines a type that can handle the result of a dependency.
///
/// A handler is consumed by [`HandleDepOutput::handle_dep_result`], so each handler
/// value processes at most one result.
pub trait HandleDepOutput: Send + Sized + Sync + 'static {
    /// Type we expect from the completed dependency.
    type Output: Send + Sync + 'static;

    /// Called once when the dependency completes; handles its result.
    ///
    /// The returned future must be `Send` so the handler can be awaited inside a
    /// spawned task.
    fn handle_dep_result(self, res: Self::Output) -> impl Future<Output = ()> + Send;
}
20
21/// A task that runs until it's dependency completes and it handles the result
22pub struct DependencyTask<D: Dependency<H::Output> + Send, H: HandleDepOutput + Send> {
23    /// Dependency this tasks waits for
24    pub(crate) dep: D,
25    /// Handles the results returned from `self.dep.completed().await`
26    pub(crate) handle: H,
27}
28
29impl<D: Dependency<H::Output> + Send, H: HandleDepOutput + Send> DependencyTask<D, H> {
30    /// Create a new `DependencyTask`
31    #[must_use]
32    pub fn new(dep: D, handle: H) -> Self {
33        Self { dep, handle }
34    }
35}
36
37impl<D: Dependency<H::Output> + Send + 'static, H: HandleDepOutput> DependencyTask<D, H> {
38    /// Spawn the dependency task
39    pub fn run(self) -> JoinHandle<()>
40    where
41        Self: Sized,
42    {
43        spawn(async move {
44            if let Some(completed) = self.dep.completed().await {
45                self.handle.handle_dep_result(completed).await;
46            }
47        })
48    }
49}
50
#[cfg(test)]
mod test {

    use std::time::Duration;

    use async_broadcast::{broadcast, Receiver, Sender};
    use futures::{stream::FuturesOrdered, StreamExt};
    use tokio::time::sleep;

    use super::*;
    use crate::dependency::*;

    /// Result reported by [`DummyHandle`] once its dependency resolves.
    #[derive(Clone, PartialEq, Eq, Debug)]
    enum TaskResult {
        /// The dependency resolved with the contained value.
        Success(usize),
    }

    /// Handler that reports the resolved value on `sender` so tests can observe it.
    struct DummyHandle {
        sender: Sender<TaskResult>,
    }

    impl HandleDepOutput for DummyHandle {
        type Output = usize;
        async fn handle_dep_result(self, res: usize) {
            self.sender
                .broadcast(TaskResult::Success(res))
                .await
                .unwrap();
        }
    }

    /// Builds an `EventDependency` over `rx` whose predicate accepts only events equal
    /// to `val`.
    fn eq_dep(
        rx: Receiver<usize>,
        cancel_rx: Receiver<()>,
        dep_name: String,
        val: usize,
    ) -> EventDependency<usize> {
        EventDependency::new(rx, cancel_rx, dep_name, Box::new(move |v| *v == val))
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn it_works() {
        let (tx, rx) = broadcast(10);
        // Bound to a named variable (not `_`) so the cancel sender stays alive.
        let (_cancel_tx, cancel_rx) = broadcast(1);
        let (res_tx, mut res_rx) = broadcast(10);
        let dep = eq_dep(rx, cancel_rx, format!("it_works {}", 2), 2);
        let handle = DummyHandle { sender: res_tx };
        let join_handle = DependencyTask { dep, handle }.run();
        tx.broadcast(2).await.unwrap();
        assert_eq!(res_rx.recv().await.unwrap(), TaskResult::Success(2));

        // Surface a panic in the spawned task instead of discarding the `JoinError`.
        join_handle.await.expect("dependency task failed");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn many_works() {
        let (tx, rx) = broadcast(20);
        let (res_tx, mut res_rx) = broadcast(20);
        let (_cancel_tx, cancel_rx) = broadcast(1);

        let mut handles = Vec::with_capacity(10);
        for i in 0..10 {
            let dep = eq_dep(rx.clone(), cancel_rx.clone(), format!("many_works {i}"), i);
            let handle = DummyHandle {
                sender: res_tx.clone(),
            };
            handles.push(DependencyTask { dep, handle }.run());
        }
        let tx2 = tx.clone();
        let sender = spawn(async move {
            for i in 0..10 {
                tx.broadcast(i).await.unwrap();
                sleep(Duration::from_millis(10)).await;
            }
        });
        for i in 0..10 {
            assert_eq!(res_rx.recv().await.unwrap(), TaskResult::Success(i));
        }
        tx2.broadcast(100).await.unwrap();

        // Check every join result so a panic inside any dependency task fails the test.
        for result in FuturesOrdered::from_iter(handles).collect::<Vec<_>>().await {
            result.expect("dependency task failed");
        }
        sender.await.expect("sender task failed");
    }
}