hotshot_query_service/data_source/storage/sql.rs

// Copyright (c) 2022 Espresso Systems (espressosys.com)
// This file is part of the HotShot Query Service library.
//
// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
// General Public License as published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// General Public License for more details.
// You should have received a copy of the GNU General Public License along with this program. If not,
// see <https://www.gnu.org/licenses/>.

13#![cfg(feature = "sql-data-source")]
14use std::{cmp::min, fmt::Debug, str::FromStr, time::Duration};
15
16use anyhow::Context;
17use async_trait::async_trait;
18use chrono::Utc;
19use committable::Committable;
20#[cfg(not(feature = "embedded-db"))]
21use futures::future::FutureExt;
22use hotshot_types::{
23    data::{Leaf, Leaf2, VidShare},
24    simple_certificate::{QuorumCertificate, QuorumCertificate2},
25    traits::{metrics::Metrics, node_implementation::NodeType},
26    vid::advz::{ADVZCommon, ADVZShare},
27};
28use itertools::Itertools;
29use log::LevelFilter;
30#[cfg(not(feature = "embedded-db"))]
31use sqlx::postgres::{PgConnectOptions, PgSslMode};
32#[cfg(feature = "embedded-db")]
33use sqlx::sqlite::SqliteConnectOptions;
34use sqlx::{
35    pool::{Pool, PoolOptions},
36    ConnectOptions, Row,
37};
38
39use crate::{
40    availability::{QueryableHeader, QueryablePayload, VidCommonMetadata, VidCommonQueryData},
41    data_source::{
42        storage::pruning::{PruneStorage, PrunerCfg, PrunerConfig},
43        update::Transaction as _,
44        VersionedDataSource,
45    },
46    metrics::PrometheusMetrics,
47    node::BlockId,
48    status::HasMetrics,
49    Header, QueryError, QueryResult, VidCommon,
50};
51pub extern crate sqlx;
52pub use sqlx::{Database, Sqlite};
53
54mod db;
55mod migrate;
56mod queries;
57mod transaction;
58
59pub use anyhow::Error;
60pub use db::*;
61pub use include_dir::include_dir;
62pub use queries::QueryBuilder;
63pub use refinery::Migration;
64pub use transaction::*;
65
66use self::{migrate::Migrator, transaction::PoolMetrics};
67use super::{AvailabilityStorage, NodeStorage};
68// This needs to be reexported so that we can reference it by absolute path relative to this crate
69// in the expansion of `include_migrations`, even when `include_migrations` is invoked from another
70// crate which doesn't have `include_dir` as a dependency.
71pub use crate::include_migrations;
72
73/// Embed migrations from the given directory into the current binary for PostgreSQL or SQLite.
74///
75/// The macro invocation `include_migrations!(path)` evaluates to an expression of type `impl
76/// Iterator<Item = Migration>`. Each migration must be a text file which is an immediate child of
77/// `path`, and there must be no non-migration files in `path`. The migration files must have names
78/// of the form `V${version}__${name}.sql`, where `version` is a positive integer indicating how the
79/// migration is to be ordered relative to other migrations, and `name` is a descriptive name for
80/// the migration.
81///
82/// `path` should be an absolute path. It is possible to give a path relative to the root of the
83/// invoking crate by using environment variable expansions and the `CARGO_MANIFEST_DIR` environment
84/// variable.
85///
/// As an example, this is the invocation used to load the default migrations from the
/// `hotshot-query-service` crate. The migrations are located under the `migrations` directory:
/// - PostgreSQL migrations are in `/migrations/postgres`.
/// - SQLite migrations are in `/migrations/sqlite`.
90///
/// ```
/// # use hotshot_query_service::data_source::sql::{include_migrations, Migration};
/// // For PostgreSQL
/// #[cfg(not(feature = "embedded-db"))]
/// let mut migrations: Vec<Migration> =
///     include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres").collect();
/// // For SQLite
/// #[cfg(feature = "embedded-db")]
/// let mut migrations: Vec<Migration> =
///     include_migrations!("$CARGO_MANIFEST_DIR/migrations/sqlite").collect();
///
/// migrations.sort();
/// assert_eq!(migrations[0].version(), 10);
/// assert_eq!(migrations[0].name(), "init_schema");
/// ```
106///
107/// Note that a similar macro is available from Refinery:
108/// [embed_migrations](https://docs.rs/refinery/0.8.11/refinery/macro.embed_migrations.html). This
109/// macro differs in that it evaluates to an iterator of [migrations](Migration), making it an
110/// expression macro, while `embed_migrations` is a statement macro that defines a module which
111/// provides access to the embedded migrations only indirectly via a
112/// [`Runner`](https://docs.rs/refinery/0.8.11/refinery/struct.Runner.html). The direct access to
113/// migrations provided by [`include_migrations`] makes this macro easier to use with
114/// [`Config::migrations`], for combining custom migrations with [`default_migrations`].
#[macro_export]
macro_rules! include_migrations {
    ($dir:tt) => {
        // Embed the directory contents at compile time and map each file to a `Migration`.
        $crate::data_source::storage::sql::include_dir!($dir)
            .files()
            .map(|file| {
                let path = file.path();
                // Only the file name matters: it carries the `V${version}__${name}.sql`
                // pattern that refinery parses, so it must be non-empty valid UTF-8.
                let name = path
                    .file_name()
                    .and_then(std::ffi::OsStr::to_str)
                    .unwrap_or_else(|| {
                        panic!(
                            "migration file {} must have a non-empty UTF-8 name",
                            path.display()
                        )
                    });
                // The SQL body must also be UTF-8 text.
                let sql = file
                    .contents_utf8()
                    .unwrap_or_else(|| panic!("migration file {name} must use UTF-8 encoding"));
                // `Migration::unapplied` parses version/name from the file name; panics at
                // first use (not compile time) if the name doesn't match the pattern.
                $crate::data_source::storage::sql::Migration::unapplied(name, sql)
                    .expect("invalid migration")
            })
    };
}
139
140/// The migrations required to build the default schema for this version of [`SqlStorage`].
141pub fn default_migrations() -> Vec<Migration> {
142    #[cfg(not(feature = "embedded-db"))]
143    let mut migrations =
144        include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres").collect::<Vec<_>>();
145
146    #[cfg(feature = "embedded-db")]
147    let mut migrations =
148        include_migrations!("$CARGO_MANIFEST_DIR/migrations/sqlite").collect::<Vec<_>>();
149
150    // Check version uniqueness and sort by version.
151    validate_migrations(&mut migrations).expect("default migrations are invalid");
152
153    // Check that all migration versions are multiples of 100, so that custom migrations can be
154    // inserted in between.
155    for m in &migrations {
156        if m.version() <= 30 {
157            // An older version of this software used intervals of 10 instead of 100. This was
158            // changed to allow more custom migrations between each default migration, but we must
159            // still accept older migrations that followed the older rule.
160            assert!(
161                m.version() > 0 && m.version() % 10 == 0,
162                "legacy default migration version {} is not a positive multiple of 10",
163                m.version()
164            );
165        } else {
166            assert!(
167                m.version() % 100 == 0,
168                "default migration version {} is not a multiple of 100",
169                m.version()
170            );
171        }
172    }
173
174    migrations
175}
176
177/// Validate and preprocess a sequence of migrations.
178///
179/// * Ensure all migrations have distinct versions
180/// * Ensure migrations are sorted by increasing version
181fn validate_migrations(migrations: &mut [Migration]) -> Result<(), Error> {
182    migrations.sort_by_key(|m| m.version());
183
184    // Check version uniqueness.
185    for (prev, next) in migrations.iter().zip(migrations.iter().skip(1)) {
186        if next <= prev {
187            return Err(Error::msg(format!(
188                "migration versions are not strictly increasing ({prev}->{next})"
189            )));
190        }
191    }
192
193    Ok(())
194}
195
/// Add custom migrations to a default migration sequence.
///
/// Migrations in `custom` replace migrations in `default` with the same version. Otherwise, the two
/// sequences `default` and `custom` are merged so that the resulting sequence is sorted by
/// ascending version number. Each of `default` and `custom` is assumed to be the output of
/// [`validate_migrations`]; that is, each is sorted by version and contains no duplicate versions.
fn add_custom_migrations(
    default: impl IntoIterator<Item = Migration>,
    custom: impl IntoIterator<Item = Migration>,
) -> impl Iterator<Item = Migration> {
    default
        .into_iter()
        // Merge sorted lists, joining pairs of equal version into `EitherOrBoth::Both`.
        .merge_join_by(custom, |l, r| l.version().cmp(&r.version()))
        // Prefer the custom migration for a given version when both default and custom versions
        // are present. (`EitherOrBoth::reduce` returns the single element in the `Left`/`Right`
        // cases and applies the closure only for `Both`.)
        .map(|pair| pair.reduce(|_, custom| custom))
}
214
/// Configuration for connecting to and initializing a [`SqlStorage`] database.
#[derive(Clone)]
pub struct Config {
    // Low-level driver connection options; SQLite or Postgres depending on features.
    #[cfg(feature = "embedded-db")]
    db_opt: SqliteConnectOptions,
    #[cfg(not(feature = "embedded-db"))]
    db_opt: PgConnectOptions,
    // Connection pool tuning (min/max connections, idle timeout, max lifetime).
    pool_opt: PoolOptions<Db>,
    // Postgres schema created and used for all queries; not applicable to SQLite.
    #[cfg(not(feature = "embedded-db"))]
    schema: String,
    // If set, wipe existing data on connect (drop the schema on Postgres, delete the DB file
    // on SQLite).
    reset: bool,
    // Custom migrations to interleave with the default ones.
    migrations: Vec<Migration>,
    // If set, fail on connect unless the DB is already fully migrated, instead of migrating.
    no_migrations: bool,
    // Pruning configuration; `None` disables pruning.
    pruner_cfg: Option<PrunerCfg>,
    // If set, clear the stored pruned height on connect so pruned data gets reconstructed.
    archive: bool,
    // Optional pre-built connection pool to reuse instead of opening a new one.
    pool: Option<Pool<Db>>,
}
231
#[cfg(not(feature = "embedded-db"))]
impl Default for Config {
    /// Default Postgres connection: user `postgres`, password `password`, `localhost:5432`,
    /// with every other [`Config`] field taken from the `From<PgConnectOptions>` conversion.
    fn default() -> Self {
        PgConnectOptions::default()
            .username("postgres")
            .password("password")
            .host("localhost")
            .port(5432)
            .into()
    }
}
243
#[cfg(feature = "embedded-db")]
impl Default for Config {
    /// Default SQLite connection: WAL journaling, a 30s busy timeout, incremental
    /// auto-vacuum, and creation of the database file if it does not exist, with every
    /// other [`Config`] field taken from the `From<SqliteConnectOptions>` conversion.
    fn default() -> Self {
        SqliteConnectOptions::default()
            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
            .busy_timeout(Duration::from_secs(30))
            .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
            .create_if_missing(true)
            .into()
    }
}
255
#[cfg(feature = "embedded-db")]
impl From<SqliteConnectOptions> for Config {
    /// Build a [`Config`] from driver connection options, with all other settings (pool
    /// options, reset/migration/pruning flags) at their defaults.
    fn from(db_opt: SqliteConnectOptions) -> Self {
        Self {
            db_opt,
            pool_opt: PoolOptions::default(),
            reset: false,
            migrations: vec![],
            no_migrations: false,
            pruner_cfg: None,
            archive: false,
            pool: None,
        }
    }
}
271
#[cfg(not(feature = "embedded-db"))]
impl From<PgConnectOptions> for Config {
    /// Build a [`Config`] from driver connection options, with the schema defaulting to
    /// `hotshot` and all other settings (pool options, reset/migration/pruning flags) at
    /// their defaults.
    fn from(db_opt: PgConnectOptions) -> Self {
        Self {
            db_opt,
            pool_opt: PoolOptions::default(),
            schema: "hotshot".into(),
            reset: false,
            migrations: vec![],
            no_migrations: false,
            pruner_cfg: None,
            archive: false,
            pool: None,
        }
    }
}
288
289#[cfg(not(feature = "embedded-db"))]
290impl FromStr for Config {
291    type Err = <PgConnectOptions as FromStr>::Err;
292
293    fn from_str(s: &str) -> Result<Self, Self::Err> {
294        Ok(PgConnectOptions::from_str(s)?.into())
295    }
296}
297
#[cfg(feature = "embedded-db")]
impl FromStr for Config {
    type Err = <SqliteConnectOptions as FromStr>::Err;

