1use std::{collections::HashMap, marker::PhantomData, time::Instant};
22
23use anyhow::{Context, bail};
24use async_trait::async_trait;
25use committable::Committable;
26use derive_more::{Deref, DerefMut};
27use futures::future::Future;
28#[cfg(feature = "embedded-db")]
29use futures::stream::TryStreamExt;
30use hotshot_types::{
31 data::VidShare,
32 simple_certificate::CertificatePair,
33 traits::{
34 EncodeBytes,
35 block_contents::BlockHeader,
36 metrics::{Counter, Gauge, Histogram, Metrics},
37 node_implementation::NodeType,
38 },
39};
40use itertools::Itertools;
41use jf_merkle_tree_compat::prelude::MerkleProof;
42pub use sqlx::Executor;
43use sqlx::{Encode, Execute, FromRow, QueryBuilder, Type, pool::Pool, query_builder::Separated};
44use tracing::instrument;
45
46#[cfg(not(feature = "embedded-db"))]
47use super::queries::state::batch_insert_hashes;
48#[cfg(feature = "embedded-db")]
49use super::queries::state::build_hash_batch_insert;
50use super::{
51 Database, Db,
52 queries::{
53 self,
54 state::{Node, collect_nodes_from_proofs},
55 },
56};
57use crate::{
58 Header, Payload, QueryError, QueryResult,
59 availability::{
60 BlockQueryData, LeafQueryData, QueryableHeader, QueryablePayload, VidCommonQueryData,
61 },
62 data_source::{
63 storage::{NodeStorage, UpdateAvailabilityStorage, pruning::PrunedHeightStorage},
64 update,
65 },
66 merklized_state::{MerklizedState, UpdateStateData},
67 types::HeightIndexed,
68};
69
/// Alias of [`sqlx::query::Query`] specialized to this crate's database backend [`Db`].
pub type Query<'q> = sqlx::query::Query<'q, Db, <Db as Database>::Arguments<'q>>;
/// Alias of [`sqlx::query::QueryAs`] specialized to [`Db`]; `T` is the type each row is decoded
/// into.
pub type QueryAs<'q, T> = sqlx::query::QueryAs<'q, Db, T, <Db as Database>::Arguments<'q>>;
72
73pub fn query(sql: &str) -> Query<'_> {
74 sqlx::query(sql)
75}
76
77pub fn query_as<'q, T>(sql: &'q str) -> QueryAs<'q, T>
78where
79 T: for<'r> FromRow<'r, <Db as Database>::Row>,
80{
81 sqlx::query_as(sql)
82}
83
/// Marker type selecting the write mode of a [`Transaction`].
///
/// Mutating operations in this file (such as `upsert`) are only implemented for
/// `Transaction<Write>`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Write;
87
/// Marker type selecting the read-only mode of a [`Transaction`].
#[derive(Clone, Copy, Debug, Default)]
pub struct Read;
91
/// A mode in which a [`Transaction`] can be opened.
///
/// Implemented in this file by the marker types [`Write`] and [`Read`].
pub trait TransactionMode: Send + Sync {
    /// Run any mode-specific setup statements on `conn`.
    ///
    /// `Transaction::new` calls this right after the underlying database transaction has been
    /// begun; an error aborts creation of the transaction.
    fn begin(
        conn: &mut <Db as Database>::Connection,
    ) -> impl Future<Output = anyhow::Result<()>> + Send;

