sequencer/
util.rs

1use std::{future::Future, sync::Arc};
2
3use anyhow::Result;
4use tokio::{
5    sync::Semaphore,
6    task::{AbortHandle, JoinError, JoinSet},
7};
8
9/// A join set that limits the number of concurrent tasks
10pub struct BoundedJoinSet<T> {
11    // The inner join set
12    inner: JoinSet<T>,
13    // The semaphore we use to limit the number of concurrent tasks
14    semaphore: Arc<Semaphore>,
15}
16
17impl<T> BoundedJoinSet<T> {
18    /// Create a new [`BoundedJoinSet`] with a maximum number of concurrent tasks
19    pub fn new(max_concurrency: usize) -> Self {
20        Self {
21            inner: JoinSet::new(),
22            semaphore: Arc::new(Semaphore::const_new(max_concurrency)),
23        }
24    }
25}
26
27impl<T: 'static> BoundedJoinSet<T> {
28    /// Spawn the provided task on the JoinSet, returning an [AbortHandle] that can be used
29    /// to remotely cancel the task.
30    pub fn spawn<F>(&mut self, task: F) -> AbortHandle
31    where
32        F: Future<Output = T> + Send + 'static,
33        T: Send,
34    {
35        // Clone the semaphore for the inner task
36        let semaphore = self.semaphore.clone();
37
38        // Wrap the task, making it wait for a semaphore permit first
39        let task = async move {
40            // Acquire the permit
41            let permit = semaphore.acquire().await;
42
43            // Perform the actual task
44            let result = task.await;
45
46            // Drop the permit
47            drop(permit);
48
49            // Return the result
50            result
51        };
52
53        // Spawn the task in the inner join set
54        self.inner.spawn(task)
55    }
56
57    /// Waits until one of the tasks in the set completes and returns its output.
58    ///
59    /// Returns None if the set is empty.
60    pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
61        self.inner.join_next().await
62    }
63
64    /// Waits until one of the tasks in the set completes and returns its output, along with the task ID of the completed task.
65    pub async fn join_next_with_id(&mut self) -> Option<Result<(tokio::task::Id, T), JoinError>> {
66        self.inner.join_next_with_id().await
67    }
68}