    /// Parse a SQLite connection string (e.g. a `sqlite://` URL or file path) into a
    /// [`Config`], delegating to [`SqliteConnectOptions`]'s parser.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        SqliteConnectOptions::from_str(s).map(Self::from)
    }
}
306
#[cfg(feature = "embedded-db")]
impl Config {
    /// Set how long SQLite waits on a locked database before failing with a busy error.
    pub fn busy_timeout(mut self, timeout: Duration) -> Self {
        self.db_opt = self.db_opt.busy_timeout(timeout);
        self
    }

    /// Set the path of the SQLite database file.
    pub fn db_path(mut self, path: std::path::PathBuf) -> Self {
        self.db_opt = self.db_opt.filename(path);
        self
    }
}
319
#[cfg(not(feature = "embedded-db"))]
impl Config {
    /// Set the hostname of the database server.
    ///
    /// The default is `localhost`.
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.db_opt = self.db_opt.host(&host.into());
        self
    }

    /// Set the port on which to connect to the database.
    ///
    /// The default is 5432, the default Postgres port.
    pub fn port(mut self, port: u16) -> Self {
        self.db_opt = self.db_opt.port(port);
        self
    }

    /// Set the DB user to connect as.
    ///
    /// The default (from [`Config::default`]) is `postgres`.
    pub fn user(mut self, user: &str) -> Self {
        self.db_opt = self.db_opt.username(user);
        self
    }

    /// Set a password for connecting to the database.
    pub fn password(mut self, password: &str) -> Self {
        self.db_opt = self.db_opt.password(password);
        self
    }

    /// Set the name of the database to connect to.
    pub fn database(mut self, database: &str) -> Self {
        self.db_opt = self.db_opt.database(database);
        self
    }

    /// Use TLS for an encrypted connection to the database.
    ///
    /// Note that an encrypted connection may be established even if this option is not set, as long
    /// as both the client and server support it. This option merely causes connection to fail if an
    /// encrypted stream cannot be established.
    pub fn tls(mut self) -> Self {
        self.db_opt = self.db_opt.ssl_mode(PgSslMode::Require);
        self
    }

    /// Set the name of the schema to use for queries.
    ///
    /// The default schema is named `hotshot` and is created via the default migrations.
    /// [`SqlStorage::connect`] also creates the schema if it does not already exist.
    pub fn schema(mut self, schema: impl Into<String>) -> Self {
        self.schema = schema.into();
        self
    }
}
374
impl Config {
    /// Sets the database connection pool
    /// This allows reusing an existing connection pool when building a new `SqlStorage` instance.
    ///
    /// NOTE(review): when a pool is supplied, [`SqlStorage::connect`] returns early and skips
    /// schema reset, migrations, and archive handling — the pool is assumed already initialized.
    pub fn pool(mut self, pool: Pool<Db>) -> Self {
        self.pool = Some(pool);
        self
    }

    /// Reset the schema on connection.
    ///
    /// When this [`Config`] is used to [`connect`](Self::connect) a
    /// [`SqlDataSource`](crate::data_source::SqlDataSource), if this option is set, the relevant
    /// [`schema`](Self::schema) will first be dropped and then recreated, yielding a completely
    /// fresh instance of the query service.
    ///
    /// This is a particularly useful capability for development and staging environments. Still, it
    /// must be used with extreme caution, as using this will irrevocably delete any data pertaining
    /// to the query service in the database.
    pub fn reset_schema(mut self) -> Self {
        self.reset = true;
        self
    }

    /// Add custom migrations to run when connecting to the database.
    pub fn migrations(mut self, migrations: impl IntoIterator<Item = Migration>) -> Self {
        self.migrations.extend(migrations);
        self
    }

    /// Skip all migrations when connecting to the database.
    ///
    /// Connecting will instead fail if the database is not already up to date with the
    /// expected migrations.
    pub fn no_migrations(mut self) -> Self {
        self.no_migrations = true;
        self
    }

    /// Enable pruning with a given configuration.
    ///
    /// If [`archive`](Self::archive) was previously specified, this will override it.
    ///
    /// # Errors
    ///
    /// Fails if `cfg` does not pass [`PrunerCfg::validate`].
    pub fn pruner_cfg(mut self, cfg: PrunerCfg) -> Result<Self, Error> {
        cfg.validate()?;
        self.pruner_cfg = Some(cfg);
        self.archive = false;
        Ok(self)
    }

    /// Disable pruning and reconstruct previously pruned data.
    ///
    /// While running without pruning is the default behavior, the default will not try to
    /// reconstruct data that was pruned in a previous run where pruning was enabled. This option
    /// instructs the service to run without pruning _and_ reconstruct all previously pruned data by
    /// fetching from peers.
    ///
    /// If [`pruner_cfg`](Self::pruner_cfg) was previously specified, this will override it.
    pub fn archive(mut self) -> Self {
        self.pruner_cfg = None;
        self.archive = true;
        self
    }

    /// Set the maximum idle time of a connection.
    ///
    /// Any connection which has been open and unused longer than this duration will be
    /// automatically closed to reduce load on the server.
    pub fn idle_connection_timeout(mut self, timeout: Duration) -> Self {
        self.pool_opt = self.pool_opt.idle_timeout(Some(timeout));
        self
    }

