1use std::{collections::VecDeque, hash::Hash, sync::Arc};
2
3use dashmap::DashMap;
4use parking_lot::Mutex;
5
/// A FIFO queue that holds at most `max_size` items.
///
/// When the queue is full, pushing a new item evicts the oldest one.
pub struct BoundedVecDeque<T> {
    /// Stored items, oldest at the front and newest at the back.
    inner: VecDeque<T>,
    /// Maximum number of items `inner` is allowed to hold.
    max_size: usize,
}

impl<T> BoundedVecDeque<T> {
    /// Creates an empty deque that retains at most `max_size` items.
    pub fn new(max_size: usize) -> Self {
        Self {
            inner: VecDeque::new(),
            max_size,
        }
    }

    /// Appends `item` at the back, evicting the front (oldest) item first if the deque is full.
    ///
    /// With a `max_size` of zero nothing can be retained, so `item` is dropped immediately.
    pub fn push(&mut self, item: T) {
        // Without this guard a zero-capacity deque would grow without bound: `pop_front` on an
        // empty deque is a no-op, yet the `push_back` below would still run.
        if self.max_size == 0 {
            return;
        }

        if self.inner.len() >= self.max_size {
            self.inner.pop_front();
        }
        self.inner.push_back(item);
    }
}
32
33#[derive(Clone)]
34pub struct NamedSemaphore<T: Clone + Eq + Hash> {
35 inner: Arc<DashMap<T, Arc<()>>>,
37
38 max_permits_per_key: usize,
40
41 max_total_permits: Option<usize>,
43
44 total_num_permits_held: Arc<Mutex<usize>>,
46}
47
48#[derive(Debug, thiserror::Error)]
49pub enum NamedSemaphoreError {
50 #[error("global permit limit reached")]
52 GlobalLimitReached,
53
54 #[error("per-key permit limit reached")]
56 PerKeyLimitReached,
57}
58
59impl<T: Clone + Eq + Hash> NamedSemaphore<T> {
60 pub fn new(max_permits_per_key: usize, max_total_permits: Option<usize>) -> Self {
62 Self {
63 inner: Arc::new(DashMap::new()),
64 max_permits_per_key,
65 max_total_permits,
66 total_num_permits_held: Arc::new(Mutex::new(0)),
67 }
68 }
69
70 pub fn try_acquire(&self, key: T) -> Result<NamedSemaphorePermit<T>, NamedSemaphoreError> {
72 let permit_tracker = self
74 .inner
75 .entry(key.clone())
76 .or_insert_with(|| Arc::new(()));
77
78 let mut total_num_permits_guard = self.total_num_permits_held.lock();
80
81 if let Some(max_total_permits) = self.max_total_permits {
83 if *total_num_permits_guard >= max_total_permits {
84 return Err(NamedSemaphoreError::GlobalLimitReached);
85 }
86 }
87
88 if Arc::strong_count(&permit_tracker).saturating_sub(1) >= self.max_permits_per_key {
90 return Err(NamedSemaphoreError::PerKeyLimitReached);
91 }
92
93 *total_num_permits_guard += 1;
95
96 Ok(NamedSemaphorePermit {
98 key,
99 parent: self.clone(),
100 permit: permit_tracker.clone(),
101 })
102 }
103
104 pub fn total_num_permits_held(&self) -> usize {
106 *self.total_num_permits_held.lock()
107 }
108}
109
110pub struct NamedSemaphorePermit<T: Clone + Eq + Hash> {
111 key: T,
113
114 parent: NamedSemaphore<T>,
116
117 permit: Arc<()>,
119}
120
121impl<T: Clone + Eq + Hash> Drop for NamedSemaphorePermit<T> {
122 fn drop(&mut self) {
123 *self.parent.total_num_permits_held.lock() -= 1;
125
126 if Arc::strong_count(&self.permit) == 2 {
128 self.parent.inner.remove(&self.key);
129 }
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushing past capacity keeps only the newest items, in insertion order.
    #[test]
    fn test_bounded_vec_deque() {
        let mut bounded = BoundedVecDeque::new(3);
        for value in 1..=5 {
            bounded.push(value);
        }

        assert_eq!(bounded.inner.len(), 3);
        assert_eq!(bounded.inner, vec![3, 4, 5]);
    }

    /// With one permit per key and no global cap, a second acquire on the same key fails until
    /// the first permit is dropped, and the map is empty once every permit is gone.
    #[test]
    fn test_named_semaphore() {
        let sem = NamedSemaphore::new(1, None);

        let first = sem.try_acquire("test");
        assert!(first.is_ok());

        let second = sem.try_acquire("test");
        assert!(second.is_err());

        drop(first);

        let third = sem.try_acquire("test");
        assert!(third.is_ok());

        drop(third);

        assert!(sem.inner.is_empty());
    }

    /// With a global cap of two, a third key is rejected until one of the held permits is
    /// released; the global counter tracks acquisitions and releases.
    #[test]
    fn test_named_semaphore_with_max_total_permits() {
        let sem = NamedSemaphore::new(1, Some(2));

        let first = sem.try_acquire("test");
        assert!(first.is_ok());

        let second = sem.try_acquire("test2");
        assert!(second.is_ok());

        let rejected = sem.try_acquire("test3");
        assert!(rejected.is_err());

        drop(first);

        let retried = sem.try_acquire("test3");
        assert!(retried.is_ok());

        assert_eq!(sem.total_num_permits_held(), 2);

        drop(second);
        drop(rejected);
        drop(retried);

        assert!(sem.inner.is_empty());
        assert_eq!(sem.total_num_permits_held(), 0);
    }
}