request_response/
util.rs

1use std::{collections::VecDeque, hash::Hash, sync::Arc};
2
3use dashmap::DashMap;
4use parking_lot::Mutex;
5
/// A [`VecDeque`] with a maximum size.
///
/// Once the deque holds `max_size` items, pushing another item evicts the
/// oldest one so the length never exceeds `max_size`.
pub struct BoundedVecDeque<T> {
    /// The inner [`VecDeque`]
    inner: VecDeque<T>,
    /// The maximum number of items `inner` is allowed to hold
    max_size: usize,
}

impl<T> BoundedVecDeque<T> {
    /// Create a new, empty bounded [`VecDeque`] with the given maximum size.
    ///
    /// A `max_size` of zero is accepted; such a deque never stores anything.
    pub fn new(max_size: usize) -> Self {
        Self {
            inner: VecDeque::new(),
            max_size,
        }
    }

    /// Push an item into the bounded [`VecDeque`], removing the oldest item if the
    /// maximum size is reached.
    ///
    /// If the maximum size is zero the item is dropped immediately.
    pub fn push(&mut self, item: T) {
        // With a zero bound there is nothing to evict: `pop_front` would be a
        // no-op and the push below would grow the deque past its limit.
        if self.max_size == 0 {
            return;
        }

        // Make room for the new item by evicting the oldest one
        if self.inner.len() >= self.max_size {
            self.inner.pop_front();
        }
        self.inner.push_back(item);
    }
}
32
33#[derive(Clone)]
34pub struct NamedSemaphore<T: Clone + Eq + Hash> {
35    /// The underlying map of keys to their semaphore
36    inner: Arc<DashMap<T, Arc<()>>>,
37
38    /// The maximum number of permits for each key
39    max_permits_per_key: usize,
40
41    /// The maximum number of permits that can be held across all keys
42    max_total_permits: Option<usize>,
43
44    /// The total number of permits that are currently being held
45    total_num_permits_held: Arc<Mutex<usize>>,
46}
47
48#[derive(Debug, thiserror::Error)]
49pub enum NamedSemaphoreError {
50    /// The global permit limit has been reached
51    #[error("global permit limit reached")]
52    GlobalLimitReached,
53
54    /// The per-key permit limit has been reached
55    #[error("per-key permit limit reached")]
56    PerKeyLimitReached,
57}
58
59impl<T: Clone + Eq + Hash> NamedSemaphore<T> {
60    /// Create a new named semaphore, specifying the maximum number of permits for each key.
61    pub fn new(max_permits_per_key: usize, max_total_permits: Option<usize>) -> Self {
62        Self {
63            inner: Arc::new(DashMap::new()),
64            max_permits_per_key,
65            max_total_permits,
66            total_num_permits_held: Arc::new(Mutex::new(0)),
67        }
68    }
69
70    /// Try to acquire a permit for the given key.
71    pub fn try_acquire(&self, key: T) -> Result<NamedSemaphorePermit<T>, NamedSemaphoreError> {
72        // Get the permit tracker for the key
73        let permit_tracker = self
74            .inner
75            .entry(key.clone())
76            .or_insert_with(|| Arc::new(()));
77
78        // Lock the number of permits held
79        let mut total_num_permits_guard = self.total_num_permits_held.lock();
80
81        // If the total number of permits is greater than the maximum number of permits, return None
82        if let Some(max_total_permits) = self.max_total_permits {
83            if *total_num_permits_guard >= max_total_permits {
84                return Err(NamedSemaphoreError::GlobalLimitReached);
85            }
86        }
87
88        // If the number of permits is greater than or equal to the maximum number of permits, return an error
89        if Arc::strong_count(&permit_tracker).saturating_sub(1) >= self.max_permits_per_key {
90            return Err(NamedSemaphoreError::PerKeyLimitReached);
91        }
92
93        // Increment the total number of permits
94        *total_num_permits_guard += 1;
95
96        // Return the new permit
97        Ok(NamedSemaphorePermit {
98            key,
99            parent: self.clone(),
100            permit: permit_tracker.clone(),
101        })
102    }
103
104    /// Get the total number of permits that are currently being held across all keys
105    pub fn total_num_permits_held(&self) -> usize {
106        *self.total_num_permits_held.lock()
107    }
108}
109
110pub struct NamedSemaphorePermit<T: Clone + Eq + Hash> {
111    /// The key that we are holding a permit for
112    key: T,
113
114    /// The parent semaphore that we are borrowing a permit from
115    parent: NamedSemaphore<T>,
116
117    /// The permit that we are holding
118    permit: Arc<()>,
119}
120
121impl<T: Clone + Eq + Hash> Drop for NamedSemaphorePermit<T> {
122    fn drop(&mut self) {
123        // Decrement the total number of permits
124        *self.parent.total_num_permits_held.lock() -= 1;
125
126        // Remove the semaphore but only if there are no more strong references to the parent
127        if Arc::strong_count(&self.permit) == 2 {
128            self.parent.inner.remove(&self.key);
129        }
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushing past the bound keeps only the most recent items, oldest first.
    #[test]
    fn test_bounded_vec_deque() {
        let mut deque = BoundedVecDeque::new(3);
        for value in 1..=5 {
            deque.push(value);
        }

        assert_eq!(deque.inner.len(), 3);
        assert_eq!(deque.inner, vec![3, 4, 5]);
    }

    /// A single key can hold at most `max_permits_per_key` permits at a time.
    #[test]
    fn test_named_semaphore() {
        // One permit per key, no global limit
        let semaphore = NamedSemaphore::new(1, None);

        // The first acquisition for "test" succeeds
        let first = semaphore.try_acquire("test");
        assert!(first.is_ok());

        // A second one for the same key is rejected while the first is held
        let second = semaphore.try_acquire("test");
        assert!(second.is_err());

        // Releasing the first permit frees the slot again
        drop(first);
        let third = semaphore.try_acquire("test");
        assert!(third.is_ok());

        // Once the last permit is gone the key is removed from the map
        drop(third);
        assert!(semaphore.inner.is_empty());
    }

    /// The global limit caps the permits held across distinct keys.
    #[test]
    fn test_named_semaphore_with_max_total_permits() {
        // One permit per key, at most two permits overall
        let semaphore = NamedSemaphore::new(1, Some(2));

        // Two different keys fill the global budget
        let first = semaphore.try_acquire("test");
        assert!(first.is_ok());
        let second = semaphore.try_acquire("test2");
        assert!(second.is_ok());

        // A third key is rejected because the global limit is reached
        let rejected = semaphore.try_acquire("test3");
        assert!(rejected.is_err());

        // Releasing one permit lets the third key in
        drop(first);
        let retried = semaphore.try_acquire("test3");
        assert!(retried.is_ok());

        // Two permits are held at this point
        assert_eq!(semaphore.total_num_permits_held(), 2);

        // Release everything that is still around
        drop(second);
        drop(rejected);
        drop(retried);

        // Nothing is left in the map and the global counter is back to zero
        assert!(semaphore.inner.is_empty());
        assert_eq!(semaphore.total_num_permits_held(), 0);
    }
}