    /// Set the maximum lifetime of a connection.
    ///
    /// Any connection which has been open longer than this duration will be automatically closed
    /// (and, if needed, replaced), even if it is otherwise healthy. It is good practice to refresh
    /// even healthy connections once in a while (e.g. daily) in case of resource leaks in the
    /// server implementation.
    ///
    /// NOTE(review): despite the name, this sets the pool's *max lifetime*, not a
    /// connect/acquire timeout; callers expecting the latter should be aware.
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.pool_opt = self.pool_opt.max_lifetime(Some(timeout));
        self
    }

    /// Set the minimum number of connections to maintain at any time.
    ///
    /// The data source will, to the best of its ability, maintain at least `min` open connections
    /// at all times. This can be used to reduce the latency hit of opening new connections when at
    /// least this many simultaneous connections are frequently needed.
    pub fn min_connections(mut self, min: u32) -> Self {
        self.pool_opt = self.pool_opt.min_connections(min);
        self
    }

    /// Set the maximum number of connections to maintain at any time.
    ///
    /// Once `max` connections are in use simultaneously, further attempts to acquire a connection
    /// (or begin a transaction) will block until one of the existing connections is released.
    pub fn max_connections(mut self, max: u32) -> Self {
        self.pool_opt = self.pool_opt.max_connections(max);
        self
    }

    /// Log at WARN level any time a SQL statement takes longer than `threshold`.
    ///
    /// The default threshold is 1s.
    pub fn slow_statement_threshold(mut self, threshold: Duration) -> Self {
        self.db_opt = self
            .db_opt
            .log_slow_statements(LevelFilter::Warn, threshold);
        self
    }
}
483
/// Storage for the APIs provided in this crate, backed by a remote PostgreSQL database.
///
/// (With the `embedded-db` feature enabled, the backend is SQLite instead; the [`Db`] alias
/// selects the active driver.)
#[derive(Clone, Debug)]
pub struct SqlStorage {
    // Shared connection pool; cloning `SqlStorage` shares the same pool.
    pool: Pool<Db>,
    // Prometheus metrics exposed via `HasMetrics`.
    metrics: PrometheusMetrics,
    // Metrics tracking connection pool usage (under the "sql" subgroup of `metrics`).
    pool_metrics: PoolMetrics,
    // Pruning configuration; `None` means pruning is disabled.
    pruner_cfg: Option<PrunerCfg>,
}
492
/// Cursor state for an in-progress pruner run.
///
/// [`PruneStorage::prune`] deletes data in batches across repeated calls; this struct remembers
/// where the previous call left off so the next call can resume without re-querying heights.
#[derive(Debug, Default)]
pub struct Pruner {
    // Height up to which data has already been deleted in this run, if a batch was deleted.
    pruned_height: Option<u64>,
    // Cached height corresponding to the target retention window, once computed.
    target_height: Option<u64>,
    // Cached height corresponding to the minimum retention window, once computed.
    minimum_retention_height: Option<u64>,
}
499
impl SqlStorage {
    /// A clone of the underlying connection pool, for sharing with other components.
    pub fn pool(&self) -> Pool<Db> {
        self.pool.clone()
    }
    /// Connect to a remote database.
    ///
    /// Applies the full [`Config`]: optionally resets the schema (Postgres) or database file
    /// (SQLite), runs or verifies migrations, and clears the stored pruned height when running
    /// in archive mode. If `config` carries a pre-built pool, all of that initialization is
    /// skipped and the pool is reused as-is.
    pub async fn connect(mut config: Config) -> Result<Self, Error> {
        let metrics = PrometheusMetrics::default();
        let pool_metrics = PoolMetrics::new(&*metrics.subgroup("sql".into()));
        let pool = config.pool_opt.clone();
        let pruner_cfg = config.pruner_cfg;

        // re-use the same pool if present and return early
        // NOTE(review): reset, migrations, and archive handling are all skipped on this path;
        // the caller is responsible for having initialized the database.
        if let Some(pool) = config.pool {
            return Ok(Self {
                metrics,
                pool_metrics,
                pool,
                pruner_cfg,
            });
        }

        #[cfg(not(feature = "embedded-db"))]
        let schema = config.schema.clone();
        // Ensure every pooled Postgres connection starts with the configured schema on its
        // search path, so unqualified table names resolve to our schema.
        #[cfg(not(feature = "embedded-db"))]
        let pool = pool.after_connect(move |conn, _| {
            let schema = config.schema.clone();
            async move {
                query(&format!("SET search_path TO {schema}"))
                    .execute(conn)
                    .await?;
                Ok(())
            }
            .boxed()
        });

        // For SQLite, "reset" means deleting the database file before connecting; it will be
        // recreated by the connection (create_if_missing in the default options).
        #[cfg(feature = "embedded-db")]
        if config.reset {
            std::fs::remove_file(config.db_opt.get_filename())?;
        }

        let pool = pool.connect_with(config.db_opt).await?;

        // Create or connect to the schema for this query service.
        let mut conn = pool.acquire().await?;

        // For Postgres, "reset" means dropping the schema (and everything in it) before
        // recreating it below.
        #[cfg(not(feature = "embedded-db"))]
        if config.reset {
            query(&format!("DROP SCHEMA IF EXISTS {schema} CASCADE"))
                .execute(conn.as_mut())
                .await?;
        }

        #[cfg(not(feature = "embedded-db"))]
        query(&format!("CREATE SCHEMA IF NOT EXISTS {schema}"))
            .execute(conn.as_mut())
            .await?;

        // Get migrations and interleave with custom migrations, sorting by version number.
        validate_migrations(&mut config.migrations)?;
        let migrations =
            add_custom_migrations(default_migrations(), config.migrations).collect::<Vec<_>>();

        // Get a migration runner. Depending on the config, we can either use this to actually run
        // the migrations or just check if the database is up to date.
        let runner = refinery::Runner::new(&migrations).set_grouped(true);

        if config.no_migrations {
            // We've been asked not to run any migrations. Abort if the DB is not already up to
            // date.
            let last_applied = runner
                .get_last_applied_migration_async(&mut Migrator::from(&mut conn))
                .await?;
            let last_expected = migrations.last();
            if last_applied.as_ref() != last_expected {
                return Err(Error::msg(format!(
                    "DB is out of date: last applied migration is {last_applied:?}, but expected \
                     {last_expected:?}"
                )));
            }
        } else {
            // Run migrations using `refinery`.
            match runner.run_async(&mut Migrator::from(&mut conn)).await {
                Ok(report) => {
                    tracing::info!("ran DB migrations: {report:?}");
                },
                Err(err) => {
                    tracing::error!("DB migrations failed: {:?}", err.report());
                    Err(err)?;
                },
            }
        }

        if config.archive {
            // If running in archive mode, ensure the pruned height is set to 0, so the fetcher will
            // reconstruct previously pruned data.
            query("DELETE FROM pruned_height WHERE id = 1")
                .execute(conn.as_mut())
                .await?;
        }

        conn.close().await?;

        Ok(Self {
            pool,
            pool_metrics,
            metrics,
            pruner_cfg,
        })
    }
}
610
611impl PrunerConfig for SqlStorage {
612    fn set_pruning_config(&mut self, cfg: PrunerCfg) {
613        self.pruner_cfg = Some(cfg);
614    }
615
616    fn get_pruning_config(&self) -> Option<PrunerCfg> {
617        self.pruner_cfg.clone()
618    }
619}
620
impl HasMetrics for SqlStorage {
    /// The Prometheus metrics collected by this storage instance (includes the "sql"
    /// subgroup holding pool metrics).
    fn metrics(&self) -> &PrometheusMetrics {
        &self.metrics
    }
}
626
impl SqlStorage {
    /// The minimum block height present in the `header` table, or `None` if it is empty.
    async fn get_minimum_height(&self) -> QueryResult<Option<u64>> {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        // `MIN` over an empty table yields SQL NULL, hence the inner `Option`.
        let (Some(height),) =
            query_as::<(Option<i64>,)>("SELECT MIN(height) as height FROM header")
                .fetch_one(tx.as_mut())
                .await?
        else {
            return Ok(None);
        };
        // Heights are stored as signed integers; assumed non-negative, so the cast is
        // lossless — TODO confirm schema never stores negative heights.
        Ok(Some(height as u64))
    }

    /// The height of the highest block whose timestamp is at most `timestamp`, or `None` if
    /// no such block exists.
    async fn get_height_by_timestamp(&self, timestamp: i64) -> QueryResult<Option<u64>> {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;

        // We order by timestamp and then height, even though logically this is no different than
        // just ordering by height, since timestamps are monotonic. The reason is that this order
        // allows the query planner to efficiently solve the where clause and presort the results
        // based on the timestamp index. The remaining sort on height, which guarantees a unique
        // block if multiple blocks have the same timestamp, is very efficient, because there are
        // never more than a handful of blocks with the same timestamp.
        let Some((height,)) = query_as::<(i64,)>(
            "SELECT height FROM header
              WHERE timestamp <= $1
              ORDER BY timestamp DESC, height DESC
              LIMIT 1",
        )
        .bind(timestamp)
        .fetch_optional(tx.as_mut())
        .await?
        else {
            return Ok(None);
        };
        Ok(Some(height as u64))
    }

    /// Get the stored VID share for a given block, if one exists.
    pub async fn get_vid_share<Types>(&self, block_id: BlockId<Types>) -> QueryResult<VidShare>
    where
        Types: NodeType,
        Header<Types>: QueryableHeader<Types>,
    {
        // Open a read-only transaction and delegate to the transaction-level query.
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let share = tx.vid_share(block_id).await?;
        Ok(share)
    }

