hotshot_task/
dependency_task.rs1use futures::Future;
8use tokio::task::{spawn, JoinHandle};
9
10use crate::dependency::Dependency;
11
/// A consumer for the value produced when a dependency completes.
///
/// Implementors are moved into the handling call, so each handler instance
/// processes at most one result.
pub trait HandleDepOutput: Sized + Send + Sync + 'static {
    /// The value type this handler accepts.
    type Output: Send + Sync + 'static;

    /// Consumes the handler and processes `res`.
    ///
    /// The returned future must be `Send` so it can be awaited from a spawned task.
    fn handle_dep_result(self, res: Self::Output) -> impl Future<Output = ()> + Send;
}
20
21pub struct DependencyTask<D: Dependency<H::Output> + Send, H: HandleDepOutput + Send> {
23 pub(crate) dep: D,
25 pub(crate) handle: H,
27}
28
29impl<D: Dependency<H::Output> + Send, H: HandleDepOutput + Send> DependencyTask<D, H> {
30 #[must_use]
32 pub fn new(dep: D, handle: H) -> Self {
33 Self { dep, handle }
34 }
35}
36
37impl<D: Dependency<H::Output> + Send + 'static, H: HandleDepOutput> DependencyTask<D, H> {
38 pub fn run(self) -> JoinHandle<()>
40 where
41 Self: Sized,
42 {
43 spawn(async move {
44 if let Some(completed) = self.dep.completed().await {
45 self.handle.handle_dep_result(completed).await;
46 }
47 })
48 }
49}
50
#[cfg(test)]
mod test {

    use std::time::Duration;

    use async_broadcast::{broadcast, Receiver, Sender};
    use futures::{stream::FuturesOrdered, StreamExt};
    use tokio::time::sleep;

    use super::*;
    use crate::dependency::*;

    /// Message a [`DummyHandle`] publishes once its dependency has completed.
    #[derive(Clone, PartialEq, Eq, Debug)]
    enum TaskResult {
        /// The dependency completed with the wrapped value.
        Success(usize),
    }

    /// Test handler that forwards every dependency output to a broadcast channel.
    struct DummyHandle {
        /// Channel on which the received output is reported back to the test.
        sender: Sender<TaskResult>,
    }

    impl HandleDepOutput for DummyHandle {
        type Output = usize;

        /// Reports `res` as [`TaskResult::Success`]; panics if the broadcast fails.
        async fn handle_dep_result(self, res: usize) {
            self.sender
                .broadcast(TaskResult::Success(res))
                .await
                .unwrap();
        }
    }

    /// Builds an [`EventDependency`] whose predicate accepts only events equal to `val`.
    fn eq_dep(
        rx: Receiver<usize>,
        cancel_rx: Receiver<()>,
        dep_name: String,
        val: usize,
    ) -> EventDependency<usize> {
        EventDependency::new(rx, cancel_rx, dep_name, Box::new(move |v| *v == val))
    }

    /// A single task reports the value that satisfied its dependency.
    #[tokio::test(flavor = "multi_thread")]
    async fn it_works() {
        let (tx, rx) = broadcast(10);
        let (_cancel_tx, cancel_rx) = broadcast(1);
        let (res_tx, mut res_rx) = broadcast(10);
        let dep = eq_dep(rx, cancel_rx, format!("it_works {}", 2), 2);
        let handle = DummyHandle { sender: res_tx };
        let join_handle = DependencyTask { dep, handle }.run();
        tx.broadcast(2).await.unwrap();
        assert_eq!(res_rx.recv().await.unwrap(), TaskResult::Success(2));

        // Surface a panic inside the spawned task instead of discarding the join result.
        join_handle.await.expect("dependency task should not panic");
    }

    /// Ten tasks, each waiting for a distinct value, report in broadcast order.
    #[tokio::test(flavor = "multi_thread")]
    async fn many_works() {
        let (tx, rx) = broadcast(20);
        let (res_tx, mut res_rx) = broadcast(20);
        let (_cancel_tx, cancel_rx) = broadcast(1);

        let mut handles = vec![];
        for i in 0..10 {
            let dep = eq_dep(rx.clone(), cancel_rx.clone(), format!("many_works {i}"), i);
            let handle = DummyHandle {
                sender: res_tx.clone(),
            };
            handles.push(DependencyTask { dep, handle }.run());
        }
        let tx2 = tx.clone();
        // Feed the values from a separate task so the assertions below can run concurrently.
        spawn(async move {
            for i in 0..10 {
                tx.broadcast(i).await.unwrap();
                sleep(Duration::from_millis(10)).await;
            }
        });
        for i in 0..10 {
            assert_eq!(res_rx.recv().await.unwrap(), TaskResult::Success(i));
        }
        // 100 matches none of the predicates registered above.
        tx2.broadcast(100).await.unwrap();

        // Every spawned task must finish without panicking.
        let outcomes = FuturesOrdered::from_iter(handles).collect::<Vec<_>>().await;
        assert!(
            outcomes.iter().all(Result::is_ok),
            "a dependency task panicked"
        );
    }
}