1use std::future::Future;
8
9use async_broadcast::{Receiver, RecvError};
10use futures::{
11 future::BoxFuture,
12 stream::{FuturesUnordered, StreamExt},
13 FutureExt,
14};
15
16pub trait Dependency<T> {
18 fn completed(self) -> impl Future<Output = Option<T>> + Send;
20 fn or<D: Dependency<T> + Send + 'static>(self, dep: D) -> OrDependency<T>
22 where
23 T: Send + Sync + Clone + 'static,
24 Self: Sized + Send + 'static,
25 {
26 let mut or = OrDependency::from_deps(vec![self]);
27 or.add_dep(dep);
28 or
29 }
30 fn and<D: Dependency<T> + Send + 'static>(self, dep: D) -> AndDependency<T>
32 where
33 T: Send + Sync + Clone + 'static,
34 Self: Sized + Send + 'static,
35 {
36 let mut and = AndDependency::from_deps(vec![self]);
37 and.add_dep(dep);
38 and
39 }
40}
41
42pub struct AndDependency<T> {
44 deps: Vec<BoxFuture<'static, Option<T>>>,
46}
47impl<T: Clone + Send + Sync> Dependency<Vec<T>> for AndDependency<T> {
48 async fn completed(self) -> Option<Vec<T>> {
51 let futures = FuturesUnordered::from_iter(self.deps);
52 futures
53 .collect::<Vec<Option<T>>>()
54 .await
55 .into_iter()
56 .collect()
57 }
58}
59
60impl<T: Clone + Send + Sync + 'static> AndDependency<T> {
61 #[must_use]
63 pub fn from_deps(deps: Vec<impl Dependency<T> + Send + 'static>) -> Self {
64 let mut pinned = vec![];
65 for dep in deps {
66 pinned.push(dep.completed().boxed());
67 }
68 Self { deps: pinned }
69 }
70 pub fn add_dep(&mut self, dep: impl Dependency<T> + Send + 'static) {
72 self.deps.push(dep.completed().boxed());
73 }
74 pub fn add_deps(&mut self, deps: AndDependency<T>) {
76 for dep in deps.deps {
77 self.deps.push(dep);
78 }
79 }
80}
81
82pub struct OrDependency<T> {
84 deps: Vec<BoxFuture<'static, Option<T>>>,
86}
87impl<T: Clone + Send + Sync> Dependency<T> for OrDependency<T> {
88 async fn completed(self) -> Option<T> {
90 let mut futures = FuturesUnordered::from_iter(self.deps);
91 loop {
92 if let Some(maybe) = futures.next().await {
93 if maybe.is_some() {
94 return maybe;
95 }
96 } else {
97 return None;
98 }
99 }
100 }
101}
102
103impl<T: Clone + Send + Sync + 'static> OrDependency<T> {
104 #[must_use]
106 pub fn from_deps(deps: Vec<impl Dependency<T> + Send + 'static>) -> Self {
107 let mut pinned = vec![];
108 for dep in deps {
109 pinned.push(dep.completed().boxed());
110 }
111 Self { deps: pinned }
112 }
113 pub fn add_dep(&mut self, dep: impl Dependency<T> + Send + 'static) {
115 self.deps.push(dep.completed().boxed());
116 }
117}
118
119pub struct EventDependency<T: Clone + Send + Sync> {
122 pub(crate) event_rx: Receiver<T>,
124
125 pub(crate) match_fn: Box<dyn Fn(&T) -> bool + Send>,
128
129 completed_dependency: Option<T>,
132
133 cancel_receiver: Receiver<()>,
134
135 dependency_name: String,
136}
137
138impl<T: Clone + Send + Sync + 'static> EventDependency<T> {
139 #[must_use]
141 pub fn new(
142 receiver: Receiver<T>,
143 cancel_receiver: Receiver<()>,
144 dependency_name: String,
145 match_fn: Box<dyn Fn(&T) -> bool + Send>,
146 ) -> Self {
147 Self {
148 event_rx: receiver,
149 match_fn: Box::new(match_fn),
150 completed_dependency: None,
151 cancel_receiver,
152 dependency_name,
153 }
154 }
155
156 pub fn mark_as_completed(&mut self, dependency: T) {
158 self.completed_dependency = Some(dependency);
159 }
160}
161
162impl<T: Clone + Send + Sync + 'static> Dependency<T> for EventDependency<T> {
163 async fn completed(mut self) -> Option<T> {
164 if let Some(dependency) = self.completed_dependency {
165 return Some(dependency);
166 }
167 loop {
168 if let Some(dependency) = self.completed_dependency {
169 return Some(dependency);
170 }
171
172 tokio::select! {
173 recv_event = self.event_rx.recv() => {
174 match recv_event {
175 Ok(event) => {
176 if (self.match_fn)(&event) {
177 return Some(event);
178 }
179 },
180 Err(RecvError::Overflowed(n)) => {
181 tracing::error!("Dependency Task overloaded, skipping {} events", n);
182 },
183 Err(RecvError::Closed) => {
184 return None;
185 },
186 }
187 }
188 _ = self.cancel_receiver.recv() => {
189 tracing::error!("{} dependency cancelled", self.dependency_name);
190 return None;
191 }
192 }
193 }
194 }
195}
196
#[cfg(test)]
mod tests {
    use async_broadcast::{broadcast, Receiver};