    /// Get the stored VID common data for a given block, if one exists.
    pub async fn get_vid_common<Types: NodeType>(
        &self,
        block_id: BlockId<Types>,
    ) -> QueryResult<VidCommonQueryData<Types>>
    where
        <Types as NodeType>::BlockPayload: QueryablePayload<Types>,
        <Types as NodeType>::BlockHeader: QueryableHeader<Types>,
    {
        // Open a read-only transaction and delegate to the transaction-level query.
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let common = tx.get_vid_common(block_id).await?;
        Ok(common)
    }

    /// Get the stored VID common metadata for a given block, if one exists.
    pub async fn get_vid_common_metadata<Types: NodeType>(
        &self,
        block_id: BlockId<Types>,
    ) -> QueryResult<VidCommonMetadata<Types>>
    where
        <Types as NodeType>::BlockPayload: QueryablePayload<Types>,
        <Types as NodeType>::BlockHeader: QueryableHeader<Types>,
    {
        // Open a read-only transaction and delegate to the transaction-level query.
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let common = tx.get_vid_common_metadata(block_id).await?;
        Ok(common)
    }
}
713
#[async_trait]
impl PruneStorage for SqlStorage {
    type Pruner = Pruner;

    /// Total on-disk size of the database in bytes: `pg_database_size` on Postgres, or
    /// `page_count * page_size` on SQLite.
    async fn get_disk_usage(&self) -> anyhow::Result<u64> {
        let mut tx = self.read().await?;

        #[cfg(not(feature = "embedded-db"))]
        let query = "SELECT pg_database_size(current_database())";

        #[cfg(feature = "embedded-db")]
        let query = "
            SELECT( (SELECT page_count FROM pragma_page_count) * (SELECT * FROM pragma_page_size)) \
                     AS total_bytes";

        let row = tx.fetch_one(query).await?;
        let size: i64 = row.get(0);

        Ok(size as u64)
    }

    /// Trigger incremental vacuum to free up space in the SQLite database.
    /// Note: We don't vacuum the Postgres database,
    /// as there is no manual trigger for incremental vacuum,
    /// and a full vacuum can take a lot of time.
    #[cfg(feature = "embedded-db")]
    async fn vacuum(&self) -> anyhow::Result<()> {
        let config = self.get_pruning_config().ok_or(QueryError::Error {
            message: "Pruning config not found".to_string(),
        })?;
        let mut conn = self.pool().acquire().await?;
        // Reclaim up to the configured number of pages per call; requires the database to
        // have been created with incremental auto-vacuum (see the default SQLite options).
        query(&format!(
            "PRAGMA incremental_vacuum({})",
            config.incremental_vacuum_pages()
        ))
        .execute(conn.as_mut())
        .await?;
        conn.close().await?;
        Ok(())
    }

    /// Note: The prune operation may not immediately free up space even after rows are deleted.
    /// This is because a vacuum operation may be necessary to reclaim more space.
    /// PostgreSQL already performs auto vacuuming, so we are not including it here
    /// as running a vacuum operation can be resource-intensive.
    ///
    /// Deletes at most one batch per call and returns the new pruned height (`Some`) when a
    /// batch was deleted, or `None` when there is nothing (more) to prune. Resume state is
    /// carried in `pruner` between calls.
    async fn prune(&self, pruner: &mut Pruner) -> anyhow::Result<Option<u64>> {
        let cfg = self.get_pruning_config().ok_or(QueryError::Error {
            message: "Pruning config not found".to_string(),
        })?;
        let batch_size = cfg.batch_size();
        let max_usage = cfg.max_usage();
        let state_tables = cfg.state_tables();

        // If a pruner run was already in progress, some variables may already be set,
        // depending on whether a batch was deleted and which batch it was (target or minimum retention).
        // This enables us to resume the pruner run from the exact heights.
        // If any of these values are not set, they can be loaded from the database if necessary.
        let mut minimum_retention_height = pruner.minimum_retention_height;
        let mut target_height = pruner.target_height;
        let mut height = match pruner.pruned_height {
            Some(h) => h,
            None => {
                // First call of a run: start from the lowest height still in the database.
                let Some(height) = self.get_minimum_height().await? else {
                    tracing::info!("database is empty, nothing to prune");
                    return Ok(None);
                };

                height
            },
        };

        // Prune data exceeding target retention in batches
        if pruner.target_height.is_none() {
            // Highest block old enough to fall outside the target retention window.
            let th = self
                .get_height_by_timestamp(
                    Utc::now().timestamp() - (cfg.target_retention().as_secs()) as i64,
                )
                .await?;
            target_height = th;
            pruner.target_height = target_height;
        };

        if let Some(target_height) = target_height {
            if height < target_height {
                // Delete one batch (capped at the target height) and return; the caller
                // invokes `prune` again to continue.
                height = min(height + batch_size, target_height);
                let mut tx = self.write().await?;
                tx.delete_batch(state_tables, height).await?;
                tx.commit().await.map_err(|e| QueryError::Error {
                    message: format!("failed to commit {e}"),
                })?;
                pruner.pruned_height = Some(height);
                return Ok(Some(height));
            }
        }

        // If threshold is set, prune data exceeding minimum retention in batches
        // This parameter is needed for SQL storage as there is no direct way to get free space.
        if let Some(threshold) = cfg.pruning_threshold() {
            let usage = self.get_disk_usage().await?;

            // Prune data exceeding minimum retention in batches starting from minimum height
            // until usage is below threshold
            if usage > threshold {
                tracing::warn!(
                    "Disk usage {usage} exceeds pruning threshold {:?}",
                    cfg.pruning_threshold()
                );

                if minimum_retention_height.is_none() {
                    minimum_retention_height = self
                        .get_height_by_timestamp(
                            Utc::now().timestamp() - (cfg.minimum_retention().as_secs()) as i64,
                        )
                        .await?;

                    pruner.minimum_retention_height = minimum_retention_height;
                }

                if let Some(min_retention_height) = minimum_retention_height {
                    // NOTE(review): `max_usage` appears to be expressed in basis points
                    // (1/100 of a percent), hence the division by 10000 — confirm against
                    // `PrunerCfg::max_usage` documentation.
                    if (usage as f64 / threshold as f64) > (f64::from(max_usage) / 10000.0)
                        && height < min_retention_height
                    {
                        height = min(height + batch_size, min_retention_height);
                        let mut tx = self.write().await?;
                        tx.delete_batch(state_tables, height).await?;
                        tx.commit().await.map_err(|e| QueryError::Error {
                            message: format!("failed to commit {e}"),
                        })?;

                        // Reclaim space after the delete. On Postgres this presumably hits a
                        // trait-default no-op, since `vacuum` above is SQLite-only — confirm
                        // against the `PruneStorage` trait definition.
                        self.vacuum().await?;

                        pruner.pruned_height = Some(height);

                        return Ok(Some(height));
                    }
                }
            }
        }

        Ok(None)
    }
}
856
impl VersionedDataSource for SqlStorage {
    // Both access modes are served by the same `Transaction` type, distinguished only by the
    // `Write`/`Read` mode marker in its type parameter.
    type Transaction<'a>
        = Transaction<Write>
    where
        Self: 'a;
    type ReadOnly<'a>
        = Transaction<Read>
    where
        Self: 'a;

    /// Open a read-write transaction against the connection pool.
    async fn write(&self) -> anyhow::Result<Transaction<Write>> {
        Transaction::new(&self.pool, self.pool_metrics.clone()).await
    }

