hotshot_query_service/fetching/provider/testing.rs
1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the HotShot Query Service library.
3//
4// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
5// General Public License as published by the Free Software Foundation, either version 3 of the
6// License, or (at your option) any later version.
7// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
8// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
9// General Public License for more details.
10// You should have received a copy of the GNU General Public License along with this program. If not,
11// see <https://www.gnu.org/licenses/>.
12
13#![cfg(any(test, feature = "testing"))]
14
15use std::{
16 fmt::Debug,
17 sync::{
18 atomic::{AtomicBool, Ordering},
19 Arc,
20 },
21};
22
23use async_lock::RwLock;
24use async_trait::async_trait;
25use derivative::Derivative;
26use hotshot_types::traits::node_implementation::NodeType;
27use tokio::sync::broadcast;
28
29use super::Provider;
30use crate::fetching::Request;
31
32/// Adaptor to add test-only functionality to an existing [`Provider`].
33///
34/// [`TestProvider`] wraps an existing provider `P` and adds some additional functionality which can
35/// be useful in tests, such as the ability to inject delays into the handling of fetch requests.
36#[derive(Derivative)]
37#[derivative(Clone(bound = ""), Debug(bound = "P: Debug"))]
38pub struct TestProvider<P> {
39 inner: Arc<P>,
40 unblock: Arc<RwLock<Option<broadcast::Sender<()>>>>,
41 fail: Arc<AtomicBool>,
42}
43
44impl<P> TestProvider<P> {
45 pub fn new(inner: P) -> Self {
46 Self {
47 inner: Arc::new(inner),
48 unblock: Default::default(),
49 fail: Arc::new(AtomicBool::new(false)),
50 }
51 }
52
53 /// Delay fetch requests until [`unblock`](Self::unblock).
54 ///
55 /// Fetch requests started after this method returns will block without completing until
56 /// [`unblock`](Self::unblock) is called. This can be useful for tests to examine the state of a
57 /// data source _before_ a fetch request completes, to check that the subsequent fetch actually
58 /// has an effect.
59 pub async fn block(&self) {
60 let mut unblock = self.unblock.write().await;
61 if unblock.is_none() {
62 *unblock = Some(broadcast::channel(1000).0);
63 }
64 }
65
66 /// Allow blocked fetch requests to proceed.
67 ///
68 /// Fetch requests which are blocked as a result of a preceding call to [`block`](Self::block)
69 /// will become unblocked.
70 pub async fn unblock(&self) {
71 let mut unblock = self.unblock.write().await;
72 if let Some(unblock) = unblock.take() {
73 unblock.send(()).ok();
74 }
75 }
76
77 /// Cause subsequent requests to fail.
78 ///
79 /// All requests to the provider after this function is called will fail, until
80 /// [`unfail`](Self::unfail) is called.
81 pub fn fail(&self) {
82 self.fail.store(true, Ordering::SeqCst);
83 }
84
85 /// Stop requests from failing as a result of a previous call to [`fail`](Self::fail).
86 pub fn unfail(&self) {
87 self.fail.store(false, Ordering::SeqCst);
88 }
89}
90
91#[async_trait]
92impl<Types, P, T> Provider<Types, T> for TestProvider<P>
93where
94 Types: NodeType,
95 T: Request<Types> + 'static,
96 P: Provider<Types, T> + Sync,
97{
98 async fn fetch(&self, req: T) -> Option<T::Response> {
99 // Fail the request if the user has called `fail`.
100 if self.fail.load(Ordering::SeqCst) {
101 return None;
102 }
103
104 // Block the request if the user has called `block`.
105 let handle = self
106 .unblock
107 .read()
108 .await
109 .as_ref()
110 .map(|unblock| unblock.subscribe());
111 if let Some(mut handle) = handle {
112 tracing::info!("request for {req:?} will block until manually unblocked");
113 handle.recv().await.ok();
114 tracing::info!("request for {req:?} unblocked");
115 }
116
117 // Do the request.
118 self.inner.fetch(req).await
119 }
120}