    use super::{AndDependency, Dependency, EventDependency, OrDependency};

    /// Builds an `EventDependency` satisfied by the first event equal to `val`.
    fn eq_dep(
        rx: Receiver<usize>,
        cancel_rx: Receiver<()>,
        dep_name: String,
        val: usize,
    ) -> EventDependency<usize> {
        EventDependency {
            event_rx: rx,
            match_fn: Box::new(move |v| *v == val),
            completed_dependency: None,
            dependency_name: dep_name,
            cancel_receiver: cancel_rx,
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn it_works() {
        let (tx, rx) = broadcast(10);
        let (_cancel_tx, cancel_rx) = broadcast(1);

        let mut deps = vec![];
        for i in 0..5 {
            tx.broadcast(i).await.unwrap();
            deps.push(eq_dep(
                rx.clone(),
                cancel_rx.clone(),
                format!("it_works {i}"),
                5,
            ));
        }

        let and = AndDependency::from_deps(deps);
        tx.broadcast(5).await.unwrap();
        let result = and.completed().await;
        assert_eq!(result, Some(vec![5; 5]));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn or_dep() {
        let (tx, rx) = broadcast(10);
        let (_cancel_tx, cancel_rx) = broadcast(1);

        tx.broadcast(5).await.unwrap();
        let mut deps = vec![];
        for i in 0..5 {
            deps.push(eq_dep(
                rx.clone(),
                cancel_rx.clone(),
                format!("or_dep {i}"),
                5,
            ));
        }
        let or = OrDependency::from_deps(deps);
        let result = or.completed().await;
        assert_eq!(result, Some(5));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn and_or_dep() {
        let (tx, rx) = broadcast(10);
        let (_cancel_tx, cancel_rx) = broadcast(1);

        tx.broadcast(1).await.unwrap();
        tx.broadcast(2).await.unwrap();
        tx.broadcast(3).await.unwrap();
        tx.broadcast(5).await.unwrap();
        tx.broadcast(6).await.unwrap();

        let or1 = OrDependency::from_deps(
            [
                eq_dep(
                    rx.clone(),
                    cancel_rx.clone(),
                    format!("and_or_dep or1 {}", 4),
                    4,
                ),
                eq_dep(
                    rx.clone(),
                    cancel_rx.clone(),
                    format!("and_or_dep or1 {}", 6),
                    6,
                ),
            ]
            .into(),
        );
        let or2 = OrDependency::from_deps(
            [
                eq_dep(
                    rx.clone(),
                    cancel_rx.clone(),
                    format!("and_or_dep or2 {}", 4),
                    4,
                ),
                eq_dep(
                    rx.clone(),
                    cancel_rx.clone(),
                    format!("and_or_dep or2 {}", 5),
                    5,
                ),
            ]
            .into(),
        );
        let and = AndDependency::from_deps([or1, or2].into());
        let result = and.completed().await;
        assert_eq!(result, Some(vec![6, 5]));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn or_and_dep() {
        let (tx, rx) = broadcast(10);
        let (_cancel_tx, cancel_rx) = broadcast(1);

        tx.broadcast(1).await.unwrap();
        tx.broadcast(2).await.unwrap();
        tx.broadcast(3).await.unwrap();
        tx.broadcast(4).await.unwrap();
        tx.broadcast(5).await.unwrap();

        let and1 = eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("or_and_dep and1 {}", 4),
            4,
        )
        .and(eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("or_and_dep and1 {}", 6),
            6,
        ));
        let and2 = eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("or_and_dep and2 {}", 4),
            4,
        )
        .and(eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("or_and_dep and2 {}", 5),
            5,
        ));
        let or = and1.or(and2);
        let result = or.completed().await;
        assert_eq!(result, Some(vec![4, 5]));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn many_and_dep() {
        let (tx, rx) = broadcast(10);
        let (_cancel_tx, cancel_rx) = broadcast(1);

        tx.broadcast(1).await.unwrap();
        tx.broadcast(2).await.unwrap();
        tx.broadcast(3).await.unwrap();
        tx.broadcast(4).await.unwrap();
        tx.broadcast(5).await.unwrap();
        tx.broadcast(6).await.unwrap();

        let mut and1 = eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("many_and_dep and1 {}", 4),
            4,
        )
        .and(eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("many_and_dep and1 {}", 6),
            6,
        ));
        let and2 = eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("many_and_dep and2 {}", 4),
            4,
        )
        .and(eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("many_and_dep and2 {}", 5),
            5,
        ));
        and1.add_deps(and2);
        let result = and1.completed().await;
        assert_eq!(result, Some(vec![4, 6, 4, 5]));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn cancel_event_dep() {
        let (tx, rx) = broadcast(10);
        let (cancel_tx, cancel_rx) = broadcast(1);

        for i in 0..=5 {
            tx.broadcast(i).await.unwrap();
        }
        cancel_tx.broadcast(()).await.unwrap();
        let dep = eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("cancel_event_dep {}", 6),
            6,
        );
        let result = dep.completed().await;
        assert_eq!(result, None);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn drop_cancel_dep() {
        let (tx, rx) = broadcast(10);
        let (cancel_tx, cancel_rx) = broadcast(1);

        for i in 0..=5 {
            tx.broadcast(i).await.unwrap();
        }
        drop(cancel_tx);
        let dep = eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("drop_cancel_dep {}", 6),
            6,
        );
        let result = dep.completed().await;
        assert_eq!(result, None);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn cancel_and_dep() {
        let (tx, rx) = broadcast(10);
        let (cancel_tx, cancel_rx) = broadcast(1);

        let mut deps = vec![];
        for i in 0..=5 {
            tx.broadcast(i).await.unwrap();
            deps.push(eq_dep(
                rx.clone(),
                cancel_rx.clone(),
                format!("cancel_and_dep {i}"),
                i,
            ));
        }
        deps.push(eq_dep(
            rx.clone(),
            cancel_rx.clone(),
            format!("cancel_and_dep {}", 6),
            6,
        ));
        cancel_tx.broadcast(()).await.unwrap();
        let result = AndDependency::from_deps(deps).completed().await;
        assert_eq!(result, None);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn cancel_or_dep() {
        // Keep the event sender alive. Binding it to `_` would drop it immediately and
        // close the event channel. Every dependency would then resolve to `None` on its
        // own, and the test would pass even if cancellation were broken.
        let (_tx, rx) = broadcast(10);
        let (cancel_tx, cancel_rx) = broadcast(1);

        let mut deps = vec![];
        for i in 0..=5 {
            deps.push(eq_dep(
                rx.clone(),
                cancel_rx.clone(),
                format!("cancel_or_dep {i}"),
                i,
            ));
        }
        cancel_tx.broadcast(()).await.unwrap();
        let result = OrDependency::from_deps(deps).completed().await;
        assert_eq!(result, None);
    }
}