1use std::{
22 collections::HashMap,
23 marker::PhantomData,
24 time::{Duration, Instant},
25};
26
27use anyhow::{bail, Context};
28use async_trait::async_trait;
29use committable::Committable;
30use derive_more::{Deref, DerefMut};
31use futures::future::Future;
32#[cfg(feature = "embedded-db")]
33use futures::stream::TryStreamExt;
34use hotshot_types::{
35 data::VidShare,
36 simple_certificate::CertificatePair,
37 traits::{
38 block_contents::BlockHeader,
39 metrics::{Counter, Gauge, Histogram, Metrics},
40 node_implementation::NodeType,
41 EncodeBytes,
42 },
43};
44use itertools::Itertools;
45use jf_merkle_tree_compat::prelude::MerkleProof;
46pub use sqlx::Executor;
47use sqlx::{pool::Pool, query_builder::Separated, Encode, Execute, FromRow, QueryBuilder, Type};
48use tokio::time::sleep;
49
50#[cfg(not(feature = "embedded-db"))]
51use super::queries::state::batch_insert_hashes;
52#[cfg(feature = "embedded-db")]
53use super::queries::state::build_hash_batch_insert;
54use super::{
55 queries::{
56 self,
57 state::{collect_nodes_from_proofs, Node},
58 },
59 Database, Db,
60};
61use crate::{
62 availability::{
63 BlockQueryData, LeafQueryData, QueryableHeader, QueryablePayload, VidCommonQueryData,
64 },
65 data_source::{
66 storage::{pruning::PrunedHeightStorage, NodeStorage, UpdateAvailabilityStorage},
67 update,
68 },
69 merklized_state::{MerklizedState, UpdateStateData},
70 types::HeightIndexed,
71 Header, Payload, QueryError, QueryResult,
72};
73
/// A SQL query with bound arguments, targeting the database driver selected at compile time ([`Db`]).
pub type Query<'q> = sqlx::query::Query<'q, Db, <Db as Database>::Arguments<'q>>;
/// A SQL query whose result rows are mapped to values of type `T`, targeting [`Db`].
pub type QueryAs<'q, T> = sqlx::query::QueryAs<'q, Db, T, <Db as Database>::Arguments<'q>>;
76
/// Construct a SQL query against the active database backend.
pub fn query(sql: &str) -> Query<'_> {
    sqlx::query(sql)
}
80
/// Construct a SQL query whose result rows are decoded into values of type `T`.
pub fn query_as<'q, T>(sql: &'q str) -> QueryAs<'q, T>
where
    T: for<'r> FromRow<'r, <Db as Database>::Row>,
{
    sqlx::query_as(sql)
}
87
/// Marker type denoting a transaction with read-write access to the database.
#[derive(Clone, Copy, Debug, Default)]
pub struct Write;
91
/// Marker type denoting a transaction with read-only access to the database.
#[derive(Clone, Copy, Debug, Default)]
pub struct Read;
95
/// An access mode ([`Read`] or [`Write`]) which knows how to initialize a new transaction.
pub trait TransactionMode: Send + Sync {
    /// Run any mode-specific setup statements on `conn` just after the transaction is opened.
    fn begin(
        conn: &mut <Db as Database>::Connection,
    ) -> impl Future<Output = anyhow::Result<()>> + Send;

    /// Human-readable name of this mode, used in log messages.
    fn display() -> &'static str;
}
104
impl TransactionMode for Write {
    #[allow(unused_variables)]
    async fn begin(conn: &mut <Db as Database>::Connection) -> anyhow::Result<()> {
        // SQLite: run a no-op UPDATE (the WHERE clause matches no rows) — presumably this
        // forces SQLite to treat the transaction as a write transaction from the start
        // rather than upgrading lazily on the first real write. TODO(review): confirm
        // against SQLite transaction semantics.
        #[cfg(feature = "embedded-db")]
        conn.execute("UPDATE pruned_height SET id = id WHERE false")
            .await?;

        // Postgres: writes use the strongest isolation level.
        #[cfg(not(feature = "embedded-db"))]
        conn.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE")
            .await?;

        Ok(())
    }

    fn display() -> &'static str {
        "write"
    }
}
153
impl TransactionMode for Read {
    #[allow(unused_variables)]
    async fn begin(conn: &mut <Db as Database>::Connection) -> anyhow::Result<()> {
        // Postgres: serializable, read-only, deferrable — a combination Postgres can run
        // without any risk of serialization failures. The embedded (SQLite) backend needs
        // no per-transaction setup for reads.
        #[cfg(not(feature = "embedded-db"))]
        conn.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY, DEFERRABLE")
            .await?;

        Ok(())
    }

