use std::{
    collections::{HashMap, HashSet},
    marker::PhantomData,
    time::{Duration, Instant},
};

use anyhow::{bail, Context};
use ark_serialize::CanonicalSerialize;
use async_trait::async_trait;
use committable::Committable;
use derive_more::{Deref, DerefMut};
use futures::{future::Future, stream::TryStreamExt};
use hotshot_types::{
    data::VidShare,
    traits::{
        block_contents::BlockHeader,
        metrics::{Counter, Gauge, Histogram, Metrics},
        node_implementation::{ConsensusTime, NodeType},
        EncodeBytes,
    },
};
use itertools::Itertools;
use jf_merkle_tree_compat::prelude::{MerkleNode, MerkleProof};
pub use sqlx::Executor;
use sqlx::{
    pool::Pool, query_builder::Separated, types::BitVec, Encode, Execute, FromRow, QueryBuilder,
    Type,
};
use tokio::time::sleep;

use super::{
    queries::{
        self,
        state::{build_hash_batch_insert, Node},
        DecodeError,
    },
    Database, Db,
};
use crate::{
    availability::{
        BlockQueryData, LeafQueryData, QueryableHeader, QueryablePayload, StateCertQueryDataV2,
        VidCommonQueryData,
    },
    data_source::{
        storage::{pruning::PrunedHeightStorage, UpdateAvailabilityStorage},
        update,
    },
    merklized_state::{MerklizedState, UpdateStateData},
    types::HeightIndexed,
    Header, Payload, QueryError, QueryResult,
};

pub type Query<'q> = sqlx::query::Query<'q, Db, <Db as Database>::Arguments<'q>>;
pub type QueryAs<'q, T> = sqlx::query::QueryAs<'q, Db, T, <Db as Database>::Arguments<'q>>;

pub fn query(sql: &str) -> Query<'_> {
    sqlx::query(sql)
}

pub fn query_as<'q, T>(sql: &'q str) -> QueryAs<'q, T>
where
    T: for<'r> FromRow<'r, <Db as Database>::Row>,
{
    sqlx::query_as(sql)
}

/// Marker type for transactions opened in read-write mode.
#[derive(Clone, Copy, Debug, Default)]
pub struct Write;

/// Marker type for transactions opened in read-only mode.
#[derive(Clone, Copy, Debug, Default)]
pub struct Read;

pub trait TransactionMode: Send + Sync {
    /// Configure a freshly begun transaction for this mode.
    fn begin(
        conn: &mut <Db as Database>::Connection,
    ) -> impl Future<Output = anyhow::Result<()>> + Send;

    /// Human-readable name of this mode, for logging.
    fn display() -> &'static str;
}

impl TransactionMode for Write {
    #[allow(unused_variables)]
    async fn begin(conn: &mut <Db as Database>::Connection) -> anyhow::Result<()> {
        // With the embedded SQLite database, issue a no-op write so the
        // transaction acquires the write lock up front, rather than failing
        // partway through if another writer is active.
        #[cfg(feature = "embedded-db")]
        conn.execute("UPDATE pruned_height SET id = id WHERE false")
            .await?;

        // Postgres supports concurrent transactions; request the strongest
        // isolation level so concurrent writes serialize cleanly.
        #[cfg(not(feature = "embedded-db"))]
        conn.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE")
            .await?;

        Ok(())
    }

    fn display() -> &'static str {
        "write"
    }
}

impl TransactionMode for Read {
    #[allow(unused_variables)]
    async fn begin(conn: &mut <Db as Database>::Connection) -> anyhow::Result<()> {
        // SQLite reads need no special setup. For Postgres, open a
        // serializable, read-only transaction; DEFERRABLE makes it wait for a
        // snapshot that is guaranteed free of serialization anomalies.
        #[cfg(not(feature = "embedded-db"))]
        conn.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY, DEFERRABLE")
            .await?;

        Ok(())
    }

    fn display() -> &'static str {
        "read-only"
    }
}

#[derive(Clone, Copy, Debug)]
enum CloseType {
    Commit,
    Revert,
    Drop,
}

#[derive(Debug)]
struct TransactionMetricsGuard<Mode> {
    started_at: Instant,
    metrics: PoolMetrics,
    close_type: CloseType,
    _mode: PhantomData<Mode>,
}

impl<Mode: TransactionMode> TransactionMetricsGuard<Mode> {
    fn begin(metrics: PoolMetrics) -> Self {
        let started_at = Instant::now();
        tracing::trace!(mode = Mode::display(), ?started_at, "begin");
        metrics.open_transactions.update(1);

        Self {
            started_at,
            metrics,
            close_type: CloseType::Drop,
            _mode: Default::default(),
        }
    }

    fn set_closed(&mut self, t: CloseType) {
        self.close_type = t;
    }
}

impl<Mode> Drop for TransactionMetricsGuard<Mode> {
    fn drop(&mut self) {
        self.metrics
            .transaction_durations
            .add_point((self.started_at.elapsed().as_millis() as f64) / 1000.);
        self.metrics.open_transactions.update(-1);
        match self.close_type {
            CloseType::Commit => self.metrics.commits.add(1),
            CloseType::Revert => self.metrics.reverts.add(1),
            CloseType::Drop => self.metrics.drops.add(1),
        }
        tracing::trace!(started_at = ?self.started_at, reason = ?self.close_type, "close");
    }
}

#[derive(Debug, Deref, DerefMut)]
/// An atomic SQL transaction.
pub struct Transaction<Mode> {
    #[deref]
    #[deref_mut]
    inner: sqlx::Transaction<'static, Db>,
    metrics: TransactionMetricsGuard<Mode>,
}

impl<Mode: TransactionMode> Transaction<Mode> {
    pub(super) async fn new(pool: &Pool<Db>, metrics: PoolMetrics) -> anyhow::Result<Self> {
        let mut inner = pool.begin().await?;
        let metrics = TransactionMetricsGuard::begin(metrics);
        Mode::begin(inner.as_mut()).await?;
        Ok(Self { inner, metrics })
    }
}

impl<Mode: TransactionMode> update::Transaction for Transaction<Mode> {
    async fn commit(mut self) -> anyhow::Result<()> {
        self.inner.commit().await?;
        self.metrics.set_closed(CloseType::Commit);
        Ok(())
    }
    fn revert(mut self) -> impl Future + Send {
        async move {
            // There is no way to recover from a failed rollback; panicking is
            // the best we can do.
            self.inner.rollback().await.unwrap();
            self.metrics.set_closed(CloseType::Revert);
        }
    }
}

/// A collection of parameters which can be bound to placeholders in a SQL
/// statement.
pub trait Params<'p> {
    /// Bind these parameters, in order, to consecutive placeholders in a
    /// `Separated` query builder.
    fn bind<'q, 'r>(
        self,
        q: &'q mut Separated<'r, 'p, Db, &'static str>,
    ) -> &'q mut Separated<'r, 'p, Db, &'static str>
    where
        'p: 'r;
}

/// A collection of exactly `N` parameters, such as one row of `N` column
/// values.
pub trait FixedLengthParams<'p, const N: usize>: Params<'p> {}

macro_rules! impl_tuple_params {
    ($n:literal, ($($t:ident,)+)) => {
        impl<'p, $($t),+> Params<'p> for ($($t,)+)
        where $(
            $t: 'p + Encode<'p, Db> + Type<Db>
        ),+ {
            fn bind<'q, 'r>(
                self,
                q: &'q mut Separated<'r, 'p, Db, &'static str>,
            ) -> &'q mut Separated<'r, 'p, Db, &'static str>
            where
                'p: 'r,
            {
                #[allow(non_snake_case)]
                let ($($t,)+) = self;
                q $(
                    .push_bind($t)
                )+
            }
        }

        impl<'p, $($t),+> FixedLengthParams<'p, $n> for ($($t,)+)
        where $(
            $t: 'p + for<'q> Encode<'q, Db> + Type<Db>
        ),+ {
        }
    };
}

impl_tuple_params!(1, (T,));
impl_tuple_params!(2, (T1, T2,));
impl_tuple_params!(3, (T1, T2, T3,));
impl_tuple_params!(4, (T1, T2, T3, T4,));
impl_tuple_params!(5, (T1, T2, T3, T4, T5,));
impl_tuple_params!(6, (T1, T2, T3, T4, T5, T6,));
impl_tuple_params!(7, (T1, T2, T3, T4, T5, T6, T7,));
impl_tuple_params!(8, (T1, T2, T3, T4, T5, T6, T7, T8,));
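
// As a result, any tuple of one to eight encodable values can be used as row
// parameters; for example, `(1i64, "hash".to_string())` implements
// `FixedLengthParams<'_, 2>` and binds two placeholders in order.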

pub fn build_where_in<'a, I>(
    query: &'a str,
    column: &'a str,
    values: I,
) -> QueryResult<(queries::QueryBuilder<'a>, String)>
where
    I: IntoIterator,
    I::Item: 'a + Encode<'a, Db> + Type<Db>,
{
    let mut builder = queries::QueryBuilder::default();
    let params = values
        .into_iter()
        .map(|v| Ok(format!("{} ", builder.bind(v)?)))
        .collect::<QueryResult<Vec<String>>>()?;

    if params.is_empty() {
        return Err(QueryError::Error {
            message: "failed to build WHERE IN query: no parameters found".to_string(),
        });
    }

    let sql = format!(
        "{query} WHERE {column} IN ({}) ",
        params.into_iter().join(",")
    );

    Ok((builder, sql))
}

impl Transaction<Write> {
    /// Upsert (insert, falling back to update on conflict) one or more rows
    /// into a table, retrying a few times on transient failures.
    ///
    /// Returns an error if the statement does not modify exactly one row per
    /// input row, even after retries.
    pub async fn upsert<'p, const N: usize, R>(
        &mut self,
        table: &str,
        columns: [&str; N],
        pk: impl IntoIterator<Item = &str>,
        rows: R,
    ) -> anyhow::Result<()>
    where
        R: IntoIterator,
        R::Item: 'p + FixedLengthParams<'p, N> + Clone,
    {
        let set_columns = columns
            .iter()
            .map(|col| format!("{col} = excluded.{col}"))
            .join(",");

        let columns_str = columns.iter().map(|col| format!("\"{col}\"")).join(",");

        let pk = pk.into_iter().join(",");

        let rows: Vec<_> = rows.into_iter().collect();
        let num_rows = rows.len();

        if num_rows == 0 {
            tracing::warn!("trying to upsert 0 rows into {table}, this has no effect");
            return Ok(());
        }

        let interval = Duration::from_secs(1);
        let mut retries = 5;

        let mut query_builder =
            QueryBuilder::new(format!("INSERT INTO \"{table}\" ({columns_str}) "));

        loop {
            // Reset the builder to the initial INSERT prefix, so each attempt
            // rebuilds the VALUES list and ON CONFLICT clause from scratch.
            let query_builder = query_builder.reset();

            query_builder.push_values(rows.clone(), |mut b, row| {
                row.bind(&mut b);
            });

            query_builder.push(format!(" ON CONFLICT ({pk}) DO UPDATE SET {set_columns}"));

            let query = query_builder.build();
            let statement = query.sql();

            match self.execute(query).await {
                Ok(res) => {
                    let rows_modified = res.rows_affected() as usize;
                    if rows_modified != num_rows {
                        let error = format!(
                            "unexpected number of rows modified: expected {num_rows}, got \
                             {rows_modified}. query: {statement}"
                        );
                        tracing::error!(error);
                        bail!(error);
                    }
                    return Ok(());
                },
                Err(err) => {
                    tracing::error!(
                        statement,
                        "error in statement execution ({} tries remaining): {err}",
                        retries
                    );
                    if retries == 0 {
                        bail!(err);
                    }
                    retries -= 1;
                    sleep(interval).await;
                },
            }
        }
    }
}

impl Transaction<Write> {
    /// Prune all data at or below the given height, then record the new
    /// pruned height.
    pub(super) async fn delete_batch(
        &mut self,
        state_tables: Vec<String>,
        height: u64,
    ) -> anyhow::Result<()> {
        self.execute(query("DELETE FROM header WHERE height <= $1").bind(height as i64))
            .await?;

        for state_table in state_tables {
            // Rank the versions of each Merkle path by recency and delete all
            // but the most recent (rank 1).
            self.execute(
                query(&format!(
                    "
                    DELETE FROM {state_table} WHERE (path, created) IN
                        (SELECT path, created FROM
                            (SELECT path, created,
                                ROW_NUMBER() OVER (PARTITION BY path ORDER BY created DESC) as rank
                            FROM {state_table} WHERE created <= $1) ranked_nodes WHERE rank != 1)"
                ))
                .bind(height as i64),
            )
            .await?;
        }

        self.save_pruned_height(height).await?;
        Ok(())
    }

    /// Record the height up to which data has been pruned.
    pub(super) async fn save_pruned_height(&mut self, height: u64) -> anyhow::Result<()> {
        // The pruned_height table has a single row, with fixed id 1.
        self.upsert(
            "pruned_height",
            ["id", "last_height"],
            ["id"],
            [(1i32, height as i64)],
        )
        .await
    }
}

impl<Types> UpdateAvailabilityStorage<Types> for Transaction<Write>
where
    Types: NodeType,
    Payload<Types>: QueryablePayload<Types>,
    Header<Types>: QueryableHeader<Types>,
{
    async fn insert_leaf(&mut self, leaf: LeafQueryData<Types>) -> anyhow::Result<()> {
        let height = leaf.height();

        // If the height is at or below the pruned height, the data has
        // already been pruned; do not reinsert it.
        if let Some(pruned_height) = self.load_pruned_height().await? {
            if height <= pruned_height {
                tracing::info!(
                    height,
                    pruned_height,
                    "ignoring leaf which is already pruned"
                );
                return Ok(());
            }
        }

        // Serialize the header as JSON so its fields can be queried directly.
        let header_json = serde_json::to_value(leaf.leaf().block_header())
            .context("failed to serialize header")?;
        self.upsert(
            "header",
            ["height", "hash", "payload_hash", "data", "timestamp"],
            ["height"],
            [(
                height as i64,
                leaf.block_hash().to_string(),
                leaf.leaf().block_header().payload_commitment().to_string(),
                header_json,
                leaf.leaf().block_header().timestamp() as i64,
            )],
        )
        .await?;

        // Insert an empty placeholder row for the payload at this height; the
        // data is filled in later, when the corresponding block is inserted.
        let query = query("INSERT INTO payload (height) VALUES ($1) ON CONFLICT DO NOTHING")
            .bind(height as i64);
        query.execute(self.as_mut()).await?;

        // Finally, insert the leaf itself along with its QC.
        let leaf_json = serde_json::to_value(leaf.leaf()).context("failed to serialize leaf")?;
        let qc_json = serde_json::to_value(leaf.qc()).context("failed to serialize QC")?;
        self.upsert(
            "leaf2",
            ["height", "hash", "block_hash", "leaf", "qc"],
            ["height"],
            [(
                height as i64,
                leaf.hash().to_string(),
                leaf.block_hash().to_string(),
                leaf_json,
                qc_json,
            )],
        )
        .await?;

        Ok(())
    }

    async fn insert_block(&mut self, block: BlockQueryData<Types>) -> anyhow::Result<()> {
        let height = block.height();

        // Skip blocks at or below the pruned height; that data has already
        // been deleted.
        if let Some(pruned_height) = self.load_pruned_height().await? {
            if height <= pruned_height {
                tracing::info!(
                    height,
                    pruned_height,
                    "ignoring block which is already pruned"
                );
                return Ok(());
            }
        }

        // Store the raw payload bytes, filling in the placeholder row created
        // when the corresponding leaf was inserted.
        let payload = block.payload.encode();

        self.upsert(
            "payload",
            ["height", "data", "size", "num_transactions"],
            ["height"],
            [(
                height as i64,
                payload.as_ref().to_vec(),
                block.size() as i32,
                block.num_transactions() as i32,
            )],
        )
        .await?;

        // Index the transactions in the block.
        let mut rows = vec![];
        for (txn_ix, txn) in block.enumerate() {
            let ns_id = block.header().namespace_id(&txn_ix.ns_index).unwrap();
            rows.push((
                txn.commit().to_string(),
                height as i64,
                txn_ix.ns_index.into(),
                ns_id.into(),
                txn_ix.position as i64,
            ));
        }
        if !rows.is_empty() {
            self.upsert(
                "transactions",
                ["hash", "block_height", "ns_index", "ns_id", "position"],
                ["block_height", "ns_id", "position"],
                rows,
            )
            .await?;
        }

        Ok(())
    }

    async fn insert_vid(
        &mut self,
        common: VidCommonQueryData<Types>,
        share: Option<VidShare>,
    ) -> anyhow::Result<()> {
        let height = common.height();

        // Skip VID data at or below the pruned height; that data has already
        // been deleted.
        if let Some(pruned_height) = self.load_pruned_height().await? {
            if height <= pruned_height {
                tracing::info!(
                    height,
                    pruned_height,
                    "ignoring VID common which is already pruned"
                );
                return Ok(());
            }
        }

        let common_data =
            bincode::serialize(common.common()).context("failed to serialize VID common data")?;
        if let Some(share) = share {
            let share_data = bincode::serialize(&share).context("failed to serialize VID share")?;
            self.upsert(
                "vid2",
                ["height", "common", "share"],
                ["height"],
                [(height as i64, common_data, share_data)],
            )
            .await
        } else {
            // Don't touch the `share` column at all if we don't have a share,
            // so we avoid overwriting a share inserted earlier.
            self.upsert(
                "vid2",
                ["height", "common"],
                ["height"],
                [(height as i64, common_data)],
            )
            .await
        }
    }

    async fn insert_state_cert(
        &mut self,
        state_cert: StateCertQueryDataV2<Types>,
    ) -> anyhow::Result<()> {
        let height = state_cert.height();

        // Skip state certs at or below the pruned height; that data has
        // already been deleted.
        if let Some(pruned_height) = self.load_pruned_height().await? {
            if height <= pruned_height {
                tracing::info!(
                    height,
                    pruned_height,
                    "ignoring state cert which is already pruned"
                );
                return Ok(());
            }
        }
        let epoch = state_cert.0.epoch.u64();
        let bytes = bincode::serialize(&state_cert.0).context("failed to serialize state cert")?;
        self.upsert(
            "finalized_state_cert",
            ["epoch", "state_cert"],
            ["epoch"],
            [(epoch as i64, bytes)],
        )
        .await?;
        Ok(())
    }
}

#[async_trait]
impl<Types: NodeType, State: MerklizedState<Types, ARITY>, const ARITY: usize>
    UpdateStateData<Types, State, ARITY> for Transaction<Write>
{
    async fn set_last_state_height(&mut self, height: usize) -> anyhow::Result<()> {
        self.upsert(
            "last_merklized_state_height",
            ["id", "height"],
            ["id"],
            [(1i32, height as i64)],
        )
        .await?;

        Ok(())
    }

    async fn insert_merkle_nodes(
        &mut self,
        proof: MerkleProof<State::Entry, State::Key, State::T, ARITY>,
        traversal_path: Vec<usize>,
        block_number: u64,
    ) -> anyhow::Result<()> {
        let pos = proof.pos;
        let path = proof.proof;

        let name = State::state_type();
        let block_number = block_number as i64;

        let mut traversal_path = traversal_path.iter().map(|n| *n as i32);

        // Collect all the nodes here; they reference hash ids which are only
        // known once the hashes have been upserted into the db.
        let mut nodes = Vec::new();
        let mut hashset = HashSet::new();

        for node in path.iter() {
            match node {
                MerkleNode::Empty => {
                    let index = serde_json::to_value(pos.clone())
                        .decode_error("malformed merkle position")?;
                    // Reverse the remaining traversal path to get this node's
                    // path from the root.
                    let node_path = traversal_path.clone().rev().collect();
                    nodes.push((
                        Node {
                            path: node_path,
                            idx: Some(index),
                            ..Default::default()
                        },
                        None,
                        [0_u8; 32].to_vec(),
                    ));
                    hashset.insert([0_u8; 32].to_vec());
                },
                MerkleNode::ForgettenSubtree { .. } => {
                    bail!("Node in the Merkle path contains a forgotten subtree");
                },
                MerkleNode::Leaf { value, pos, elem } => {
                    // Serialize the leaf commitment for the hash table.
                    let mut leaf_commit = Vec::new();
                    value
                        .serialize_compressed(&mut leaf_commit)
                        .decode_error("malformed merkle leaf commitment")?;

                    let path = traversal_path.clone().rev().collect();

                    let index = serde_json::to_value(pos.clone())
                        .decode_error("malformed merkle position")?;
                    let entry =
                        serde_json::to_value(elem).decode_error("malformed merkle element")?;

                    nodes.push((
                        Node {
                            path,
                            idx: Some(index),
                            entry: Some(entry),
                            ..Default::default()
                        },
                        None,
                        leaf_commit.clone(),
                    ));

                    hashset.insert(leaf_commit);
                },
                MerkleNode::Branch { value, children } => {
                    // Serialize the branch hash for the hash table.
                    let mut branch_hash = Vec::new();
                    value
                        .serialize_compressed(&mut branch_hash)
                        .decode_error("malformed merkle branch hash")?;

                    // A bitvec marks which children are present (non-empty);
                    // the hashes of the present children are stored in order.
                    let mut children_bitvec = BitVec::new();
                    let mut children_values = Vec::new();
                    for child in children {
                        let child = child.as_ref();
                        match child {
                            MerkleNode::Empty => {
                                children_bitvec.push(false);
                            },
                            MerkleNode::Branch { value, .. }
                            | MerkleNode::Leaf { value, .. }
                            | MerkleNode::ForgettenSubtree { value } => {
                                let mut hash = Vec::new();
                                value
                                    .serialize_compressed(&mut hash)
                                    .decode_error("malformed merkle node hash")?;

                                children_values.push(hash);
                                children_bitvec.push(true);
                            },
                        }
                    }

                    let path = traversal_path.clone().rev().collect();
                    nodes.push((
                        Node {
                            path,
                            children: None,
                            children_bitvec: Some(children_bitvec),
                            ..Default::default()
                        },
                        Some(children_values.clone()),
                        branch_hash.clone(),
                    ));
                    hashset.insert(branch_hash);
                    hashset.extend(children_values);
                },
            }

            // Move up one level in the tree for the next node in the path.
            traversal_path.next();
        }
        let hashes = hashset.into_iter().collect::<Vec<Vec<u8>>>();

        // Batch insert all the hashes and collect their ids, which the nodes
        // reference by foreign key.
        let (query, sql) = build_hash_batch_insert(&hashes)?;
        let nodes_hash_ids: HashMap<Vec<u8>, i32> = query
            .query_as(&sql)
            .fetch(self.as_mut())
            .try_collect()
            .await?;

        // Patch the hash ids (and child hash ids) into the collected nodes.
        for (node, children, hash) in &mut nodes {
            node.created = block_number;
            node.hash_id = *nodes_hash_ids.get(&*hash).ok_or(QueryError::Error {
                message: "Missing node hash".to_string(),
            })?;

            if let Some(children) = children {
                let children_hashes = children
                    .iter()
                    .map(|c| nodes_hash_ids.get(c).copied())
                    .collect::<Option<Vec<i32>>>()
                    .ok_or(QueryError::Error {
                        message: "Missing child hash".to_string(),
                    })?;

                node.children = Some(children_hashes.into());
            }
        }

        Node::upsert(name, nodes.into_iter().map(|(n, ..)| n), self).await?;

        Ok(())
    }
}

#[async_trait]
impl<Mode: TransactionMode> PrunedHeightStorage for Transaction<Mode> {
    async fn load_pruned_height(&mut self) -> anyhow::Result<Option<u64>> {
        let Some((height,)) =
            query_as::<(i64,)>("SELECT last_height FROM pruned_height ORDER BY id DESC LIMIT 1")
                .fetch_optional(self.as_mut())
                .await?
        else {
            return Ok(None);
        };
        Ok(Some(height as u64))
    }
}

#[derive(Clone, Debug)]
pub(super) struct PoolMetrics {
    open_transactions: Box<dyn Gauge>,
    transaction_durations: Box<dyn Histogram>,
    commits: Box<dyn Counter>,
    reverts: Box<dyn Counter>,
    drops: Box<dyn Counter>,
}

impl PoolMetrics {
    pub(super) fn new(metrics: &(impl Metrics + ?Sized)) -> Self {
        Self {
            open_transactions: metrics.create_gauge("open_transactions".into(), None),
            transaction_durations: metrics
                .create_histogram("transaction_duration".into(), Some("s".into())),
            commits: metrics.create_counter("committed_transactions".into(), None),
            reverts: metrics.create_counter("reverted_transactions".into(), None),
            drops: metrics.create_counter("dropped_transactions".into(), None),
        }
    }
}