1#![cfg(any(test, feature = "testing"))]
14
15use std::{ops::RangeBounds, sync::Arc};
16
17use async_lock::Mutex;
18use async_trait::async_trait;
19use futures::future::Future;
20use hotshot_types::{
21 data::VidShare, simple_certificate::CertificatePair, traits::node_implementation::NodeType,
22};
23
24use super::{
25 Aggregate, AggregatesStorage, AvailabilityStorage, NodeStorage, UpdateAggregatesStorage,
26 UpdateAvailabilityStorage,
27 pruning::{PruneStorage, PrunedHeightStorage, PrunerCfg, PrunerConfig},
28};
29use crate::{
30 Header, Payload, QueryError, QueryResult,
31 availability::{
32 BlockId, BlockQueryData, LeafId, LeafQueryData, NamespaceId, PayloadQueryData,
33 QueryableHeader, QueryablePayload, TransactionHash, VidCommonQueryData,
34 },
35 data_source::{
36 VersionedDataSource,
37 storage::{PayloadMetadata, VidCommonMetadata},
38 update,
39 },
40 metrics::PrometheusMetrics,
41 node::{SyncStatusQueryData, TimeWindowQueryData, WindowStart},
42 status::HasMetrics,
43};
44
/// Storage operations that a [`FailStorage`] can be told to fail.
///
/// Availability read queries on a [`Transaction`] tag themselves with the specific variant that
/// names them, so a test can target e.g. leaf lookups alone. Every other operation (transaction
/// begins, writes, commits, node/aggregate/pruned-height queries) is tagged with
/// [`FailableAction::Any`], so it is only affected by a failure configured with `Any`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FailableAction {
    /// A single header lookup (`get_header`).
    GetHeader,
    /// A single leaf lookup (`get_leaf`).
    GetLeaf,
    /// A single block lookup (`get_block`).
    GetBlock,
    /// A single payload lookup (`get_payload`).
    GetPayload,
    /// A single payload-metadata lookup (`get_payload_metadata`).
    GetPayloadMetadata,
    /// A single VID-common lookup (`get_vid_common`).
    GetVidCommon,
    /// A single VID-common-metadata lookup (`get_vid_common_metadata`).
    GetVidCommonMetadata,
    /// A header range lookup.
    ///
    /// NOTE(review): no wrapper in this module checks this tag — confirm where it is used.
    GetHeaderRange,
    /// A leaf range lookup (`get_leaf_range`).
    GetLeafRange,
    /// A block range lookup (`get_block_range`).
    GetBlockRange,
    /// A payload range lookup (`get_payload_range`).
    GetPayloadRange,
    /// A payload-metadata range lookup (`get_payload_metadata_range`).
    GetPayloadMetadataRange,
    /// A VID-common range lookup (`get_vid_common_range`).
    GetVidCommonRange,
    /// A VID-common-metadata range lookup (`get_vid_common_metadata_range`).
    GetVidCommonMetadataRange,
    /// Finding the block that contains a transaction (`get_block_with_transaction`).
    GetTransaction,
    /// Finding the first available leaf (`first_available_leaf`).
    FirstAvailableLeaf,
    /// NOTE(review): no wrapper in this module checks this tag — confirm where it is used.
    GetStateCert,

    /// Wildcard: a failure configured with `Any` fires on every operation.
    Any,
}
71
72impl FailableAction {
73 fn matches(self, action: Self) -> bool {
75 self == action || self == Self::Any
78 }
79}
80
81#[derive(Clone, Copy, Debug, Default)]
82enum FailureMode {
83 #[default]
84 Never,
85 Once(FailableAction),
86 Always(FailableAction),
87}
88
89impl FailureMode {
90 fn maybe_fail(&mut self, action: FailableAction) -> QueryResult<()> {
91 match self {
92 Self::Once(fail_action) if fail_action.matches(action) => {
93 *self = Self::Never;
94 },
95 Self::Always(fail_action) if fail_action.matches(action) => {},
96 _ => return Ok(()),
97 }
98
99 Err(QueryError::Error {
100 message: "injected error".into(),
101 })
102 }
103}
104
105#[derive(Debug, Default)]
106struct Failure {
107 on_read: FailureMode,
108 on_write: FailureMode,
109 on_commit: FailureMode,
110 on_begin_writable: FailureMode,
111 on_begin_read_only: FailureMode,
112}
113
114#[derive(Clone, Debug)]
116pub struct FailStorage<S> {
117 inner: S,
118 failure: Arc<Mutex<Failure>>,
119}
120
121impl<S> From<S> for FailStorage<S> {
122 fn from(inner: S) -> Self {
123 Self {
124 inner,
125 failure: Default::default(),
126 }
127 }
128}
129
130impl<S> FailStorage<S> {
131 pub async fn fail_reads(&self, action: FailableAction) {
132 self.failure.lock().await.on_read = FailureMode::Always(action);
133 }
134
135 pub async fn fail_writes(&self, action: FailableAction) {
136 self.failure.lock().await.on_write = FailureMode::Always(action);
137 }
138
139 pub async fn fail_commits(&self, action: FailableAction) {
140 self.failure.lock().await.on_commit = FailureMode::Always(action);
141 }
142
143 pub async fn fail_begins_writable(&self, action: FailableAction) {
144 self.failure.lock().await.on_begin_writable = FailureMode::Always(action);
145 }
146
147 pub async fn fail_begins_read_only(&self, action: FailableAction) {
148 self.failure.lock().await.on_begin_read_only = FailureMode::Always(action);
149 }
150
151 pub async fn fail(&self, action: FailableAction) {
152 let mut failure = self.failure.lock().await;
153 failure.on_read = FailureMode::Always(action);
154 failure.on_write = FailureMode::Always(action);
155 failure.on_commit = FailureMode::Always(action);
156 failure.on_begin_writable = FailureMode::Always(action);
157 failure.on_begin_read_only = FailureMode::Always(action);
158 }
159
160 pub async fn pass_reads(&self) {
161 self.failure.lock().await.on_read = FailureMode::Never;
162 }
163
164 pub async fn pass_writes(&self) {
165 self.failure.lock().await.on_write = FailureMode::Never;
166 }
167
168 pub async fn pass_commits(&self) {
169 self.failure.lock().await.on_commit = FailureMode::Never;
170 }
171
172 pub async fn pass_begins_writable(&self) {
173 self.failure.lock().await.on_begin_writable = FailureMode::Never;
174 }
175
176 pub async fn pass_begins_read_only(&self) {
177 self.failure.lock().await.on_begin_read_only = FailureMode::Never;
178 }
179
180 pub async fn pass(&self) {
181 let mut failure = self.failure.lock().await;
182 failure.on_read = FailureMode::Never;
183 failure.on_write = FailureMode::Never;
184 failure.on_commit = FailureMode::Never;
185 failure.on_begin_writable = FailureMode::Never;
186 failure.on_begin_read_only = FailureMode::Never;
187 }
188
189 pub async fn fail_one_read(&self, action: FailableAction) {
190 self.failure.lock().await.on_read = FailureMode::Once(action);
191 }
192
193 pub async fn fail_one_write(&self, action: FailableAction) {
194 self.failure.lock().await.on_write = FailureMode::Once(action);
195 }
196
197 pub async fn fail_one_commit(&self, action: FailableAction) {
198 self.failure.lock().await.on_commit = FailureMode::Once(action);
199 }
200
201 pub async fn fail_one_begin_writable(&self, action: FailableAction) {
202 self.failure.lock().await.on_begin_writable = FailureMode::Once(action);
203 }
204
205 pub async fn fail_one_begin_read_only(&self, action: FailableAction) {
206 self.failure.lock().await.on_begin_read_only = FailureMode::Once(action);
207 }
208}
209
210impl<S> VersionedDataSource for FailStorage<S>
211where
212 S: VersionedDataSource,
213{
214 type Transaction<'a>
215 = Transaction<S::Transaction<'a>>
216 where
217 Self: 'a;
218 type ReadOnly<'a>
219 = Transaction<S::ReadOnly<'a>>
220 where
221 Self: 'a;
222
223 async fn write(&self) -> anyhow::Result<<Self as VersionedDataSource>::Transaction<'_>> {
224 self.failure
225 .lock()
226 .await
227 .on_begin_writable
228 .maybe_fail(FailableAction::Any)?;
229 Ok(Transaction {
230 inner: self.inner.write().await?,
231 failure: self.failure.clone(),
232 })
233 }
234
235 async fn read(&self) -> anyhow::Result<<Self as VersionedDataSource>::ReadOnly<'_>> {
236 self.failure
237 .lock()
238 .await
239 .on_begin_read_only
240 .maybe_fail(FailableAction::Any)?;
241 Ok(Transaction {
242 inner: self.inner.read().await?,
243 failure: self.failure.clone(),
244 })
245 }
246}
247
248impl<S> PrunerConfig for FailStorage<S>
249where
250 S: PrunerConfig,
251{
252 fn set_pruning_config(&mut self, cfg: PrunerCfg) {
253 self.inner.set_pruning_config(cfg);
254 }
255
256 fn get_pruning_config(&self) -> Option<PrunerCfg> {
257 self.inner.get_pruning_config()
258 }
259}
260
261#[async_trait]
262impl<S> PruneStorage for FailStorage<S>
263where
264 S: PruneStorage + Sync,
265{
266 type Pruner = S::Pruner;
267
268 async fn get_disk_usage(&self) -> anyhow::Result<u64> {
269 self.inner.get_disk_usage().await
270 }
271
272 async fn prune(&self, pruner: &mut Self::Pruner) -> anyhow::Result<Option<u64>> {
273 self.inner.prune(pruner).await
274 }
275}
276
277impl<S> HasMetrics for FailStorage<S>
278where
279 S: HasMetrics,
280{
281 fn metrics(&self) -> &PrometheusMetrics {
282 self.inner.metrics()
283 }
284}
285
286#[derive(Debug)]
287pub struct Transaction<T> {
288 inner: T,
289 failure: Arc<Mutex<Failure>>,
290}
291
292impl<T> Transaction<T> {
293 async fn maybe_fail_read(&self, action: FailableAction) -> QueryResult<()> {
294 self.failure.lock().await.on_read.maybe_fail(action)
295 }
296
297 async fn maybe_fail_write(&self, action: FailableAction) -> QueryResult<()> {
298 self.failure.lock().await.on_write.maybe_fail(action)
299 }
300
301 async fn maybe_fail_commit(&self, action: FailableAction) -> QueryResult<()> {
302 self.failure.lock().await.on_commit.maybe_fail(action)
303 }
304}
305
306impl<T> update::Transaction for Transaction<T>
307where
308 T: update::Transaction,
309{
310 async fn commit(self) -> anyhow::Result<()> {
311 self.maybe_fail_commit(FailableAction::Any).await?;
312 self.inner.commit().await
313 }
314
315 fn revert(self) -> impl Future + Send {
316 self.inner.revert()
317 }
318}
319
320#[async_trait]
321impl<Types, T> AvailabilityStorage<Types> for Transaction<T>
322where
323 Types: NodeType,
324 Header<Types>: QueryableHeader<Types>,
325 Payload<Types>: QueryablePayload<Types>,
326 T: AvailabilityStorage<Types>,
327{
328 async fn get_leaf(&mut self, id: LeafId<Types>) -> QueryResult<LeafQueryData<Types>> {
329 self.maybe_fail_read(FailableAction::GetLeaf).await?;
330 self.inner.get_leaf(id).await
331 }
332
333 async fn get_block(&mut self, id: BlockId<Types>) -> QueryResult<BlockQueryData<Types>> {
334 self.maybe_fail_read(FailableAction::GetBlock).await?;
335 self.inner.get_block(id).await
336 }
337
338 async fn get_header(&mut self, id: BlockId<Types>) -> QueryResult<Header<Types>> {
339 self.maybe_fail_read(FailableAction::GetHeader).await?;
340 self.inner.get_header(id).await
341 }
342
343 async fn get_payload(&mut self, id: BlockId<Types>) -> QueryResult<PayloadQueryData<Types>> {
344 self.maybe_fail_read(FailableAction::GetPayload).await?;
345 self.inner.get_payload(id).await
346 }
347
348 async fn get_payload_metadata(
349 &mut self,
350 id: BlockId<Types>,
351 ) -> QueryResult<PayloadMetadata<Types>> {
352 self.maybe_fail_read(FailableAction::GetPayloadMetadata)
353 .await?;
354 self.inner.get_payload_metadata(id).await
355 }
356
357 async fn get_vid_common(
358 &mut self,
359 id: BlockId<Types>,
360 ) -> QueryResult<VidCommonQueryData<Types>> {
361 self.maybe_fail_read(FailableAction::GetVidCommon).await?;
362 self.inner.get_vid_common(id).await
363 }
364
365 async fn get_vid_common_metadata(
366 &mut self,
367 id: BlockId<Types>,
368 ) -> QueryResult<VidCommonMetadata<Types>> {
369 self.maybe_fail_read(FailableAction::GetVidCommonMetadata)
370 .await?;
371 self.inner.get_vid_common_metadata(id).await
372 }
373
374 async fn get_leaf_range<R>(
375 &mut self,
376 range: R,
377 ) -> QueryResult<Vec<QueryResult<LeafQueryData<Types>>>>
378 where
379 R: RangeBounds<usize> + Send + 'static,
380 {
381 self.maybe_fail_read(FailableAction::GetLeafRange).await?;
382 self.inner.get_leaf_range(range).await
383 }
384
385 async fn get_block_range<R>(
386 &mut self,
387 range: R,
388 ) -> QueryResult<Vec<QueryResult<BlockQueryData<Types>>>>
389 where
390 R: RangeBounds<usize> + Send + 'static,
391 {
392 self.maybe_fail_read(FailableAction::GetBlockRange).await?;
393 self.inner.get_block_range(range).await
394 }
395
396 async fn get_payload_range<R>(
397 &mut self,
398 range: R,
399 ) -> QueryResult<Vec<QueryResult<PayloadQueryData<Types>>>>
400 where
401 R: RangeBounds<usize> + Send + 'static,
402 {
403 self.maybe_fail_read(FailableAction::GetPayloadRange)
404 .await?;
405 self.inner.get_payload_range(range).await
406 }
407
408 async fn get_payload_metadata_range<R>(
409 &mut self,
410 range: R,
411 ) -> QueryResult<Vec<QueryResult<PayloadMetadata<Types>>>>
412 where
413 R: RangeBounds<usize> + Send + 'static,
414 {
415 self.maybe_fail_read(FailableAction::GetPayloadMetadataRange)
416 .await?;
417 self.inner.get_payload_metadata_range(range).await
418 }
419
420 async fn get_vid_common_range<R>(
421 &mut self,
422 range: R,
423 ) -> QueryResult<Vec<QueryResult<VidCommonQueryData<Types>>>>
424 where
425 R: RangeBounds<usize> + Send + 'static,
426 {
427 self.maybe_fail_read(FailableAction::GetVidCommonRange)
428 .await?;
429 self.inner.get_vid_common_range(range).await
430 }
431
432 async fn get_vid_common_metadata_range<R>(
433 &mut self,
434 range: R,
435 ) -> QueryResult<Vec<QueryResult<VidCommonMetadata<Types>>>>
436 where
437 R: RangeBounds<usize> + Send + 'static,
438 {
439 self.maybe_fail_read(FailableAction::GetVidCommonMetadataRange)
440 .await?;
441 self.inner.get_vid_common_metadata_range(range).await
442 }
443
444 async fn get_block_with_transaction(
445 &mut self,
446 hash: TransactionHash<Types>,
447 ) -> QueryResult<BlockQueryData<Types>> {
448 self.maybe_fail_read(FailableAction::GetTransaction).await?;
449 self.inner.get_block_with_transaction(hash).await
450 }
451
452 async fn first_available_leaf(&mut self, from: u64) -> QueryResult<LeafQueryData<Types>> {
453 self.maybe_fail_read(FailableAction::FirstAvailableLeaf)
454 .await?;
455 self.inner.first_available_leaf(from).await
456 }
457}
458
459impl<Types, T> UpdateAvailabilityStorage<Types> for Transaction<T>
460where
461 Types: NodeType,
462 Header<Types>: QueryableHeader<Types>,
463 Payload<Types>: QueryablePayload<Types>,
464 T: UpdateAvailabilityStorage<Types> + Send + Sync,
465{
466 async fn insert_leaf_with_qc_chain(
467 &mut self,
468 leaf: LeafQueryData<Types>,
469 qc_chain: Option<[CertificatePair<Types>; 2]>,
470 ) -> anyhow::Result<()> {
471 self.maybe_fail_write(FailableAction::Any).await?;
472 self.inner.insert_leaf_with_qc_chain(leaf, qc_chain).await
473 }
474
475 async fn insert_block(&mut self, block: BlockQueryData<Types>) -> anyhow::Result<()> {
476 self.maybe_fail_write(FailableAction::Any).await?;
477 self.inner.insert_block(block).await
478 }
479
480 async fn insert_vid(
481 &mut self,
482 common: VidCommonQueryData<Types>,
483 share: Option<VidShare>,
484 ) -> anyhow::Result<()> {
485 self.maybe_fail_write(FailableAction::Any).await?;
486 self.inner.insert_vid(common, share).await
487 }
488}
489
490#[async_trait]
491impl<T> PrunedHeightStorage for Transaction<T>
492where
493 T: PrunedHeightStorage + Send + Sync,
494{
495 async fn load_pruned_height(&mut self) -> anyhow::Result<Option<u64>> {
496 self.maybe_fail_read(FailableAction::Any).await?;
497 self.inner.load_pruned_height().await
498 }
499}
500
501#[async_trait]
502impl<Types, T> NodeStorage<Types> for Transaction<T>
503where
504 Types: NodeType,
505 Header<Types>: QueryableHeader<Types>,
506 T: NodeStorage<Types> + Send + Sync,
507{
508 async fn block_height(&mut self) -> QueryResult<usize> {
509 self.maybe_fail_read(FailableAction::Any).await?;
510 self.inner.block_height().await
511 }
512
513 async fn count_transactions_in_range(
514 &mut self,
515 range: impl RangeBounds<usize> + Send,
516 namespace: Option<NamespaceId<Types>>,
517 ) -> QueryResult<usize> {
518 self.maybe_fail_read(FailableAction::Any).await?;
519 self.inner
520 .count_transactions_in_range(range, namespace)
521 .await
522 }
523
524 async fn payload_size_in_range(
525 &mut self,
526 range: impl RangeBounds<usize> + Send,
527 namespace: Option<NamespaceId<Types>>,
528 ) -> QueryResult<usize> {
529 self.maybe_fail_read(FailableAction::Any).await?;
530 self.inner.payload_size_in_range(range, namespace).await
531 }
532
533 async fn vid_share<ID>(&mut self, id: ID) -> QueryResult<VidShare>
534 where
535 ID: Into<BlockId<Types>> + Send + Sync,
536 {
537 self.maybe_fail_read(FailableAction::Any).await?;
538 self.inner.vid_share(id).await
539 }
540
541 async fn sync_status_for_range(
542 &mut self,
543 start: usize,
544 end: usize,
545 ) -> QueryResult<SyncStatusQueryData> {
546 self.maybe_fail_read(FailableAction::Any).await?;
547 self.inner.sync_status_for_range(start, end).await
548 }
549
550 async fn get_header_window(
551 &mut self,
552 start: impl Into<WindowStart<Types>> + Send + Sync,
553 end: u64,
554 limit: usize,
555 ) -> QueryResult<TimeWindowQueryData<Header<Types>>> {
556 self.maybe_fail_read(FailableAction::Any).await?;
557 self.inner.get_header_window(start, end, limit).await
558 }
559
560 async fn latest_qc_chain(&mut self) -> QueryResult<Option<[CertificatePair<Types>; 2]>> {
561 self.maybe_fail_read(FailableAction::Any).await?;
562 self.inner.latest_qc_chain().await
563 }
564}
565
566impl<Types, T> AggregatesStorage<Types> for Transaction<T>
567where
568 Types: NodeType,
569 Header<Types>: QueryableHeader<Types>,
570 T: AggregatesStorage<Types> + Send + Sync,
571{
572 async fn aggregates_height(&mut self) -> anyhow::Result<usize> {
573 self.maybe_fail_read(FailableAction::Any).await?;
574 self.inner.aggregates_height().await
575 }
576
577 async fn load_prev_aggregate(&mut self) -> anyhow::Result<Option<Aggregate<Types>>> {
578 self.maybe_fail_read(FailableAction::Any).await?;
579 self.inner.load_prev_aggregate().await
580 }
581}
582
583impl<T, Types> UpdateAggregatesStorage<Types> for Transaction<T>
584where
585 Types: NodeType,
586 Header<Types>: QueryableHeader<Types>,
587 T: UpdateAggregatesStorage<Types> + Send + Sync,
588{
589 async fn update_aggregates(
590 &mut self,
591 prev: Aggregate<Types>,
592 blocks: &[PayloadMetadata<Types>],
593 ) -> anyhow::Result<Aggregate<Types>> {
594 self.maybe_fail_write(FailableAction::Any).await?;
595 self.inner.update_aggregates(prev, blocks).await
596 }
597}