    fn display() -> &'static str {
        "read-only"
    }
}
176
/// The way a [`Transaction`] terminated, for metrics reporting.
#[derive(Clone, Copy, Debug)]
enum CloseType {
    // Successfully committed.
    Commit,
    // Explicitly rolled back.
    Revert,
    // Dropped without an explicit commit or revert.
    Drop,
}
183
/// RAII guard which maintains pool metrics over the lifetime of a transaction.
///
/// Created when a transaction begins; on drop it records the transaction's duration
/// and increments a counter for the way the transaction was closed.
#[derive(Debug)]
struct TransactionMetricsGuard<Mode> {
    // When the transaction was opened, for computing its duration on close.
    started_at: Instant,
    metrics: PoolMetrics,
    // How the transaction ended; stays `Drop` unless explicitly marked closed.
    close_type: CloseType,
    _mode: PhantomData<Mode>,
}
191
192impl<Mode: TransactionMode> TransactionMetricsGuard<Mode> {
193 fn begin(metrics: PoolMetrics) -> Self {
194 let started_at = Instant::now();
195 tracing::trace!(mode = Mode::display(), ?started_at, "begin");
196 metrics.open_transactions.update(1);
197
198 Self {
199 started_at,
200 metrics,
201 close_type: CloseType::Drop,
202 _mode: Default::default(),
203 }
204 }
205
206 fn set_closed(&mut self, t: CloseType) {
207 self.close_type = t;
208 }
209}
210
impl<Mode> Drop for TransactionMetricsGuard<Mode> {
    fn drop(&mut self) {
        // Record the transaction duration in seconds (the histogram's unit is "s").
        self.metrics
            .transaction_durations
            .add_point((self.started_at.elapsed().as_millis() as f64) / 1000.);
        self.metrics.open_transactions.update(-1);
        // Count the close according to how the transaction terminated.
        match self.close_type {
            CloseType::Commit => self.metrics.commits.add(1),
            CloseType::Revert => self.metrics.reverts.add(1),
            CloseType::Drop => self.metrics.drops.add(1),
        }
        tracing::trace!(started_at = ?self.started_at, reason = ?self.close_type, "close");
    }
}
225
/// An open database transaction with access mode `Mode` ([`Read`] or [`Write`]).
///
/// Dereferences to the underlying [`sqlx::Transaction`], so sqlx methods can be
/// called on it directly.
#[derive(Debug, Deref, DerefMut)]
pub struct Transaction<Mode> {
    #[deref]
    #[deref_mut]
    inner: sqlx::Transaction<'static, Db>,
    // Keeps pool metrics up to date for the lifetime of this transaction.
    metrics: TransactionMetricsGuard<Mode>,
}
234
235impl<Mode: TransactionMode> Transaction<Mode> {
236 pub(super) async fn new(pool: &Pool<Db>, metrics: PoolMetrics) -> anyhow::Result<Self> {
237 let mut inner = pool.begin().await?;
238 let metrics = TransactionMetricsGuard::begin(metrics);
239 Mode::begin(inner.as_mut()).await?;
240 Ok(Self { inner, metrics })
241 }
242}
243
244impl<Mode: TransactionMode> update::Transaction for Transaction<Mode> {
245 async fn commit(mut self) -> anyhow::Result<()> {
246 self.inner.commit().await?;
247 self.metrics.set_closed(CloseType::Commit);
248 Ok(())
249 }
250 fn revert(mut self) -> impl Future + Send {
251 async move {
252 self.inner.rollback().await.unwrap();
253 self.metrics.set_closed(CloseType::Revert);
254 }
255 }
256}
257
/// A collection of parameters which can be bound to a SQL statement.
pub trait Params<'p> {
    /// Bind each parameter in `self`, in order, onto the separated builder `q`.
    fn bind<'q, 'r>(
        self,
        q: &'q mut Separated<'r, 'p, Db, &'static str>,
    ) -> &'q mut Separated<'r, 'p, Db, &'static str>
    where
        'p: 'r;
}
295
/// A collection of exactly `N` parameters which can be bound to a SQL statement.
pub trait FixedLengthParams<'p, const N: usize>: Params<'p> {}
302
// Implement `Params` (and `FixedLengthParams<N>`) for a tuple of `$n` bindable values,
// binding each element in tuple order.
macro_rules! impl_tuple_params {
    ($n:literal, ($($t:ident,)+)) => {
        impl<'p, $($t),+> Params<'p> for ($($t,)+)
        where $(
            $t: 'p + Encode<'p, Db> + Type<Db>
        ),+{
            fn bind<'q, 'r>(self, q: &'q mut Separated<'r, 'p, Db, &'static str>) -> &'q mut Separated<'r, 'p, Db, &'static str>
            where 'p: 'r,
            {
                // Destructure the tuple into one variable per element, then chain
                // a `push_bind` for each.
                #[allow(non_snake_case)]
                let ($($t,)+) = self;
                q $(
                    .push_bind($t)
                )+

            }
        }

        impl<'p, $($t),+> FixedLengthParams<'p, $n> for ($($t,)+)
        where $(
            $t: 'p + for<'q> Encode<'q, Db> + Type<Db>
        ),+ {
        }
    };
}
328
// Support binding tuples of 1 through 8 parameters.
impl_tuple_params!(1, (T,));
impl_tuple_params!(2, (T1, T2,));
impl_tuple_params!(3, (T1, T2, T3,));
impl_tuple_params!(4, (T1, T2, T3, T4,));
impl_tuple_params!(5, (T1, T2, T3, T4, T5,));
impl_tuple_params!(6, (T1, T2, T3, T4, T5, T6,));
impl_tuple_params!(7, (T1, T2, T3, T4, T5, T6, T7,));
impl_tuple_params!(8, (T1, T2, T3, T4, T5, T6, T7, T8,));
337
338pub fn build_where_in<'a, I>(
339 query: &'a str,
340 column: &'a str,
341 values: I,
342) -> QueryResult<(queries::QueryBuilder<'a>, String)>
343where
344 I: IntoIterator,
345 I::Item: 'a + Encode<'a, Db> + Type<Db>,
346{
347 let mut builder = queries::QueryBuilder::default();
348 let params = values
349 .into_iter()
350 .map(|v| Ok(format!("{} ", builder.bind(v)?)))
351 .collect::<QueryResult<Vec<String>>>()?;
352
353 if params.is_empty() {
354 return Err(QueryError::Error {
355 message: "failed to build WHERE IN query. No parameter found ".to_string(),
356 });
357 }
358
359 let sql = format!(
360 "{query} where {column} IN ({}) ",
361 params.into_iter().join(",")
362 );
363
364 Ok((builder, sql))
365}
366
impl Transaction<Write> {
    /// Insert `rows` into `table`, updating existing rows on conflict with the
    /// primary key columns `pk` ("upsert").
    ///
    /// Each row provides exactly `N` values matching `columns`. Upserting zero
    /// rows is a no-op (logged as a warning). On execution errors, the statement
    /// is retried up to 5 times with a 1 second sleep between attempts.
    ///
    /// # Errors
    ///
    /// Fails if the statement still errors after all retries, or if the number of
    /// rows reported modified differs from the number of rows supplied.
    pub async fn upsert<'p, const N: usize, R>(
        &mut self,
        table: &str,
        columns: [&str; N],
        pk: impl IntoIterator<Item = &str>,
        rows: R,
    ) -> anyhow::Result<()>
    where
        R: IntoIterator,
        R::Item: 'p + FixedLengthParams<'p, N> + Clone,
    {
        // `col = excluded.col` assignments for the ON CONFLICT update clause.
        let set_columns = columns
            .iter()
            .map(|col| format!("{col} = excluded.{col}"))
            .join(",");

        let columns_str = columns.iter().map(|col| format!("\"{col}\"")).join(",");

        let pk = pk.into_iter().join(",");

        let rows: Vec<_> = rows.into_iter().collect();
        let num_rows = rows.len();

        if num_rows == 0 {
            tracing::warn!("trying to upsert 0 rows into {table}, this has no effect");
            return Ok(());
        }

        let interval = Duration::from_secs(1);
        let mut retries = 5;

        let mut query_builder =
            QueryBuilder::new(format!("INSERT INTO \"{table}\" ({columns_str}) "));

        loop {
            // Reset the builder back to the INSERT prefix so each retry rebuilds
            // the VALUES and ON CONFLICT clauses from scratch.
            let query_builder = query_builder.reset();

            query_builder.push_values(rows.clone(), |mut b, row| {
                row.bind(&mut b);
            });

            query_builder.push(format!(" ON CONFLICT ({pk}) DO UPDATE SET {set_columns}"));

            let query = query_builder.build();
            let statement = query.sql();

            match self.execute(query).await {
                Ok(res) => {
                    // An upsert should touch every supplied row (insert or update);
                    // anything else indicates the statement did not do what we meant.
                    let rows_modified = res.rows_affected() as usize;
                    if rows_modified != num_rows {
                        let error = format!(
                            "unexpected number of rows modified: expected {num_rows}, got \
                             {rows_modified}. query: {statement}"
                        );
                        tracing::error!(error);
                        bail!(error);
                    }
                    return Ok(());
                },
                Err(err) => {
                    tracing::error!(
                        statement,
                        "error in statement execution ({} tries remaining): {err}",
                        retries
                    );
                    if retries == 0 {
                        bail!(err);
                    }
                    retries -= 1;
                    sleep(interval).await;
                },
            }
        }
    }
}
448
impl Transaction<Write> {
    /// Prune storage at and below `height`: delete headers (and, via the schema,
    /// whatever references them), drop superseded Merkle nodes from each table in
    /// `state_tables`, and record `height` as the new pruned height.
    pub(super) async fn delete_batch(
        &mut self,
        state_tables: Vec<String>,
        height: u64,
    ) -> anyhow::Result<()> {
        self.execute(query("DELETE FROM header WHERE height <= $1").bind(height as i64))
            .await?;

        for state_table in state_tables {
            // For each path, keep only the most recently created node (rank 1 when
            // ordered by `created` descending) among nodes created at or below
            // `height`; delete the older versions.
            self.execute(
                query(&format!(
                    "
        DELETE FROM {state_table} WHERE (path, created) IN
        (SELECT path, created FROM
            (SELECT path, created,
            ROW_NUMBER() OVER (PARTITION BY path ORDER BY created DESC) as rank
            FROM {state_table} WHERE created <= $1) ranked_nodes WHERE rank != 1)"
                ))
                .bind(height as i64),
            )
            .await?;
        }

        self.save_pruned_height(height).await?;
        Ok(())
    }