    /// A short, human-readable name for this mode, used in trace logging.
    fn display() -> &'static str;
}
100
101impl TransactionMode for Write {
102 #[allow(unused_variables)]
103 async fn begin(conn: &mut <Db as Database>::Connection) -> anyhow::Result<()> {
104 #[cfg(feature = "embedded-db")]
133 conn.execute("UPDATE pruned_height SET id = id WHERE false")
134 .await?;
135
136 #[cfg(not(feature = "embedded-db"))]
139 conn.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE")
140 .await?;
141
142 Ok(())
143 }
144
145 fn display() -> &'static str {
146 "write"
147 }
148}
149
150impl TransactionMode for Read {
151 #[allow(unused_variables)]
152 async fn begin(conn: &mut <Db as Database>::Connection) -> anyhow::Result<()> {
153 #[cfg(not(feature = "embedded-db"))]
162 conn.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY, DEFERRABLE")
163 .await?;
164
165 Ok(())
166 }
167
168 fn display() -> &'static str {
169 "read-only"
170 }
171}
172
/// How a transaction ended, as reported to the transaction metrics.
#[derive(Clone, Copy, Debug)]
enum CloseType {
    /// Recorded after the transaction was committed successfully.
    Commit,
    /// Recorded after the transaction was explicitly reverted.
    Revert,
    /// The initial state of a metrics guard; it remains in place when the transaction goes away
    /// without a completed commit or revert.
    Drop,
}
179
/// Guard which accounts for a single open transaction in [`PoolMetrics`].
///
/// Creating the guard (via `begin`) increments the open-transaction gauge; dropping it
/// decrements the gauge again and records the duration and the way the transaction was closed.
#[derive(Debug)]
struct TransactionMetricsGuard<Mode> {
    /// The instant at which the guard was created, used to compute the transaction duration.
    started_at: Instant,
    /// Handles to the metrics updated by this guard.
    metrics: PoolMetrics,
    /// How the transaction was closed; reported when the guard is dropped.
    close_type: CloseType,
    /// Ties the guard to a transaction mode without storing a value of that type.
    _mode: PhantomData<Mode>,
}
187
188impl<Mode: TransactionMode> TransactionMetricsGuard<Mode> {
189 fn begin(metrics: PoolMetrics) -> Self {
190 let started_at = Instant::now();
191 tracing::trace!(mode = Mode::display(), ?started_at, "begin");
192 metrics.open_transactions.update(1);
193
194 Self {
195 started_at,
196 metrics,
197 close_type: CloseType::Drop,
198 _mode: Default::default(),
199 }
200 }
201
202 fn set_closed(&mut self, t: CloseType) {
203 self.close_type = t;
204 }
205}
206
207impl<Mode> Drop for TransactionMetricsGuard<Mode> {
208 fn drop(&mut self) {
209 self.metrics
210 .transaction_durations
211 .add_point((self.started_at.elapsed().as_millis() as f64) / 1000.);
212 self.metrics.open_transactions.update(-1);
213 match self.close_type {
214 CloseType::Commit => self.metrics.commits.add(1),
215 CloseType::Revert => self.metrics.reverts.add(1),
216 CloseType::Drop => self.metrics.drops.add(1),
217 }
218 tracing::trace!(started_at = ?self.started_at, reason = ?self.close_type, "close");
219 }
220}
221
/// A SQL transaction, parameterized by its mode (a [`TransactionMode`] such as [`Read`] or
/// [`Write`]).
///
/// Dereferences to the underlying [`sqlx::Transaction`].
#[derive(Debug, Deref, DerefMut)]
pub struct Transaction<Mode> {
    /// The underlying database transaction.
    #[deref]
    #[deref_mut]
    inner: sqlx::Transaction<'static, Db>,
    /// Metrics guard for this transaction; it reports to the pool metrics when dropped.
    metrics: TransactionMetricsGuard<Mode>,
}
230
231impl<Mode: TransactionMode> Transaction<Mode> {
232 pub(super) async fn new(pool: &Pool<Db>, metrics: PoolMetrics) -> anyhow::Result<Self> {
233 let mut inner = pool.begin().await?;
234 let metrics = TransactionMetricsGuard::begin(metrics);
235 Mode::begin(inner.as_mut()).await?;
236 Ok(Self { inner, metrics })
237 }
238}
239
240impl<Mode: TransactionMode> update::Transaction for Transaction<Mode> {
241 async fn commit(mut self) -> anyhow::Result<()> {
242 self.inner.commit().await?;
243 self.metrics.set_closed(CloseType::Commit);
244 Ok(())
245 }
246 fn revert(mut self) -> impl Future + Send {
247 async move {
248 self.inner.rollback().await.unwrap();
249 self.metrics.set_closed(CloseType::Revert);
250 }
251 }
252}
253
/// A collection of parameters which can be bound to a SQL query as a separated list.
///
/// In this file the trait is implemented (via `impl_tuple_params!`) for tuples of up to eight
/// elements, each of which must be encodable for [`Db`].
pub trait Params<'p> {
    /// Bind the parameters in `self` to `q`, in order, returning `q` to allow chaining.
    fn bind<'q, 'r>(
        self,
        q: &'q mut Separated<'r, 'p, Db, &'static str>,
    ) -> &'q mut Separated<'r, 'p, Db, &'static str>
    where
        'p: 'r;
}
291
/// Marker for [`Params`] whose number of parameters is the compile-time constant `N`.
///
/// `Transaction::upsert` uses this to require that every row has as many values as there are
/// columns.
pub trait FixedLengthParams<'p, const N: usize>: Params<'p> {}
298
/// Implement [`Params`] and [`FixedLengthParams`] for a tuple type.
///
/// Takes the tuple arity as a literal, followed by the parenthesized list of type parameter
/// names (one per element, each followed by a comma).
macro_rules! impl_tuple_params {
    ($n:literal, ($($t:ident,)+)) => {
        impl<'p, $($t),+> Params<'p> for ($($t,)+)
        where $(
            $t: 'p + Encode<'p, Db> + Type<Db>
        ),+{
            fn bind<'q, 'r>(self, q: &'q mut Separated<'r, 'p, Db, &'static str>) -> &'q mut Separated<'r, 'p, Db, &'static str>
            where 'p: 'r,
            {
                // The type parameter names double as the names of the destructured tuple
                // elements, hence the non-snake-case bindings.
                #[allow(non_snake_case)]
                let ($($t,)+) = self;
                // Push one bound parameter per tuple element, in order.
                q $(
                    .push_bind($t)
                )+

            }
        }

        impl<'p, $($t),+> FixedLengthParams<'p, $n> for ($($t,)+)
        where $(
            $t: 'p + for<'q> Encode<'q, Db> + Type<Db>
        ),+ {
        }
    };
}

