1#![cfg(any(test, feature = "testing"))]
14
15use std::{ops::RangeBounds, sync::Arc};
16
17use async_lock::Mutex;
18use async_trait::async_trait;
19use futures::future::Future;
20use hotshot_types::{
21 data::VidShare, simple_certificate::CertificatePair, traits::node_implementation::NodeType,
22};
23
24use super::{
25 pruning::{PruneStorage, PrunedHeightStorage, PrunerCfg, PrunerConfig},
26 sql::MigrateTypes,
27 Aggregate, AggregatesStorage, AvailabilityStorage, NodeStorage, UpdateAggregatesStorage,
28 UpdateAvailabilityStorage,
29};
30use crate::{
31 availability::{
32 BlockId, BlockQueryData, LeafId, LeafQueryData, NamespaceId, PayloadQueryData,
33 QueryableHeader, QueryablePayload, TransactionHash, VidCommonQueryData,
34 },
35 data_source::{
36 storage::{PayloadMetadata, VidCommonMetadata},
37 update, VersionedDataSource,
38 },
39 metrics::PrometheusMetrics,
40 node::{SyncStatus, TimeWindowQueryData, WindowStart},
41 status::HasMetrics,
42 Header, Payload, QueryError, QueryResult,
43};
44
/// A storage operation that failure injection can be aimed at.
///
/// Every variant except [`FailableAction::Any`] names one specific operation. The wrapper
/// passes the corresponding variant when that operation is attempted, and the configured
/// action decides (via `matches`) whether an error is injected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FailableAction {
    GetHeader,
    GetLeaf,
    GetBlock,
    GetPayload,
    GetPayloadMetadata,
    GetVidCommon,
    GetVidCommonMetadata,
    GetHeaderRange,
    GetLeafRange,
    GetBlockRange,
    GetPayloadRange,
    GetPayloadMetadataRange,
    GetVidCommonRange,
    GetVidCommonMetadataRange,
    GetTransaction,
    FirstAvailableLeaf,
    GetStateCert,

    /// Wildcard. When configured as the action to fail, it matches every attempted action.
    /// It is also the action reported by operations that have no dedicated variant.
    Any,
}
71
72impl FailableAction {
73 fn matches(self, action: Self) -> bool {
75 self == action || self == Self::Any
78 }
79}
80
81#[derive(Clone, Copy, Debug, Default)]
82enum FailureMode {
83 #[default]
84 Never,
85 Once(FailableAction),
86 Always(FailableAction),
87}
88
89impl FailureMode {
90 fn maybe_fail(&mut self, action: FailableAction) -> QueryResult<()> {
91 match self {
92 Self::Once(fail_action) if fail_action.matches(action) => {
93 *self = Self::Never;
94 },
95 Self::Always(fail_action) if fail_action.matches(action) => {},
96 _ => return Ok(()),
97 }
98
99 Err(QueryError::Error {
100 message: "injected error".into(),
101 })
102 }
103}
104
105#[derive(Debug, Default)]
106struct Failure {
107 on_read: FailureMode,
108 on_write: FailureMode,
109 on_commit: FailureMode,
110 on_begin_writable: FailureMode,
111 on_begin_read_only: FailureMode,
112}
113
114#[derive(Clone, Debug)]
116pub struct FailStorage<S> {
117 inner: S,
118 failure: Arc<Mutex<Failure>>,
119}
120
121impl<S> From<S> for FailStorage<S> {
122 fn from(inner: S) -> Self {
123 Self {
124 inner,
125 failure: Default::default(),
126 }
127 }
128}
129
130impl<S> FailStorage<S> {
131 pub async fn fail_reads(&self, action: FailableAction) {
132 self.failure.lock().await.on_read = FailureMode::Always(action);
133 }
134
135 pub async fn fail_writes(&self, action: FailableAction) {
136 self.failure.lock().await.on_write = FailureMode::Always(action);
137 }
138
139 pub async fn fail_commits(&self, action: FailableAction) {
140 self.failure.lock().await.on_commit = FailureMode::Always(action);
141 }
142
143 pub async fn fail_begins_writable(&self, action: FailableAction) {
144 self.failure.lock().await.on_begin_writable = FailureMode::Always(action);
145 }
146
147 pub async fn fail_begins_read_only(&self, action: FailableAction) {
148 self.failure.lock().await.on_begin_read_only = FailureMode::Always(action);
149 }
150
151 pub async fn fail(&self, action: FailableAction) {
152 let mut failure = self.failure.lock().await;
153 failure.on_read = FailureMode::Always(action);
154 failure.on_write = FailureMode::Always(action);
155 failure.on_commit = FailureMode::Always(action);
156 failure.on_begin_writable = FailureMode::Always(action);
157 failure.on_begin_read_only = FailureMode::Always(action);
158 }
159
160 pub async fn pass_reads(&self) {
161 self.failure.lock().await.on_read = FailureMode::Never;
162 }
163
164 pub async fn pass_writes(&self) {
165 self.failure.lock().await.on_write = FailureMode::Never;
166 }
167
168 pub async fn pass_commits(&self) {
169 self.failure.lock().await.on_commit = FailureMode::Never;
170 }
171
172 pub async fn pass_begins_writable(&self) {
173 self.failure.lock().await.on_begin_writable = FailureMode::Never;
174 }
175
176 pub async fn pass_begins_read_only(&self) {
177 self.failure.lock().await.on_begin_read_only = FailureMode::Never;
178 }
179
180 pub async fn pass(&self) {
181 let mut failure = self.failure.lock().await;
182 failure.on_read = FailureMode::Never;
183 failure.on_write = FailureMode::Never;
184 failure.on_commit = FailureMode::Never;
185 failure.on_begin_writable = FailureMode::Never;
186 failure.on_begin_read_only = FailureMode::Never;
187 }
188
189 pub async fn fail_one_read(&self, action: FailableAction) {
190 self.failure.lock().await.on_read = FailureMode::Once(action);
191 }
192
193 pub async fn fail_one_write(&self, action: FailableAction) {
194 self.failure.lock().await.on_write = FailureMode::Once(action);
195 }
196
197 pub async fn fail_one_commit(&self, action: FailableAction) {
198 self.failure.lock().await.on_commit = FailureMode::Once(action);
199 }
200
201 pub async fn fail_one_begin_writable(&self, action: FailableAction) {
202 self.failure.lock().await.on_begin_writable = FailureMode::Once(action);
203 }
204
205 pub async fn fail_one_begin_read_only(&self, action: FailableAction) {
206 self.failure.lock().await.on_begin_read_only = FailureMode::Once(action);
207 }
208}
209
210impl<S> VersionedDataSource for FailStorage<S>
211where
212 S: VersionedDataSource,
213{
214 type Transaction<'a>
215 = Transaction<S::Transaction<'a>>
216 where
217 Self: 'a;
218 type ReadOnly<'a>
219 = Transaction<S::ReadOnly<'a>>
220 where
221 Self: 'a;
222
223 async fn write(&self) -> anyhow::Result<<Self as VersionedDataSource>::Transaction<'_>> {
224 self.failure
225 .lock()
226 .await
227 .on_begin_writable
228 .maybe_fail(FailableAction::Any)?;
229 Ok(Transaction {
230 inner: self.inner.write().await?,
231 failure: self.failure.clone(),
232 })
233 }
234
235 async fn read(&self) -> anyhow::Result<<Self as VersionedDataSource>::ReadOnly<'_>> {
236 self.failure
237 .lock()
238 .await
239 .on_begin_read_only
240 .maybe_fail(FailableAction::Any)?;
241 Ok(Transaction {
242 inner: self.inner.read().await?,
243 failure: self.failure.clone(),
244 })
245 }
246}
247
248impl<S> PrunerConfig for FailStorage<S>
249where
250 S: PrunerConfig,
251{
252 fn set_pruning_config(&mut self, cfg: PrunerCfg) {
253 self.inner.set_pruning_config(cfg);
254 }
255
256 fn get_pruning_config(&self) -> Option<PrunerCfg> {
257 self.inner.get_pruning_config()
258 }
259}
260
261#[async_trait]
262impl<S, Types: NodeType> MigrateTypes<Types> for FailStorage<S>
263where
264 S: MigrateTypes<Types> + Sync,
265{
266 async fn migrate_types(&self, _batch_size: u64) -> anyhow::Result<()> {
267 Ok(())
268 }
269}
270
271#[async_trait]
272impl<S> PruneStorage for FailStorage<S>
273where
274 S: PruneStorage + Sync,
275{
276 type Pruner = S::Pruner;
277
278 async fn get_disk_usage(&self) -> anyhow::Result<u64> {
279 self.inner.get_disk_usage().await
280 }
281
282 async fn prune(&self, pruner: &mut Self::Pruner) -> anyhow::Result<Option<u64>> {
283 self.inner.prune(pruner).await
284 }
285}
286
287impl<S> HasMetrics for FailStorage<S>
288where
289 S: HasMetrics,
290{
291 fn metrics(&self) -> &PrometheusMetrics {
292 self.inner.metrics()
293 }
294}
295
296#[derive(Debug)]
297pub struct Transaction<T> {
298 inner: T,
299 failure: Arc<Mutex<Failure>>,
300}
301
302impl<T> Transaction<T> {
303 async fn maybe_fail_read(&self, action: FailableAction) -> QueryResult<()> {
304 self.failure.lock().await.on_read.maybe_fail(action)
305 }
306
307 async fn maybe_fail_write(&self, action: FailableAction) -> QueryResult<()> {
308 self.failure.lock().await.on_write.maybe_fail(action)
309 }
310
311 async fn maybe_fail_commit(&self, action: FailableAction) -> QueryResult<()> {
312 self.failure.lock().await.on_commit.maybe_fail(action)
313 }
314}
315
316impl<T> update::Transaction for Transaction<T>
317where
318 T: update::Transaction,
319{
320 async fn commit(self) -> anyhow::Result<()> {
321 self.maybe_fail_commit(FailableAction::Any).await?;
322 self.inner.commit().await
323 }
324
325 fn revert(self) -> impl Future + Send {
326 self.inner.revert()
327 }
328}
329
330#[async_trait]
331impl<Types, T> AvailabilityStorage<Types> for Transaction<T>
332where
333 Types: NodeType,
334 Header<Types>: QueryableHeader<Types>,
335 Payload<Types>: QueryablePayload<Types>,
336 T: AvailabilityStorage<Types>,
337{
338 async fn get_leaf(&mut self, id: LeafId<Types>) -> QueryResult<LeafQueryData<Types>> {
339 self.maybe_fail_read(FailableAction::GetLeaf).await?;
340 self.inner.get_leaf(id).await
341 }
342
343 async fn get_block(&mut self, id: BlockId<Types>) -> QueryResult<BlockQueryData<Types>> {
344 self.maybe_fail_read(FailableAction::GetBlock).await?;
345 self.inner.get_block(id).await
346 }
347
348 async fn get_header(&mut self, id: BlockId<Types>) -> QueryResult<Header<Types>> {
349 self.maybe_fail_read(FailableAction::GetHeader).await?;
350 self.inner.get_header(id).await
351 }
352
353 async fn get_payload(&mut self, id: BlockId<Types>) -> QueryResult<PayloadQueryData<Types>> {
354 self.maybe_fail_read(FailableAction::GetPayload).await?;
355 self.inner.get_payload(id).await
356 }
357
358 async fn get_payload_metadata(
359 &mut self,
360 id: BlockId<Types>,
361 ) -> QueryResult<PayloadMetadata<Types>> {
362 self.maybe_fail_read(FailableAction::GetPayloadMetadata)
363 .await?;
364 self.inner.get_payload_metadata(id).await
365 }
366
367 async fn get_vid_common(
368 &mut self,
369 id: BlockId<Types>,
370 ) -> QueryResult<VidCommonQueryData<Types>> {
371 self.maybe_fail_read(FailableAction::GetVidCommon).await?;
372 self.inner.get_vid_common(id).await
373 }
374
375 async fn get_vid_common_metadata(
376 &mut self,
377 id: BlockId<Types>,
378 ) -> QueryResult<VidCommonMetadata<Types>> {
379 self.maybe_fail_read(FailableAction::GetVidCommonMetadata)
380 .await?;
381 self.inner.get_vid_common_metadata(id).await
382 }
383
384 async fn get_leaf_range<R>(
385 &mut self,
386 range: R,
387 ) -> QueryResult<Vec<QueryResult<LeafQueryData<Types>>>>
388 where
389 R: RangeBounds<usize> + Send + 'static,
390 {
391 self.maybe_fail_read(FailableAction::GetLeafRange).await?;
392 self.inner.get_leaf_range(range).await
393 }
394
395 async fn get_block_range<R>(
396 &mut self,
397 range: R,
398 ) -> QueryResult<Vec<QueryResult<BlockQueryData<Types>>>>
399 where
400 R: RangeBounds<usize> + Send + 'static,
401 {
402 self.maybe_fail_read(FailableAction::GetBlockRange).await?;
403 self.inner.get_block_range(range).await
404 }
405
406 async fn get_payload_range<R>(
407 &mut self,
408 range: R,
409 ) -> QueryResult<Vec<QueryResult<PayloadQueryData<Types>>>>
410 where
411 R: RangeBounds<usize> + Send + 'static,
412 {
413 self.maybe_fail_read(FailableAction::GetPayloadRange)
414 .await?;
415 self.inner.get_payload_range(range).await
416 }
417
418 async fn get_payload_metadata_range<R>(
419 &mut self,
420 range: R,
421 ) -> QueryResult<Vec<QueryResult<PayloadMetadata<Types>>>>
422 where
423 R: RangeBounds<usize> + Send + 'static,
424 {
425 self.maybe_fail_read(FailableAction::GetPayloadMetadataRange)
426 .await?;
427 self.inner.get_payload_metadata_range(range).await
428 }
429
430 async fn get_vid_common_range<R>(
431 &mut self,
432 range: R,
433 ) -> QueryResult<Vec<QueryResult<VidCommonQueryData<Types>>>>
434 where
435 R: RangeBounds<usize> + Send + 'static,
436 {
437 self.maybe_fail_read(FailableAction::GetVidCommonRange)
438 .await?;
439 self.inner.get_vid_common_range(range).await
440 }
441
442 async fn get_vid_common_metadata_range<R>(
443 &mut self,
444 range: R,
445 ) -> QueryResult<Vec<QueryResult<VidCommonMetadata<Types>>>>
446 where
447 R: RangeBounds<usize> + Send + 'static,
448 {
449 self.maybe_fail_read(FailableAction::GetVidCommonMetadataRange)
450 .await?;
451 self.inner.get_vid_common_metadata_range(range).await
452 }
453
454 async fn get_block_with_transaction(
455 &mut self,
456 hash: TransactionHash<Types>,
457 ) -> QueryResult<BlockQueryData<Types>> {
458 self.maybe_fail_read(FailableAction::GetTransaction).await?;
459 self.inner.get_block_with_transaction(hash).await
460 }
461
462 async fn first_available_leaf(&mut self, from: u64) -> QueryResult<LeafQueryData<Types>> {
463 self.maybe_fail_read(FailableAction::FirstAvailableLeaf)
464 .await?;
465 self.inner.first_available_leaf(from).await
466 }
467}
468
469impl<Types, T> UpdateAvailabilityStorage<Types> for Transaction<T>
470where
471 Types: NodeType,
472 Header<Types>: QueryableHeader<Types>,
473 Payload<Types>: QueryablePayload<Types>,
474 T: UpdateAvailabilityStorage<Types> + Send + Sync,
475{
476 async fn insert_leaf_with_qc_chain(
477 &mut self,
478 leaf: LeafQueryData<Types>,
479 qc_chain: Option<[CertificatePair<Types>; 2]>,
480 ) -> anyhow::Result<()> {
481 self.maybe_fail_write(FailableAction::Any).await?;
482 self.inner.insert_leaf_with_qc_chain(leaf, qc_chain).await
483 }
484
485 async fn insert_block(&mut self, block: BlockQueryData<Types>) -> anyhow::Result<()> {
486 self.maybe_fail_write(FailableAction::Any).await?;
487 self.inner.insert_block(block).await
488 }
489
490 async fn insert_vid(
491 &mut self,
492 common: VidCommonQueryData<Types>,
493 share: Option<VidShare>,
494 ) -> anyhow::Result<()> {
495 self.maybe_fail_write(FailableAction::Any).await?;
496 self.inner.insert_vid(common, share).await
497 }
498}
499
500#[async_trait]
501impl<T> PrunedHeightStorage for Transaction<T>
502where
503 T: PrunedHeightStorage + Send + Sync,
504{
505 async fn load_pruned_height(&mut self) -> anyhow::Result<Option<u64>> {
506 self.maybe_fail_read(FailableAction::Any).await?;
507 self.inner.load_pruned_height().await
508 }
509}
510
511#[async_trait]
512impl<Types, T> NodeStorage<Types> for Transaction<T>
513where
514 Types: NodeType,
515 Header<Types>: QueryableHeader<Types>,
516 T: NodeStorage<Types> + Send + Sync,
517{
518 async fn block_height(&mut self) -> QueryResult<usize> {
519 self.maybe_fail_read(FailableAction::Any).await?;
520 self.inner.block_height().await
521 }
522
523 async fn count_transactions_in_range(
524 &mut self,
525 range: impl RangeBounds<usize> + Send,
526 namespace: Option<NamespaceId<Types>>,
527 ) -> QueryResult<usize> {
528 self.maybe_fail_read(FailableAction::Any).await?;
529 self.inner
530 .count_transactions_in_range(range, namespace)
531 .await
532 }
533
534 async fn payload_size_in_range(
535 &mut self,
536 range: impl RangeBounds<usize> + Send,
537 namespace: Option<NamespaceId<Types>>,
538 ) -> QueryResult<usize> {
539 self.maybe_fail_read(FailableAction::Any).await?;
540 self.inner.payload_size_in_range(range, namespace).await
541 }
542
543 async fn vid_share<ID>(&mut self, id: ID) -> QueryResult<VidShare>
544 where
545 ID: Into<BlockId<Types>> + Send + Sync,
546 {
547 self.maybe_fail_read(FailableAction::Any).await?;
548 self.inner.vid_share(id).await
549 }
550
551 async fn sync_status(&mut self) -> QueryResult<SyncStatus> {
552 self.maybe_fail_read(FailableAction::Any).await?;
553 self.inner.sync_status().await
554 }
555
556 async fn get_header_window(
557 &mut self,
558 start: impl Into<WindowStart<Types>> + Send + Sync,
559 end: u64,
560 limit: usize,
561 ) -> QueryResult<TimeWindowQueryData<Header<Types>>> {
562 self.maybe_fail_read(FailableAction::Any).await?;
563 self.inner.get_header_window(start, end, limit).await
564 }
565
566 async fn latest_qc_chain(&mut self) -> QueryResult<Option<[CertificatePair<Types>; 2]>> {
567 self.maybe_fail_read(FailableAction::Any).await?;
568 self.inner.latest_qc_chain().await
569 }
570}
571
572impl<Types, T> AggregatesStorage<Types> for Transaction<T>
573where
574 Types: NodeType,
575 Header<Types>: QueryableHeader<Types>,
576 T: AggregatesStorage<Types> + Send + Sync,
577{
578 async fn aggregates_height(&mut self) -> anyhow::Result<usize> {
579 self.maybe_fail_read(FailableAction::Any).await?;
580 self.inner.aggregates_height().await
581 }
582
583 async fn load_prev_aggregate(&mut self) -> anyhow::Result<Option<Aggregate<Types>>> {
584 self.maybe_fail_read(FailableAction::Any).await?;
585 self.inner.load_prev_aggregate().await
586 }
587}
588
589impl<T, Types> UpdateAggregatesStorage<Types> for Transaction<T>
590where
591 Types: NodeType,
592 Header<Types>: QueryableHeader<Types>,
593 T: UpdateAggregatesStorage<Types> + Send + Sync,
594{
595 async fn update_aggregates(
596 &mut self,
597 prev: Aggregate<Types>,
598 blocks: &[PayloadMetadata<Types>],
599 ) -> anyhow::Result<Aggregate<Types>> {
600 self.maybe_fail_write(FailableAction::Any).await?;
601 self.inner.update_aggregates(prev, blocks).await
602 }
603}