    /// Open a read-only transaction against the connection pool.
    async fn read(&self) -> anyhow::Result<Transaction<Read>> {
        Transaction::new(&self.pool, self.pool_metrics.clone()).await
    }
}
875
/// Migration of persisted HotShot data types to their current versions.
#[async_trait]
pub trait MigrateTypes<Types: NodeType> {
    /// Copy legacy rows into the new-format tables, `batch_size` rows at a time.
    ///
    /// The provided implementation records its progress in storage, so a run that is interrupted
    /// can be resumed by calling this method again.
    async fn migrate_types(&self, batch_size: u64) -> anyhow::Result<()>;
}
880
881#[async_trait]
882impl<Types: NodeType> MigrateTypes<Types> for SqlStorage {
883    async fn migrate_types(&self, batch_size: u64) -> anyhow::Result<()> {
884        let limit = batch_size;
885        let mut tx = self.read().await.map_err(|err| QueryError::Error {
886            message: err.to_string(),
887        })?;
888
889        // The table `types_migration` is populated in the SQL migration with `completed = false` and `migrated_rows = 0`
890        // so fetch_one() would always return a row.
891        // After each batch insert, it is updated with the number of rows migrated.
892        // This is necessary to resume from the same point in case of a restart.
893        let (is_migration_completed, mut offset) = query_as::<(bool, i64)>(
894            "SELECT completed, migrated_rows from types_migration WHERE id = 1 ",
895        )
896        .fetch_one(tx.as_mut())
897        .await?;
898
899        if is_migration_completed {
900            tracing::info!("types migration already completed");
901            return Ok(());
902        }
903
904        tracing::warn!(
905            "migrating query service types storage. Offset={offset}, batch_size={limit}"
906        );
907
908        loop {
909            let mut tx = self.read().await.map_err(|err| QueryError::Error {
910                message: err.to_string(),
911            })?;
912
913            let rows = QueryBuilder::default()
914                .query(
915                    "SELECT leaf, qc, common as vid_common, share as vid_share
916                    FROM leaf INNER JOIN vid on leaf.height = vid.height 
917                    WHERE leaf.height >= $1 AND leaf.height < $2",
918                )
919                .bind(offset)
920                .bind(offset + limit as i64)
921                .fetch_all(tx.as_mut())
922                .await?;
923
924            drop(tx);
925
926            if rows.is_empty() {
927                break;
928            }
929
930            let mut leaf_rows = Vec::new();
931            let mut vid_rows = Vec::new();
932
933            for row in rows.iter() {
934                let leaf1 = row.try_get("leaf")?;
935                let qc = row.try_get("qc")?;
936                let leaf1: Leaf<Types> = serde_json::from_value(leaf1)?;
937                let qc: QuorumCertificate<Types> = serde_json::from_value(qc)?;
938
939                let leaf2: Leaf2<Types> = leaf1.into();
940                let qc2: QuorumCertificate2<Types> = qc.to_qc2();
941
942                let commit = leaf2.commit();
943
944                let leaf2_json =
945                    serde_json::to_value(leaf2.clone()).context("failed to serialize leaf2")?;
946                let qc2_json = serde_json::to_value(qc2).context("failed to serialize QC2")?;
947
948                let vid_common_bytes: Vec<u8> = row.try_get("vid_common")?;
949                let vid_share_bytes: Option<Vec<u8>> = row.try_get("vid_share")?;
950
951                let mut new_vid_share_bytes = None;
952
953                if let Some(vid_share_bytes) = vid_share_bytes {
954                    let vid_share: ADVZShare = bincode::deserialize(&vid_share_bytes)
955                        .context("failed to deserialize vid_share")?;
956                    new_vid_share_bytes = Some(
957                        bincode::serialize(&VidShare::V0(vid_share))
958                            .context("failed to serialize vid_share")?,
959                    );
960                }
961
962                let vid_common: ADVZCommon = bincode::deserialize(&vid_common_bytes)
963                    .context("failed to deserialize vid_common")?;
964                let new_vid_common_bytes = bincode::serialize(&VidCommon::V0(vid_common))
965                    .context("failed to serialize vid_common")?;
966
967                vid_rows.push((
968                    leaf2.height() as i64,
969                    new_vid_common_bytes,
970                    new_vid_share_bytes,
971                ));
972                leaf_rows.push((
973                    leaf2.height() as i64,
974                    commit.to_string(),
975                    leaf2.block_header().commit().to_string(),
976                    leaf2_json,
977                    qc2_json,
978                ));
979            }
980
981            // migrate leaf2
982            let mut query_builder: sqlx::QueryBuilder<Db> =
983                sqlx::QueryBuilder::new("INSERT INTO leaf2 (height, hash, block_hash, leaf, qc) ");
984
985            // Advance the `offset` to the highest `leaf.height` processed in this batch.
986            // This ensures the next iteration starts from the next unseen leaf
987            offset += limit as i64;
988
989            query_builder.push_values(leaf_rows.into_iter(), |mut b, row| {
990                b.push_bind(row.0)
991                    .push_bind(row.1)
992                    .push_bind(row.2)
993                    .push_bind(row.3)
994                    .push_bind(row.4);
995            });
996
997            query_builder.push(" ON CONFLICT DO NOTHING");
998            let query = query_builder.build();
999
1000            let mut tx = self.write().await.map_err(|err| QueryError::Error {
1001                message: err.to_string(),
1002            })?;
1003
1004            query.execute(tx.as_mut()).await?;
1005
1006            // update migrated_rows column with the offset
1007            tx.upsert(
1008                "types_migration",
1009                ["id", "completed", "migrated_rows"],
1010                ["id"],
1011                [(1_i64, false, offset)],
1012            )
1013            .await?;
1014
1015            // migrate vid
1016            let mut query_builder: sqlx::QueryBuilder<Db> =
1017                sqlx::QueryBuilder::new("INSERT INTO vid2 (height, common, share) ");
1018
1019            query_builder.push_values(vid_rows.into_iter(), |mut b, row| {
1020                b.push_bind(row.0).push_bind(row.1).push_bind(row.2);
1021            });
1022            query_builder.push(" ON CONFLICT DO NOTHING");
1023            let query = query_builder.build();
1024
1025            query.execute(tx.as_mut()).await?;
1026
1027            tx.commit().await?;
1028
1029            tracing::warn!("Migrated leaf and vid: offset={offset}");
1030
1031            tracing::info!("offset={offset}");
1032            if rows.len() < limit as usize {
1033                break;
1034            }
1035        }
1036
1037        let mut tx = self.write().await.map_err(|err| QueryError::Error {
1038            message: err.to_string(),
1039        })?;
1040
1041        tracing::warn!("query service types migration is completed!");
1042
1043        tx.upsert(
1044            "types_migration",
1045            ["id", "completed", "migrated_rows"],
1046            ["id"],
1047            [(1_i64, true, offset)],
1048        )
1049        .await?;
1050
1051        tracing::info!("updated types_migration table");
1052
1053        tx.commit().await?;
1054        Ok(())
1055    }
1056}
1057
1058// These tests run the `postgres` Docker image, which doesn't work on Windows.
1059#[cfg(all(any(test, feature = "testing"), not(target_os = "windows")))]
1060pub mod testing {
1061    #![allow(unused_imports)]
1062    use std::{
1063        env,
1064        process::{Command, Stdio},
1065        str::{self, FromStr},
1066        time::Duration,
1067    };
1068
1069    use portpicker::pick_unused_port;
1070    use refinery::Migration;
1071    use tokio::{net::TcpStream, time::timeout};
1072
1073    use super::Config;
1074    use crate::testing::sleep;
    /// A throwaway database for tests: a Dockerized Postgres server by default, or a temporary
    /// SQLite file when the `embedded-db` feature is enabled.
    #[derive(Debug)]
    pub struct TmpDb {
        /// Hostname where the Postgres container is reachable.
        #[cfg(not(feature = "embedded-db"))]
        host: String,
        /// Host port mapped to the container's Postgres port (5432).
        #[cfg(not(feature = "embedded-db"))]
        port: u16,
        /// Docker container id, used to stop/start the server and clean up on drop.
        #[cfg(not(feature = "embedded-db"))]
        container_id: String,
        /// Path of the on-disk SQLite database file.
        #[cfg(feature = "embedded-db")]
        db_path: std::path::PathBuf,
        /// When true, the data survives `Drop`: the Postgres container is started without `--rm`,
        /// and the SQLite file is not deleted.
        #[allow(dead_code)]
        persistent: bool,
    }
    impl TmpDb {
        /// Create a fresh temporary SQLite database file on disk.
        ///
        /// The tempfile is "kept" (not auto-deleted on handle drop) so it survives until `Drop`,
        /// which removes it unless `persistent` is set.
        #[cfg(feature = "embedded-db")]
        fn init_sqlite_db(persistent: bool) -> Self {
            let file = tempfile::Builder::new()
                .prefix("sqlite-")
                .suffix(".db")
                .tempfile()
                .unwrap();

            let (_, db_path) = file.keep().unwrap();

            Self {
                db_path,
                persistent,
            }
        }
        /// Spin up a non-persistent temporary database (SQLite file or Postgres container,
        /// depending on the `embedded-db` feature).
        pub async fn init() -> Self {
            #[cfg(feature = "embedded-db")]
            return Self::init_sqlite_db(false);

            #[cfg(not(feature = "embedded-db"))]
            Self::init_postgres(false).await
        }

        /// Spin up a temporary database whose data survives `Drop`.
        pub async fn persistent() -> Self {
            #[cfg(feature = "embedded-db")]
            return Self::init_sqlite_db(true);

            #[cfg(not(feature = "embedded-db"))]
            Self::init_postgres(true).await
        }

        /// Launch a Postgres Docker container on an unused local port and wait for it to be ready.
        #[cfg(not(feature = "embedded-db"))]
        async fn init_postgres(persistent: bool) -> Self {
            let docker_hostname = env::var("DOCKER_HOSTNAME");
            // This picks an unused port on the current system.  If docker is
            // configured to run on a different host then this may not find a
            // "free" port on that system.
            // We *might* be able to get away with this as any remote docker
            // host should hopefully be pretty open with it's port space.
            let port = pick_unused_port().unwrap();
            let host = docker_hostname.unwrap_or("localhost".to_string());

            let mut cmd = Command::new("docker");
            cmd.arg("run")
                .arg("-d")
                .args(["-p", &format!("{port}:5432")])
                .args(["-e", "POSTGRES_PASSWORD=password"]);

            // Without `--rm`, the container (and its data) is left behind after it is stopped.
            if !persistent {
                cmd.arg("--rm");
            }

            let output = cmd.arg("postgres").output().unwrap();
            let stdout = str::from_utf8(&output.stdout).unwrap();
            let stderr = str::from_utf8(&output.stderr).unwrap();
            if !output.status.success() {
                panic!("failed to start postgres docker: {stderr}");
            }

            // Create the TmpDb object immediately after starting the Docker container, so if
            // anything panics after this `drop` will be called and we will clean up.
            let container_id = stdout.trim().to_owned();
            tracing::info!("launched postgres docker {container_id}");
            let db = Self {
                host,
                port,
                container_id: container_id.clone(),
                persistent,
            };

            db.wait_for_ready().await;
            db
        }

        /// Hostname where the Postgres server is reachable.
        #[cfg(not(feature = "embedded-db"))]
        pub fn host(&self) -> String {
            self.host.clone()
        }

        /// Host port mapped to the Postgres container's port 5432.
        #[cfg(not(feature = "embedded-db"))]
        pub fn port(&self) -> u16 {
            self.port
        }

        /// Path of the temporary SQLite database file.
        #[cfg(feature = "embedded-db")]
        pub fn path(&self) -> std::path::PathBuf {
            self.db_path.clone()
        }

        /// Build a connection [`Config`] for this database, including a test migration that
        /// creates the `test_tree` Merkle table.
        pub fn config(&self) -> Config {
            #[cfg(feature = "embedded-db")]
            let mut cfg = Config::default().db_path(self.db_path.clone());

            #[cfg(not(feature = "embedded-db"))]
            let mut cfg = Config::default()
                .user("postgres")
                .password("password")
                .host(self.host())
                .port(self.port());

            cfg = cfg.migrations(vec![Migration::unapplied(
                "V101__create_test_merkle_tree_table.sql",
                &TestMerkleTreeMigration::create("test_tree"),
            )
            .unwrap()]);

            cfg
        }

        /// Stop the Postgres container; it can be resumed later with [`Self::start_postgres`].
        #[cfg(not(feature = "embedded-db"))]
        pub fn stop_postgres(&mut self) {
            tracing::info!(container = self.container_id, "stopping postgres");
            let output = Command::new("docker")
                .args(["stop", self.container_id.as_str()])
                .output()
                .unwrap();
            assert!(
                output.status.success(),
                "error killing postgres docker {}: {}",
                self.container_id,
                str::from_utf8(&output.stderr).unwrap()
            );
        }

        /// Restart a previously stopped Postgres container and wait for it to accept connections.
        #[cfg(not(feature = "embedded-db"))]
        pub async fn start_postgres(&mut self) {
            tracing::info!(container = self.container_id, "resuming postgres");
            let output = Command::new("docker")
                .args(["start", self.container_id.as_str()])
                .output()
                .unwrap();
            assert!(
                output.status.success(),
                "error starting postgres docker {}: {}",
                self.container_id,
                str::from_utf8(&output.stderr).unwrap()
            );

            self.wait_for_ready().await;
        }

        /// Poll the container with `pg_isready` (and then a host-side TCP connect) until the
        /// database accepts connections, panicking after `SQL_TMP_DB_CONNECT_TIMEOUT` seconds
        /// (default 60).
        #[cfg(not(feature = "embedded-db"))]
        async fn wait_for_ready(&self) {
            let timeout_duration = Duration::from_secs(
                env::var("SQL_TMP_DB_CONNECT_TIMEOUT")
                    .unwrap_or("60".to_string())
                    .parse()
                    .expect("SQL_TMP_DB_CONNECT_TIMEOUT must be an integer number of seconds"),
            );

            if let Err(err) = timeout(timeout_duration, async {
                while Command::new("docker")
                    .args([
                        "exec",
                        &self.container_id,
                        "pg_isready",
                        "-h",
                        "localhost",
                        "-U",
                        "postgres",
                    ])
                    .env("PGPASSWORD", "password")
                    // Null input so the command terminates as soon as it manages to connect.
                    .stdin(Stdio::null())
                    // Discard command output.
                    .stdout(Stdio::null())
                    .stderr(Stdio::null())
                    .status()
                    // We should ensure the exit status. A simple `unwrap`
                    // would panic on unrelated errors (such as network
                    // connection failures)
                    .and_then(|status| {
                        status
                            .success()
                            .then_some(true)
                            // Any ol' Error will do
                            .ok_or(std::io::Error::from_raw_os_error(666))
                    })
                    .is_err()
                {
                    tracing::warn!("database is not ready");
                    sleep(Duration::from_secs(1)).await;
                }

                // The above command ensures the database is ready inside the Docker container.
                // However, on some systems, there is a slight delay before the port is exposed via
                // host networking. We don't need to check again that the database is ready on the
                // host (and maybe can't, because the host might not have pg_isready installed), but
                // we can ensure the port is open by just establishing a TCP connection.
                while let Err(err) =
                    TcpStream::connect(format!("{}:{}", self.host, self.port)).await
                {
                    tracing::warn!("database is ready, but port is not available to host: {err:#}");
                    sleep(Duration::from_millis(100)).await;
                }
            })
            .await
            {
                panic!(
                    "failed to connect to TmpDb within configured timeout {timeout_duration:?}: \
                     {err:#}\n{}",
                    "Consider increasing the timeout by setting SQL_TMP_DB_CONNECT_TIMEOUT"
                );
            }
        }
    }
1295
    #[cfg(not(feature = "embedded-db"))]
    impl Drop for TmpDb {
        /// Stop the Postgres container on drop. A database created with `persistent()` was
        /// started without `--rm`, so its container is stopped but not removed.
        fn drop(&mut self) {
            self.stop_postgres();
        }
    }
1302
1303    #[cfg(feature = "embedded-db")]
1304    impl Drop for TmpDb {
1305        fn drop(&mut self) {
1306            if !self.persistent {
1307                std::fs::remove_file(self.db_path.clone()).unwrap();
1308            }
1309        }
1310    }
1311
    /// Generates the SQL migration used by tests to create a Merkle-tree table, along with the
    /// shared `hash` table and a generated `test_merkle_tree_root` column on `header`.
    pub struct TestMerkleTreeMigration;

    impl TestMerkleTreeMigration {
        /// Render the `CREATE TABLE`/`ALTER TABLE` statements for a Merkle table named `name`,
        /// in SQLite or Postgres dialect depending on the `embedded-db` feature.
        fn create(name: &str) -> String {
            // Dialect-specific fragments: bitvec column type, binary blob type, auto-increment
            // primary key, and the expression extracting the tree root from the header JSON.
            let (bit_vec, binary, hash_pk, root_stored_column) = if cfg!(feature = "embedded-db") {
                (
                    "TEXT",
                    "BLOB",
                    "INTEGER PRIMARY KEY AUTOINCREMENT",
                    " (json_extract(data, '$.test_merkle_tree_root'))",
                )
            } else {
                (
                    "BIT(8)",
                    "BYTEA",
                    "SERIAL PRIMARY KEY",
                    "(data->>'test_merkle_tree_root')",
                )
            };

            format!(
                "CREATE TABLE IF NOT EXISTS hash
            (
                id {hash_pk},
                value {binary}  NOT NULL UNIQUE
            );
    
            ALTER TABLE header
            ADD column test_merkle_tree_root text
            GENERATED ALWAYS as {root_stored_column} STORED;

            CREATE TABLE {name}
            (
                path JSONB NOT NULL, 
                created BIGINT NOT NULL,
                hash_id INT NOT NULL,
                children JSONB,
                children_bitvec {bit_vec},
                idx JSONB,
                entry JSONB,
                PRIMARY KEY (path, created)
            );
            CREATE INDEX {name}_created ON {name} (created);"
            )
        }
    }
1358}
1359
1360// These tests run the `postgres` Docker image, which doesn't work on Windows.
1361#[cfg(all(test, not(target_os = "windows")))]
1362mod test {
1363    use std::time::Duration;
1364
1365    use committable::{Commitment, CommitmentBoundsArkless, Committable};
1366    use hotshot::traits::BlockPayload;
1367    use hotshot_example_types::{
1368        node_types::TestVersions,
1369        state_types::{TestInstanceState, TestValidatedState},
1370    };
1371    use hotshot_types::{
1372        data::{QuorumProposal, ViewNumber},
1373        simple_vote::QuorumData,
1374        traits::{
1375            block_contents::{BlockHeader, GENESIS_VID_NUM_STORAGE_NODES},
1376            node_implementation::{ConsensusTime, Versions},
1377            EncodeBytes,
1378        },
1379        vid::advz::advz_scheme,
1380    };
1381    use jf_merkle_tree::{
1382        prelude::UniversalMerkleTree, MerkleTreeScheme, ToTraversalPath, UniversalMerkleTreeScheme,
1383    };
1384    use jf_vid::VidScheme;
1385    use tokio::time::sleep;
1386    use vbs::version::StaticVersionType;
1387
1388    use super::{testing::TmpDb, *};
1389    use crate::{
1390        availability::LeafQueryData,
1391        data_source::storage::{pruning::PrunedHeightStorage, UpdateAvailabilityStorage},
1392        merklized_state::{MerklizedState, UpdateStateData},
1393        testing::mocks::{MockHeader, MockMerkleTree, MockPayload, MockTypes, MockVersions},
1394    };
1395
1396    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1397    async fn test_migrations() {
1398        let db = TmpDb::init().await;
1399        let cfg = db.config();
1400
1401        let connect = |migrations: bool, custom_migrations| {
1402            let cfg = cfg.clone();
1403            async move {
1404                let mut cfg = cfg.migrations(custom_migrations);
1405                if !migrations {
1406                    cfg = cfg.no_migrations();
1407                }
1408                let client = SqlStorage::connect(cfg).await?;
1409                Ok::<_, Error>(client)
1410            }
1411        };
1412
1413        // Connecting with migrations disabled should fail if the database is not already up to date
1414        // (since we've just created a fresh database, it isn't).
1415        let err = connect(false, vec![]).await.unwrap_err();
1416        tracing::info!("connecting without running migrations failed as expected: {err}");
1417
1418        // Now connect and run migrations to bring the database up to date.
1419        connect(true, vec![]).await.unwrap();
1420        // Now connecting without migrations should work.
1421        connect(false, vec![]).await.unwrap();
1422
1423        // Connect with some custom migrations, to advance the schema even further. Pass in the
1424        // custom migrations out of order; they should still execute in order of version number.
1425        // The SQL commands used here will fail if not run in order.
1426        let migrations = vec![
1427            Migration::unapplied(
1428                "V9999__create_test_table.sql",
1429                "ALTER TABLE test ADD COLUMN data INTEGER;",
1430            )
1431            .unwrap(),
1432            Migration::unapplied(
1433                "V9998__create_test_table.sql",
1434                "CREATE TABLE test (x bigint);",
1435            )
1436            .unwrap(),
1437        ];
1438        connect(true, migrations.clone()).await.unwrap();
1439
1440        // Connect using the default schema (no custom migrations) and not running migrations. This
1441        // should fail because the database is _ahead_ of the client in terms of schema.
1442        let err = connect(false, vec![]).await.unwrap_err();
1443        tracing::info!("connecting without running migrations failed as expected: {err}");
1444
1445        // Connecting with the customized schema should work even without running migrations.
1446        connect(true, migrations).await.unwrap();
1447    }
1448
1449    #[test]
1450    #[cfg(not(feature = "embedded-db"))]
1451    fn test_config_from_str() {
1452        let cfg = Config::from_str("postgresql://user:password@host:8080").unwrap();
1453        assert_eq!(cfg.db_opt.get_username(), "user");
1454        assert_eq!(cfg.db_opt.get_host(), "host");
1455        assert_eq!(cfg.db_opt.get_port(), 8080);
1456    }
1457
1458    #[test]
1459    #[cfg(feature = "embedded-db")]
1460    fn test_config_from_str() {
1461        let cfg = Config::from_str("sqlite://data.db").unwrap();
1462        assert_eq!(cfg.db_opt.get_filename().to_string_lossy(), "data.db");
1463    }
1464
1465    async fn vacuum(storage: &SqlStorage) {
1466        #[cfg(feature = "embedded-db")]
1467        let query = "PRAGMA incremental_vacuum(16000)";
1468        #[cfg(not(feature = "embedded-db"))]
1469        let query = "VACUUM";
1470        storage
1471            .pool
1472            .acquire()
1473            .await
1474            .unwrap()
1475            .execute(query)
1476            .await
1477            .unwrap();
1478    }
1479
    /// End-to-end check of time-based pruning: with the default (1 day) retention nothing is
    /// pruned, and with a 1 s target retention all data is pruned once it ages past 1 s.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_target_period_pruning() {
        let db = TmpDb::init().await;
        let cfg = db.config();

        let mut storage = SqlStorage::connect(cfg).await.unwrap();
        let mut leaf = LeafQueryData::<MockTypes>::genesis::<TestVersions>(
            &TestValidatedState::default(),
            &TestInstanceState::default(),
        )
        .await;
        // insert some mock data
        // (reuse the genesis leaf, bumping its block number and stamping it with "now")
        for i in 0..20 {
            leaf.leaf.block_header_mut().block_number = i;
            leaf.leaf.block_header_mut().timestamp = Utc::now().timestamp() as u64;
            let mut tx = storage.write().await.unwrap();
            tx.insert_leaf(leaf.clone()).await.unwrap();
            tx.commit().await.unwrap();
        }

        let height_before_pruning = storage.get_minimum_height().await.unwrap().unwrap();

        // Set pruner config to default which has minimum retention set to 1 day
        storage.set_pruning_config(PrunerCfg::new());
        // No data will be pruned
        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();

        // Vacuum the database to reclaim space.
        // This is necessary to ensure the test passes.
        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
        vacuum(&storage).await;
        // Pruned height should be none
        assert!(pruned_height.is_none());

        let height_after_pruning = storage.get_minimum_height().await.unwrap().unwrap();

        assert_eq!(
            height_after_pruning, height_before_pruning,
            "some data has been pruned"
        );

        // Set pruner config to target retention set to 1s
        storage.set_pruning_config(PrunerCfg::new().with_target_retention(Duration::from_secs(1)));
        sleep(Duration::from_secs(2)).await;
        let usage_before_pruning = storage.get_disk_usage().await.unwrap();
        // All of the data is now older than 1s.
        // This would prune all the data as the target retention is set to 1s
        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
        // Vacuum the database to reclaim space.
        // This is necessary to ensure the test passes.
        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
        vacuum(&storage).await;

        // Pruned height should be some
        assert!(pruned_height.is_some());
        let usage_after_pruning = storage.get_disk_usage().await.unwrap();
        // All the tables should be empty
        // counting rows in header table
        let header_rows = storage
            .read()
            .await
            .unwrap()
            .fetch_one("select count(*) as count from header")
            .await
            .unwrap()
            .get::<i64, _>("count");
        // the table should be empty
        assert_eq!(header_rows, 0);

        // counting rows in leaf table.
        // Deleting rows from header table would delete rows in all the tables
        // as each of table implement "ON DELETE CASCADE" fk constraint with the header table.
        let leaf_rows = storage
            .read()
            .await
            .unwrap()
            .fetch_one("select count(*) as count from leaf")
            .await
            .unwrap()
            .get::<i64, _>("count");
        // the table should be empty
        assert_eq!(leaf_rows, 0);

        assert!(
            usage_before_pruning > usage_after_pruning,
            " disk usage should decrease after pruning"
        )
    }
