hotshot_query_service/data_source/storage/
fail_storage.rs

// Copyright (c) 2022 Espresso Systems (espressosys.com)
// This file is part of the HotShot Query Service library.
//
// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
// General Public License as published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// General Public License for more details.
// You should have received a copy of the GNU General Public License along with this program. If not,
// see <https://www.gnu.org/licenses/>.
12
13#![cfg(any(test, feature = "testing"))]
14
15use std::{ops::RangeBounds, sync::Arc};
16
17use async_lock::Mutex;
18use async_trait::async_trait;
19use futures::future::Future;
20use hotshot_types::{
21    data::VidShare, simple_certificate::CertificatePair, traits::node_implementation::NodeType,
22};
23
24use super::{
25    Aggregate, AggregatesStorage, AvailabilityStorage, NodeStorage, UpdateAggregatesStorage,
26    UpdateAvailabilityStorage,
27    pruning::{PruneStorage, PrunedHeightStorage, PrunerCfg, PrunerConfig},
28};
29use crate::{
30    Header, Payload, QueryError, QueryResult,
31    availability::{
32        BlockId, BlockQueryData, LeafId, LeafQueryData, NamespaceId, PayloadQueryData,
33        QueryableHeader, QueryablePayload, TransactionHash, VidCommonQueryData,
34    },
35    data_source::{
36        VersionedDataSource,
37        storage::{PayloadMetadata, VidCommonMetadata},
38        update,
39    },
40    metrics::PrometheusMetrics,
41    node::{SyncStatusQueryData, TimeWindowQueryData, WindowStart},
42    status::HasMetrics,
43};
44
/// Names a single operation that can be singled out for error injection.
///
/// Every variant other than [`Any`](Self::Any) identifies one specific operation; `Any` is a
/// wildcard.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FailableAction {
    // TODO currently we implement failable actions for the availability methods, but if needed we
    // can always add more variants for other actions.
    GetHeader,
    GetLeaf,
    GetBlock,
    GetPayload,
    GetPayloadMetadata,
    GetVidCommon,
    GetVidCommonMetadata,
    GetHeaderRange,
    GetLeafRange,
    GetBlockRange,
    GetPayloadRange,
    GetPayloadMetadataRange,
    GetVidCommonRange,
    GetVidCommonMetadataRange,
    GetTransaction,
    FirstAvailableLeaf,
    GetStateCert,

    /// Wildcard: when this is the failure target, every action fails.
    Any,
}

