hotshot_query_service/data_source/storage/sql.rs

// Copyright (c) 2022 Espresso Systems (espressosys.com)
// This file is part of the HotShot Query Service library.
//
// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
// General Public License as published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// General Public License for more details.
// You should have received a copy of the GNU General Public License along with this program. If not,
// see <https://www.gnu.org/licenses/>.

#![cfg(feature = "sql-data-source")]
use std::{cmp::min, fmt::Debug, str::FromStr, time::Duration};

use anyhow::Context;
use async_trait::async_trait;
use chrono::Utc;
use committable::Committable;
#[cfg(not(feature = "embedded-db"))]
use futures::future::FutureExt;
use hotshot_types::{
    data::{Leaf, Leaf2, VidShare},
    simple_certificate::{QuorumCertificate, QuorumCertificate2},
    traits::{metrics::Metrics, node_implementation::NodeType},
    vid::advz::{ADVZCommon, ADVZShare},
};
use itertools::Itertools;
use log::LevelFilter;
#[cfg(not(feature = "embedded-db"))]
use sqlx::postgres::{PgConnectOptions, PgSslMode};
#[cfg(feature = "embedded-db")]
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::{
    pool::{Pool, PoolOptions},
    ConnectOptions, Row,
};

use crate::{
    availability::{QueryableHeader, QueryablePayload, VidCommonMetadata, VidCommonQueryData},
    data_source::{
        storage::pruning::{PruneStorage, PrunerCfg, PrunerConfig},
        update::Transaction as _,
        VersionedDataSource,
    },
    metrics::PrometheusMetrics,
    node::BlockId,
    status::HasMetrics,
    Header, QueryError, QueryResult, VidCommon,
};
pub extern crate sqlx;
pub use sqlx::{Database, Sqlite};

mod db;
mod migrate;
mod queries;
mod transaction;

pub use anyhow::Error;
pub use db::*;
pub use include_dir::include_dir;
pub use queries::QueryBuilder;
pub use refinery::Migration;
pub use transaction::*;

use self::{migrate::Migrator, transaction::PoolMetrics};
use super::{AvailabilityStorage, NodeStorage};
// This needs to be reexported so that we can reference it by absolute path relative to this crate
// in the expansion of `include_migrations`, even when `include_migrations` is invoked from another
// crate which doesn't have `include_dir` as a dependency.
pub use crate::include_migrations;

/// Embed migrations from the given directory into the current binary for PostgreSQL or SQLite.
///
/// The macro invocation `include_migrations!(path)` evaluates to an expression of type `impl
/// Iterator<Item = Migration>`. Each migration must be a text file which is an immediate child of
/// `path`, and there must be no non-migration files in `path`. The migration files must have names
/// of the form `V${version}__${name}.sql`, where `version` is a positive integer indicating how the
/// migration is to be ordered relative to other migrations, and `name` is a descriptive name for
/// the migration.
///
/// `path` should be an absolute path. It is possible to give a path relative to the root of the
/// invoking crate by using environment variable expansions and the `CARGO_MANIFEST_DIR` environment
/// variable.
///
/// As an example, this is the invocation used to load the default migrations from the
/// `hotshot-query-service` crate. The migrations are located in a directory called `migrations` at
/// the crate root:
/// - PostgreSQL migrations are in `/migrations/postgres`.
/// - SQLite migrations are in `/migrations/sqlite`.
///
/// ```
/// # use hotshot_query_service::data_source::sql::{include_migrations, Migration};
/// // For PostgreSQL
/// #[cfg(not(feature = "embedded-db"))]
/// let mut migrations: Vec<Migration> =
///     include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres").collect();
/// // For SQLite
/// #[cfg(feature = "embedded-db")]
/// let mut migrations: Vec<Migration> =
///     include_migrations!("$CARGO_MANIFEST_DIR/migrations/sqlite").collect();
///
/// migrations.sort();
/// assert_eq!(migrations[0].version(), 10);
/// assert_eq!(migrations[0].name(), "init_schema");
/// ```
///
/// Note that a similar macro is available from Refinery:
/// [embed_migrations](https://docs.rs/refinery/0.8.11/refinery/macro.embed_migrations.html). This
/// macro differs in that it evaluates to an iterator of [migrations](Migration), making it an
/// expression macro, while `embed_migrations` is a statement macro that defines a module which
/// provides access to the embedded migrations only indirectly via a
/// [`Runner`](https://docs.rs/refinery/0.8.11/refinery/struct.Runner.html). The direct access to
/// migrations provided by [`include_migrations`] makes this macro easier to use with
/// [`Config::migrations`], for combining custom migrations with [`default_migrations`].
#[macro_export]
macro_rules! include_migrations {
    ($dir:tt) => {
        $crate::data_source::storage::sql::include_dir!($dir)
            .files()
            .map(|file| {
                let path = file.path();
                let name = path
                    .file_name()
                    .and_then(std::ffi::OsStr::to_str)
                    .unwrap_or_else(|| {
                        panic!(
                            "migration file {} must have a non-empty UTF-8 name",
                            path.display()
                        )
                    });
                let sql = file
                    .contents_utf8()
                    .unwrap_or_else(|| panic!("migration file {name} must use UTF-8 encoding"));
                $crate::data_source::storage::sql::Migration::unapplied(name, sql)
                    .expect("invalid migration")
            })
    };
}

/// The migrations required to build the default schema for this version of [`SqlStorage`].
pub fn default_migrations() -> Vec<Migration> {
    #[cfg(not(feature = "embedded-db"))]
    let mut migrations =
        include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres").collect::<Vec<_>>();

    #[cfg(feature = "embedded-db")]
    let mut migrations =
        include_migrations!("$CARGO_MANIFEST_DIR/migrations/sqlite").collect::<Vec<_>>();

    // Check version uniqueness and sort by version.
    validate_migrations(&mut migrations).expect("default migrations are invalid");

    // Check that all migration versions are multiples of 100, so that custom migrations can be
    // inserted in between.
    for m in &migrations {
        if m.version() <= 30 {
            // An older version of this software used intervals of 10 instead of 100. This was
            // changed to allow more custom migrations between each default migration, but we must
            // still accept older migrations that followed the older rule.
            assert!(
                m.version() > 0 && m.version() % 10 == 0,
                "legacy default migration version {} is not a positive multiple of 10",
                m.version()
            );
        } else {
            assert!(
                m.version() % 100 == 0,
                "default migration version {} is not a multiple of 100",
                m.version()
            );
        }
    }

    migrations
}

/// Validate and preprocess a sequence of migrations.
///
/// * Ensure all migrations have distinct versions
/// * Ensure migrations are sorted by increasing version
fn validate_migrations(migrations: &mut [Migration]) -> Result<(), Error> {
    migrations.sort_by_key(|m| m.version());

    // Check version uniqueness.
    for (prev, next) in migrations.iter().zip(migrations.iter().skip(1)) {
        if next <= prev {
            return Err(Error::msg(format!(
                "migration versions are not strictly increasing ({prev}->{next})"
            )));
        }
    }

    Ok(())
}

/// Add custom migrations to a default migration sequence.
///
/// Migrations in `custom` replace migrations in `default` with the same version. Otherwise, the two
/// sequences `default` and `custom` are merged so that the resulting sequence is sorted by
/// ascending version number. Each of `default` and `custom` is assumed to be the output of
/// [`validate_migrations`]; that is, each is sorted by version and contains no duplicate versions.
fn add_custom_migrations(
    default: impl IntoIterator<Item = Migration>,
    custom: impl IntoIterator<Item = Migration>,
) -> impl Iterator<Item = Migration> {
    default
        .into_iter()
        // Merge sorted lists, joining pairs of equal version into `EitherOrBoth::Both`.
        .merge_join_by(custom, |l, r| l.version().cmp(&r.version()))
        // Prefer the custom migration for a given version when both default and custom versions
        // are present.
        .map(|pair| pair.reduce(|_, custom| custom))
}
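
// Illustrative sketch (not part of the API): given a default sequence with versions 100 and 200
// and a custom migration that reuses version 100, the merged sequence keeps the custom V100 and
// the default V200. The migration names and SQL below are hypothetical.
//
//     let default = vec![
//         Migration::unapplied("V100__init.sql", "CREATE TABLE a (x BIGINT);").unwrap(),
//         Migration::unapplied("V200__more.sql", "CREATE TABLE b (x BIGINT);").unwrap(),
//     ];
//     let custom = vec![
//         Migration::unapplied("V100__init_custom.sql", "CREATE TABLE c (x BIGINT);").unwrap(),
//     ];
//     let merged: Vec<_> = add_custom_migrations(default, custom).collect();
//     assert_eq!(merged.len(), 2);
//     assert_eq!(merged[0].name(), "init_custom");
//     assert_eq!(merged[1].version(), 200);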

#[derive(Clone)]
pub struct Config {
    #[cfg(feature = "embedded-db")]
    db_opt: SqliteConnectOptions,

    #[cfg(not(feature = "embedded-db"))]
    db_opt: PgConnectOptions,

    pool_opt: PoolOptions<Db>,

    /// Separate pool options, so the connection pool for the query service can be configured
    /// independently of the sequencer pool.
    #[cfg(not(feature = "embedded-db"))]
    pool_opt_query: PoolOptions<Db>,

    #[cfg(not(feature = "embedded-db"))]
    schema: String,
    reset: bool,
    migrations: Vec<Migration>,
    no_migrations: bool,
    pruner_cfg: Option<PrunerCfg>,
    archive: bool,
    pool: Option<Pool<Db>>,
}

#[cfg(not(feature = "embedded-db"))]
impl Default for Config {
    fn default() -> Self {
        PgConnectOptions::default()
            .username("postgres")
            .password("password")
            .host("localhost")
            .port(5432)
            .into()
    }
}

#[cfg(feature = "embedded-db")]
impl Default for Config {
    fn default() -> Self {
        SqliteConnectOptions::default()
            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
            .busy_timeout(Duration::from_secs(30))
            .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
            .create_if_missing(true)
            .into()
    }
}

#[cfg(feature = "embedded-db")]
impl From<SqliteConnectOptions> for Config {
    fn from(db_opt: SqliteConnectOptions) -> Self {
        Self {
            db_opt,
            pool_opt: PoolOptions::default(),
            reset: false,
            migrations: vec![],
            no_migrations: false,
            pruner_cfg: None,
            archive: false,
            pool: None,
        }
    }
}

#[cfg(not(feature = "embedded-db"))]
impl From<PgConnectOptions> for Config {
    fn from(db_opt: PgConnectOptions) -> Self {
        Self {
            db_opt,
            pool_opt: PoolOptions::default(),
            pool_opt_query: PoolOptions::default(),
            schema: "hotshot".into(),
            reset: false,
            migrations: vec![],
            no_migrations: false,
            pruner_cfg: None,
            archive: false,
            pool: None,
        }
    }
}

#[cfg(not(feature = "embedded-db"))]
impl FromStr for Config {
    type Err = <PgConnectOptions as FromStr>::Err;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(PgConnectOptions::from_str(s)?.into())
    }
}

#[cfg(feature = "embedded-db")]
impl FromStr for Config {
    type Err = <SqliteConnectOptions as FromStr>::Err;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SqliteConnectOptions::from_str(s)?.into())
    }
}

#[cfg(feature = "embedded-db")]
impl Config {
    /// Set how long a connection waits when the database file is locked by another connection.
    pub fn busy_timeout(mut self, timeout: Duration) -> Self {
        self.db_opt = self.db_opt.busy_timeout(timeout);
        self
    }

    /// Set the path of the SQLite database file.
    pub fn db_path(mut self, path: std::path::PathBuf) -> Self {
        self.db_opt = self.db_opt.filename(path);
        self
    }
}

#[cfg(not(feature = "embedded-db"))]
impl Config {
    /// Set the hostname of the database server.
    ///
    /// The default is `localhost`.
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.db_opt = self.db_opt.host(&host.into());
        self
    }

    /// Set the port on which to connect to the database.
    ///
    /// The default is 5432, the default Postgres port.
    pub fn port(mut self, port: u16) -> Self {
        self.db_opt = self.db_opt.port(port);
        self
    }

    /// Set the DB user to connect as.
    pub fn user(mut self, user: &str) -> Self {
        self.db_opt = self.db_opt.username(user);
        self
    }

    /// Set a password for connecting to the database.
    pub fn password(mut self, password: &str) -> Self {
        self.db_opt = self.db_opt.password(password);
        self
    }

    /// Set the name of the database to connect to.
    pub fn database(mut self, database: &str) -> Self {
        self.db_opt = self.db_opt.database(database);
        self
    }

    /// Use TLS for an encrypted connection to the database.
    ///
    /// Note that an encrypted connection may be established even if this option is not set, as
    /// long as both the client and server support it. This option merely causes the connection to
    /// fail if an encrypted stream cannot be established.
    pub fn tls(mut self) -> Self {
        self.db_opt = self.db_opt.ssl_mode(PgSslMode::Require);
        self
    }

    /// Set the name of the schema to use for queries.
    ///
    /// The default schema is named `hotshot` and is created via the default migrations.
    pub fn schema(mut self, schema: impl Into<String>) -> Self {
        self.schema = schema.into();
        self
    }
}

impl Config {
    /// Set the database connection pool.
    ///
    /// This allows reusing an existing connection pool when building a new `SqlStorage` instance.
    pub fn pool(mut self, pool: Pool<Db>) -> Self {
        self.pool = Some(pool);
        self
    }

    /// Reset the schema on connection.
    ///
    /// When this [`Config`] is used to [`connect`](Self::connect) a
    /// [`SqlDataSource`](crate::data_source::SqlDataSource), if this option is set, the relevant
    /// [`schema`](Self::schema) will first be dropped and then recreated, yielding a completely
    /// fresh instance of the query service.
    ///
    /// This is a particularly useful capability for development and staging environments. Still, it
    /// must be used with extreme caution, as using this will irrevocably delete any data pertaining
    /// to the query service in the database.
    pub fn reset_schema(mut self) -> Self {
        self.reset = true;
        self
    }

    /// Add custom migrations to run when connecting to the database.
    pub fn migrations(mut self, migrations: impl IntoIterator<Item = Migration>) -> Self {
        self.migrations.extend(migrations);
        self
    }

    /// Skip all migrations when connecting to the database.
    pub fn no_migrations(mut self) -> Self {
        self.no_migrations = true;
        self
    }

    /// Enable pruning with a given configuration.
    ///
    /// If [`archive`](Self::archive) was previously specified, this will override it.
    pub fn pruner_cfg(mut self, cfg: PrunerCfg) -> Result<Self, Error> {
        cfg.validate()?;
        self.pruner_cfg = Some(cfg);
        self.archive = false;
        Ok(self)
    }

    /// Disable pruning and reconstruct previously pruned data.
    ///
    /// While running without pruning is the default behavior, the default will not try to
    /// reconstruct data that was pruned in a previous run where pruning was enabled. This option
    /// instructs the service to run without pruning _and_ reconstruct all previously pruned data by
    /// fetching from peers.
    ///
    /// If [`pruner_cfg`](Self::pruner_cfg) was previously specified, this will override it.
    pub fn archive(mut self) -> Self {
        self.pruner_cfg = None;
        self.archive = true;
        self
    }

    /// Set the maximum idle time of a connection.
    ///
    /// Any connection which has been open and unused longer than this duration will be
    /// automatically closed to reduce load on the server.
    pub fn idle_connection_timeout(mut self, timeout: Duration) -> Self {
        self.pool_opt = self.pool_opt.idle_timeout(Some(timeout));

        #[cfg(not(feature = "embedded-db"))]
        {
            self.pool_opt_query = self.pool_opt_query.idle_timeout(Some(timeout));
        }

        self
    }

    /// Set the maximum lifetime of a connection.
    ///
    /// Any connection which has been open longer than this duration will be automatically closed
    /// (and, if needed, replaced), even if it is otherwise healthy. It is good practice to refresh
    /// even healthy connections once in a while (e.g. daily) in case of resource leaks in the
    /// server implementation.
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.pool_opt = self.pool_opt.max_lifetime(Some(timeout));

        #[cfg(not(feature = "embedded-db"))]
        {
            self.pool_opt_query = self.pool_opt_query.max_lifetime(Some(timeout));
        }

        self
    }

    /// Set the minimum number of connections to maintain at any time.
    ///
    /// The data source will, to the best of its ability, maintain at least `min` open connections
    /// at all times. This can be used to reduce the latency hit of opening new connections when at
    /// least this many simultaneous connections are frequently needed.
    pub fn min_connections(mut self, min: u32) -> Self {
        self.pool_opt = self.pool_opt.min_connections(min);
        self
    }

    /// Like [`min_connections`](Self::min_connections), but for the dedicated query-service pool.
    #[cfg(not(feature = "embedded-db"))]
    pub fn query_min_connections(mut self, min: u32) -> Self {
        self.pool_opt_query = self.pool_opt_query.min_connections(min);
        self
    }

    /// Set the maximum number of connections to maintain at any time.
    ///
    /// Once `max` connections are in use simultaneously, further attempts to acquire a connection
    /// (or begin a transaction) will block until one of the existing connections is released.
    pub fn max_connections(mut self, max: u32) -> Self {
        self.pool_opt = self.pool_opt.max_connections(max);
        self
    }

    /// Like [`max_connections`](Self::max_connections), but for the dedicated query-service pool.
    #[cfg(not(feature = "embedded-db"))]
    pub fn query_max_connections(mut self, max: u32) -> Self {
        self.pool_opt_query = self.pool_opt_query.max_connections(max);
        self
    }

    /// Log at WARN level any time a SQL statement takes longer than `threshold`.
    ///
    /// The default threshold is 1s.
    pub fn slow_statement_threshold(mut self, threshold: Duration) -> Self {
        self.db_opt = self
            .db_opt
            .log_slow_statements(LevelFilter::Warn, threshold);
        self
    }

    /// Set the maximum time a single SQL statement is allowed to run before being canceled.
    ///
    /// This helps prevent queries from running indefinitely even when the client is dropped.
    #[cfg(not(feature = "embedded-db"))]
    pub fn statement_timeout(mut self, timeout: Duration) -> Self {
        // Format the duration as milliseconds; PostgreSQL interprets values without units as
        // milliseconds.
        let timeout_ms = timeout.as_millis();
        self.db_opt = self
            .db_opt
            .options([("statement_timeout", timeout_ms.to_string())]);
        self
    }

    /// Statement timeouts are not supported for SQLite, so this is a no-op.
    #[cfg(feature = "embedded-db")]
    pub fn statement_timeout(self, _timeout: Duration) -> Self {
        self
    }
}
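
// Configuration sketch: the values below (credentials, schema name, limits) are illustrative, not
// defaults recommended by this crate. With the `embedded-db` feature the Postgres-specific
// builders (`host`, `port`, `user`, `schema`, ...) are unavailable and the SQLite file path would
// be set with `db_path` instead.
//
//     let config = Config::default()
//         .host("db.example.com")
//         .port(5432)
//         .user("postgres")
//         .password("password")
//         .database("hotshot_query")
//         .schema("hotshot")
//         .max_connections(25)
//         .idle_connection_timeout(Duration::from_secs(120))
//         .slow_statement_threshold(Duration::from_secs(2));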

/// Storage for the APIs provided in this crate, backed by a SQL database (PostgreSQL by default,
/// or SQLite when the `embedded-db` feature is enabled).
#[derive(Clone, Debug)]
pub struct SqlStorage {
    pool: Pool<Db>,
    metrics: PrometheusMetrics,
    pool_metrics: PoolMetrics,
    pruner_cfg: Option<PrunerCfg>,
}

#[derive(Debug, Default)]
pub struct Pruner {
    pruned_height: Option<u64>,
    target_height: Option<u64>,
    minimum_retention_height: Option<u64>,
}

#[derive(PartialEq)]
pub enum StorageConnectionType {
    Sequencer,
    Query,
}

impl SqlStorage {
    pub fn pool(&self) -> Pool<Db> {
        self.pool.clone()
    }

    /// Connect to the database.
    #[allow(unused_variables)]
    pub async fn connect(
        mut config: Config,
        connection_type: StorageConnectionType,
    ) -> Result<Self, Error> {
        let metrics = PrometheusMetrics::default();
        let pool_metrics = PoolMetrics::new(&*metrics.subgroup("sql".into()));

        #[cfg(feature = "embedded-db")]
        let pool = config.pool_opt.clone();
        #[cfg(not(feature = "embedded-db"))]
        let pool = match connection_type {
            StorageConnectionType::Sequencer => config.pool_opt.clone(),
            StorageConnectionType::Query => config.pool_opt_query.clone(),
        };

        let pruner_cfg = config.pruner_cfg;

        // Only reuse an existing pool for SQLite, or for the sequencer connection when using
        // Postgres.
        if cfg!(feature = "embedded-db") || connection_type == StorageConnectionType::Sequencer {
            // Re-use the same pool if present and return early.
            if let Some(pool) = config.pool {
                return Ok(Self {
                    metrics,
                    pool_metrics,
                    pool,
                    pruner_cfg,
                });
            }
        } else if config.pool.is_some() {
            tracing::info!("not reusing existing pool for query connection");
        }

        #[cfg(not(feature = "embedded-db"))]
        let schema = config.schema.clone();
        #[cfg(not(feature = "embedded-db"))]
        let pool = pool.after_connect(move |conn, _| {
            let schema = config.schema.clone();
            async move {
                query(&format!("SET search_path TO {schema}"))
                    .execute(conn)
                    .await?;
                Ok(())
            }
            .boxed()
        });

        #[cfg(feature = "embedded-db")]
        if config.reset {
            std::fs::remove_file(config.db_opt.get_filename())?;
        }

        let pool = pool.connect_with(config.db_opt).await?;

        // Create or connect to the schema for this query service.
        let mut conn = pool.acquire().await?;

        // Disable the statement timeout for migrations, as they can take a long time.
        #[cfg(not(feature = "embedded-db"))]
        query("SET statement_timeout = 0")
            .execute(conn.as_mut())
            .await?;

        #[cfg(not(feature = "embedded-db"))]
        if config.reset {
            query(&format!("DROP SCHEMA IF EXISTS {schema} CASCADE"))
                .execute(conn.as_mut())
                .await?;
        }

        #[cfg(not(feature = "embedded-db"))]
        query(&format!("CREATE SCHEMA IF NOT EXISTS {schema}"))
            .execute(conn.as_mut())
            .await?;

        // Get migrations and interleave with custom migrations, sorting by version number.
        validate_migrations(&mut config.migrations)?;
        let migrations =
            add_custom_migrations(default_migrations(), config.migrations).collect::<Vec<_>>();

        // Get a migration runner. Depending on the config, we can either use this to actually run
        // the migrations or just check if the database is up to date.
        let runner = refinery::Runner::new(&migrations).set_grouped(true);

        if config.no_migrations {
            // We've been asked not to run any migrations. Abort if the DB is not already up to
            // date.
            let last_applied = runner
                .get_last_applied_migration_async(&mut Migrator::from(&mut conn))
                .await?;
            let last_expected = migrations.last();
            if last_applied.as_ref() != last_expected {
                return Err(Error::msg(format!(
                    "DB is out of date: last applied migration is {last_applied:?}, but expected \
                     {last_expected:?}"
                )));
            }
        } else {
            // Run migrations using `refinery`.
            match runner.run_async(&mut Migrator::from(&mut conn)).await {
                Ok(report) => {
                    tracing::info!("ran DB migrations: {report:?}");
                },
                Err(err) => {
                    tracing::error!("DB migrations failed: {:?}", err.report());
                    Err(err)?;
                },
            }
        }

        if config.archive {
            // If running in archive mode, ensure the pruned height is set to 0, so the fetcher will
            // reconstruct previously pruned data.
            query("DELETE FROM pruned_height WHERE id = 1")
                .execute(conn.as_mut())
                .await?;
        }

        conn.close().await?;

        Ok(Self {
            pool,
            pool_metrics,
            metrics,
            pruner_cfg,
        })
    }
}
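
// Connection sketch (illustrative variable names): open separate storage handles for the
// sequencer and the query service. Note that a pool passed via `Config::pool` is only reused for
// SQLite or for `StorageConnectionType::Sequencer`; query connections always build their own pool
// from `pool_opt_query`.
//
//     let sequencer_storage =
//         SqlStorage::connect(config.clone(), StorageConnectionType::Sequencer).await?;
//     let query_storage =
//         SqlStorage::connect(config, StorageConnectionType::Query).await?;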

impl PrunerConfig for SqlStorage {
    fn set_pruning_config(&mut self, cfg: PrunerCfg) {
        self.pruner_cfg = Some(cfg);
    }

    fn get_pruning_config(&self) -> Option<PrunerCfg> {
        self.pruner_cfg.clone()
    }
}

impl HasMetrics for SqlStorage {
    fn metrics(&self) -> &PrometheusMetrics {
        &self.metrics
    }
}

impl SqlStorage {
    async fn get_minimum_height(&self) -> QueryResult<Option<u64>> {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let (Some(height),) =
            query_as::<(Option<i64>,)>("SELECT MIN(height) as height FROM header")
                .fetch_one(tx.as_mut())
                .await?
        else {
            return Ok(None);
        };
        Ok(Some(height as u64))
    }

    async fn get_height_by_timestamp(&self, timestamp: i64) -> QueryResult<Option<u64>> {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;

        // We order by timestamp and then height, even though logically this is no different than
        // just ordering by height, since timestamps are monotonic. The reason is that this order
        // allows the query planner to efficiently solve the where clause and presort the results
        // based on the timestamp index. The remaining sort on height, which guarantees a unique
        // block if multiple blocks have the same timestamp, is very efficient, because there are
        // never more than a handful of blocks with the same timestamp.
        let Some((height,)) = query_as::<(i64,)>(
            "SELECT height FROM header
              WHERE timestamp <= $1
              ORDER BY timestamp DESC, height DESC
              LIMIT 1",
        )
        .bind(timestamp)
        .fetch_optional(tx.as_mut())
        .await?
        else {
            return Ok(None);
        };
        Ok(Some(height as u64))
    }

    /// Get the stored VID share for a given block, if one exists.
    pub async fn get_vid_share<Types>(&self, block_id: BlockId<Types>) -> QueryResult<VidShare>
    where
        Types: NodeType,
        Header<Types>: QueryableHeader<Types>,
    {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let share = tx.vid_share(block_id).await?;
        Ok(share)
    }

    /// Get the stored VID common data for a given block, if one exists.
    pub async fn get_vid_common<Types: NodeType>(
        &self,
        block_id: BlockId<Types>,
    ) -> QueryResult<VidCommonQueryData<Types>>
    where
        <Types as NodeType>::BlockPayload: QueryablePayload<Types>,
        <Types as NodeType>::BlockHeader: QueryableHeader<Types>,
    {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let common = tx.get_vid_common(block_id).await?;
        Ok(common)
    }

    /// Get the stored VID common metadata for a given block, if one exists.
    pub async fn get_vid_common_metadata<Types: NodeType>(
        &self,
        block_id: BlockId<Types>,
    ) -> QueryResult<VidCommonMetadata<Types>>
    where
        <Types as NodeType>::BlockPayload: QueryablePayload<Types>,
        <Types as NodeType>::BlockHeader: QueryableHeader<Types>,
    {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let common = tx.get_vid_common_metadata(block_id).await?;
        Ok(common)
    }
}

#[async_trait]
impl PruneStorage for SqlStorage {
    type Pruner = Pruner;

    async fn get_disk_usage(&self) -> anyhow::Result<u64> {
        let mut tx = self.read().await?;

        #[cfg(not(feature = "embedded-db"))]
        let query = "SELECT pg_database_size(current_database())";

        #[cfg(feature = "embedded-db")]
        let query = "
            SELECT( (SELECT page_count FROM pragma_page_count) * (SELECT * FROM pragma_page_size)) \
                     AS total_bytes";

        let row = tx.fetch_one(query).await?;
        let size: i64 = row.get(0);

        Ok(size as u64)
    }

    /// Trigger incremental vacuum to free up space in the SQLite database.
    ///
    /// Note: we don't vacuum the Postgres database, as there is no manual trigger for incremental
    /// vacuum, and a full vacuum can take a lot of time.
    #[cfg(feature = "embedded-db")]
    async fn vacuum(&self) -> anyhow::Result<()> {
        let config = self.get_pruning_config().ok_or(QueryError::Error {
            message: "Pruning config not found".to_string(),
        })?;
        let mut conn = self.pool().acquire().await?;
        query(&format!(
            "PRAGMA incremental_vacuum({})",
            config.incremental_vacuum_pages()
        ))
        .execute(conn.as_mut())
        .await?;
        conn.close().await?;
        Ok(())
    }

    /// Note: The prune operation may not immediately free up space even after rows are deleted,
    /// because a vacuum operation may be necessary to reclaim more space. PostgreSQL already
    /// performs auto vacuuming, so we are not including it here, as running a vacuum operation can
    /// be resource-intensive.
    async fn prune(&self, pruner: &mut Pruner) -> anyhow::Result<Option<u64>> {
        let cfg = self.get_pruning_config().ok_or(QueryError::Error {
            message: "Pruning config not found".to_string(),
        })?;
        let batch_size = cfg.batch_size();
        let max_usage = cfg.max_usage();
        let state_tables = cfg.state_tables();

        // If a pruner run was already in progress, some variables may already be set, depending on
        // whether a batch was deleted and which batch it was (target or minimum retention). This
        // enables us to resume the pruner run from the exact heights. If any of these values are
        // not set, they can be loaded from the database if necessary.
        let mut minimum_retention_height = pruner.minimum_retention_height;
        let mut target_height = pruner.target_height;
        let mut height = match pruner.pruned_height {
            Some(h) => h,
            None => {
                let Some(height) = self.get_minimum_height().await? else {
                    tracing::info!("database is empty, nothing to prune");
                    return Ok(None);
                };

                height
            },
        };

        // Prune data exceeding target retention in batches.
        if pruner.target_height.is_none() {
            let th = self
                .get_height_by_timestamp(
                    Utc::now().timestamp() - (cfg.target_retention().as_secs()) as i64,
                )
                .await?;
            target_height = th;
            pruner.target_height = target_height;
        };

        if let Some(target_height) = target_height {
            if height < target_height {
                height = min(height + batch_size, target_height);
                let mut tx = self.write().await?;
                tx.delete_batch(state_tables, height).await?;
                tx.commit().await.map_err(|e| QueryError::Error {
                    message: format!("failed to commit {e}"),
                })?;
                pruner.pruned_height = Some(height);
                return Ok(Some(height));
            }
        }

        // If a threshold is set, prune data exceeding minimum retention in batches. This parameter
        // is needed for SQL storage as there is no direct way to get free space.
        if let Some(threshold) = cfg.pruning_threshold() {
            let usage = self.get_disk_usage().await?;

            // Prune data exceeding minimum retention in batches starting from minimum height
            // until usage is below threshold.
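            //
            // Worked example (illustrative numbers): with `pruning_threshold` = 100 GiB and
            // `max_usage` = 8000 basis points, a batch of minimum-retention data is deleted when
            // usage exceeds 100 GiB, usage / threshold is above 8000 / 10000 = 0.8, and the
            // current pruned height is still below the minimum-retention height.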
            if usage > threshold {
                tracing::warn!(
                    "Disk usage {usage} exceeds pruning threshold {:?}",
                    cfg.pruning_threshold()
                );

                if minimum_retention_height.is_none() {
                    minimum_retention_height = self
                        .get_height_by_timestamp(
                            Utc::now().timestamp() - (cfg.minimum_retention().as_secs()) as i64,
                        )
                        .await?;

                    pruner.minimum_retention_height = minimum_retention_height;
                }

                if let Some(min_retention_height) = minimum_retention_height {
                    if (usage as f64 / threshold as f64) > (f64::from(max_usage) / 10000.0)
                        && height < min_retention_height
                    {
                        height = min(height + batch_size, min_retention_height);
                        let mut tx = self.write().await?;
                        tx.delete_batch(state_tables, height).await?;
                        tx.commit().await.map_err(|e| QueryError::Error {
                            message: format!("failed to commit {e}"),
                        })?;

                        self.vacuum().await?;

                        pruner.pruned_height = Some(height);

                        return Ok(Some(height));
                    }
                }
            }
        }

        Ok(None)
    }
}

impl VersionedDataSource for SqlStorage {
    type Transaction<'a>
        = Transaction<Write>
    where
        Self: 'a;
    type ReadOnly<'a>
        = Transaction<Read>
    where
        Self: 'a;

    async fn write(&self) -> anyhow::Result<Transaction<Write>> {
        Transaction::new(&self.pool, self.pool_metrics.clone()).await
    }

    async fn read(&self) -> anyhow::Result<Transaction<Read>> {
        Transaction::new(&self.pool, self.pool_metrics.clone()).await
    }
}
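
// Transaction sketch (illustrative variable names): reads and writes go through short-lived
// transactions acquired from the pool; a write must be committed explicitly.
//
//     // Write path:
//     let mut tx = storage.write().await?;
//     tx.delete_batch(state_tables, height).await?;
//     tx.commit().await?;
//
//     // Read path:
//     let mut tx = storage.read().await?;
//     let (count,) = query_as::<(i64,)>("SELECT count(*) FROM header")
//         .fetch_one(tx.as_mut())
//         .await?;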

#[async_trait]
pub trait MigrateTypes<Types: NodeType> {
    async fn migrate_types(&self, batch_size: u64) -> anyhow::Result<()>;
}

#[async_trait]
impl<Types: NodeType> MigrateTypes<Types> for SqlStorage {
    async fn migrate_types(&self, batch_size: u64) -> anyhow::Result<()> {
        let limit = batch_size;
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;

        // The table `types_migration` is populated in the SQL migration with `completed = false`
        // and `migrated_rows = 0`, so `fetch_one()` always returns a row.
        // After each batch insert, it is updated with the number of rows migrated.
        // This is necessary to resume from the same point in case of a restart.
        let (is_migration_completed, mut offset) = query_as::<(bool, i64)>(
            "SELECT completed, migrated_rows from types_migration WHERE id = 1 ",
        )
        .fetch_one(tx.as_mut())
        .await?;

        if is_migration_completed {
            tracing::info!("types migration already completed");
            return Ok(());
        }

        tracing::warn!(
            "migrating query service types storage. Offset={offset}, batch_size={limit}"
        );

        loop {
            let mut tx = self.read().await.map_err(|err| QueryError::Error {
                message: err.to_string(),
            })?;

            let rows = QueryBuilder::default()
                .query(
                    "SELECT leaf, qc, common as vid_common, share as vid_share
                    FROM leaf INNER JOIN vid on leaf.height = vid.height
                    WHERE leaf.height >= $1 AND leaf.height < $2",
                )
                .bind(offset)
                .bind(offset + limit as i64)
                .fetch_all(tx.as_mut())
                .await?;

            drop(tx);

            if rows.is_empty() {
                break;
            }

            let mut leaf_rows = Vec::new();
            let mut vid_rows = Vec::new();

            for row in rows.iter() {
                let leaf1 = row.try_get("leaf")?;
                let qc = row.try_get("qc")?;
                let leaf1: Leaf<Types> = serde_json::from_value(leaf1)?;
                let qc: QuorumCertificate<Types> = serde_json::from_value(qc)?;

                let leaf2: Leaf2<Types> = leaf1.into();
                let qc2: QuorumCertificate2<Types> = qc.to_qc2();

                let commit = leaf2.commit();

                let leaf2_json =
                    serde_json::to_value(leaf2.clone()).context("failed to serialize leaf2")?;
                let qc2_json = serde_json::to_value(qc2).context("failed to serialize QC2")?;

                let vid_common_bytes: Vec<u8> = row.try_get("vid_common")?;
                let vid_share_bytes: Option<Vec<u8>> = row.try_get("vid_share")?;

                let mut new_vid_share_bytes = None;

                if let Some(vid_share_bytes) = vid_share_bytes {
                    let vid_share: ADVZShare = bincode::deserialize(&vid_share_bytes)
                        .context("failed to deserialize vid_share")?;
                    new_vid_share_bytes = Some(
                        bincode::serialize(&VidShare::V0(vid_share))
                            .context("failed to serialize vid_share")?,
                    );
                }

                let vid_common: ADVZCommon = bincode::deserialize(&vid_common_bytes)
                    .context("failed to deserialize vid_common")?;
                let new_vid_common_bytes = bincode::serialize(&VidCommon::V0(vid_common))
                    .context("failed to serialize vid_common")?;

                vid_rows.push((
                    leaf2.height() as i64,
                    new_vid_common_bytes,
                    new_vid_share_bytes,
                ));
                leaf_rows.push((
                    leaf2.height() as i64,
                    commit.to_string(),
                    leaf2.block_header().commit().to_string(),
                    leaf2_json,
                    qc2_json,
                ));
            }

            // migrate leaf2
            let mut query_builder: sqlx::QueryBuilder<Db> =
                sqlx::QueryBuilder::new("INSERT INTO leaf2 (height, hash, block_hash, leaf, qc) ");

            // Advance `offset` past the window queried in this batch ([offset, offset + limit)),
            // so the next iteration starts at the first height not yet seen.
            offset += limit as i64;

            query_builder.push_values(leaf_rows, |mut b, row| {
                b.push_bind(row.0)
                    .push_bind(row.1)
                    .push_bind(row.2)
                    .push_bind(row.3)
                    .push_bind(row.4);
            });

            query_builder.push(" ON CONFLICT DO NOTHING");
            let query = query_builder.build();

            let mut tx = self.write().await.map_err(|err| QueryError::Error {
                message: err.to_string(),
            })?;

            query.execute(tx.as_mut()).await?;

            // update migrated_rows column with the offset
            tx.upsert(
                "types_migration",
                ["id", "completed", "migrated_rows"],
                ["id"],
                [(1_i64, false, offset)],
            )
            .await?;

            // migrate vid
            let mut query_builder: sqlx::QueryBuilder<Db> =
                sqlx::QueryBuilder::new("INSERT INTO vid2 (height, common, share) ");

            query_builder.push_values(vid_rows, |mut b, row| {
                b.push_bind(row.0).push_bind(row.1).push_bind(row.2);
            });
            query_builder.push(" ON CONFLICT DO NOTHING");
            let query = query_builder.build();

            query.execute(tx.as_mut()).await?;

            tx.commit().await?;

            tracing::warn!("Migrated leaf and vid: offset={offset}");

            tracing::info!("offset={offset}");
            if rows.len() < limit as usize {
                break;
            }
        }

        let mut tx = self.write().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;

        tracing::warn!("query service types migration is completed!");

        tx.upsert(
            "types_migration",
            ["id", "completed", "migrated_rows"],
            ["id"],
            [(1_i64, true, offset)],
        )
        .await?;

        tracing::info!("updated types_migration table");

        tx.commit().await?;
        Ok(())
    }
}
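
// Migration sketch (illustrative): the types migration is idempotent and resumable, so it can
// simply be invoked on startup with a batch size tuned to the deployment. `MockTypes` here stands
// in for the application's `NodeType`.
//
//     <SqlStorage as MigrateTypes<MockTypes>>::migrate_types(&storage, 10_000).await?;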

// These tests run the `postgres` Docker image, which doesn't work on Windows.
#[cfg(all(any(test, feature = "testing"), not(target_os = "windows")))]
pub mod testing {
    #![allow(unused_imports)]
    use std::{
        env,
        process::{Command, Stdio},
        str::{self, FromStr},
        time::Duration,
    };

    use portpicker::pick_unused_port;
    use refinery::Migration;
    use tokio::{net::TcpStream, time::timeout};

    use super::Config;
    use crate::testing::sleep;

    #[derive(Debug)]
    pub struct TmpDb {
        #[cfg(not(feature = "embedded-db"))]
        host: String,
        #[cfg(not(feature = "embedded-db"))]
        port: u16,
        #[cfg(not(feature = "embedded-db"))]
        container_id: String,
        #[cfg(feature = "embedded-db")]
        db_path: std::path::PathBuf,
        #[allow(dead_code)]
        persistent: bool,
    }

    impl TmpDb {
        #[cfg(feature = "embedded-db")]
        fn init_sqlite_db(persistent: bool) -> Self {
            let file = tempfile::Builder::new()
                .prefix("sqlite-")
                .suffix(".db")
                .tempfile()
                .unwrap();

            let (_, db_path) = file.keep().unwrap();

            Self {
                db_path,
                persistent,
            }
        }

        pub async fn init() -> Self {
            #[cfg(feature = "embedded-db")]
            return Self::init_sqlite_db(false);

            #[cfg(not(feature = "embedded-db"))]
            Self::init_postgres(false).await
        }

        pub async fn persistent() -> Self {
            #[cfg(feature = "embedded-db")]
            return Self::init_sqlite_db(true);

            #[cfg(not(feature = "embedded-db"))]
            Self::init_postgres(true).await
        }

        #[cfg(not(feature = "embedded-db"))]
        async fn init_postgres(persistent: bool) -> Self {
            let docker_hostname = env::var("DOCKER_HOSTNAME");
            // This picks an unused port on the current system. If Docker is
            // configured to run on a different host then this may not find a
            // "free" port on that system.
            // We *might* be able to get away with this as any remote Docker
            // host should hopefully be pretty open with its port space.
            let port = pick_unused_port().unwrap();
            let host = docker_hostname.unwrap_or("localhost".to_string());

            let mut cmd = Command::new("docker");
            cmd.arg("run")
                .arg("-d")
                .args(["-p", &format!("{port}:5432")])
                .args(["-e", "POSTGRES_PASSWORD=password"]);

            if !persistent {
                cmd.arg("--rm");
            }

            let output = cmd.arg("postgres").output().unwrap();
            let stdout = str::from_utf8(&output.stdout).unwrap();
            let stderr = str::from_utf8(&output.stderr).unwrap();
            if !output.status.success() {
                panic!("failed to start postgres docker: {stderr}");
            }

            // Create the TmpDb object immediately after starting the Docker container, so if
            // anything panics after this `drop` will be called and we will clean up.
            let container_id = stdout.trim().to_owned();
            tracing::info!("launched postgres docker {container_id}");
            let db = Self {
                host,
                port,
                container_id: container_id.clone(),
                persistent,
            };

            db.wait_for_ready().await;
            db
        }

        #[cfg(not(feature = "embedded-db"))]
        pub fn host(&self) -> String {
            self.host.clone()
        }

        #[cfg(not(feature = "embedded-db"))]
        pub fn port(&self) -> u16 {
            self.port
        }

        #[cfg(feature = "embedded-db")]
        pub fn path(&self) -> std::path::PathBuf {
            self.db_path.clone()
        }

        pub fn config(&self) -> Config {
            #[cfg(feature = "embedded-db")]
            let mut cfg = Config::default().db_path(self.db_path.clone());

            #[cfg(not(feature = "embedded-db"))]
            let mut cfg = Config::default()
                .user("postgres")
                .password("password")
                .host(self.host())
                .port(self.port());

            cfg = cfg.migrations(vec![Migration::unapplied(
                "V101__create_test_merkle_tree_table.sql",
                &TestMerkleTreeMigration::create("test_tree"),
            )
            .unwrap()]);

            cfg
        }

        #[cfg(not(feature = "embedded-db"))]
        pub fn stop_postgres(&mut self) {
            tracing::info!(container = self.container_id, "stopping postgres");
            let output = Command::new("docker")
                .args(["stop", self.container_id.as_str()])
                .output()
                .unwrap();
            assert!(
                output.status.success(),
                "error killing postgres docker {}: {}",
                self.container_id,
                str::from_utf8(&output.stderr).unwrap()
            );
        }

        #[cfg(not(feature = "embedded-db"))]
        pub async fn start_postgres(&mut self) {
            tracing::info!(container = self.container_id, "resuming postgres");
            let output = Command::new("docker")
                .args(["start", self.container_id.as_str()])
                .output()
                .unwrap();
            assert!(
                output.status.success(),
                "error starting postgres docker {}: {}",
                self.container_id,
                str::from_utf8(&output.stderr).unwrap()
            );

            self.wait_for_ready().await;
        }

        #[cfg(not(feature = "embedded-db"))]
        async fn wait_for_ready(&self) {
            let timeout_duration = Duration::from_secs(
                env::var("SQL_TMP_DB_CONNECT_TIMEOUT")
                    .unwrap_or("60".to_string())
                    .parse()
                    .expect("SQL_TMP_DB_CONNECT_TIMEOUT must be an integer number of seconds"),
            );

            if let Err(err) = timeout(timeout_duration, async {
                while Command::new("docker")
                    .args([
                        "exec",
                        &self.container_id,
                        "pg_isready",
                        "-h",
                        "localhost",
                        "-U",
                        "postgres",
                    ])
                    .env("PGPASSWORD", "password")
                    // Null input so the command terminates as soon as it manages to connect.
                    .stdin(Stdio::null())
                    // Discard command output.
                    .stdout(Stdio::null())
                    .stderr(Stdio::null())
                    .status()
                    // Check the exit status explicitly: a simple `unwrap` would
                    // panic on unrelated errors (such as network connection
                    // failures) instead of retrying.
                    .and_then(|status| {
                        status
                            .success()
                            .then_some(true)
                            // Any ol' Error will do
                            .ok_or(std::io::Error::from_raw_os_error(666))
                    })
                    .is_err()
                {
                    tracing::warn!("database is not ready");
                    sleep(Duration::from_secs(1)).await;
                }

                // The above command ensures the database is ready inside the Docker container.
                // However, on some systems, there is a slight delay before the port is exposed via
                // host networking. We don't need to check again that the database is ready on the
                // host (and maybe can't, because the host might not have pg_isready installed), but
                // we can ensure the port is open by just establishing a TCP connection.
                while let Err(err) =
                    TcpStream::connect(format!("{}:{}", self.host, self.port)).await
                {
                    tracing::warn!("database is ready, but port is not available to host: {err:#}");
                    sleep(Duration::from_millis(100)).await;
                }
            })
            .await
            {
                panic!(
                    "failed to connect to TmpDb within configured timeout {timeout_duration:?}: \
                     {err:#}\n{}",
                    "Consider increasing the timeout by setting SQL_TMP_DB_CONNECT_TIMEOUT"
                );
            }
        }
    }

    #[cfg(not(feature = "embedded-db"))]
    impl Drop for TmpDb {
        fn drop(&mut self) {
            self.stop_postgres();
        }
    }

    #[cfg(feature = "embedded-db")]
    impl Drop for TmpDb {
        fn drop(&mut self) {
            if !self.persistent {
                std::fs::remove_file(self.db_path.clone()).unwrap();
            }
        }
    }

    pub struct TestMerkleTreeMigration;

    impl TestMerkleTreeMigration {
        fn create(name: &str) -> String {
            let (bit_vec, binary, hash_pk, root_stored_column) = if cfg!(feature = "embedded-db") {
                (
                    "TEXT",
                    "BLOB",
                    "INTEGER PRIMARY KEY AUTOINCREMENT",
                    " (json_extract(data, '$.test_merkle_tree_root'))",
                )
            } else {
                (
                    "BIT(8)",
                    "BYTEA",
                    "SERIAL PRIMARY KEY",
                    "(data->>'test_merkle_tree_root')",
                )
            };

            format!(
                "CREATE TABLE IF NOT EXISTS hash
            (
                id {hash_pk},
                value {binary}  NOT NULL UNIQUE
            );

            ALTER TABLE header
            ADD column test_merkle_tree_root text
            GENERATED ALWAYS as {root_stored_column} STORED;

            CREATE TABLE {name}
            (
                path JSONB NOT NULL,
                created BIGINT NOT NULL,
                hash_id INT NOT NULL,
                children JSONB,
                children_bitvec {bit_vec},
                idx JSONB,
                entry JSONB,
                PRIMARY KEY (path, created)
            );
            CREATE INDEX {name}_created ON {name} (created);"
            )
        }
    }
}

1441
1442// These tests run the `postgres` Docker image, which doesn't work on Windows.
1443#[cfg(all(test, not(target_os = "windows")))]
1444mod test {
1445    use std::time::Duration;
1446
1447    use committable::{Commitment, CommitmentBoundsArkless, Committable};
1448    use hotshot::traits::BlockPayload;
1449    use hotshot_example_types::{
1450        node_types::TestVersions,
1451        state_types::{TestInstanceState, TestValidatedState},
1452    };
1453    use hotshot_types::{
1454        data::{QuorumProposal, ViewNumber},
1455        simple_vote::QuorumData,
1456        traits::{
1457            block_contents::{BlockHeader, GENESIS_VID_NUM_STORAGE_NODES},
1458            node_implementation::{ConsensusTime, Versions},
1459            EncodeBytes,
1460        },
1461        vid::advz::advz_scheme,
1462    };
1463    use jf_advz::VidScheme;
1464    use jf_merkle_tree_compat::{
1465        prelude::UniversalMerkleTree, MerkleTreeScheme, ToTraversalPath, UniversalMerkleTreeScheme,
1466    };
1467    use tokio::time::sleep;
1468    use vbs::version::StaticVersionType;
1469
1470    use super::{testing::TmpDb, *};
1471    use crate::{
1472        availability::LeafQueryData,
1473        data_source::storage::{pruning::PrunedHeightStorage, UpdateAvailabilityStorage},
1474        merklized_state::{MerklizedState, UpdateStateData},
1475        testing::mocks::{MockHeader, MockMerkleTree, MockPayload, MockTypes, MockVersions},
1476    };
1477
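    // Exercise the migration machinery: a fresh database must be migrated before use, custom
    // migrations are applied in order of version number, and a schema mismatch in either
    // direction causes the connection to fail.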
1478    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1479    async fn test_migrations() {
1480        let db = TmpDb::init().await;
1481        let cfg = db.config();
1482
1483        let connect = |migrations: bool, custom_migrations| {
1484            let cfg = cfg.clone();
1485            async move {
1486                let mut cfg = cfg.migrations(custom_migrations);
1487                if !migrations {
1488                    cfg = cfg.no_migrations();
1489                }
1490                let client = SqlStorage::connect(cfg, StorageConnectionType::Query).await?;
1491                Ok::<_, Error>(client)
1492            }
1493        };
1494
1495        // Connecting with migrations disabled should fail if the database is not already up to date
1496        // (since we've just created a fresh database, it isn't).
1497        let err = connect(false, vec![]).await.unwrap_err();
1498        tracing::info!("connecting without running migrations failed as expected: {err}");
1499
1500        // Now connect and run migrations to bring the database up to date.
1501        connect(true, vec![]).await.unwrap();
1502        // Now connecting without migrations should work.
1503        connect(false, vec![]).await.unwrap();
1504
1505        // Connect with some custom migrations, to advance the schema even further. Pass in the
1506        // custom migrations out of order; they should still execute in order of version number.
1507        // The SQL commands used here will fail if not run in order.
1508        let migrations = vec![
1509            Migration::unapplied(
1510                "V9999__create_test_table.sql",
1511                "ALTER TABLE test ADD COLUMN data INTEGER;",
1512            )
1513            .unwrap(),
1514            Migration::unapplied(
1515                "V9998__create_test_table.sql",
1516                "CREATE TABLE test (x bigint);",
1517            )
1518            .unwrap(),
1519        ];
1520        connect(true, migrations.clone()).await.unwrap();
1521
1522        // Connect using the default schema (no custom migrations) and not running migrations. This
1523        // should fail because the database is _ahead_ of the client in terms of schema.
1524        let err = connect(false, vec![]).await.unwrap_err();
1525        tracing::info!("connecting without running migrations failed as expected: {err}");
1526
1527        // Connecting with the customized schema should work even without running migrations.
1528        connect(false, migrations).await.unwrap();
1529    }
1530
1531    #[test]
1532    #[cfg(not(feature = "embedded-db"))]
1533    fn test_config_from_str() {
1534        let cfg = Config::from_str("postgresql://user:password@host:8080").unwrap();
1535        assert_eq!(cfg.db_opt.get_username(), "user");
1536        assert_eq!(cfg.db_opt.get_host(), "host");
1537        assert_eq!(cfg.db_opt.get_port(), 8080);
1538    }
1539
1540    #[test]
1541    #[cfg(feature = "embedded-db")]
1542    fn test_config_from_str() {
1543        let cfg = Config::from_str("sqlite://data.db").unwrap();
1544        assert_eq!(cfg.db_opt.get_filename().to_string_lossy(), "data.db");
1545    }
1546
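    /// Reclaim space freed by the pruner so that the disk usage measurements in the tests below
    /// reflect the deletions (incremental vacuum for SQLite, full VACUUM for Postgres).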
1547    async fn vacuum(storage: &SqlStorage) {
1548        #[cfg(feature = "embedded-db")]
1549        let query = "PRAGMA incremental_vacuum(16000)";
1550        #[cfg(not(feature = "embedded-db"))]
1551        let query = "VACUUM";
1552        storage
1553            .pool
1554            .acquire()
1555            .await
1556            .unwrap()
1557            .execute(query)
1558            .await
1559            .unwrap();
1560    }
1561
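    // The pruner should only remove data once it is older than the configured target retention
    // period.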
1562    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1563    async fn test_target_period_pruning() {
1564        let db = TmpDb::init().await;
1565        let cfg = db.config();
1566
1567        let mut storage = SqlStorage::connect(cfg, StorageConnectionType::Query)
1568            .await
1569            .unwrap();
1570        let mut leaf = LeafQueryData::<MockTypes>::genesis::<TestVersions>(
1571            &TestValidatedState::default(),
1572            &TestInstanceState::default(),
1573        )
1574        .await;
1575        // insert some mock data
1576        for i in 0..20 {
1577            leaf.leaf.block_header_mut().block_number = i;
1578            leaf.leaf.block_header_mut().timestamp = Utc::now().timestamp() as u64;
1579            let mut tx = storage.write().await.unwrap();
1580            tx.insert_leaf(leaf.clone()).await.unwrap();
1581            tx.commit().await.unwrap();
1582        }
1583
1584        let height_before_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1585
1586        // Set pruner config to the default, which has a minimum retention of 1 day.
1587        storage.set_pruning_config(PrunerCfg::new());
1588        // No data will be pruned
1589        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1590
1591        // Vacuum the database to reclaim space.
1592        // This is necessary to ensure the test passes.
1593        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1594        vacuum(&storage).await;
1595        // Pruned height should be none
1596        assert!(pruned_height.is_none());
1597
1598        let height_after_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1599
1600        assert_eq!(
1601            height_after_pruning, height_before_pruning,
1602            "some data has been pruned"
1603        );
1604
1605        // Set pruner config with the target retention set to 1s.
1606        storage.set_pruning_config(PrunerCfg::new().with_target_retention(Duration::from_secs(1)));
1607        sleep(Duration::from_secs(2)).await;
1608        let usage_before_pruning = storage.get_disk_usage().await.unwrap();
1609        // All of the data is now older than 1s.
1610        // This would prune all the data as the target retention is set to 1s
1611        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1612        // Vacuum the database to reclaim space.
1613        // This is necessary to ensure the test passes.
1614        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1615        vacuum(&storage).await;
1616
1617        // Pruned height should be some
1618        assert!(pruned_height.is_some());
1619        let usage_after_pruning = storage.get_disk_usage().await.unwrap();
1620        // All the tables should be empty
1621        // counting rows in header table
1622        let header_rows = storage
1623            .read()
1624            .await
1625            .unwrap()
1626            .fetch_one("select count(*) as count from header")
1627            .await
1628            .unwrap()
1629            .get::<i64, _>("count");
1630        // the table should be empty
1631        assert_eq!(header_rows, 0);
1632
1633        // counting rows in leaf table.
1634        // Deleting rows from the header table deletes rows in all the other tables,
1635        // as each table has an "ON DELETE CASCADE" foreign key constraint referencing the header table.
1636        let leaf_rows = storage
1637            .read()
1638            .await
1639            .unwrap()
1640            .fetch_one("select count(*) as count from leaf")
1641            .await
1642            .unwrap()
1643            .get::<i64, _>("count");
1644        // the table should be empty
1645        assert_eq!(leaf_rows, 0);
1646
1647        assert!(
1648            usage_before_pruning > usage_after_pruning,
1649            "disk usage should decrease after pruning"
1650        )
1651    }
1652
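    // Pruning of Merkle tree snapshots: `delete_batch` should remove stale nodes, keeping only
    // the newest node for each position.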
1653    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1654    async fn test_merklized_state_pruning() {
1655        let db = TmpDb::init().await;
1656        let config = db.config();
1657
1658        let storage = SqlStorage::connect(config, StorageConnectionType::Query)
1659            .await
1660            .unwrap();
1661        let mut test_tree: UniversalMerkleTree<_, _, _, 8, _> =
1662            MockMerkleTree::new(MockMerkleTree::tree_height());
1663
1664        // Insert some entries into the tree and the header table.
1665        // The header table is used by the get_path query to check that a header exists for the block height.
1666        let mut tx = storage.write().await.unwrap();
1667
1668        for block_height in 0..250 {
1669            test_tree.update(block_height, block_height).unwrap();
1670
1671            // data field of the header
1672            let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});
1673            tx.upsert(
1674                "header",
1675                ["height", "hash", "payload_hash", "timestamp", "data"],
1676                ["height"],
1677                [(
1678                    block_height as i64,
1679                    format!("randomHash{block_height}"),
1680                    "t".to_string(),
1681                    0,
1682                    test_data,
1683                )],
1684            )
1685            .await
1686            .unwrap();
1687            // proof for the index from the tree
1688            let (_, proof) = test_tree.lookup(block_height).expect_ok().unwrap();
1689            // traversal path for the index.
1690            let traversal_path =
1691                <usize as ToTraversalPath<8>>::to_traversal_path(&block_height, test_tree.height());
1692
1693            UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
1694                &mut tx,
1695                proof.clone(),
1696                traversal_path.clone(),
1697                block_height as u64,
1698            )
1699            .await
1700            .expect("failed to insert nodes");
1701        }
1702
1703        // update saved state height
1704        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, 250)
1705            .await
1706            .unwrap();
1707        tx.commit().await.unwrap();
1708
1709        let mut tx = storage.read().await.unwrap();
1710
1711        // Check that the data was inserted correctly:
1712        // there should be multiple nodes with the same index but different created times.
1713        let (count,) = query_as::<(i64,)>(
1714            " SELECT count(*) FROM (SELECT count(*) as count FROM test_tree GROUP BY path having \
1715             count(*) > 1)",
1716        )
1717        .fetch_one(tx.as_mut())
1718        .await
1719        .unwrap();
1720
1721        tracing::info!("Number of nodes with multiple snapshots: {count}");
1722        assert!(count > 0);
1723
1724        // This should delete every node with height < 250 that is not the newest node for its position.
1725        let mut tx = storage.write().await.unwrap();
1726        tx.delete_batch(vec!["test_tree".to_string()], 250)
1727            .await
1728            .unwrap();
1729
1730        tx.commit().await.unwrap();
1731        let mut tx = storage.read().await.unwrap();
1732        let (count,) = query_as::<(i64,)>(
1733            "SELECT count(*) FROM (SELECT count(*) as count FROM test_tree GROUP BY path having \
1734             count(*) > 1)",
1735        )
1736        .fetch_one(tx.as_mut())
1737        .await
1738        .unwrap();
1739
1740        tracing::info!("Number of nodes with multiple snapshots: {count}");
1741
1742        assert_eq!(count, 0);
1743    }
1744
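    // The minimum retention period takes precedence over the pruning threshold: data younger
    // than the minimum retention is kept even when disk usage exceeds the threshold.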
1745    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1746    async fn test_minimum_retention_pruning() {
1747        let db = TmpDb::init().await;
1748
1749        let mut storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1750            .await
1751            .unwrap();
1752        let mut leaf = LeafQueryData::<MockTypes>::genesis::<TestVersions>(
1753            &TestValidatedState::default(),
1754            &TestInstanceState::default(),
1755        )
1756        .await;
1757        // insert some mock data
1758        for i in 0..20 {
1759            leaf.leaf.block_header_mut().block_number = i;
1760            leaf.leaf.block_header_mut().timestamp = Utc::now().timestamp() as u64;
1761            let mut tx = storage.write().await.unwrap();
1762            tx.insert_leaf(leaf.clone()).await.unwrap();
1763            tx.commit().await.unwrap();
1764        }
1765
1766        let height_before_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1767        let cfg = PrunerCfg::new();
1768        // Set pruning_threshold to 1.
1769        // SQL storage size is more than 1000 bytes even without any data indexed,
1770        // so the disk usage always exceeds the threshold.
1771        // However, minimum retention is set to 24 hours by default, so the data would not be pruned.
1772        storage.set_pruning_config(cfg.clone().with_pruning_threshold(1));
1773        println!("{:?}", storage.get_pruning_config().unwrap());
1774        // Pruning would not delete any data:
1775        // all the data is younger than the minimum retention period even though usage > threshold.
1776        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1777        // Vacuum the database to reclaim space.
1778        // This is necessary to ensure the test passes.
1779        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1780        vacuum(&storage).await;
1781
1782        // Pruned height should be none
1783        assert!(pruned_height.is_none());
1784
1785        let height_after_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1786
1787        assert_eq!(
1788            height_after_pruning, height_before_pruning,
1789            "some data has been pruned"
1790        );
1791
1792        // Change minimum retention to 1s
1793        storage.set_pruning_config(
1794            cfg.with_minimum_retention(Duration::from_secs(1))
1795                .with_pruning_threshold(1),
1796        );
1797        // sleep for 2s to make sure the data is older than minimum retention
1798        sleep(Duration::from_secs(2)).await;
1799        // This would prune all the data
1800        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1801        // Vacuum the database to reclaim space.
1802        // This is necessary to ensure the test passes.
1803        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1804        vacuum(&storage).await;
1805
1806        // Pruned height should be some
1807        assert!(pruned_height.is_some());
1808        // All the tables should be empty
1809        // counting rows in header table
1810        let header_rows = storage
1811            .read()
1812            .await
1813            .unwrap()
1814            .fetch_one("select count(*) as count from header")
1815            .await
1816            .unwrap()
1817            .get::<i64, _>("count");
1818        // the table should be empty
1819        assert_eq!(header_rows, 0);
1820    }
1821
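    // The pruned height should be persisted by `save_pruned_height` and read back by
    // `load_pruned_height` after each update.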
1822    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1823    async fn test_pruned_height_storage() {
1824        let db = TmpDb::init().await;
1825        let cfg = db.config();
1826
1827        let storage = SqlStorage::connect(cfg, StorageConnectionType::Query)
1828            .await
1829            .unwrap();
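        // A fresh database has no pruned height recorded yet.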
1830        assert!(storage
1831            .read()
1832            .await
1833            .unwrap()
1834            .load_pruned_height()
1835            .await
1836            .unwrap()
1837            .is_none());
1838        for height in [10, 20, 30] {
1839            let mut tx = storage.write().await.unwrap();
1840            tx.save_pruned_height(height).await.unwrap();
1841            tx.commit().await.unwrap();
1842            assert_eq!(
1843                storage
1844                    .read()
1845                    .await
1846                    .unwrap()
1847                    .load_pruned_height()
1848                    .await
1849                    .unwrap(),
1850                Some(height)
1851            );
1852        }
1853    }
1854
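    // Populate the legacy `leaf` and `vid` tables with old-format rows, then run the type
    // migration (twice, to check that re-running it is harmless) and verify every row ends up
    // in `leaf2` and `vid2`.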
1855    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1856    async fn test_types_migration() {
1857        let num_rows = 500;
1858        let db = TmpDb::init().await;
1859
1860        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1861            .await
1862            .unwrap();
1863
1864        for i in 0..num_rows {
1865            let view = ViewNumber::new(i);
1866            let validated_state = TestValidatedState::default();
1867            let instance_state = TestInstanceState::default();
1868
1869            let (payload, metadata) = <MockPayload as BlockPayload<MockTypes>>::from_transactions(
1870                [],
1871                &validated_state,
1872                &instance_state,
1873            )
1874            .await
1875            .unwrap();
1876
1877            let mut block_header = <MockHeader as BlockHeader<MockTypes>>::genesis::<MockVersions>(
1878                &instance_state,
1879                payload.clone(),
1880                &metadata,
1881            );
1882
1883            block_header.block_number = i;
1884
1885            let null_quorum_data = QuorumData {
1886                leaf_commit: Commitment::<Leaf<MockTypes>>::default_commitment_no_preimage(),
1887            };
1888
1889            let mut qc = QuorumCertificate::new(
1890                null_quorum_data.clone(),
1891                null_quorum_data.commit(),
1892                view,
1893                None,
1894                std::marker::PhantomData,
1895            );
1896
1897            let quorum_proposal = QuorumProposal {
1898                block_header,
1899                view_number: view,
1900                justify_qc: qc.clone(),
1901                upgrade_certificate: None,
1902                proposal_certificate: None,
1903            };
1904
1905            let mut leaf = Leaf::from_quorum_proposal(&quorum_proposal);
1906            leaf.fill_block_payload::<MockVersions>(
1907                payload.clone(),
1908                GENESIS_VID_NUM_STORAGE_NODES,
1909                <MockVersions as Versions>::Base::VERSION,
1910            )
1911            .unwrap();
1912            qc.data.leaf_commit = <Leaf<MockTypes> as Committable>::commit(&leaf);
1913
1914            let height = leaf.height() as i64;
1915            let hash = <Leaf<_> as Committable>::commit(&leaf).to_string();
1916            let header = leaf.block_header();
1917
1918            let header_json = serde_json::to_value(header)
1919                .context("failed to serialize header")
1920                .unwrap();
1921
1922            let payload_commitment =
1923                <MockHeader as BlockHeader<MockTypes>>::payload_commitment(header);
1924            let mut tx = storage.write().await.unwrap();
1925
1926            tx.upsert(
1927                "header",
1928                ["height", "hash", "payload_hash", "data", "timestamp"],
1929                ["height"],
1930                [(
1931                    height,
1932                    leaf.block_header().commit().to_string(),
1933                    payload_commitment.to_string(),
1934                    header_json,
1935                    <MockHeader as BlockHeader<MockTypes>>::timestamp(leaf.block_header()) as i64,
1936                )],
1937            )
1938            .await
1939            .unwrap();
1940
1941            let leaf_json = serde_json::to_value(leaf.clone()).expect("failed to serialize leaf");
1942            let qc_json = serde_json::to_value(qc).expect("failed to serialize QC");
1943            tx.upsert(
1944                "leaf",
1945                ["height", "hash", "block_hash", "leaf", "qc"],
1946                ["height"],
1947                [(
1948                    height,
1949                    hash,
1950                    header.commit().to_string(),
1951                    leaf_json,
1952                    qc_json,
1953                )],
1954            )
1955            .await
1956            .unwrap();
1957
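            // Disperse the payload with the legacy ADVZ VID scheme (2 storage nodes) to
            // produce old-format `vid` rows for the migration to convert.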
1958            let mut vid = advz_scheme(2);
1959            let disperse = vid.disperse(payload.encode()).unwrap();
1960            let common = disperse.common;
1961
1962            let common_bytes = bincode::serialize(&common).unwrap();
1963            let share = disperse.shares[0].clone();
1964            let mut share_bytes = Some(bincode::serialize(&share).unwrap());
1965
1966            // Leave the share NULL for every 10th row to exercise the nullable column.
1967            if i % 10 == 0 {
1968                share_bytes = None
1969            }
1970
1971            tx.upsert(
1972                "vid",
1973                ["height", "common", "share"],
1974                ["height"],
1975                [(height, common_bytes, share_bytes)],
1976            )
1977            .await
1978            .unwrap();
1979            tx.commit().await.unwrap();
1980        }
1981
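        // Migrate the legacy rows in batches of 50; running the migration a second time should
        // succeed as well, without duplicating any rows.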
1982        <SqlStorage as MigrateTypes<MockTypes>>::migrate_types(&storage, 50)
1983            .await
1984            .expect("failed to migrate");
1985
1986        <SqlStorage as MigrateTypes<MockTypes>>::migrate_types(&storage, 50)
1987            .await
1988            .expect("failed to migrate");
1989
1990        let mut tx = storage.read().await.unwrap();
1991        let (leaf_count,) = query_as::<(i64,)>("SELECT COUNT(*) from leaf2")
1992            .fetch_one(tx.as_mut())
1993            .await
1994            .unwrap();
1995
1996        let (vid_count,) = query_as::<(i64,)>("SELECT COUNT(*) from vid2")
1997            .fetch_one(tx.as_mut())
1998            .await
1999            .unwrap();
2000
2001        assert_eq!(leaf_count as u64, num_rows, "not all leaves migrated");
2002        assert_eq!(vid_count as u64, num_rows, "not all vid migrated");
2003    }
2004
2005    #[test_log::test(tokio::test(flavor = "multi_thread"))]
2006    async fn test_transaction_upsert_retries() {
2007        let db = TmpDb::init().await;
2008        let config = db.config();
2009
2010        let storage = SqlStorage::connect(config, StorageConnectionType::Query)
2011            .await
2012            .unwrap();
2013
2014        let mut tx = storage.write().await.unwrap();
2015
2016        // Try to upsert into a table that does not exist.
2017        // This will fail, so our `upsert` function will enter the retry loop.
2018        // Since the table does not exist, all retries will eventually
2019        // fail and we expect an error to be returned.
2020        //
2021        // Previously, this case caused a panic because we were calling
2022        // methods on `QueryBuilder` after `.build()` without first calling
2023        // `.reset()`, which the sqlx docs say always panics.
2024        // Now that `upsert()` properly calls `.reset()` on the query builder,
2025        // the function returns an error instead of panicking.
2026        tx.upsert("does_not_exist", ["test"], ["test"], [(1_i64,)])
2027            .await
2028            .unwrap_err();
2029    }
2030}