hotshot_query_service/data_source/storage/
fail_storage.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the HotShot Query Service library.
3//
4// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
5// General Public License as published by the Free Software Foundation, either version 3 of the
6// License, or (at your option) any later version.
7// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
8// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
9// General Public License for more details.
10// You should have received a copy of the GNU General Public License along with this program. If not,
11// see <https://www.gnu.org/licenses/>.
12
13#![cfg(any(test, feature = "testing"))]
14
15use std::{ops::RangeBounds, sync::Arc};
16
17use async_lock::Mutex;
18use async_trait::async_trait;
19use futures::future::Future;
20use hotshot_types::{
21    data::VidShare, simple_certificate::CertificatePair, traits::node_implementation::NodeType,
22};
23
24use super::{
25    pruning::{PruneStorage, PrunedHeightStorage, PrunerCfg, PrunerConfig},
26    sql::MigrateTypes,
27    Aggregate, AggregatesStorage, AvailabilityStorage, NodeStorage, UpdateAggregatesStorage,
28    UpdateAvailabilityStorage,
29};
30use crate::{
31    availability::{
32        BlockId, BlockQueryData, LeafId, LeafQueryData, NamespaceId, PayloadQueryData,
33        QueryableHeader, QueryablePayload, TransactionHash, VidCommonQueryData,
34    },
35    data_source::{
36        storage::{PayloadMetadata, VidCommonMetadata},
37        update, VersionedDataSource,
38    },
39    metrics::PrometheusMetrics,
40    node::{SyncStatus, TimeWindowQueryData, WindowStart},
41    status::HasMetrics,
42    Header, Payload, QueryError, QueryResult,
43};
44
/// An operation on the wrapped storage that can be singled out for error injection.
///
/// The `AvailabilityStorage` methods implemented for `Transaction` in this file each pass one of
/// these variants to the failure check before delegating to the inner storage.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FailableAction {
    // TODO currently we implement failable actions for the availability methods, but if needed we
    // can always add more variants for other actions.
    GetHeader,
    GetLeaf,
    GetBlock,
    GetPayload,
    GetPayloadMetadata,
    GetVidCommon,
    GetVidCommonMetadata,
    GetHeaderRange,
    GetLeafRange,
    GetBlockRange,
    GetPayloadRange,
    GetPayloadMetadataRange,
    GetVidCommonRange,
    GetVidCommonMetadataRange,
    GetTransaction,
    FirstAvailableLeaf,
    GetStateCert,

    /// Wildcard target: when this is the targeted action, every attempted action matches.
    Any,
}