1568
    /// End-to-end check of merklized-state pruning: insert 250 versioned
    /// snapshots of a mock Merkle tree, verify that stale versions of the same
    /// node exist, run `delete_batch`, and verify that only the newest version
    /// of each node path survives.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_merklized_state_pruning() {
        let db = TmpDb::init().await;
        let config = db.config();

        let storage = SqlStorage::connect(config).await.unwrap();
        let mut test_tree: UniversalMerkleTree<_, _, _, 8, _> =
            MockMerkleTree::new(MockMerkleTree::tree_height());

        // Insert some entries into the tree and the header table.
        // The header table is used by the get_path query to check if a header
        // exists for the block height.
        let mut tx = storage.write().await.unwrap();

        for block_height in 0..250 {
            test_tree.update(block_height, block_height).unwrap();

            // data field of the header: only the state commitment field is
            // needed by the merklized-state queries.
            let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});
            tx.upsert(
                "header",
                ["height", "hash", "payload_hash", "timestamp", "data"],
                ["height"],
                [(
                    block_height as i64,
                    format!("randomHash{block_height}"),
                    "t".to_string(),
                    0,
                    test_data,
                )],
            )
            .await
            .unwrap();
            // proof for the index from the tree
            let (_, proof) = test_tree.lookup(block_height).expect_ok().unwrap();
            // traversal path for the index.
            let traversal_path =
                <usize as ToTraversalPath<8>>::to_traversal_path(&block_height, test_tree.height());

            // Store the Merkle nodes along this proof path, versioned by block
            // height; earlier versions of the same path become prunable.
            UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
                &mut tx,
                proof.clone(),
                traversal_path.clone(),
                block_height as u64,
            )
            .await
            .expect("failed to insert nodes");
        }

        // update saved state height
        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, 250)
            .await
            .unwrap();
        tx.commit().await.unwrap();

        // NOTE(review): this read transaction is shadowed (not dropped) by the
        // later `let mut tx = storage.write()` below, so it stays open until
        // the end of the function — confirm this is intentional and does not
        // block the write on single-connection backends.
        let mut tx = storage.read().await.unwrap();

        // Checking that the data was inserted correctly: there should be
        // multiple nodes with the same path but different creation heights.
        let (count,) = query_as::<(i64,)>(
            " SELECT count(*) FROM (SELECT count(*) as count FROM test_tree GROUP BY path having \
             count(*) > 1)",
        )
        .fetch_one(tx.as_mut())
        .await
        .unwrap();

        tracing::info!("Number of nodes with multiple snapshots : {count}");
        assert!(count > 0);

        // This should delete all the nodes having height < 250 and is not the newest node with its position
        let mut tx = storage.write().await.unwrap();
        tx.delete_batch(vec!["test_tree".to_string()], 250)
            .await
            .unwrap();

        tx.commit().await.unwrap();
        let mut tx = storage.read().await.unwrap();
        // After pruning, every path should have exactly one (newest) snapshot,
        // so the duplicate count must drop to zero.
        let (count,) = query_as::<(i64,)>(
            "SELECT count(*) FROM (SELECT count(*) as count FROM test_tree GROUP BY path having \
             count(*) > 1)",
        )
        .fetch_one(tx.as_mut())
        .await
        .unwrap();

        tracing::info!("Number of nodes with multiple snapshots : {count}");

        assert!(count == 0);
    }