    /// Persist `height` as the latest pruned height (singleton row with id 1).
    pub(super) async fn save_pruned_height(&mut self, height: u64) -> anyhow::Result<()> {
        self.upsert(
            "pruned_height",
            ["id", "last_height"],
            ["id"],
            [(1i32, height as i64)],
        )
        .await
    }
}
495
impl<Types> UpdateAvailabilityStorage<Types> for Transaction<Write>
where
    Types: NodeType,
    Payload<Types>: QueryablePayload<Types>,
    Header<Types>: QueryableHeader<Types>,
{
    /// Store a leaf (with its header and QC), and, if the leaf is at the tip of
    /// the chain, the accompanying QC chain.
    async fn insert_leaf_with_qc_chain(
        &mut self,
        leaf: LeafQueryData<Types>,
        qc_chain: Option<[CertificatePair<Types>; 2]>,
    ) -> anyhow::Result<()> {
        let height = leaf.height();

        // Data at or below the pruned height has been deliberately discarded;
        // don't re-insert it.
        if let Some(pruned_height) = self.load_pruned_height().await? {
            if height <= pruned_height {
                tracing::info!(
                    height,
                    pruned_height,
                    "ignoring leaf which is already pruned"
                );
                return Ok(());
            }
        }

        // Store the block header extracted from the leaf.
        let header_json = serde_json::to_value(leaf.leaf().block_header())
            .context("failed to serialize header")?;
        self.upsert(
            "header",
            ["height", "hash", "payload_hash", "data", "timestamp"],
            ["height"],
            [(
                height as i64,
                leaf.block_hash().to_string(),
                leaf.leaf().block_header().payload_commitment().to_string(),
                header_json,
                leaf.leaf().block_header().timestamp() as i64,
            )],
        )
        .await?;

        // Reserve a payload row for this height; the payload data itself is
        // filled in by `insert_block`, so an existing row must not be clobbered.
        let query = query("INSERT INTO payload (height) VALUES ($1) ON CONFLICT DO NOTHING")
            .bind(height as i64);
        query.execute(self.as_mut()).await?;

        // Store the leaf itself together with its QC.
        let leaf_json = serde_json::to_value(leaf.leaf()).context("failed to serialize leaf")?;
        let qc_json = serde_json::to_value(leaf.qc()).context("failed to serialize QC")?;
        self.upsert(
            "leaf2",
            ["height", "hash", "block_hash", "leaf", "qc"],
            ["height"],
            [(
                height as i64,
                leaf.hash().to_string(),
                leaf.block_hash().to_string(),
                leaf_json,
                qc_json,
            )],
        )
        .await?;

        // If this leaf is at (or beyond) the current tip, remember its QC chain
        // in the singleton `latest_qc_chain` row. NOTE(review): `height + 1 >=
        // block_height` presumably means "this is the newest leaf" — confirm
        // against `NodeStorage::block_height` semantics.
        let block_height = NodeStorage::<Types>::block_height(self).await? as u64;
        if height + 1 >= block_height {
            let qcs = serde_json::to_value(&qc_chain)?;
            self.upsert("latest_qc_chain", ["id", "qcs"], ["id"], [(1i32, qcs)])
                .await?;
        }

        Ok(())
    }

    /// Store a block payload and index its transactions.
    async fn insert_block(&mut self, block: BlockQueryData<Types>) -> anyhow::Result<()> {
        let height = block.height();

        // Skip blocks at or below the pruned height (see `insert_leaf_with_qc_chain`).
        if let Some(pruned_height) = self.load_pruned_height().await? {
            if height <= pruned_height {
                tracing::info!(
                    height,
                    pruned_height,
                    "ignoring block which is already pruned"
                );
                return Ok(());
            }
        }

        let payload = block.payload.encode();

        self.upsert(
            "payload",
            ["height", "data", "size", "num_transactions"],
            ["height"],
            [(
                height as i64,
                payload.as_ref().to_vec(),
                block.size() as i32,
                block.num_transactions() as i32,
            )],
        )
        .await?;

        // Index the transactions in the block: one row per transaction, keyed by
        // (block height, namespace, position within the namespace).
        let mut rows = vec![];
        for (txn_ix, txn) in block.enumerate() {
            let ns_id = block.header().namespace_id(&txn_ix.ns_index).unwrap();
            rows.push((
                txn.commit().to_string(),
                height as i64,
                txn_ix.ns_index.into(),
                ns_id.into(),
                txn_ix.position as i64,
            ));
        }
        if !rows.is_empty() {
            self.upsert(
                "transactions",
                ["hash", "block_height", "ns_index", "ns_id", "position"],
                ["block_height", "ns_id", "position"],
                rows,
            )
            .await?;
        }

        Ok(())
    }

    /// Store VID common data and, if available, this node's VID share.
    async fn insert_vid(
        &mut self,
        common: VidCommonQueryData<Types>,
        share: Option<VidShare>,
    ) -> anyhow::Result<()> {
        let height = common.height();

        // Skip data at or below the pruned height (see `insert_leaf_with_qc_chain`).
        if let Some(pruned_height) = self.load_pruned_height().await? {
            if height <= pruned_height {
                tracing::info!(
                    height,
                    pruned_height,
                    "ignoring VID common which is already pruned"
                );
                return Ok(());
            }
        }

        let common_data =
            bincode::serialize(common.common()).context("failed to serialize VID common data")?;
        if let Some(share) = share {
            let share_data = bincode::serialize(&share).context("failed to serialize VID share")?;
            self.upsert(
                "vid2",
                ["height", "common", "share"],
                ["height"],
                [(height as i64, common_data, share_data)],
            )
            .await
        } else {
            // No share available: upsert only the `common` column, so an existing
            // share for this height is not overwritten with NULL.
            self.upsert(
                "vid2",
                ["height", "common"],
                ["height"],
                [(height as i64, common_data)],
            )
            .await
        }
    }
}
686
#[async_trait]
impl<Types: NodeType, State: MerklizedState<Types, ARITY>, const ARITY: usize>
    UpdateStateData<Types, State, ARITY> for Transaction<Write>
{
    /// Record the latest block height whose Merklized state has been stored
    /// (singleton row with id 1).
    async fn set_last_state_height(&mut self, height: usize) -> anyhow::Result<()> {
        self.upsert(
            "last_merklized_state_height",
            ["id", "height"],
            ["id"],
            [(1i32, height as i64)],
        )
        .await?;

        Ok(())
    }

    /// Insert the Merkle nodes along a single proof path; delegates to the batch
    /// variant with a one-element batch.
    async fn insert_merkle_nodes(
        &mut self,
        proof: MerkleProof<State::Entry, State::Key, State::T, ARITY>,
        traversal_path: Vec<usize>,
        block_number: u64,
    ) -> anyhow::Result<()> {
        let proofs = vec![(proof, traversal_path)];
        UpdateStateData::<Types, State, ARITY>::insert_merkle_nodes_batch(
            self,
            proofs,
            block_number,
        )
        .await
    }

    /// Insert the Merkle nodes from a batch of proofs: first persist all node
    /// hashes (obtaining their integer ids), then upsert the nodes themselves
    /// with ids substituted for hashes.
    async fn insert_merkle_nodes_batch(
        &mut self,
        proofs: Vec<(
            MerkleProof<State::Entry, State::Key, State::T, ARITY>,
            Vec<usize>,
        )>,
        block_number: u64,
    ) -> anyhow::Result<()> {
        if proofs.is_empty() {
            return Ok(());
        }

        let name = State::state_type();
        let block_number = block_number as i64;

        let (mut all_nodes, all_hashes) = collect_nodes_from_proofs(&proofs)?;
        let hashes: Vec<Vec<u8>> = all_hashes.into_iter().collect();

        // Postgres: insert all hashes in one batched statement, getting back a
        // map from hash bytes to row id.
        #[cfg(not(feature = "embedded-db"))]
        let nodes_hash_ids: HashMap<Vec<u8>, i32> = batch_insert_hashes(hashes, self).await?;

        // SQLite: insert the hashes in chunks of 20 — presumably to keep each
        // statement under SQLite's bind-parameter limit (TODO confirm) — and
        // accumulate the hash → id map across chunks.
        #[cfg(feature = "embedded-db")]
        let nodes_hash_ids: HashMap<Vec<u8>, i32> = {
            let mut hash_ids: HashMap<Vec<u8>, i32> = HashMap::with_capacity(hashes.len());
            for hash_chunk in hashes.chunks(20) {
                let (query, sql) = build_hash_batch_insert(hash_chunk)?;
                let chunk_ids: HashMap<Vec<u8>, i32> = query
                    .query_as(&sql)
                    .fetch(self.as_mut())
                    .try_collect()
                    .await?;
                hash_ids.extend(chunk_ids);
            }
            hash_ids
        };

        // Fix up each node in place: stamp it with the block number and replace
        // its own hash and its children's hashes with the ids assigned above.
        for (node, children, hash) in &mut all_nodes {
            node.created = block_number;
            node.hash_id = *nodes_hash_ids.get(&*hash).ok_or(QueryError::Error {
                message: "Missing node hash".to_string(),
            })?;

            if let Some(children) = children {
                let children_hashes = children
                    .iter()
                    .map(|c| nodes_hash_ids.get(c).copied())
                    .collect::<Option<Vec<i32>>>()
                    .ok_or(QueryError::Error {
                        message: "Missing child hash".to_string(),
                    })?;

                node.children = Some(children_hashes.into());
            }
        }

        Node::upsert(name, all_nodes.into_iter().map(|(n, ..)| n), self).await?;

        Ok(())
    }
}
778
779#[async_trait]
780impl<Mode: TransactionMode> PrunedHeightStorage for Transaction<Mode> {
781 async fn load_pruned_height(&mut self) -> anyhow::Result<Option<u64>> {
782 let Some((height,)) =
783 query_as::<(i64,)>("SELECT last_height FROM pruned_height ORDER BY id DESC LIMIT 1")
784 .fetch_optional(self.as_mut())
785 .await?
786 else {
787 return Ok(None);
788 };
789 Ok(Some(height as u64))
790 }
791}
792
/// Metrics describing transaction activity on a connection pool.
#[derive(Clone, Debug)]
pub(super) struct PoolMetrics {
    // Number of transactions currently open.
    open_transactions: Box<dyn Gauge>,
    // Distribution of transaction durations, in seconds.
    transaction_durations: Box<dyn Histogram>,
    // Counts of transactions by how they closed.
    commits: Box<dyn Counter>,
    reverts: Box<dyn Counter>,
    drops: Box<dyn Counter>,
}
801
802impl PoolMetrics {
803 pub(super) fn new(metrics: &(impl Metrics + ?Sized)) -> Self {
804 Self {
805 open_transactions: metrics.create_gauge("open_transactions".into(), None),
806 transaction_durations: metrics
807 .create_histogram("transaction_duration".into(), Some("s".into())),
808 commits: metrics.create_counter("committed_transactions".into(), None),
809 reverts: metrics.create_counter("reverted_transactions".into(), None),
810 drops: metrics.create_counter("dropped_transactions".into(), None),
811 }
812 }
813}