impl FailableAction {
    /// Decides whether `action` must fail, given that `self` is the action targeted for failure.
    ///
    /// Returns `true` when `self` is the wildcard [`Any`](Self::Any), or when `self` and `action`
    /// are the same variant. The relation is not symmetric: a specific target does not match an
    /// `action` of `Any`.
    fn matches(self, action: Self) -> bool {
        matches!(self, Self::Any) || self == action
    }
}
80
81#[derive(Clone, Copy, Debug, Default)]
82enum FailureMode {
83    #[default]
84    Never,
85    Once(FailableAction),
86    Always(FailableAction),
87}
88
89impl FailureMode {
90    fn maybe_fail(&mut self, action: FailableAction) -> QueryResult<()> {
91        match self {
92            Self::Once(fail_action) if fail_action.matches(action) => {
93                *self = Self::Never;
94            },
95            Self::Always(fail_action) if fail_action.matches(action) => {},
96            _ => return Ok(()),
97        }
98
99        Err(QueryError::Error {
100            message: "injected error".into(),
101        })
102    }
103}
104
/// The failure modes currently configured, one per category of storage operation.
///
/// The derived `Default` leaves every category at `FailureMode::Never`.
#[derive(Debug, Default)]
struct Failure {
    /// Checked by the read methods of a transaction.
    on_read: FailureMode,
    /// Checked by the write (insert/update) methods of a transaction.
    on_write: FailureMode,
    /// Checked when a transaction is committed.
    on_commit: FailureMode,
    /// Checked when a writable transaction is opened.
    on_begin_writable: FailureMode,
    /// Checked when a read-only transaction is opened.
    on_begin_read_only: FailureMode,
}
113
114/// Storage wrapper for error injection.
115#[derive(Clone, Debug)]
116pub struct FailStorage<S> {
117    inner: S,
118    failure: Arc<Mutex<Failure>>,
119}
120
121impl<S> From<S> for FailStorage<S> {
122    fn from(inner: S) -> Self {
123        Self {
124            inner,
125            failure: Default::default(),
126        }
127    }
128}
129
130impl<S> FailStorage<S> {
131    pub async fn fail_reads(&self, action: FailableAction) {
132        self.failure.lock().await.on_read = FailureMode::Always(action);
133    }
134
135    pub async fn fail_writes(&self, action: FailableAction) {
136        self.failure.lock().await.on_write = FailureMode::Always(action);
137    }
138
139    pub async fn fail_commits(&self, action: FailableAction) {
140        self.failure.lock().await.on_commit = FailureMode::Always(action);
141    }
142
143    pub async fn fail_begins_writable(&self, action: FailableAction) {
144        self.failure.lock().await.on_begin_writable = FailureMode::Always(action);
145    }
146
147    pub async fn fail_begins_read_only(&self, action: FailableAction) {
148        self.failure.lock().await.on_begin_read_only = FailureMode::Always(action);
149    }
150
151    pub async fn fail(&self, action: FailableAction) {
152        let mut failure = self.failure.lock().await;
153        failure.on_read = FailureMode::Always(action);
154        failure.on_write = FailureMode::Always(action);
155        failure.on_commit = FailureMode::Always(action);
156        failure.on_begin_writable = FailureMode::Always(action);
157        failure.on_begin_read_only = FailureMode::Always(action);
158    }
159
160    pub async fn pass_reads(&self) {
161        self.failure.lock().await.on_read = FailureMode::Never;
162    }
163
164    pub async fn pass_writes(&self) {
165        self.failure.lock().await.on_write = FailureMode::Never;
166    }
167
168    pub async fn pass_commits(&self) {
169        self.failure.lock().await.on_commit = FailureMode::Never;
170    }
171
172    pub async fn pass_begins_writable(&self) {
173        self.failure.lock().await.on_begin_writable = FailureMode::Never;
174    }
175
176    pub async fn pass_begins_read_only(&self) {
177        self.failure.lock().await.on_begin_read_only = FailureMode::Never;
178    }
179
180    pub async fn pass(&self) {
181        let mut failure = self.failure.lock().await;
182        failure.on_read = FailureMode::Never;
183        failure.on_write = FailureMode::Never;
184        failure.on_commit = FailureMode::Never;
185        failure.on_begin_writable = FailureMode::Never;
186        failure.on_begin_read_only = FailureMode::Never;
187    }
188
189    pub async fn fail_one_read(&self, action: FailableAction) {
190        self.failure.lock().await.on_read = FailureMode::Once(action);
191    }
192
193    pub async fn fail_one_write(&self, action: FailableAction) {
194        self.failure.lock().await.on_write = FailureMode::Once(action);
195    }
196
197    pub async fn fail_one_commit(&self, action: FailableAction) {
198        self.failure.lock().await.on_commit = FailureMode::Once(action);
199    }
200
201    pub async fn fail_one_begin_writable(&self, action: FailableAction) {
202        self.failure.lock().await.on_begin_writable = FailureMode::Once(action);
203    }
204
205    pub async fn fail_one_begin_read_only(&self, action: FailableAction) {
206        self.failure.lock().await.on_begin_read_only = FailureMode::Once(action);
207    }
208}
209
210impl<S> VersionedDataSource for FailStorage<S>
211where
212    S: VersionedDataSource,
213{
214    type Transaction<'a>
215        = Transaction<S::Transaction<'a>>
216    where
217        Self: 'a;
218    type ReadOnly<'a>
219        = Transaction<S::ReadOnly<'a>>
220    where
221        Self: 'a;
222
223    async fn write(&self) -> anyhow::Result<<Self as VersionedDataSource>::Transaction<'_>> {
224        self.failure
225            .lock()
226            .await
227            .on_begin_writable
228            .maybe_fail(FailableAction::Any)?;
229        Ok(Transaction {
230            inner: self.inner.write().await?,
231            failure: self.failure.clone(),
232        })
233    }
234
235    async fn read(&self) -> anyhow::Result<<Self as VersionedDataSource>::ReadOnly<'_>> {
236        self.failure
237            .lock()
238            .await
239            .on_begin_read_only
240            .maybe_fail(FailableAction::Any)?;
241        Ok(Transaction {
242            inner: self.inner.read().await?,
243            failure: self.failure.clone(),
244        })
245    }
246}
247
// Pruning configuration is forwarded unchanged to the wrapped storage; no failure is injected.
impl<S> PrunerConfig for FailStorage<S>
where
    S: PrunerConfig,
{
    /// Passes `cfg` on to the inner storage.
    fn set_pruning_config(&mut self, cfg: PrunerCfg) {
        self.inner.set_pruning_config(cfg);
    }

    /// Returns whatever pruning configuration the inner storage reports.
    fn get_pruning_config(&self) -> Option<PrunerCfg> {
        self.inner.get_pruning_config()
    }
}
260
// Pruning operations bypass failure injection and go straight to the wrapped storage.
#[async_trait]
impl<S> PruneStorage for FailStorage<S>
where
    S: PruneStorage + Sync,
{
    // The pruner type is the inner storage's own.
    type Pruner = S::Pruner;

    /// Returns the disk usage reported by the inner storage.
    async fn get_disk_usage(&self) -> anyhow::Result<u64> {
        self.inner.get_disk_usage().await
    }

    /// Delegates pruning to the inner storage, returning its result unchanged.
    async fn prune(&self, pruner: &mut Self::Pruner) -> anyhow::Result<Option<u64>> {
        self.inner.prune(pruner).await
    }
}
276
// Metrics are those of the wrapped storage; the wrapper keeps none of its own.
impl<S> HasMetrics for FailStorage<S>
where
    S: HasMetrics,
{
    /// Returns the inner storage's metrics.
    fn metrics(&self) -> &PrometheusMetrics {
        self.inner.metrics()
    }
}
285
286#[derive(Debug)]
287pub struct Transaction<T> {
288    inner: T,
289    failure: Arc<Mutex<Failure>>,
290}
291
292impl<T> Transaction<T> {
293    async fn maybe_fail_read(&self, action: FailableAction) -> QueryResult<()> {
294        self.failure.lock().await.on_read.maybe_fail(action)
295    }
296
297    async fn maybe_fail_write(&self, action: FailableAction) -> QueryResult<()> {
298        self.failure.lock().await.on_write.maybe_fail(action)
299    }
300
301    async fn maybe_fail_commit(&self, action: FailableAction) -> QueryResult<()> {
302        self.failure.lock().await.on_commit.maybe_fail(action)
303    }
304}
305
// Commits can have failures injected; reverts are always passed through.
impl<T> update::Transaction for Transaction<T>
where
    T: update::Transaction,
{
    /// Commits the inner transaction unless a commit failure is injected.
    ///
    /// The commit failure mode is checked first, against `FailableAction::Any`. On an injected
    /// error the inner `commit` is never called.
    async fn commit(self) -> anyhow::Result<()> {
        self.maybe_fail_commit(FailableAction::Any).await?;
        self.inner.commit().await
    }

    /// Reverts the inner transaction. No failure is injected here.
    fn revert(self) -> impl Future + Send {
        self.inner.revert()
    }
}
319
// Availability reads with error injection.
//
// Every method follows the same shape: consult the read failure mode with the `FailableAction`
// variant naming that method, return the injected error if it fires, and otherwise delegate to
// the inner transaction. When a failure is injected the inner transaction is not called.
#[async_trait]
impl<Types, T> AvailabilityStorage<Types> for Transaction<T>
where
    Types: NodeType,
    Header<Types>: QueryableHeader<Types>,
    Payload<Types>: QueryablePayload<Types>,
    T: AvailabilityStorage<Types>,
{
    /// Delegates to the inner `get_leaf`; failable via `FailableAction::GetLeaf`.
    async fn get_leaf(&mut self, id: LeafId<Types>) -> QueryResult<LeafQueryData<Types>> {
        self.maybe_fail_read(FailableAction::GetLeaf).await?;
        self.inner.get_leaf(id).await
    }

    /// Delegates to the inner `get_block`; failable via `FailableAction::GetBlock`.
    async fn get_block(&mut self, id: BlockId<Types>) -> QueryResult<BlockQueryData<Types>> {
        self.maybe_fail_read(FailableAction::GetBlock).await?;
        self.inner.get_block(id).await
    }

    /// Delegates to the inner `get_header`; failable via `FailableAction::GetHeader`.
    async fn get_header(&mut self, id: BlockId<Types>) -> QueryResult<Header<Types>> {
        self.maybe_fail_read(FailableAction::GetHeader).await?;
        self.inner.get_header(id).await
    }

    /// Delegates to the inner `get_payload`; failable via `FailableAction::GetPayload`.
    async fn get_payload(&mut self, id: BlockId<Types>) -> QueryResult<PayloadQueryData<Types>> {
        self.maybe_fail_read(FailableAction::GetPayload).await?;
        self.inner.get_payload(id).await
    }

    /// Delegates to the inner `get_payload_metadata`; failable via
    /// `FailableAction::GetPayloadMetadata`.
    async fn get_payload_metadata(
        &mut self,
        id: BlockId<Types>,
    ) -> QueryResult<PayloadMetadata<Types>> {
        self.maybe_fail_read(FailableAction::GetPayloadMetadata)
            .await?;
        self.inner.get_payload_metadata(id).await
    }

    /// Delegates to the inner `get_vid_common`; failable via `FailableAction::GetVidCommon`.
    async fn get_vid_common(
        &mut self,
        id: BlockId<Types>,
    ) -> QueryResult<VidCommonQueryData<Types>> {
        self.maybe_fail_read(FailableAction::GetVidCommon).await?;
        self.inner.get_vid_common(id).await
    }

    /// Delegates to the inner `get_vid_common_metadata`; failable via
    /// `FailableAction::GetVidCommonMetadata`.
    async fn get_vid_common_metadata(
        &mut self,
        id: BlockId<Types>,
    ) -> QueryResult<VidCommonMetadata<Types>> {
        self.maybe_fail_read(FailableAction::GetVidCommonMetadata)
            .await?;
        self.inner.get_vid_common_metadata(id).await
    }

    /// Delegates to the inner `get_leaf_range`; failable via `FailableAction::GetLeafRange`.
    ///
    /// An injected failure fails the call as a whole (the outer `QueryResult`).
    async fn get_leaf_range<R>(
        &mut self,
        range: R,
    ) -> QueryResult<Vec<QueryResult<LeafQueryData<Types>>>>
    where
        R: RangeBounds<usize> + Send + 'static,
    {
        self.maybe_fail_read(FailableAction::GetLeafRange).await?;
        self.inner.get_leaf_range(range).await
    }

    /// Delegates to the inner `get_block_range`; failable via `FailableAction::GetBlockRange`.
    ///
    /// An injected failure fails the call as a whole (the outer `QueryResult`).
    async fn get_block_range<R>(
        &mut self,
        range: R,
    ) -> QueryResult<Vec<QueryResult<BlockQueryData<Types>>>>
    where
        R: RangeBounds<usize> + Send + 'static,
    {
        self.maybe_fail_read(FailableAction::GetBlockRange).await?;
        self.inner.get_block_range(range).await
    }

    /// Delegates to the inner `get_payload_range`; failable via
    /// `FailableAction::GetPayloadRange`.
    ///
    /// An injected failure fails the call as a whole (the outer `QueryResult`).
    async fn get_payload_range<R>(
        &mut self,
        range: R,
    ) -> QueryResult<Vec<QueryResult<PayloadQueryData<Types>>>>
    where
        R: RangeBounds<usize> + Send + 'static,
    {
        self.maybe_fail_read(FailableAction::GetPayloadRange)
            .await?;
        self.inner.get_payload_range(range).await
    }

    /// Delegates to the inner `get_payload_metadata_range`; failable via
    /// `FailableAction::GetPayloadMetadataRange`.
    ///
    /// An injected failure fails the call as a whole (the outer `QueryResult`).
    async fn get_payload_metadata_range<R>(
        &mut self,
        range: R,
    ) -> QueryResult<Vec<QueryResult<PayloadMetadata<Types>>>>
    where
        R: RangeBounds<usize> + Send + 'static,
    {
        self.maybe_fail_read(FailableAction::GetPayloadMetadataRange)
            .await?;
        self.inner.get_payload_metadata_range(range).await
    }

    /// Delegates to the inner `get_vid_common_range`; failable via
    /// `FailableAction::GetVidCommonRange`.
    ///
    /// An injected failure fails the call as a whole (the outer `QueryResult`).
    async fn get_vid_common_range<R>(
        &mut self,
        range: R,
    ) -> QueryResult<Vec<QueryResult<VidCommonQueryData<Types>>>>
    where
        R: RangeBounds<usize> + Send + 'static,
    {
        self.maybe_fail_read(FailableAction::GetVidCommonRange)
            .await?;
        self.inner.get_vid_common_range(range).await
    }

    /// Delegates to the inner `get_vid_common_metadata_range`; failable via
    /// `FailableAction::GetVidCommonMetadataRange`.
    ///
    /// An injected failure fails the call as a whole (the outer `QueryResult`).
    async fn get_vid_common_metadata_range<R>(
        &mut self,
        range: R,
    ) -> QueryResult<Vec<QueryResult<VidCommonMetadata<Types>>>>
    where
        R: RangeBounds<usize> + Send + 'static,
    {
        self.maybe_fail_read(FailableAction::GetVidCommonMetadataRange)
            .await?;
        self.inner.get_vid_common_metadata_range(range).await
    }

    /// Delegates to the inner `get_block_with_transaction`.
    ///
    /// Note that the failure target for this method is `FailableAction::GetTransaction`.
    async fn get_block_with_transaction(
        &mut self,
        hash: TransactionHash<Types>,
    ) -> QueryResult<BlockQueryData<Types>> {
        self.maybe_fail_read(FailableAction::GetTransaction).await?;
        self.inner.get_block_with_transaction(hash).await
    }

    /// Delegates to the inner `first_available_leaf`; failable via
    /// `FailableAction::FirstAvailableLeaf`.
    async fn first_available_leaf(&mut self, from: u64) -> QueryResult<LeafQueryData<Types>> {
        self.maybe_fail_read(FailableAction::FirstAvailableLeaf)
            .await?;
        self.inner.first_available_leaf(from).await
    }
}
458
// Availability writes with error injection.
//
// Each insert first consults the write failure mode against `FailableAction::Any`, so these
// methods fail only while the write target is `FailableAction::Any`. On an injected error the
// inner transaction is not called.
impl<Types, T> UpdateAvailabilityStorage<Types> for Transaction<T>
where
    Types: NodeType,
    Header<Types>: QueryableHeader<Types>,
    Payload<Types>: QueryablePayload<Types>,
    T: UpdateAvailabilityStorage<Types> + Send + Sync,
{
    /// Delegates to the inner `insert_leaf_with_qc_chain` unless a write failure is injected.
    async fn insert_leaf_with_qc_chain(
        &mut self,
        leaf: LeafQueryData<Types>,
        qc_chain: Option<[CertificatePair<Types>; 2]>,
    ) -> anyhow::Result<()> {
        self.maybe_fail_write(FailableAction::Any).await?;
        self.inner.insert_leaf_with_qc_chain(leaf, qc_chain).await
    }

    /// Delegates to the inner `insert_block` unless a write failure is injected.
    async fn insert_block(&mut self, block: BlockQueryData<Types>) -> anyhow::Result<()> {
        self.maybe_fail_write(FailableAction::Any).await?;
        self.inner.insert_block(block).await
    }

    /// Delegates to the inner `insert_vid` unless a write failure is injected.
    async fn insert_vid(
        &mut self,
        common: VidCommonQueryData<Types>,
        share: Option<VidShare>,
    ) -> anyhow::Result<()> {
        self.maybe_fail_write(FailableAction::Any).await?;
        self.inner.insert_vid(common, share).await
    }
}
489
// Loading the pruned height counts as a read for the purposes of error injection.
#[async_trait]
impl<T> PrunedHeightStorage for Transaction<T>
where
    T: PrunedHeightStorage + Send + Sync,
{
    /// Delegates to the inner `load_pruned_height` unless a read failure is injected.
    ///
    /// The read failure mode is checked against `FailableAction::Any`, so this fails only while
    /// the read target is `FailableAction::Any`.
    async fn load_pruned_height(&mut self) -> anyhow::Result<Option<u64>> {
        self.maybe_fail_read(FailableAction::Any).await?;
        self.inner.load_pruned_height().await
    }
}
500
// Node reads with error injection.
//
// None of these methods has a dedicated `FailableAction` variant: each one checks the read
// failure mode against `FailableAction::Any`, so they fail only while the read target is
// `FailableAction::Any`. On an injected error the inner transaction is not called.
#[async_trait]
impl<Types, T> NodeStorage<Types> for Transaction<T>
where
    Types: NodeType,
    Header<Types>: QueryableHeader<Types>,
    T: NodeStorage<Types> + Send + Sync,
{
    /// Delegates to the inner `block_height` unless a read failure is injected.
    async fn block_height(&mut self) -> QueryResult<usize> {
        self.maybe_fail_read(FailableAction::Any).await?;
        self.inner.block_height().await
    }

    /// Delegates to the inner `count_transactions_in_range` unless a read failure is injected.
    async fn count_transactions_in_range(
        &mut self,
        range: impl RangeBounds<usize> + Send,
        namespace: Option<NamespaceId<Types>>,
    ) -> QueryResult<usize> {
        self.maybe_fail_read(FailableAction::Any).await?;
        self.inner
            .count_transactions_in_range(range, namespace)
            .await
    }

    /// Delegates to the inner `payload_size_in_range` unless a read failure is injected.
    async fn payload_size_in_range(
        &mut self,
        range: impl RangeBounds<usize> + Send,
        namespace: Option<NamespaceId<Types>>,
    ) -> QueryResult<usize> {
        self.maybe_fail_read(FailableAction::Any).await?;
        self.inner.payload_size_in_range(range, namespace).await
    }

    /// Delegates to the inner `vid_share` unless a read failure is injected.
    async fn vid_share<ID>(&mut self, id: ID) -> QueryResult<VidShare>
    where
        ID: Into<BlockId<Types>> + Send + Sync,
    {
        self.maybe_fail_read(FailableAction::Any).await?;
        self.inner.vid_share(id).await
    }

    /// Delegates to the inner `sync_status_for_range` unless a read failure is injected.
    async fn sync_status_for_range(
        &mut self,
        start: usize,
        end: usize,
    ) -> QueryResult<SyncStatusQueryData> {
        self.maybe_fail_read(FailableAction::Any).await?;
        self.inner.sync_status_for_range(start, end).await
    }

    /// Delegates to the inner `get_header_window` unless a read failure is injected.
    async fn get_header_window(
        &mut self,
        start: impl Into<WindowStart<Types>> + Send + Sync,
        end: u64,
        limit: usize,
    ) -> QueryResult<TimeWindowQueryData<Header<Types>>> {
        self.maybe_fail_read(FailableAction::Any).await?;
        self.inner.get_header_window(start, end, limit).await
    }

    /// Delegates to the inner `latest_qc_chain` unless a read failure is injected.
    async fn latest_qc_chain(&mut self) -> QueryResult<Option<[CertificatePair<Types>; 2]>> {
        self.maybe_fail_read(FailableAction::Any).await?;
        self.inner.latest_qc_chain().await
    }
}
565
// Aggregate reads with error injection.
//
// Both methods check the read failure mode against `FailableAction::Any`, so they fail only
// while the read target is `FailableAction::Any`. The injected `QueryError` is converted into an
// `anyhow` error by `?`.
impl<Types, T> AggregatesStorage<Types> for Transaction<T>
where
    Types: NodeType,
    Header<Types>: QueryableHeader<Types>,
    T: AggregatesStorage<Types> + Send + Sync,
{
    /// Delegates to the inner `aggregates_height` unless a read failure is injected.
    async fn aggregates_height(&mut self) -> anyhow::Result<usize> {
        self.maybe_fail_read(FailableAction::Any).await?;
        self.inner.aggregates_height().await
    }

    /// Delegates to the inner `load_prev_aggregate` unless a read failure is injected.
    async fn load_prev_aggregate(&mut self) -> anyhow::Result<Option<Aggregate<Types>>> {
        self.maybe_fail_read(FailableAction::Any).await?;
        self.inner.load_prev_aggregate().await
    }
}
582
// Aggregate writes with error injection.
impl<T, Types> UpdateAggregatesStorage<Types> for Transaction<T>
where
    Types: NodeType,
    Header<Types>: QueryableHeader<Types>,
    T: UpdateAggregatesStorage<Types> + Send + Sync,
{
    /// Delegates to the inner `update_aggregates` unless a write failure is injected.
    ///
    /// The write failure mode is checked against `FailableAction::Any`, so this fails only while
    /// the write target is `FailableAction::Any`. On an injected error the inner transaction is
    /// not called.
    async fn update_aggregates(
        &mut self,
        prev: Aggregate<Types>,
        blocks: &[PayloadMetadata<Types>],
    ) -> anyhow::Result<Aggregate<Types>> {
        self.maybe_fail_write(FailableAction::Any).await?;
        self.inner.update_aggregates(prev, blocks).await
    }
}
597}