hotshot_query_service/fetching/provider/
testing.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the HotShot Query Service library.
3//
4// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
5// General Public License as published by the Free Software Foundation, either version 3 of the
6// License, or (at your option) any later version.
7// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
8// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
9// General Public License for more details.
10// You should have received a copy of the GNU General Public License along with this program. If not,
11// see <https://www.gnu.org/licenses/>.
12
13#![cfg(any(test, feature = "testing"))]
14
15use std::{
16    fmt::Debug,
17    sync::{
18        atomic::{AtomicBool, Ordering},
19        Arc,
20    },
21};
22
23use async_lock::RwLock;
24use async_trait::async_trait;
25use derivative::Derivative;
26use hotshot_types::traits::node_implementation::NodeType;
27use tokio::sync::broadcast;
28
29use super::Provider;
30use crate::fetching::Request;
31
32/// Adaptor to add test-only functionality to an existing [`Provider`].
33///
34/// [`TestProvider`] wraps an existing provider `P` and adds some additional functionality which can
35/// be useful in tests, such as the ability to inject delays into the handling of fetch requests.
36#[derive(Derivative)]
37#[derivative(Clone(bound = ""), Debug(bound = "P: Debug"))]
38pub struct TestProvider<P> {
39    inner: Arc<P>,
40    unblock: Arc<RwLock<Option<broadcast::Sender<()>>>>,
41    fail: Arc<AtomicBool>,
42}
43
44impl<P> TestProvider<P> {
45    pub fn new(inner: P) -> Self {
46        Self {
47            inner: Arc::new(inner),
48            unblock: Default::default(),
49            fail: Arc::new(AtomicBool::new(false)),
50        }
51    }
52
53    /// Delay fetch requests until [`unblock`](Self::unblock).
54    ///
55    /// Fetch requests started after this method returns will block without completing until
56    /// [`unblock`](Self::unblock) is called. This can be useful for tests to examine the state of a
57    /// data source _before_ a fetch request completes, to check that the subsequent fetch actually
58    /// has an effect.
59    pub async fn block(&self) {
60        let mut unblock = self.unblock.write().await;
61        if unblock.is_none() {
62            *unblock = Some(broadcast::channel(1000).0);
63        }
64    }
65
66    /// Allow blocked fetch requests to proceed.
67    ///
68    /// Fetch requests which are blocked as a result of a preceding call to [`block`](Self::block)
69    /// will become unblocked.
70    pub async fn unblock(&self) {
71        let mut unblock = self.unblock.write().await;
72        if let Some(unblock) = unblock.take() {
73            unblock.send(()).ok();
74        }
75    }
76
77    /// Cause subsequent requests to fail.
78    ///
79    /// All requests to the provider after this function is called will fail, until
80    /// [`unfail`](Self::unfail) is called.
81    pub fn fail(&self) {
82        self.fail.store(true, Ordering::SeqCst);
83    }
84
85    /// Stop requests from failing as a result of a previous call to [`fail`](Self::fail).
86    pub fn unfail(&self) {
87        self.fail.store(false, Ordering::SeqCst);
88    }
89}
90
91#[async_trait]
92impl<Types, P, T> Provider<Types, T> for TestProvider<P>
93where
94    Types: NodeType,
95    T: Request<Types> + 'static,
96    P: Provider<Types, T> + Sync,
97{
98    async fn fetch(&self, req: T) -> Option<T::Response> {
99        // Fail the request if the user has called `fail`.
100        if self.fail.load(Ordering::SeqCst) {
101            return None;
102        }
103
104        // Block the request if the user has called `block`.
105        let handle = self
106            .unblock
107            .read()
108            .await
109            .as_ref()
110            .map(|unblock| unblock.subscribe());
111        if let Some(mut handle) = handle {
112            tracing::info!("request for {req:?} will block until manually unblocked");
113            handle.recv().await.ok();
114            tracing::info!("request for {req:?} unblocked");
115        }
116
117        // Do the request.
118        self.inner.fetch(req).await
119    }
120}