1658
1659    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1660    async fn test_minimum_retention_pruning() {
1661        let db = TmpDb::init().await;
1662
1663        let mut storage = SqlStorage::connect(db.config()).await.unwrap();
1664        let mut leaf = LeafQueryData::<MockTypes>::genesis::<TestVersions>(
1665            &TestValidatedState::default(),
1666            &TestInstanceState::default(),
1667        )
1668        .await;
1669        // insert some mock data
1670        for i in 0..20 {
1671            leaf.leaf.block_header_mut().block_number = i;
1672            leaf.leaf.block_header_mut().timestamp = Utc::now().timestamp() as u64;
1673            let mut tx = storage.write().await.unwrap();
1674            tx.insert_leaf(leaf.clone()).await.unwrap();
1675            tx.commit().await.unwrap();
1676        }
1677
1678        let height_before_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1679        let cfg = PrunerCfg::new();
1680        // Set pruning_threshold to 1
1681        // SQL storage size is more than 1000 bytes even without any data indexed
1682        // This would mean that the threshold would always be greater than the disk usage
1683        // However, minimum retention is set to 24 hours by default so the data would not be pruned
1684        storage.set_pruning_config(cfg.clone().with_pruning_threshold(1));
1685        println!("{:?}", storage.get_pruning_config().unwrap());
1686        // Pruning would not delete any data
1687        // All the data is younger than minimum retention period even though the usage > threshold
1688        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1689        // Vacuum the database to reclaim space.
1690        // This is necessary to ensure the test passes.
1691        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1692        vacuum(&storage).await;
1693
1694        // Pruned height should be none
1695        assert!(pruned_height.is_none());
1696
1697        let height_after_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1698
1699        assert_eq!(
1700            height_after_pruning, height_before_pruning,
1701            "some data has been pruned"
1702        );
1703
1704        // Change minimum retention to 1s
1705        storage.set_pruning_config(
1706            cfg.with_minimum_retention(Duration::from_secs(1))
1707                .with_pruning_threshold(1),
1708        );
1709        // sleep for 2s to make sure the data is older than minimum retention
1710        sleep(Duration::from_secs(2)).await;
1711        // This would prune all the data
1712        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1713        // Vacuum the database to reclaim space.
1714        // This is necessary to ensure the test passes.
1715        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1716        vacuum(&storage).await;
1717
1718        // Pruned height should be some
1719        assert!(pruned_height.is_some());
1720        // All the tables should be empty
1721        // counting rows in header table
1722        let header_rows = storage
1723            .read()
1724            .await
1725            .unwrap()
1726            .fetch_one("select count(*) as count from header")
1727            .await
1728            .unwrap()
1729            .get::<i64, _>("count");
1730        // the table should be empty
1731        assert_eq!(header_rows, 0);
1732    }
1733
1734    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1735    async fn test_pruned_height_storage() {
1736        let db = TmpDb::init().await;
1737        let cfg = db.config();
1738
1739        let storage = SqlStorage::connect(cfg).await.unwrap();
1740        assert!(storage
1741            .read()
1742            .await
1743            .unwrap()
1744            .load_pruned_height()
1745            .await
1746            .unwrap()
1747            .is_none());
1748        for height in [10, 20, 30] {
1749            let mut tx = storage.write().await.unwrap();
1750            tx.save_pruned_height(height).await.unwrap();
1751            tx.commit().await.unwrap();
1752            assert_eq!(
1753                storage
1754                    .read()
1755                    .await
1756                    .unwrap()
1757                    .load_pruned_height()
1758                    .await
1759                    .unwrap(),
1760                Some(height)
1761            );
1762        }
1763    }
1764
    /// Populate the legacy `leaf` and `vid` tables with `num_rows` mock
    /// entries (including some NULL VID shares), then run the types migration
    /// twice and verify every row was copied into the new `leaf2` and `vid2`
    /// tables. The second migration run checks idempotency.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_types_migration() {
        let num_rows = 500;
        let db = TmpDb::init().await;

        let storage = SqlStorage::connect(db.config()).await.unwrap();

        for i in 0..num_rows {
            let view = ViewNumber::new(i);
            let validated_state = TestValidatedState::default();
            let instance_state = TestInstanceState::default();

            // Build an empty block payload for this row.
            let (payload, metadata) = <MockPayload as BlockPayload<MockTypes>>::from_transactions(
                [],
                &validated_state,
                &instance_state,
            )
            .await
            .unwrap();

            let mut block_header = <MockHeader as BlockHeader<MockTypes>>::genesis::<MockVersions>(
                &instance_state,
                payload.clone(),
                &metadata,
            );

            block_header.block_number = i;

            // Placeholder quorum data; the real leaf commitment is filled in
            // below once the leaf has been constructed.
            let null_quorum_data = QuorumData {
                leaf_commit: Commitment::<Leaf<MockTypes>>::default_commitment_no_preimage(),
            };

            let mut qc = QuorumCertificate::new(
                null_quorum_data.clone(),
                null_quorum_data.commit(),
                view,
                None,
                std::marker::PhantomData,
            );

            let quorum_proposal = QuorumProposal {
                block_header,
                view_number: view,
                justify_qc: qc.clone(),
                upgrade_certificate: None,
                proposal_certificate: None,
            };

            let mut leaf = Leaf::from_quorum_proposal(&quorum_proposal);
            leaf.fill_block_payload::<MockVersions>(
                payload.clone(),
                GENESIS_VID_NUM_STORAGE_NODES,
                <MockVersions as Versions>::Base::VERSION,
            )
            .unwrap();
            // Now that the leaf exists, point the QC at its real commitment.
            qc.data.leaf_commit = <Leaf<MockTypes> as Committable>::commit(&leaf);

            let height = leaf.height() as i64;
            let hash = <Leaf<_> as Committable>::commit(&leaf).to_string();
            let header = leaf.block_header();

            let header_json = serde_json::to_value(header)
                .context("failed to serialize header")
                .unwrap();

            let payload_commitment =
                <MockHeader as BlockHeader<MockTypes>>::payload_commitment(header);
            let mut tx = storage.write().await.unwrap();

            // Insert the header row first: the other tables reference it via
            // foreign keys on height.
            tx.upsert(
                "header",
                ["height", "hash", "payload_hash", "data", "timestamp"],
                ["height"],
                [(
                    height,
                    leaf.block_header().commit().to_string(),
                    payload_commitment.to_string(),
                    header_json,
                    <MockHeader as BlockHeader<MockTypes>>::timestamp(leaf.block_header()) as i64,
                )],
            )
            .await
            .unwrap();

            // Legacy `leaf` table row, the migration source.
            let leaf_json = serde_json::to_value(leaf.clone()).expect("failed to serialize leaf");
            let qc_json = serde_json::to_value(qc).expect("failed to serialize QC");
            tx.upsert(
                "leaf",
                ["height", "hash", "block_hash", "leaf", "qc"],
                ["height"],
                [(
                    height,
                    hash,
                    header.commit().to_string(),
                    leaf_json,
                    qc_json,
                )],
            )
            .await
            .unwrap();

            // Disperse the payload with the legacy ADVZ VID scheme and store
            // the bincode-serialized common data and one share.
            let mut vid = advz_scheme(2);
            let disperse = vid.disperse(payload.encode()).unwrap();
            let common = disperse.common;

            let common_bytes = bincode::serialize(&common).unwrap();
            let share = disperse.shares[0].clone();
            let mut share_bytes = Some(bincode::serialize(&share).unwrap());

            // Insert some nullable vid shares (every 10th row) so the
            // migration must handle NULL `share` columns.
            if i % 10 == 0 {
                share_bytes = None
            }

            tx.upsert(
                "vid",
                ["height", "common", "share"],
                ["height"],
                [(height, common_bytes, share_bytes)],
            )
            .await
            .unwrap();
            tx.commit().await.unwrap();
        }

        // Migrate in batches of 50 rows.
        <SqlStorage as MigrateTypes<MockTypes>>::migrate_types(&storage, 50)
            .await
            .expect("failed to migrate");

        // Running the migration a second time must be a no-op (idempotent).
        <SqlStorage as MigrateTypes<MockTypes>>::migrate_types(&storage, 50)
            .await
            .expect("failed to migrate");

        let mut tx = storage.read().await.unwrap();
        let (leaf_count,) = query_as::<(i64,)>("SELECT COUNT(*) from leaf2")
            .fetch_one(tx.as_mut())
            .await
            .unwrap();

        let (vid_count,) = query_as::<(i64,)>("SELECT COUNT(*) from vid2")
            .fetch_one(tx.as_mut())
            .await
            .unwrap();

        assert_eq!(leaf_count as u64, num_rows, "not all leaves migrated");
        assert_eq!(vid_count as u64, num_rows, "not all vid migrated");
    }
1912}