sequencer/util.rs
1use std::{future::Future, sync::Arc};
2
3use anyhow::Result;
4use tokio::{
5 sync::Semaphore,
6 task::{AbortHandle, JoinError, JoinSet},
7};
8
9/// A join set that limits the number of concurrent tasks
10pub struct BoundedJoinSet<T> {
11 // The inner join set
12 inner: JoinSet<T>,
13 // The semaphore we use to limit the number of concurrent tasks
14 semaphore: Arc<Semaphore>,
15}
16
17impl<T> BoundedJoinSet<T> {
18 /// Create a new [`BoundedJoinSet`] with a maximum number of concurrent tasks
19 pub fn new(max_concurrency: usize) -> Self {
20 Self {
21 inner: JoinSet::new(),
22 semaphore: Arc::new(Semaphore::const_new(max_concurrency)),
23 }
24 }
25}
26
27impl<T: 'static> BoundedJoinSet<T> {
28 /// Spawn the provided task on the JoinSet, returning an [AbortHandle] that can be used
29 /// to remotely cancel the task.
30 pub fn spawn<F>(&mut self, task: F) -> AbortHandle
31 where
32 F: Future<Output = T> + Send + 'static,
33 T: Send,
34 {
35 // Clone the semaphore for the inner task
36 let semaphore = self.semaphore.clone();
37
38 // Wrap the task, making it wait for a semaphore permit first
39 let task = async move {
40 // Acquire the permit
41 let permit = semaphore.acquire().await;
42
43 // Perform the actual task
44 let result = task.await;
45
46 // Drop the permit
47 drop(permit);
48
49 // Return the result
50 result
51 };
52
53 // Spawn the task in the inner join set
54 self.inner.spawn(task)
55 }
56
57 /// Waits until one of the tasks in the set completes and returns its output.
58 ///
59 /// Returns None if the set is empty.
60 pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
61 self.inner.join_next().await
62 }
63
64 /// Waits until one of the tasks in the set completes and returns its output, along with the task ID of the completed task.
65 pub async fn join_next_with_id(&mut self) -> Option<Result<(tokio::task::Id, T), JoinError>> {
66 self.inner.join_next_with_id().await
67 }
68}