// Tuples of one through eight elements can be used as query parameters.
impl_tuple_params!(1, (T,));
impl_tuple_params!(2, (T1, T2,));
impl_tuple_params!(3, (T1, T2, T3,));
impl_tuple_params!(4, (T1, T2, T3, T4,));
impl_tuple_params!(5, (T1, T2, T3, T4, T5,));
impl_tuple_params!(6, (T1, T2, T3, T4, T5, T6,));
impl_tuple_params!(7, (T1, T2, T3, T4, T5, T6, T7,));
impl_tuple_params!(8, (T1, T2, T3, T4, T5, T6, T7, T8,));
333
334pub fn build_where_in<'a, I>(
335 query: &'a str,
336 column: &'a str,
337 values: I,
338) -> QueryResult<(queries::QueryBuilder<'a>, String)>
339where
340 I: IntoIterator,
341 I::Item: 'a + Encode<'a, Db> + Type<Db>,
342{
343 let mut builder = queries::QueryBuilder::default();
344 let params = values
345 .into_iter()
346 .map(|v| Ok(format!("{} ", builder.bind(v)?)))
347 .collect::<QueryResult<Vec<String>>>()?;
348
349 if params.is_empty() {
350 return Err(QueryError::Error {
351 message: "failed to build WHERE IN query. No parameter found ".to_string(),
352 });
353 }
354
355 let sql = format!(
356 "{query} where {column} IN ({}) ",
357 params.into_iter().join(",")
358 );
359
360 Ok((builder, sql))
361}
362
363impl Transaction<Write> {
365 pub async fn upsert<'p, const N: usize, R>(
366 &mut self,
367 table: &str,
368 columns: [&str; N],
369 pk: impl IntoIterator<Item = &str>,
370 rows: R,
371 ) -> anyhow::Result<()>
372 where
373 R: IntoIterator,
374 R::Item: 'p + FixedLengthParams<'p, N>,
375 {
376 let set_columns = columns
377 .iter()
378 .map(|col| format!("{col} = excluded.{col}"))
379 .join(",");
380
381 let columns_str = columns.iter().map(|col| format!("\"{col}\"")).join(",");
382
383 let pk = pk.into_iter().join(",");
384
385 let rows: Vec<_> = rows.into_iter().collect();
386 let num_rows = rows.len();
387
388 if num_rows == 0 {
389 tracing::warn!("trying to upsert 0 rows into {table}, this has no effect");
390 return Ok(());
391 }
392
393 let mut query_builder =
394 QueryBuilder::new(format!("INSERT INTO \"{table}\" ({columns_str}) "));
395 query_builder.push_values(rows, |mut b, row| {
396 row.bind(&mut b);
397 });
398 query_builder.push(format!(" ON CONFLICT ({pk}) DO UPDATE SET {set_columns}"));
399
400 let query = query_builder.build();
401 let statement = query.sql();
402
403 let res = self.execute(query).await.inspect_err(|err| {
404 tracing::error!(statement, "error in statement execution: {err:#}");
405 })?;
406 let rows_modified = res.rows_affected() as usize;
407 if rows_modified != num_rows {
408 let error = format!(
409 "unexpected number of rows modified: expected {num_rows}, got {rows_modified}. \
410 query: {statement}"
411 );
412 tracing::error!(error);
413 bail!(error);
414 }
415 Ok(())
416 }
417}
418
419impl Transaction<Write> {
421 #[instrument(skip(self))]
423 pub(super) async fn delete_batch(
424 &mut self,
425 state_tables: Vec<String>,
426 height: u64,
427 ) -> anyhow::Result<()> {
428 let res = query(
430 "WITH to_delete AS (
431 SELECT h.payload_hash, h.ns_table FROM header AS h
432 WHERE (h.payload_hash, h.ns_table) IN (
433 SELECT range.payload_hash, range.ns_table
434 FROM header AS range
435 WHERE range.height <= $1
436 )
437 GROUP BY h.payload_hash, h.ns_table
438 HAVING count(*) <= 1
439 )
440 DELETE FROM payload AS p
441 WHERE (p.hash, p.ns_table) IN (SELECT * FROM to_delete)",
442 )
443 .bind(height as i64)
444 .execute(self.as_mut())
445 .await
446 .context("deleting payloads")?;
447 tracing::debug!(
448 rows_affected = res.rows_affected(),
449 "garbage collected payloads"
450 );
451
452 let res = query(
454 "WITH to_delete AS (
455 SELECT h.payload_hash FROM header AS h
456 WHERE h.payload_hash IN (
457 SELECT range.payload_hash
458 FROM header AS range
459 WHERE range.height <= $1
460 )
461 GROUP BY h.payload_hash
462 HAVING count(*) <= 1
463 )
464 DELETE FROM vid_common AS v
465 WHERE v.hash IN (SELECT * FROM to_delete)",
466 )
467 .bind(height as i64)
468 .execute(self.as_mut())
469 .await
470 .context("deleting VID common")?;
471 tracing::debug!(
472 rows_affected = res.rows_affected(),
473 "garbage collected VID common"
474 );
475
476 let res = query("DELETE FROM header WHERE height <= $1")
478 .bind(height as i64)
479 .execute(self.as_mut())
480 .await?;
481 tracing::debug!(rows_affected = res.rows_affected(), "pruned headers");
482
483 for state_table in state_tables {
487 self.execute(
488 query(&format!(
489 "
490 DELETE FROM {state_table} WHERE (path, created) IN
491 (SELECT path, created FROM
492 (SELECT path, created,
493 ROW_NUMBER() OVER (PARTITION BY path ORDER BY created DESC) as rank
494 FROM {state_table} WHERE created <= $1) ranked_nodes WHERE rank != 1)"
495 ))
496 .bind(height as i64),
497 )
498 .await?;
499 }
500
501 self.save_pruned_height(height).await?;
502 Ok(())
503 }
504
505 pub(crate) async fn save_pruned_height(&mut self, height: u64) -> anyhow::Result<()> {
507 self.upsert(
510 "pruned_height",
511 ["id", "last_height"],
512 ["id"],
513 [(1i32, height as i64)],
514 )
515 .await
516 }
517}
518
519impl<Types> UpdateAvailabilityStorage<Types> for Transaction<Write>
520where
521 Types: NodeType,
522 Payload<Types>: QueryablePayload<Types>,
523 Header<Types>: QueryableHeader<Types>,
524{
525 async fn insert_leaf_with_qc_chain(
526 &mut self,
527 leaf: LeafQueryData<Types>,
528 qc_chain: Option<[CertificatePair<Types>; 2]>,
529 ) -> anyhow::Result<()> {
530 let height = leaf.height();
531
532 if let Some(pruned_height) = self.load_pruned_height().await?
535 && height <= pruned_height
536 {
537 tracing::info!(
538 height,
539 pruned_height,
540 "ignoring leaf which is already pruned"
541 );
542 return Ok(());
543 }
544
545 let header_json = serde_json::to_value(leaf.leaf().block_header())
548 .context("failed to serialize header")?;
549 self.upsert(
550 "header",
551 [
552 "height",
553 "hash",
554 "payload_hash",
555 "ns_table",
556 "data",
557 "timestamp",
558 ],
559 ["height"],
560 [(
561 height as i64,
562 leaf.block_hash().to_string(),
563 leaf.leaf().block_header().payload_commitment().to_string(),
564 leaf.leaf().block_header().ns_table(),
565 header_json,
566 leaf.leaf().block_header().timestamp() as i64,
567 )],
568 )
569 .await
570 .context("inserting header")?;
571
572 let leaf_json = serde_json::to_value(leaf.leaf()).context("failed to serialize leaf")?;
575 let qc_json = serde_json::to_value(leaf.qc()).context("failed to serialize QC")?;
576 self.upsert(
577 "leaf2",
578 ["height", "hash", "block_hash", "leaf", "qc"],
579 ["height"],
580 [(
581 height as i64,
582 leaf.hash().to_string(),
583 leaf.block_hash().to_string(),
584 leaf_json,
585 qc_json,
586 )],
587 )
588 .await
589 .context("inserting leaf")?;
590
591 let block_height = NodeStorage::<Types>::block_height(self).await? as u64;
592 if height + 1 >= block_height {
593 let qcs = serde_json::to_value(&qc_chain)?;
598 self.upsert("latest_qc_chain", ["id", "qcs"], ["id"], [(1i32, qcs)])
599 .await
600 .context("inserting QC chain")?;
601 }
602
603 Ok(())
604 }
605
606 async fn insert_block(&mut self, block: BlockQueryData<Types>) -> anyhow::Result<()> {
607 let height = block.height();
608
609 if let Some(pruned_height) = self.load_pruned_height().await?
612 && height <= pruned_height
613 {
614 tracing::info!(
615 height,
616 pruned_height,
617 "ignoring block which is already pruned"
618 );
619 return Ok(());
620 }
621
622 self.upsert(
623 "payload",
624 ["hash", "ns_table", "size", "num_transactions", "data"],
625 ["hash", "ns_table"],
626 [(
627 block.payload_hash().to_string(),
628 block.header().ns_table(),
629 block.size() as i32,
630 block.num_transactions() as i32,
631 block.payload.encode().as_ref().to_vec(),
632 )],
633 )
634 .await
635 .context("inserting payload")?;
636
637 let mut rows = vec![];
639 for (txn_ix, txn) in block.enumerate() {
640 let ns_id = block.header().namespace_id(&txn_ix.ns_index).unwrap();
641 rows.push((
642 txn.commit().to_string(),
643 height as i64,
644 txn_ix.ns_index.into(),
645 ns_id.into(),
646 txn_ix.position as i64,
647 ));
648 }
649 if !rows.is_empty() {
650 self.upsert(
651 "transactions",
652 ["hash", "block_height", "ns_index", "ns_id", "position"],
653 ["block_height", "ns_id", "position"],
654 rows,
655 )
656 .await
657 .context("inserting transactions")?;
658 }
659
660 Ok(())
661 }
662
663 async fn insert_vid(
664 &mut self,
665 common: VidCommonQueryData<Types>,
666 share: Option<VidShare>,
667 ) -> anyhow::Result<()> {
668 let height = common.height();
669
670 if let Some(pruned_height) = self.load_pruned_height().await?
673 && height <= pruned_height
674 {
675 tracing::info!(
676 height,
677 pruned_height,
678 "ignoring VID common which is already pruned"
679 );
680 return Ok(());
681 }
682
683 let common_data =
684 bincode::serialize(common.common()).context("failed to serialize VID common data")?;
685 self.upsert(
686 "vid_common",
687 ["hash", "data"],
688 ["hash"],
689 [(common.payload_hash().to_string(), common_data)],
690 )
691 .await
692 .context("isnerting VID common")?;
693
694 if let Some(share) = share {
695 let share_data = bincode::serialize(&share).context("failed to serialize VID share")?;
696 query("UPDATE header SET vid_share = $1 WHERE height = $2")
697 .bind(share_data)
698 .bind(height as i64)
699 .execute(self.as_mut())
700 .await
701 .context("inserting VID share")?;
702 }
703
704 Ok(())
705 }
706}
707
708#[async_trait]
709impl<Types: NodeType, State: MerklizedState<Types, ARITY>, const ARITY: usize>
710 UpdateStateData<Types, State, ARITY> for Transaction<Write>
711{
712 async fn set_last_state_height(&mut self, height: usize) -> anyhow::Result<()> {
713 self.upsert(
714 "last_merklized_state_height",
715 ["id", "height"],
716 ["id"],
717 [(1i32, height as i64)],
718 )
719 .await?;
720
721 Ok(())
722 }
723
724 async fn insert_merkle_nodes(
725 &mut self,
726 proof: MerkleProof<State::Entry, State::Key, State::T, ARITY>,
727 traversal_path: Vec<usize>,
728 block_number: u64,
729 ) -> anyhow::Result<()> {
730 let proofs = vec![(proof, traversal_path)];
731 UpdateStateData::<Types, State, ARITY>::insert_merkle_nodes_batch(
732 self,
733 proofs,
734 block_number,
735 )
736 .await
737 }
738
739 async fn insert_merkle_nodes_batch(
740 &mut self,
741 proofs: Vec<(
742 MerkleProof<State::Entry, State::Key, State::T, ARITY>,
743 Vec<usize>,
744 )>,
745 block_number: u64,
746 ) -> anyhow::Result<()> {
747 if proofs.is_empty() {
748 return Ok(());
749 }
750
751 let name = State::state_type();
752 let block_number = block_number as i64;
753
754 let (mut all_nodes, all_hashes) = collect_nodes_from_proofs(&proofs)?;
755 let hashes: Vec<Vec<u8>> = all_hashes.into_iter().collect();
756
757 #[cfg(not(feature = "embedded-db"))]
758 let nodes_hash_ids: HashMap<Vec<u8>, i32> = batch_insert_hashes(hashes, self).await?;
759
760 #[cfg(feature = "embedded-db")]
761 let nodes_hash_ids: HashMap<Vec<u8>, i32> = {
762 let mut hash_ids: HashMap<Vec<u8>, i32> = HashMap::with_capacity(hashes.len());
763 for hash_chunk in hashes.chunks(20) {
764 let (query, sql) = build_hash_batch_insert(hash_chunk)?;
765 let chunk_ids: HashMap<Vec<u8>, i32> = query
766 .query_as(&sql)
767 .fetch(self.as_mut())
768 .try_collect()
769 .await?;
770 hash_ids.extend(chunk_ids);
771 }
772 hash_ids
773 };
774
775 for (node, children, hash) in &mut all_nodes {
776 node.created = block_number;
777 node.hash_id = *nodes_hash_ids.get(&*hash).ok_or(QueryError::Error {
778 message: "Missing node hash".to_string(),
779 })?;
780
781 if let Some(children) = children {
782 let children_hashes = children
783 .iter()
784 .map(|c| nodes_hash_ids.get(c).copied())
785 .collect::<Option<Vec<i32>>>()
786 .ok_or(QueryError::Error {
787 message: "Missing child hash".to_string(),
788 })?;
789
790 node.children = Some(children_hashes.into());
791 }
792 }
793
794 Node::upsert(name, all_nodes.into_iter().map(|(n, ..)| n), self).await?;
795
796 Ok(())
797 }
798}
799
800#[async_trait]
801impl<Mode: TransactionMode> PrunedHeightStorage for Transaction<Mode> {
802 async fn load_pruned_height(&mut self) -> anyhow::Result<Option<u64>> {
803 let Some((height,)) =
804 query_as::<(i64,)>("SELECT last_height FROM pruned_height ORDER BY id DESC LIMIT 1")
805 .fetch_optional(self.as_mut())
806 .await?
807 else {
808 return Ok(None);
809 };
810 Ok(Some(height as u64))
811 }
812}
813
/// Handles to the metrics describing the transactions of a connection pool.
#[derive(Clone, Debug)]
pub(super) struct PoolMetrics {
    /// Number of transactions currently open.
    open_transactions: Box<dyn Gauge>,
    /// Durations of closed transactions, in seconds.
    transaction_durations: Box<dyn Histogram>,
    /// Number of transactions closed by a commit.
    commits: Box<dyn Counter>,
    /// Number of transactions closed by an explicit revert.
    reverts: Box<dyn Counter>,
    /// Number of transactions closed without a completed commit or revert.
    drops: Box<dyn Counter>,
}
822
823impl PoolMetrics {
824 pub(super) fn new(metrics: &(impl Metrics + ?Sized)) -> Self {
825 Self {
826 open_transactions: metrics.create_gauge("open_transactions".into(), None),
827 transaction_durations: metrics
828 .create_histogram("transaction_duration".into(), Some("s".into())),
829 commits: metrics.create_counter("committed_transactions".into(), None),
830 reverts: metrics.create_counter("reverted_transactions".into(), None),
831 drops: metrics.create_counter("dropped_transactions".into(), None),
832 }
833 }
834}