impl FailableAction {
    /// Decides whether a failure targeted at `self` applies to the attempted `action`.
    ///
    /// Returns `true` when `self` is [`Any`](Self::Any) or when both values are the same
    /// variant. The relation is not symmetric: a specific target does not match an attempted
    /// `Any`.
    fn matches(self, action: Self) -> bool {
        match self {
            // A wildcard target covers whatever is being attempted.
            Self::Any => true,
            // Otherwise only the exact targeted action fails.
            targeted => targeted == action,
        }
    }
}
80
/// How often an error should be injected for one kind of operation.
#[derive(Clone, Copy, Debug, Default)]
enum FailureMode {
    /// Inject nothing. This is the default mode.
    #[default]
    Never,
    /// Fail the next matching action, after which the mode resets itself to `Never`.
    Once(FailableAction),
    /// Fail every matching action for as long as this mode stays in place.
    Always(FailableAction),
}
88
89impl FailureMode {
90    fn maybe_fail(&mut self, action: FailableAction) -> QueryResult<()> {
91        match self {
92            Self::Once(fail_action) if fail_action.matches(action) => {
93                *self = Self::Never;
94            },
95            Self::Always(fail_action) if fail_action.matches(action) => {},
96            _ => return Ok(()),
97        }
98
99        Err(QueryError::Error {
100            message: "injected error".into(),
101        })
102    }
103}
104
/// The error-injection settings, one [`FailureMode`] per kind of operation.
///
/// All modes default to `FailureMode::Never`.
#[derive(Debug, Default)]
struct Failure {
    /// Consulted by `Transaction::maybe_fail_read`.
    on_read: FailureMode,
    /// Consulted by `Transaction::maybe_fail_write`.
    on_write: FailureMode,
    /// Consulted by `Transaction::maybe_fail_commit`.
    on_commit: FailureMode,
    /// Consulted by `FailStorage::write` before a writable transaction is opened.
    on_begin_writable: FailureMode,
    /// Consulted by `FailStorage::read` before a read-only transaction is opened.
    on_begin_read_only: FailureMode,
}
113
/// Storage wrapper for error injection.
///
/// Wraps a storage `S` together with shared failure settings. The settings live behind an
/// `Arc<Mutex<_>>`, and transactions opened through this wrapper hold a clone of that `Arc`, so
/// changing the settings also affects transactions that are already open.
#[derive(Clone, Debug)]
pub struct FailStorage<S> {
    /// The wrapped storage that operations are forwarded to.
    inner: S,
    /// Failure settings shared with the transactions created by this wrapper.
    failure: Arc<Mutex<Failure>>,
}
120
121impl<S> From<S> for FailStorage<S> {
122    fn from(inner: S) -> Self {
123        Self {
124            inner,
125            failure: Default::default(),
126        }
127    }
128}
129
130impl<S> FailStorage<S> {
131    pub async fn fail_reads(&self, action: FailableAction) {
132        self.failure.lock().await.on_read = FailureMode::Always(action);
133    }
134
135    pub async fn fail_writes(&self, action: FailableAction) {
136        self.failure.lock().await.on_write = FailureMode::Always(action);
137    }
138
139    pub async fn fail_commits(&self, action: FailableAction) {
140        self.failure.lock().await.on_commit = FailureMode::Always(action);
141    }
142
143    pub async fn fail_begins_writable(&self, action: FailableAction) {
144        self.failure.lock().await.on_begin_writable = FailureMode::Always(action);
145    }
146
147    pub async fn fail_begins_read_only(&self, action: FailableAction) {
148        self.failure.lock().await.on_begin_read_only = FailureMode::Always(action);
149    }
150
151    pub async fn fail(&self, action: FailableAction) {
152        let mut failure = self.failure.lock().await;
153        failure.on_read = FailureMode::Always(action);
154        failure.on_write = FailureMode::Always(action);
155        failure.on_commit = FailureMode::Always(action);
156        failure.on_begin_writable = FailureMode::Always(action);
157        failure.on_begin_read_only = FailureMode::Always(action);
158    }
159
160    pub async fn pass_reads(&self) {
161        self.failure.lock().await.on_read = FailureMode::Never;
162    }
163
164    pub async fn pass_writes(&self) {
165        self.failure.lock().await.on_write = FailureMode::Never;
166    }
167
168    pub async fn pass_commits(&self) {
169        self.failure.lock().await.on_commit = FailureMode::Never;
170    }
171
172    pub async fn pass_begins_writable(&self) {
173        self.failure.lock().await.on_begin_writable = FailureMode::Never;
174    }
175
176    pub async fn pass_begins_read_only(&self) {
177        self.failure.lock().await.on_begin_read_only = FailureMode::Never;
178    }
179
180    pub async fn pass(&self) {
181        let mut failure = self.failure.lock().await;
182        failure.on_read = FailureMode::Never;
183        failure.on_write = FailureMode::Never;
184        failure.on_commit = FailureMode::Never;
185        failure.on_begin_writable = FailureMode::Never;
186        failure.on_begin_read_only = FailureMode::Never;
187    }
188
189    pub async fn fail_one_read(&self, action: FailableAction) {
190        self.failure.lock().await.on_read = FailureMode::Once(action);
191    }
192
193    pub async fn fail_one_write(&self, action: FailableAction) {
194        self.failure.lock().await.on_write = FailureMode::Once(action);
195    }
196
197    pub async fn fail_one_commit(&self, action: FailableAction) {
198        self.failure.lock().await.on_commit = FailureMode::Once(action);
199    }
200
201    pub async fn fail_one_begin_writable(&self, action: FailableAction) {
202        self.failure.lock().await.on_begin_writable = FailureMode::Once(action);
203    }
204
205    pub async fn fail_one_begin_read_only(&self, action: FailableAction) {
206        self.failure.lock().await.on_begin_read_only = FailureMode::Once(action);
207    }
208}
209
210impl<S> VersionedDataSource for FailStorage<S>
211where
212    S: VersionedDataSource,
213{
214    type Transaction<'a>
215        = Transaction<S::Transaction<'a>>
216    where
217        Self: 'a;
218    type ReadOnly<'a>
219        = Transaction<S::ReadOnly<'a>>
220    where
221        Self: 'a;
222
223    async fn write(&self) -> anyhow::Result<<Self as VersionedDataSource>::Transaction<'_>> {
224        self.failure
225            .lock()
226            .await
227            .on_begin_writable
228            .maybe_fail(FailableAction::Any)?;
229        Ok(Transaction {
230            inner: self.inner.write().await?,
231            failure: self.failure.clone(),
232        })
233    }
234
235    async fn read(&self) -> anyhow::Result<<Self as VersionedDataSource>::ReadOnly<'_>> {
236        self.failure
237            .lock()
238            .await
239            .on_begin_read_only
240            .maybe_fail(FailableAction::Any)?;
241        Ok(Transaction {
242            inner: self.inner.read().await?,
243            failure: self.failure.clone(),
244        })
245    }
246}
247
248impl<S> PrunerConfig for FailStorage<S>
249where
250    S: PrunerConfig,
251{
252    fn set_pruning_config(&mut self, cfg: PrunerCfg) {
253        self.inner.set_pruning_config(cfg);
254    }
255
256    fn get_pruning_config(&self) -> Option<PrunerCfg> {
257        self.inner.get_pruning_config()
258    }
259}
260
#[async_trait]
impl<S, Types: NodeType> MigrateTypes<Types> for FailStorage<S>
where
    S: MigrateTypes<Types> + Sync,
{
    /// Returns `Ok(())` immediately without doing any work.
    ///
    /// NOTE(review): the batch size is ignored and the call is not forwarded to the wrapped
    /// storage, even though `S: MigrateTypes<Types>` is required. Presumably deliberate for this
    /// test-only wrapper; confirm before relying on it.
    async fn migrate_types(&self, _batch_size: u64) -> anyhow::Result<()> {
        Ok(())
    }
}
270
271#[async_trait]
272impl<S> PruneStorage for FailStorage<S>
273where
274    S: PruneStorage + Sync,
275{
276    type Pruner = S::Pruner;
277
278    async fn get_disk_usage(&self) -> anyhow::Result<u64> {
279        self.inner.get_disk_usage().await
280    }
281
282    async fn prune(&self, pruner: &mut Self::Pruner) -> anyhow::Result<Option<u64>> {
283        self.inner.prune(pruner).await
284    }
285}
286
287impl<S> HasMetrics for FailStorage<S>
288where
289    S: HasMetrics,
290{
291    fn metrics(&self) -> &PrometheusMetrics {
292        self.inner.metrics()
293    }
294}
295
/// A transaction wrapper that can inject errors before forwarding to the inner transaction.
///
/// Instances are created by `FailStorage::write` and `FailStorage::read`.
#[derive(Debug)]
pub struct Transaction<T> {
    /// The wrapped transaction that operations are forwarded to.
    inner: T,
    /// Failure settings shared with the `FailStorage` that opened this transaction.
    failure: Arc<Mutex<Failure>>,
}
301
302impl<T> Transaction<T> {
303    async fn maybe_fail_read(&self, action: FailableAction) -> QueryResult<()> {
304        self.failure.lock().await.on_read.maybe_fail(action)
305    }
306
307    async fn maybe_fail_write(&self, action: FailableAction) -> QueryResult<()> {
308        self.failure.lock().await.on_write.maybe_fail(action)
309    }
310
311    async fn maybe_fail_commit(&self, action: FailableAction) -> QueryResult<()> {
312        self.failure.lock().await.on_commit.maybe_fail(action)
313    }
314}
315
316impl<T> update::Transaction for Transaction<T>
317where
318    T: update::Transaction,
319{
320    async fn commit(self) -> anyhow::Result<()> {
321        self.maybe_fail_commit(FailableAction::Any).await?;
322        self.inner.commit().await
323    }
324
325    fn revert(self) -> impl Future + Send {
326        self.inner.revert()
327    }
328}
329
330#[async_trait]
331impl<Types, T> AvailabilityStorage<Types> for Transaction<T>
332where
333    Types: NodeType,
334    Header<Types>: QueryableHeader<Types>,
335    Payload<Types>: QueryablePayload<Types>,
336    T: AvailabilityStorage<Types>,
337{
338    async fn get_leaf(&mut self, id: LeafId<Types>) -> QueryResult<LeafQueryData<Types>> {
339        self.maybe_fail_read(FailableAction::GetLeaf).await?;
340        self.inner.get_leaf(id).await
341    }
342
343    async fn get_block(&mut self, id: BlockId<Types>) -> QueryResult<BlockQueryData<Types>> {
344        self.maybe_fail_read(FailableAction::GetBlock).await?;
345        self.inner.get_block(id).await
346    }
347
348    async fn get_header(&mut self, id: BlockId<Types>) -> QueryResult<Header<Types>> {
349        self.maybe_fail_read(FailableAction::GetHeader).await?;
350        self.inner.get_header(id).await
351    }
352
353    async fn get_payload(&mut self, id: BlockId<Types>) -> QueryResult<PayloadQueryData<Types>> {
354        self.maybe_fail_read(FailableAction::GetPayload).await?;
355        self.inner.get_payload(id).await
356    }
357
358    async fn get_payload_metadata(
359        &mut self,
360        id: BlockId<Types>,
361    ) -> QueryResult<PayloadMetadata<Types>> {
362        self.maybe_fail_read(FailableAction::GetPayloadMetadata)
363            .await?;
364        self.inner.get_payload_metadata(id).await
365    }
366
367    async fn get_vid_common(
368        &mut self,
369        id: BlockId<Types>,
370    ) -> QueryResult<VidCommonQueryData<Types>> {
371        self.maybe_fail_read(FailableAction::GetVidCommon).await?;
372        self.inner.get_vid_common(id).await
373    }
374
375    async fn get_vid_common_metadata(
376        &mut self,
377        id: BlockId<Types>,
378    ) -> QueryResult<VidCommonMetadata<Types>> {
379        self.maybe_fail_read(FailableAction::GetVidCommonMetadata)
380            .await?;
381        self.inner.get_vid_common_metadata(id).await
382    }
383
384    async fn get_leaf_range<R>(
385        &mut self,
386        range: R,
387    ) -> QueryResult<Vec<QueryResult<LeafQueryData<Types>>>>
388    where
389        R: RangeBounds<usize> + Send + 'static,
390    {
391        self.maybe_fail_read(FailableAction::GetLeafRange).await?;
392        self.inner.get_leaf_range(range).await
393    }
394
395    async fn get_block_range<R>(
396        &mut self,
397        range: R,
398    ) -> QueryResult<Vec<QueryResult<BlockQueryData<Types>>>>
399    where
400        R: RangeBounds<usize> + Send + 'static,
401    {
402        self.maybe_fail_read(FailableAction::GetBlockRange).await?;
403        self.inner.get_block_range(range).await
404    }
405
406    async fn get_payload_range<R>(
407        &mut self,
408        range: R,
409    ) -> QueryResult<Vec<QueryResult<PayloadQueryData<Types>>>>
410    where
411        R: RangeBounds<usize> + Send + 'static,
412    {
413        self.maybe_fail_read(FailableAction::GetPayloadRange)
414            .await?;
415        self.inner.get_payload_range(range).await
416    }
417
418    async fn get_payload_metadata_range<R>(
419        &mut self,
420        range: R,
421    ) -> QueryResult<Vec<QueryResult<PayloadMetadata<Types>>>>
422    where
423        R: RangeBounds<usize> + Send + 'static,
424    {
425        self.maybe_fail_read(FailableAction::GetPayloadMetadataRange)
426            .await?;
427        self.inner.get_payload_metadata_range(range).await
428    }
429
430    async fn get_vid_common_range<R>(
431        &mut self,
432        range: R,
433    ) -> QueryResult<Vec<QueryResult<VidCommonQueryData<Types>>>>
434    where
435        R: RangeBounds<usize> + Send + 'static,
436    {
437        self.maybe_fail_read(FailableAction::GetVidCommonRange)
438            .await?;
439        self.inner.get_vid_common_range(range).await
440    }
441
442    async fn get_vid_common_metadata_range<R>(
443        &mut self,
444        range: R,
445    ) -> QueryResult<Vec<QueryResult<VidCommonMetadata<Types>>>>
446    where
447        R: RangeBounds<usize> + Send + 'static,
448    {
449        self.maybe_fail_read(FailableAction::GetVidCommonMetadataRange)
450            .await?;
451        self.inner.get_vid_common_metadata_range(range).await
452    }
453
454    async fn get_block_with_transaction(
455        &mut self,
456        hash: TransactionHash<Types>,
457    ) -> QueryResult<BlockQueryData<Types>> {
458        self.maybe_fail_read(FailableAction::GetTransaction).await?;
459        self.inner.get_block_with_transaction(hash).await
460    }
461
462    async fn first_available_leaf(&mut self, from: u64) -> QueryResult<LeafQueryData<Types>> {
463        self.maybe_fail_read(FailableAction::FirstAvailableLeaf)
464            .await?;
465        self.inner.first_available_leaf(from).await
466    }
467}
468
469impl<Types, T> UpdateAvailabilityStorage<Types> for Transaction<T>
470where
471    Types: NodeType,
472    Header<Types>: QueryableHeader<Types>,
473    Payload<Types>: QueryablePayload<Types>,
474    T: UpdateAvailabilityStorage<Types> + Send + Sync,
475{
476    async fn insert_leaf_with_qc_chain(
477        &mut self,
478        leaf: LeafQueryData<Types>,
479        qc_chain: Option<[CertificatePair<Types>; 2]>,
480    ) -> anyhow::Result<()> {
481        self.maybe_fail_write(FailableAction::Any).await?;
482        self.inner.insert_leaf_with_qc_chain(leaf, qc_chain).await
483    }
484
485    async fn insert_block(&mut self, block: BlockQueryData<Types>) -> anyhow::Result<()> {
486        self.maybe_fail_write(FailableAction::Any).await?;
487        self.inner.insert_block(block).await
488    }
489
490    async fn insert_vid(
491        &mut self,
492        common: VidCommonQueryData<Types>,
493        share: Option<VidShare>,
494    ) -> anyhow::Result<()> {
495        self.maybe_fail_write(FailableAction::Any).await?;
496        self.inner.insert_vid(common, share).await
497    }
498}
499
500#[async_trait]
501impl<T> PrunedHeightStorage for Transaction<T>
502where
503    T: PrunedHeightStorage + Send + Sync,
504{
505    async fn load_pruned_height(&mut self) -> anyhow::Result<Option<u64>> {
506        self.maybe_fail_read(FailableAction::Any).await?;
507        self.inner.load_pruned_height().await
508    }
509}
510
511#[async_trait]
512impl<Types, T> NodeStorage<Types> for Transaction<T>
513where
514    Types: NodeType,
515    Header<Types>: QueryableHeader<Types>,
516    T: NodeStorage<Types> + Send + Sync,
517{
518    async fn block_height(&mut self) -> QueryResult<usize> {
519        self.maybe_fail_read(FailableAction::Any).await?;
520        self.inner.block_height().await
521    }
522
523    async fn count_transactions_in_range(
524        &mut self,
525        range: impl RangeBounds<usize> + Send,
526        namespace: Option<NamespaceId<Types>>,
527    ) -> QueryResult<usize> {
528        self.maybe_fail_read(FailableAction::Any).await?;
529        self.inner
530            .count_transactions_in_range(range, namespace)
531            .await
532    }
533
534    async fn payload_size_in_range(
535        &mut self,
536        range: impl RangeBounds<usize> + Send,
537        namespace: Option<NamespaceId<Types>>,
538    ) -> QueryResult<usize> {
539        self.maybe_fail_read(FailableAction::Any).await?;
540        self.inner.payload_size_in_range(range, namespace).await
541    }
542
543    async fn vid_share<ID>(&mut self, id: ID) -> QueryResult<VidShare>
544    where
545        ID: Into<BlockId<Types>> + Send + Sync,
546    {
547        self.maybe_fail_read(FailableAction::Any).await?;
548        self.inner.vid_share(id).await
549    }
550
551    async fn sync_status(&mut self) -> QueryResult<SyncStatus> {
552        self.maybe_fail_read(FailableAction::Any).await?;
553        self.inner.sync_status().await
554    }
555
556    async fn get_header_window(
557        &mut self,
558        start: impl Into<WindowStart<Types>> + Send + Sync,
559        end: u64,
560        limit: usize,
561    ) -> QueryResult<TimeWindowQueryData<Header<Types>>> {
562        self.maybe_fail_read(FailableAction::Any).await?;
563        self.inner.get_header_window(start, end, limit).await
564    }
565
566    async fn latest_qc_chain(&mut self) -> QueryResult<Option<[CertificatePair<Types>; 2]>> {
567        self.maybe_fail_read(FailableAction::Any).await?;
568        self.inner.latest_qc_chain().await
569    }
570}
571
572impl<Types, T> AggregatesStorage<Types> for Transaction<T>
573where
574    Types: NodeType,
575    Header<Types>: QueryableHeader<Types>,
576    T: AggregatesStorage<Types> + Send + Sync,
577{
578    async fn aggregates_height(&mut self) -> anyhow::Result<usize> {
579        self.maybe_fail_read(FailableAction::Any).await?;
580        self.inner.aggregates_height().await
581    }
582
583    async fn load_prev_aggregate(&mut self) -> anyhow::Result<Option<Aggregate<Types>>> {
584        self.maybe_fail_read(FailableAction::Any).await?;
585        self.inner.load_prev_aggregate().await
586    }
587}
588
589impl<T, Types> UpdateAggregatesStorage<Types> for Transaction<T>
590where
591    Types: NodeType,
592    Header<Types>: QueryableHeader<Types>,
593    T: UpdateAggregatesStorage<Types> + Send + Sync,
594{
595    async fn update_aggregates(
596        &mut self,
597        prev: Aggregate<Types>,
598        blocks: &[PayloadMetadata<Types>],
599    ) -> anyhow::Result<Aggregate<Types>> {
600        self.maybe_fail_write(FailableAction::Any).await?;
601        self.inner.update_aggregates(prev, blocks).await
602    }
603}