hotshot_query_service/data_source/storage/sql.rs

// Copyright (c) 2022 Espresso Systems (espressosys.com)
// This file is part of the HotShot Query Service library.
//
// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
// General Public License as published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// General Public License for more details.
// You should have received a copy of the GNU General Public License along with this program. If not,
// see <https://www.gnu.org/licenses/>.

#![cfg(feature = "sql-data-source")]
use std::{cmp::min, fmt::Debug, str::FromStr, time::Duration};

use async_trait::async_trait;
use chrono::Utc;
#[cfg(not(feature = "embedded-db"))]
use futures::future::FutureExt;
use hotshot_types::{
    data::VidShare,
    traits::{metrics::Metrics, node_implementation::NodeType},
};
use itertools::Itertools;
use log::LevelFilter;
#[cfg(not(feature = "embedded-db"))]
use sqlx::postgres::{PgConnectOptions, PgSslMode};
#[cfg(feature = "embedded-db")]
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::{
    ConnectOptions, Row,
    pool::{Pool, PoolOptions},
};

use crate::{
    Header, QueryError, QueryResult,
    availability::{QueryableHeader, QueryablePayload, VidCommonMetadata, VidCommonQueryData},
    data_source::{
        VersionedDataSource,
        storage::pruning::{PruneStorage, PrunerCfg, PrunerConfig},
        update::Transaction as _,
    },
    metrics::PrometheusMetrics,
    node::BlockId,
    status::HasMetrics,
};
pub extern crate sqlx;
pub use sqlx::{Database, Sqlite};

mod db;
mod migrate;
mod queries;
mod transaction;

pub use anyhow::Error;
pub use db::*;
pub use include_dir::include_dir;
pub use queries::QueryBuilder;
pub use refinery::Migration;
pub use transaction::*;

use self::{migrate::Migrator, transaction::PoolMetrics};
use super::{AvailabilityStorage, NodeStorage};
// This needs to be reexported so that we can reference it by absolute path relative to this crate
// in the expansion of `include_migrations`, even when `include_migrations` is invoked from another
// crate which doesn't have `include_dir` as a dependency.
pub use crate::include_migrations;
/// Embed migrations from the given directory into the current binary for PostgreSQL or SQLite.
///
/// The macro invocation `include_migrations!(path)` evaluates to an expression of type `impl
/// Iterator<Item = Migration>`. Each migration must be a text file which is an immediate child of
/// `path`, and there must be no non-migration files in `path`. The migration files must have names
/// of the form `V${version}__${name}.sql`, where `version` is a positive integer indicating how the
/// migration is to be ordered relative to other migrations, and `name` is a descriptive name for
/// the migration.
///
/// `path` should be an absolute path. It is possible to give a path relative to the root of the
/// invoking crate by using environment variable expansions and the `CARGO_MANIFEST_DIR` environment
/// variable.
///
/// As an example, this is the invocation used to load the default migrations from the
/// `hotshot-query-service` crate. The migrations are located in a directory called `migrations` at
/// the root of the crate:
/// - PostgreSQL migrations are in `/migrations/postgres`.
/// - SQLite migrations are in `/migrations/sqlite`.
///
/// ```
/// # use hotshot_query_service::data_source::sql::{include_migrations, Migration};
/// // For PostgreSQL
/// #[cfg(not(feature = "embedded-db"))]
/// let mut migrations: Vec<Migration> =
///     include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres").collect();
/// // For SQLite
/// #[cfg(feature = "embedded-db")]
/// let mut migrations: Vec<Migration> =
///     include_migrations!("$CARGO_MANIFEST_DIR/migrations/sqlite").collect();
///
/// migrations.sort();
/// assert_eq!(migrations[0].version(), 10);
/// assert_eq!(migrations[0].name(), "init_schema");
/// ```
///
/// Note that a similar macro is available from Refinery:
/// [embed_migrations](https://docs.rs/refinery/0.8.11/refinery/macro.embed_migrations.html). This
/// macro differs in that it evaluates to an iterator of [migrations](Migration), making it an
/// expression macro, while `embed_migrations` is a statement macro that defines a module which
/// provides access to the embedded migrations only indirectly via a
/// [`Runner`](https://docs.rs/refinery/0.8.11/refinery/struct.Runner.html). The direct access to
/// migrations provided by [`include_migrations`] makes this macro easier to use with
/// [`Config::migrations`], for combining custom migrations with [`default_migrations`].
#[macro_export]
macro_rules! include_migrations {
    ($dir:tt) => {
        $crate::data_source::storage::sql::include_dir!($dir)
            .files()
            .map(|file| {
                let path = file.path();
                let name = path
                    .file_name()
                    .and_then(std::ffi::OsStr::to_str)
                    .unwrap_or_else(|| {
                        panic!(
                            "migration file {} must have a non-empty UTF-8 name",
                            path.display()
                        )
                    });
                let sql = file
                    .contents_utf8()
                    .unwrap_or_else(|| panic!("migration file {name} must use UTF-8 encoding"));
                $crate::data_source::storage::sql::Migration::unapplied(name, sql)
                    .expect("invalid migration")
            })
    };
}

/// The migrations required to build the default schema for this version of [`SqlStorage`].
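///
/// These migrations are applied automatically by [`SqlStorage::connect`] (unless
/// [`Config::no_migrations`] is set). A minimal sketch of inspecting them directly, assuming the
/// re-export path used in the example above:
///
/// ```ignore
/// use hotshot_query_service::data_source::sql::default_migrations;
///
/// for m in default_migrations() {
///     println!("V{} {}", m.version(), m.name());
/// }
/// ```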
pub fn default_migrations() -> Vec<Migration> {
    #[cfg(not(feature = "embedded-db"))]
    let mut migrations =
        include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres").collect::<Vec<_>>();

    #[cfg(feature = "embedded-db")]
    let mut migrations =
        include_migrations!("$CARGO_MANIFEST_DIR/migrations/sqlite").collect::<Vec<_>>();

    // Check version uniqueness and sort by version.
    validate_migrations(&mut migrations).expect("default migrations are invalid");

    // Check that all migration versions are multiples of 100, so that custom migrations can be
    // inserted in between.
    for m in &migrations {
        if m.version() <= 30 {
            // An older version of this software used intervals of 10 instead of 100. This was
            // changed to allow more custom migrations between each default migration, but we must
            // still accept older migrations that followed the older rule.
            assert!(
                m.version() > 0 && m.version() % 10 == 0,
                "legacy default migration version {} is not a positive multiple of 10",
                m.version()
            );
        } else {
            assert!(
                m.version() % 100 == 0,
                "default migration version {} is not a multiple of 100",
                m.version()
            );
        }
    }

    migrations
}

/// Validate and preprocess a sequence of migrations.
///
/// * Ensure all migrations have distinct versions
/// * Ensure migrations are sorted by increasing version
fn validate_migrations(migrations: &mut [Migration]) -> Result<(), Error> {
    migrations.sort_by_key(|m| m.version());

    // Check version uniqueness.
    for (prev, next) in migrations.iter().zip(migrations.iter().skip(1)) {
        if next <= prev {
            return Err(Error::msg(format!(
                "migration versions are not strictly increasing ({prev}->{next})"
            )));
        }
    }

    Ok(())
}

/// Add custom migrations to a default migration sequence.
///
/// Migrations in `custom` replace migrations in `default` with the same version. Otherwise, the two
/// sequences `default` and `custom` are merged so that the resulting sequence is sorted by
/// ascending version number. Each of `default` and `custom` is assumed to be the output of
/// [`validate_migrations`]; that is, each is sorted by version and contains no duplicate versions.
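///
/// A sketch of the merge semantics, with hypothetical version numbers:
///
/// ```ignore
/// // `default` contains V100 and V200; `custom` contains V150 and V200.
/// let merged: Vec<Migration> = add_custom_migrations(default, custom).collect();
/// // `merged` contains V100, V150, and V200, where V200 is the *custom* migration.
/// ```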
fn add_custom_migrations(
    default: impl IntoIterator<Item = Migration>,
    custom: impl IntoIterator<Item = Migration>,
) -> impl Iterator<Item = Migration> {
    default
        .into_iter()
        // Merge sorted lists, joining pairs of equal version into `EitherOrBoth::Both`.
        .merge_join_by(custom, |l, r| l.version().cmp(&r.version()))
        // Prefer the custom migration for a given version when both default and custom versions
        // are present.
        .map(|pair| pair.reduce(|_, custom| custom))
}

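/// Configuration for connecting [`SqlStorage`] to a PostgreSQL or SQLite database.
///
/// A minimal sketch of typical usage (connection parameters are placeholders; the builder methods
/// are defined below, and the Postgres-specific ones are only available without the `embedded-db`
/// feature):
///
/// ```ignore
/// let config = Config::default()
///     .host("localhost")
///     .port(5432)
///     .user("postgres")
///     .password("password")
///     .max_connections(25);
/// ```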
#[derive(Clone)]
pub struct Config {
    #[cfg(feature = "embedded-db")]
    db_opt: SqliteConnectOptions,

    #[cfg(not(feature = "embedded-db"))]
    db_opt: PgConnectOptions,

    pool_opt: PoolOptions<Db>,

    /// Separate pool options, allowing the connection pool for the query service to be configured
    /// independently of the primary pool.
    #[cfg(not(feature = "embedded-db"))]
    pool_opt_query: PoolOptions<Db>,

    #[cfg(not(feature = "embedded-db"))]
    schema: String,
    reset: bool,
    migrations: Vec<Migration>,
    no_migrations: bool,
    pruner_cfg: Option<PrunerCfg>,
    archive: bool,
    pool: Option<Pool<Db>>,
}

#[cfg(not(feature = "embedded-db"))]
impl Default for Config {
    fn default() -> Self {
        PgConnectOptions::default()
            .username("postgres")
            .password("password")
            .host("localhost")
            .port(5432)
            .into()
    }
}

#[cfg(feature = "embedded-db")]
impl Default for Config {
    fn default() -> Self {
        SqliteConnectOptions::default()
            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
            .busy_timeout(Duration::from_secs(30))
            .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
            .create_if_missing(true)
            .into()
    }
}

#[cfg(feature = "embedded-db")]
impl From<SqliteConnectOptions> for Config {
    fn from(db_opt: SqliteConnectOptions) -> Self {
        Self {
            db_opt,
            pool_opt: PoolOptions::default(),
            reset: false,
            migrations: vec![],
            no_migrations: false,
            pruner_cfg: None,
            archive: false,
            pool: None,
        }
    }
}

#[cfg(not(feature = "embedded-db"))]
impl From<PgConnectOptions> for Config {
    fn from(db_opt: PgConnectOptions) -> Self {
        Self {
            db_opt,
            pool_opt: PoolOptions::default(),
            pool_opt_query: PoolOptions::default(),
            schema: "hotshot".into(),
            reset: false,
            migrations: vec![],
            no_migrations: false,
            pruner_cfg: None,
            archive: false,
            pool: None,
        }
    }
}

#[cfg(not(feature = "embedded-db"))]
impl FromStr for Config {
    type Err = <PgConnectOptions as FromStr>::Err;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(PgConnectOptions::from_str(s)?.into())
    }
}

#[cfg(feature = "embedded-db")]
impl FromStr for Config {
    type Err = <SqliteConnectOptions as FromStr>::Err;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SqliteConnectOptions::from_str(s)?.into())
    }
}

#[cfg(feature = "embedded-db")]
impl Config {
    pub fn busy_timeout(mut self, timeout: Duration) -> Self {
        self.db_opt = self.db_opt.busy_timeout(timeout);
        self
    }

    pub fn db_path(mut self, path: std::path::PathBuf) -> Self {
        self.db_opt = self.db_opt.filename(path);
        self
    }
}

#[cfg(not(feature = "embedded-db"))]
impl Config {
    /// Set the hostname of the database server.
    ///
    /// The default is `localhost`.
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.db_opt = self.db_opt.host(&host.into());
        self
    }

    /// Set the port on which to connect to the database.
    ///
    /// The default is 5432, the default Postgres port.
    pub fn port(mut self, port: u16) -> Self {
        self.db_opt = self.db_opt.port(port);
        self
    }

    /// Set the DB user to connect as.
    pub fn user(mut self, user: &str) -> Self {
        self.db_opt = self.db_opt.username(user);
        self
    }

    /// Set a password for connecting to the database.
    pub fn password(mut self, password: &str) -> Self {
        self.db_opt = self.db_opt.password(password);
        self
    }

    /// Set the name of the database to connect to.
    pub fn database(mut self, database: &str) -> Self {
        self.db_opt = self.db_opt.database(database);
        self
    }

    /// Use TLS for an encrypted connection to the database.
    ///
    /// Note that an encrypted connection may be established even if this option is not set, as long
    /// as both the client and server support it. This option merely causes connection to fail if an
    /// encrypted stream cannot be established.
    pub fn tls(mut self) -> Self {
        self.db_opt = self.db_opt.ssl_mode(PgSslMode::Require);
        self
    }

    /// Set the name of the schema to use for queries.
    ///
    /// The default schema is named `hotshot` and is created via the default migrations.
    pub fn schema(mut self, schema: impl Into<String>) -> Self {
        self.schema = schema.into();
        self
    }
}

impl Config {
    /// Set the database connection pool.
    ///
    /// This allows reusing an existing connection pool when building a new `SqlStorage` instance.
    pub fn pool(mut self, pool: Pool<Db>) -> Self {
        self.pool = Some(pool);
        self
    }

    /// Reset the schema on connection.
    ///
    /// When this [`Config`] is used to [`connect`](Self::connect) a
    /// [`SqlDataSource`](crate::data_source::SqlDataSource), if this option is set, the relevant
    /// [`schema`](Self::schema) will first be dropped and then recreated, yielding a completely
    /// fresh instance of the query service.
    ///
    /// This is a particularly useful capability for development and staging environments. Still, it
    /// must be used with extreme caution, as using this will irrevocably delete any data pertaining
    /// to the query service in the database.
    pub fn reset_schema(mut self) -> Self {
        self.reset = true;
        self
    }

    /// Add custom migrations to run when connecting to the database.
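    ///
    /// A sketch of combining embedded custom migrations with the defaults (the directory path is
    /// a placeholder):
    ///
    /// ```ignore
    /// let config = config.migrations(include_migrations!("$CARGO_MANIFEST_DIR/migrations/custom"));
    /// ```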
    pub fn migrations(mut self, migrations: impl IntoIterator<Item = Migration>) -> Self {
        self.migrations.extend(migrations);
        self
    }

    /// Skip all migrations when connecting to the database.
    pub fn no_migrations(mut self) -> Self {
        self.no_migrations = true;
        self
    }

    /// Enable pruning with a given configuration.
    ///
    /// If [`archive`](Self::archive) was previously specified, this will override it.
    pub fn pruner_cfg(mut self, cfg: PrunerCfg) -> Result<Self, Error> {
        cfg.validate()?;
        self.pruner_cfg = Some(cfg);
        self.archive = false;
        Ok(self)
    }

    /// Disable pruning and reconstruct previously pruned data.
    ///
    /// While running without pruning is the default behavior, the default will not try to
    /// reconstruct data that was pruned in a previous run where pruning was enabled. This option
    /// instructs the service to run without pruning _and_ reconstruct all previously pruned data by
    /// fetching from peers.
    ///
    /// If [`pruner_cfg`](Self::pruner_cfg) was previously specified, this will override it.
    pub fn archive(mut self) -> Self {
        self.pruner_cfg = None;
        self.archive = true;
        self
    }

    /// Set the maximum idle time of a connection.
    ///
    /// Any connection which has been open and unused longer than this duration will be
    /// automatically closed to reduce load on the server.
    pub fn idle_connection_timeout(mut self, timeout: Duration) -> Self {
        self.pool_opt = self.pool_opt.idle_timeout(Some(timeout));

        #[cfg(not(feature = "embedded-db"))]
        {
            self.pool_opt_query = self.pool_opt_query.idle_timeout(Some(timeout));
        }

        self
    }

    /// Set the maximum lifetime of a connection.
    ///
    /// Any connection which has been open longer than this duration will be automatically closed
    /// (and, if needed, replaced), even if it is otherwise healthy. It is good practice to refresh
    /// even healthy connections once in a while (e.g. daily) in case of resource leaks in the
    /// server implementation.
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.pool_opt = self.pool_opt.max_lifetime(Some(timeout));

        #[cfg(not(feature = "embedded-db"))]
        {
            self.pool_opt_query = self.pool_opt_query.max_lifetime(Some(timeout));
        }

        self
    }

    /// Set the minimum number of connections to maintain at any time.
    ///
    /// The data source will, to the best of its ability, maintain at least `min` open connections
    /// at all times. This can be used to reduce the latency hit of opening new connections when at
    /// least this many simultaneous connections are frequently needed.
    pub fn min_connections(mut self, min: u32) -> Self {
        self.pool_opt = self.pool_opt.min_connections(min);
        self
    }

    #[cfg(not(feature = "embedded-db"))]
    pub fn query_min_connections(mut self, min: u32) -> Self {
        self.pool_opt_query = self.pool_opt_query.min_connections(min);
        self
    }

    /// Set the maximum number of connections to maintain at any time.
    ///
    /// Once `max` connections are in use simultaneously, further attempts to acquire a connection
    /// (or begin a transaction) will block until one of the existing connections is released.
    pub fn max_connections(mut self, max: u32) -> Self {
        self.pool_opt = self.pool_opt.max_connections(max);
        self
    }

    #[cfg(not(feature = "embedded-db"))]
    pub fn query_max_connections(mut self, max: u32) -> Self {
        self.pool_opt_query = self.pool_opt_query.max_connections(max);
        self
    }

    /// Log at WARN level any time a SQL statement takes longer than `threshold`.
    ///
    /// The default threshold is 1s.
    pub fn slow_statement_threshold(mut self, threshold: Duration) -> Self {
        self.db_opt = self
            .db_opt
            .log_slow_statements(LevelFilter::Warn, threshold);
        self
    }

    /// Set the maximum time a single SQL statement is allowed to run before being canceled.
    ///
    /// This helps prevent queries from running indefinitely, even when the client is dropped.
    #[cfg(not(feature = "embedded-db"))]
    pub fn statement_timeout(mut self, timeout: Duration) -> Self {
        // Format the duration as milliseconds, since PostgreSQL interprets values without units
        // as milliseconds.
        let timeout_ms = timeout.as_millis();
        self.db_opt = self
            .db_opt
            .options([("statement_timeout", timeout_ms.to_string())]);
        self
    }

    /// Statement timeouts are not supported for SQLite, so this is a no-op.
    #[cfg(feature = "embedded-db")]
    pub fn statement_timeout(self, _timeout: Duration) -> Self {
        self
    }
}

/// Storage for the APIs provided in this crate, backed by a remote PostgreSQL database or an
/// embedded SQLite database, depending on the `embedded-db` feature.
#[derive(Clone, Debug)]
pub struct SqlStorage {
    pool: Pool<Db>,
    metrics: PrometheusMetrics,
    pool_metrics: PoolMetrics,
    pruner_cfg: Option<PrunerCfg>,
}

#[derive(Debug, Default)]
pub struct Pruner {
    pruned_height: Option<u64>,
    target_height: Option<u64>,
    minimum_retention_height: Option<u64>,
}

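/// Selects which connection pool options a [`SqlStorage`] uses: `Sequencer` uses the primary pool
/// options, while `Query` (on Postgres) uses the separately configured query-service pool options.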
#[derive(PartialEq)]
pub enum StorageConnectionType {
    Sequencer,
    Query,
}

impl SqlStorage {
    pub fn pool(&self) -> Pool<Db> {
        self.pool.clone()
    }

    /// Connect to the database (a remote PostgreSQL server or an embedded SQLite file).
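    ///
    /// A minimal sketch (assuming a reachable database configured via [`Config`]; errors are
    /// propagated with `?`):
    ///
    /// ```ignore
    /// let storage = SqlStorage::connect(Config::default(), StorageConnectionType::Query).await?;
    /// ```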
    #[allow(unused_variables)]
    pub async fn connect(
        mut config: Config,
        connection_type: StorageConnectionType,
    ) -> Result<Self, Error> {
        let metrics = PrometheusMetrics::default();
        let pool_metrics = PoolMetrics::new(&*metrics.subgroup("sql".into()));

        #[cfg(feature = "embedded-db")]
        let pool = config.pool_opt.clone();
        #[cfg(not(feature = "embedded-db"))]
        let pool = match connection_type {
            StorageConnectionType::Sequencer => config.pool_opt.clone(),
            StorageConnectionType::Query => config.pool_opt_query.clone(),
        };

        let pruner_cfg = config.pruner_cfg;

        // Reuse an existing pool only for SQLite or for sequencer connections; query connections
        // on Postgres get their own pool.
        if cfg!(feature = "embedded-db") || connection_type == StorageConnectionType::Sequencer {
            // Re-use the same pool if present and return early.
            if let Some(pool) = config.pool {
                return Ok(Self {
                    metrics,
                    pool_metrics,
                    pool,
                    pruner_cfg,
                });
            }
        } else if config.pool.is_some() {
            tracing::info!("not reusing existing pool for query connection");
        }

        #[cfg(not(feature = "embedded-db"))]
        let schema = config.schema.clone();
        #[cfg(not(feature = "embedded-db"))]
        let pool = pool.after_connect(move |conn, _| {
            let schema = config.schema.clone();
            async move {
                query(&format!("SET search_path TO {schema}"))
                    .execute(conn)
                    .await?;
                Ok(())
            }
            .boxed()
        });

        #[cfg(feature = "embedded-db")]
        if config.reset {
            std::fs::remove_file(config.db_opt.get_filename())?;
        }

        let pool = pool.connect_with(config.db_opt).await?;

        // Create or connect to the schema for this query service.
        let mut conn = pool.acquire().await?;

        // Disable statement timeout for migrations, as they can take a long time
        #[cfg(not(feature = "embedded-db"))]
        query("SET statement_timeout = 0")
            .execute(conn.as_mut())
            .await?;

        #[cfg(not(feature = "embedded-db"))]
        if config.reset {
            query(&format!("DROP SCHEMA IF EXISTS {schema} CASCADE"))
                .execute(conn.as_mut())
                .await?;
        }

        #[cfg(not(feature = "embedded-db"))]
        query(&format!("CREATE SCHEMA IF NOT EXISTS {schema}"))
            .execute(conn.as_mut())
            .await?;

        // Get migrations and interleave with custom migrations, sorting by version number.
        validate_migrations(&mut config.migrations)?;
        let migrations =
            add_custom_migrations(default_migrations(), config.migrations).collect::<Vec<_>>();

        // Get a migration runner. Depending on the config, we can either use this to actually run
        // the migrations or just check if the database is up to date.
        let runner = refinery::Runner::new(&migrations).set_grouped(true);

        if config.no_migrations {
            // We've been asked not to run any migrations. Abort if the DB is not already up to
            // date.
            let last_applied = runner
                .get_last_applied_migration_async(&mut Migrator::from(&mut conn))
                .await?;
            let last_expected = migrations.last();
            if last_applied.as_ref() != last_expected {
                return Err(Error::msg(format!(
                    "DB is out of date: last applied migration is {last_applied:?}, but expected \
                     {last_expected:?}"
                )));
            }
        } else {
            // Run migrations using `refinery`.
            match runner.run_async(&mut Migrator::from(&mut conn)).await {
                Ok(report) => {
                    tracing::info!("ran DB migrations: {report:?}");
                },
                Err(err) => {
                    tracing::error!("DB migrations failed: {:?}", err.report());
                    Err(err)?;
                },
            }
        }

        if config.archive {
            // If running in archive mode, ensure the pruned height is set to 0, so the fetcher will
            // reconstruct previously pruned data.
            query("DELETE FROM pruned_height WHERE id = 1")
                .execute(conn.as_mut())
                .await?;
        }

        conn.close().await?;

        Ok(Self {
            pool,
            pool_metrics,
            metrics,
            pruner_cfg,
        })
    }
}

impl PrunerConfig for SqlStorage {
    fn set_pruning_config(&mut self, cfg: PrunerCfg) {
        self.pruner_cfg = Some(cfg);
    }

    fn get_pruning_config(&self) -> Option<PrunerCfg> {
        self.pruner_cfg.clone()
    }
}

impl HasMetrics for SqlStorage {
    fn metrics(&self) -> &PrometheusMetrics {
        &self.metrics
    }
}

impl SqlStorage {
    async fn get_minimum_height(&self) -> QueryResult<Option<u64>> {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let (Some(height),) =
            query_as::<(Option<i64>,)>("SELECT MIN(height) as height FROM header")
                .fetch_one(tx.as_mut())
                .await?
        else {
            return Ok(None);
        };
        Ok(Some(height as u64))
    }

    async fn get_height_by_timestamp(&self, timestamp: i64) -> QueryResult<Option<u64>> {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;

        // We order by timestamp and then height, even though logically this is no different than
        // just ordering by height, since timestamps are monotonic. The reason is that this order
        // allows the query planner to efficiently solve the where clause and presort the results
        // based on the timestamp index. The remaining sort on height, which guarantees a unique
        // block if multiple blocks have the same timestamp, is very efficient, because there are
        // never more than a handful of blocks with the same timestamp.
        let Some((height,)) = query_as::<(i64,)>(
            "SELECT height FROM header
              WHERE timestamp <= $1
              ORDER BY timestamp DESC, height DESC
              LIMIT 1",
        )
        .bind(timestamp)
        .fetch_optional(tx.as_mut())
        .await?
        else {
            return Ok(None);
        };
        Ok(Some(height as u64))
    }

    /// Get the stored VID share for a given block, if one exists.
    pub async fn get_vid_share<Types>(&self, block_id: BlockId<Types>) -> QueryResult<VidShare>
    where
        Types: NodeType,
        Header<Types>: QueryableHeader<Types>,
    {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let share = tx.vid_share(block_id).await?;
        Ok(share)
    }

    /// Get the stored VID common data for a given block, if one exists.
    pub async fn get_vid_common<Types: NodeType>(
        &self,
        block_id: BlockId<Types>,
    ) -> QueryResult<VidCommonQueryData<Types>>
    where
        <Types as NodeType>::BlockPayload: QueryablePayload<Types>,
        <Types as NodeType>::BlockHeader: QueryableHeader<Types>,
    {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let common = tx.get_vid_common(block_id).await?;
        Ok(common)
    }

    /// Get the stored VID common metadata for a given block, if one exists.
    pub async fn get_vid_common_metadata<Types: NodeType>(
        &self,
        block_id: BlockId<Types>,
    ) -> QueryResult<VidCommonMetadata<Types>>
    where
        <Types as NodeType>::BlockPayload: QueryablePayload<Types>,
        <Types as NodeType>::BlockHeader: QueryableHeader<Types>,
    {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let common = tx.get_vid_common_metadata(block_id).await?;
        Ok(common)
    }
}

#[async_trait]
impl PruneStorage for SqlStorage {
    type Pruner = Pruner;

    async fn get_disk_usage(&self) -> anyhow::Result<u64> {
        let mut tx = self.read().await?;

        #[cfg(not(feature = "embedded-db"))]
        let query = "SELECT pg_database_size(current_database())";

        #[cfg(feature = "embedded-db")]
        let query = "
            SELECT( (SELECT page_count FROM pragma_page_count) * (SELECT * FROM pragma_page_size)) \
                     AS total_bytes";

        let row = tx.fetch_one(query).await?;
        let size: i64 = row.get(0);

        Ok(size as u64)
    }

    /// Trigger incremental vacuum to free up space in the SQLite database.
    /// Note: We don't vacuum the Postgres database,
    /// as there is no manual trigger for incremental vacuum,
    /// and a full vacuum can take a lot of time.
    #[cfg(feature = "embedded-db")]
    async fn vacuum(&self) -> anyhow::Result<()> {
        let config = self.get_pruning_config().ok_or(QueryError::Error {
            message: "Pruning config not found".to_string(),
        })?;
        let mut conn = self.pool().acquire().await?;
        query(&format!(
            "PRAGMA incremental_vacuum({})",
            config.incremental_vacuum_pages()
        ))
        .execute(conn.as_mut())
        .await?;
        conn.close().await?;
        Ok(())
    }

    /// Note: The prune operation may not immediately free up space even after rows are deleted.
    /// This is because a vacuum operation may be necessary to reclaim more space.
    /// PostgreSQL already performs auto vacuuming, so we are not including it here
    /// as running a vacuum operation can be resource-intensive.
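    ///
    /// A sketch of driving a complete pruner run: the [`Pruner`] state is carried across calls,
    /// and each call deletes at most one batch, so callers loop until `None` is returned.
    ///
    /// ```ignore
    /// let mut pruner = Pruner::default();
    /// while let Some(height) = storage.prune(&mut pruner).await? {
    ///     tracing::info!("pruned up to height {height}");
    /// }
    /// ```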
    async fn prune(&self, pruner: &mut Pruner) -> anyhow::Result<Option<u64>> {
        let cfg = self.get_pruning_config().ok_or(QueryError::Error {
            message: "Pruning config not found".to_string(),
        })?;
        let batch_size = cfg.batch_size();
        let max_usage = cfg.max_usage();
        let state_tables = cfg.state_tables();

        // If a pruner run was already in progress, some variables may already be set,
        // depending on whether a batch was deleted and which batch it was (target or minimum retention).
        // This enables us to resume the pruner run from the exact heights.
        // If any of these values are not set, they can be loaded from the database if necessary.
        let mut minimum_retention_height = pruner.minimum_retention_height;
        let mut target_height = pruner.target_height;
        let pruned_height = match pruner.pruned_height {
            Some(h) => Some(h),
            None => {
                let Some(height) = self.get_minimum_height().await? else {
                    tracing::info!("database is empty, nothing to prune");
                    return Ok(None);
                };

                if height > 0 { Some(height - 1) } else { None }
            },
        };

        // Prune data exceeding target retention in batches
        if pruner.target_height.is_none() {
            let th = self
                .get_height_by_timestamp(
                    Utc::now().timestamp() - (cfg.target_retention().as_secs()) as i64,
                )
                .await?;
            target_height = th;
            pruner.target_height = target_height;
        };

        if let Some(th) = target_height
            && pruned_height < Some(th)
        {
            let batch_end = match pruned_height {
                None => batch_size - 1,
                Some(h) => h + batch_size,
            };
            let to = min(batch_end, th);
            let mut tx = self.write().await?;
            tx.delete_batch(state_tables, to).await?;
            tx.commit().await.map_err(|e| QueryError::Error {
                message: format!("failed to commit {e}"),
            })?;
            pruner.pruned_height = Some(to);
            return Ok(Some(to));
        }

        // If threshold is set, prune data exceeding minimum retention in batches
        // This parameter is needed for SQL storage as there is no direct way to get free space.
        if let Some(threshold) = cfg.pruning_threshold() {
            let usage = self.get_disk_usage().await?;

            // Prune data exceeding minimum retention in batches starting from minimum height
            // until usage is below threshold
            if usage > threshold {
                tracing::warn!(
                    "Disk usage {usage} exceeds pruning threshold {:?}",
                    cfg.pruning_threshold()
                );

                if minimum_retention_height.is_none() {
                    minimum_retention_height = self
                        .get_height_by_timestamp(
                            Utc::now().timestamp() - (cfg.minimum_retention().as_secs()) as i64,
                        )
                        .await?;

                    pruner.minimum_retention_height = minimum_retention_height;
                }

                if let Some(min_retention_height) = minimum_retention_height
                    && (usage as f64 / threshold as f64) > (f64::from(max_usage) / 10000.0)
                    && pruned_height < Some(min_retention_height)
                {
                    let batch_end = match pruned_height {
                        None => batch_size - 1,
                        Some(h) => h + batch_size,
                    };
                    let to = min(batch_end, min_retention_height);
                    let mut tx = self.write().await?;
                    tx.delete_batch(state_tables, to).await?;
                    tx.commit().await.map_err(|e| QueryError::Error {
                        message: format!("failed to commit {e}"),
                    })?;

                    self.vacuum().await?;

                    pruner.pruned_height = Some(to);
                    return Ok(Some(to));
                }
            }
        }

        Ok(None)
    }
}

impl VersionedDataSource for SqlStorage {
    type Transaction<'a>
        = Transaction<Write>
    where
        Self: 'a;
    type ReadOnly<'a>
        = Transaction<Read>
    where
        Self: 'a;

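    /// Open a read-write transaction. A sketch of the usual flow (mirroring the tests below;
    /// `insert_leaf` is one of the update methods available on a write transaction):
    ///
    /// ```ignore
    /// let mut tx = storage.write().await?;
    /// tx.insert_leaf(leaf).await?;
    /// tx.commit().await?;
    /// ```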
    async fn write(&self) -> anyhow::Result<Transaction<Write>> {
        Transaction::new(&self.pool, self.pool_metrics.clone()).await
    }

    async fn read(&self) -> anyhow::Result<Transaction<Read>> {
        Transaction::new(&self.pool, self.pool_metrics.clone()).await
    }
}

// These tests run the `postgres` Docker image, which doesn't work on Windows.
#[cfg(all(any(test, feature = "testing"), not(target_os = "windows")))]
pub mod testing {
    #![allow(unused_imports)]
    use std::{
        env,
        process::{Command, Stdio},
        str::{self, FromStr},
        time::Duration,
    };

    use refinery::Migration;
    use test_utils::reserve_tcp_port;
    use tokio::{net::TcpStream, time::timeout};

    use super::Config;
    use crate::testing::sleep;
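
    /// An ephemeral database for tests: a Dockerized Postgres container, or a temporary SQLite
    /// file when the `embedded-db` feature is enabled.
    ///
    /// A sketch of typical usage, mirroring the tests below:
    ///
    /// ```ignore
    /// let db = TmpDb::init().await;
    /// let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query).await?;
    /// ```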
    #[derive(Debug)]
    pub struct TmpDb {
        #[cfg(not(feature = "embedded-db"))]
        host: String,
        #[cfg(not(feature = "embedded-db"))]
        port: u16,
        #[cfg(not(feature = "embedded-db"))]
        container_id: String,
        #[cfg(feature = "embedded-db")]
        db_path: std::path::PathBuf,
        #[allow(dead_code)]
        persistent: bool,
    }
    impl TmpDb {
        #[cfg(feature = "embedded-db")]
        fn init_sqlite_db(persistent: bool) -> Self {
            let file = tempfile::Builder::new()
                .prefix("sqlite-")
                .suffix(".db")
                .tempfile()
                .unwrap();

            let (_, db_path) = file.keep().unwrap();

            Self {
                db_path,
                persistent,
            }
        }
        pub async fn init() -> Self {
            #[cfg(feature = "embedded-db")]
            return Self::init_sqlite_db(false);

            #[cfg(not(feature = "embedded-db"))]
            Self::init_postgres(false).await
        }

        pub async fn persistent() -> Self {
            #[cfg(feature = "embedded-db")]
            return Self::init_sqlite_db(true);

            #[cfg(not(feature = "embedded-db"))]
            Self::init_postgres(true).await
        }

        #[cfg(not(feature = "embedded-db"))]
        async fn init_postgres(persistent: bool) -> Self {
            let docker_hostname = env::var("DOCKER_HOSTNAME");
            // This picks an unused port on the current system. If Docker is
            // configured to run on a different host, then this may not find a
            // "free" port on that system.
            // We *might* be able to get away with this, as any remote Docker
            // host should hopefully be pretty open with its port space.
            let port = reserve_tcp_port().unwrap();
            let host = docker_hostname.unwrap_or("localhost".to_string());

            let mut cmd = Command::new("docker");
            cmd.arg("run")
                .arg("-d")
                .args(["-p", &format!("{port}:5432")])
                .args(["-e", "POSTGRES_PASSWORD=password"]);

            if !persistent {
                cmd.arg("--rm");
            }

            let output = cmd.arg("postgres").output().unwrap();
            let stdout = str::from_utf8(&output.stdout).unwrap();
            let stderr = str::from_utf8(&output.stderr).unwrap();
            if !output.status.success() {
                panic!("failed to start postgres docker: {stderr}");
            }

            // Create the TmpDb object immediately after starting the Docker container, so that if
            // anything panics after this, `drop` will be called and we will clean up.
            let container_id = stdout.trim().to_owned();
            tracing::info!("launched postgres docker {container_id}");
            let db = Self {
                host,
                port,
                container_id: container_id.clone(),
                persistent,
            };

            db.wait_for_ready().await;
            db
        }

        #[cfg(not(feature = "embedded-db"))]
        pub fn host(&self) -> String {
            self.host.clone()
        }

        #[cfg(not(feature = "embedded-db"))]
        pub fn port(&self) -> u16 {
            self.port
        }

        #[cfg(feature = "embedded-db")]
        pub fn path(&self) -> std::path::PathBuf {
            self.db_path.clone()
        }

        pub fn config(&self) -> Config {
            #[cfg(feature = "embedded-db")]
            let mut cfg = Config::default().db_path(self.db_path.clone());

            #[cfg(not(feature = "embedded-db"))]
            let mut cfg = Config::default()
                .user("postgres")
                .password("password")
                .host(self.host())
                .port(self.port());

            cfg = cfg.migrations(vec![
                Migration::unapplied(
                    "V101__create_test_merkle_tree_table.sql",
                    &TestMerkleTreeMigration::create("test_tree"),
                )
                .unwrap(),
            ]);

            cfg
        }

        #[cfg(not(feature = "embedded-db"))]
        pub fn stop_postgres(&mut self) {
            tracing::info!(container = self.container_id, "stopping postgres");
            let output = Command::new("docker")
                .args(["stop", self.container_id.as_str()])
                .output()
                .unwrap();
            assert!(
                output.status.success(),
                "error killing postgres docker {}: {}",
                self.container_id,
                str::from_utf8(&output.stderr).unwrap()
            );
        }

        #[cfg(not(feature = "embedded-db"))]
        pub async fn start_postgres(&mut self) {
            tracing::info!(container = self.container_id, "resuming postgres");
            let output = Command::new("docker")
                .args(["start", self.container_id.as_str()])
                .output()
                .unwrap();
            assert!(
                output.status.success(),
                "error starting postgres docker {}: {}",
                self.container_id,
                str::from_utf8(&output.stderr).unwrap()
            );

            self.wait_for_ready().await;
        }

        #[cfg(not(feature = "embedded-db"))]
        async fn wait_for_ready(&self) {
            let timeout_duration = Duration::from_secs(
                env::var("SQL_TMP_DB_CONNECT_TIMEOUT")
                    .unwrap_or("60".to_string())
                    .parse()
                    .expect("SQL_TMP_DB_CONNECT_TIMEOUT must be an integer number of seconds"),
            );

            if let Err(err) = timeout(timeout_duration, async {
                while Command::new("docker")
                    .args([
                        "exec",
                        &self.container_id,
                        "pg_isready",
                        "-h",
                        "localhost",
                        "-U",
                        "postgres",
                    ])
                    .env("PGPASSWORD", "password")
                    // Null input so the command terminates as soon as it manages to connect.
                    .stdin(Stdio::null())
                    // Discard command output.
                    .stdout(Stdio::null())
                    .stderr(Stdio::null())
                    .status()
                    // Check the exit status explicitly: a simple `unwrap` would
                    // panic on unrelated errors (such as network connection
                    // failures).
                    .and_then(|status| {
                        status
                            .success()
                            .then_some(true)
                            // Any ol' Error will do
                            .ok_or(std::io::Error::from_raw_os_error(666))
                    })
                    .is_err()
                {
                    tracing::warn!("database is not ready");
                    sleep(Duration::from_secs(1)).await;
                }

                // The above command ensures the database is ready inside the Docker container.
                // However, on some systems, there is a slight delay before the port is exposed via
                // host networking. We don't need to check again that the database is ready on the
                // host (and maybe can't, because the host might not have pg_isready installed), but
                // we can ensure the port is open by just establishing a TCP connection.
                while let Err(err) =
                    TcpStream::connect(format!("{}:{}", self.host, self.port)).await
                {
                    tracing::warn!("database is ready, but port is not available to host: {err:#}");
                    sleep(Duration::from_millis(100)).await;
                }
            })
            .await
            {
                panic!(
                    "failed to connect to TmpDb within configured timeout {timeout_duration:?}: \
                     {err:#}\n{}",
                    "Consider increasing the timeout by setting SQL_TMP_DB_CONNECT_TIMEOUT"
                );
            }
        }
    }

    #[cfg(not(feature = "embedded-db"))]
    impl Drop for TmpDb {
        fn drop(&mut self) {
            self.stop_postgres();
        }
    }

    #[cfg(feature = "embedded-db")]
    impl Drop for TmpDb {
        fn drop(&mut self) {
            if !self.persistent {
                std::fs::remove_file(self.db_path.clone()).unwrap();
            }
        }
    }

    pub struct TestMerkleTreeMigration;

    impl TestMerkleTreeMigration {
        fn create(name: &str) -> String {
            let (bit_vec, binary, hash_pk, root_stored_column) = if cfg!(feature = "embedded-db") {
                (
                    "TEXT",
                    "BLOB",
                    "INTEGER PRIMARY KEY AUTOINCREMENT",
                    " (json_extract(data, '$.test_merkle_tree_root'))",
                )
            } else {
                (
                    "BIT(8)",
                    "BYTEA",
                    "SERIAL PRIMARY KEY",
                    "(data->>'test_merkle_tree_root')",
                )
            };

            format!(
                "CREATE TABLE IF NOT EXISTS hash
            (
                id {hash_pk},
                value {binary}  NOT NULL UNIQUE
            );

            ALTER TABLE header
            ADD column test_merkle_tree_root text
            GENERATED ALWAYS as {root_stored_column} STORED;

            CREATE TABLE {name}
            (
                path JSONB NOT NULL,
                created BIGINT NOT NULL,
                hash_id INT NOT NULL,
                children JSONB,
                children_bitvec {bit_vec},
                idx JSONB,
                entry JSONB,
                PRIMARY KEY (path, created)
            );
            CREATE INDEX {name}_created ON {name} (created);"
            )
        }
    }
}

// These tests run the `postgres` Docker image, which doesn't work on Windows.
#[cfg(all(test, not(target_os = "windows")))]
mod test {
    use std::time::Duration;

    use hotshot_example_types::{
        node_types::TEST_VERSIONS,
        state_types::{TestInstanceState, TestValidatedState},
    };
    use jf_merkle_tree_compat::{
        MerkleTreeScheme, ToTraversalPath, UniversalMerkleTreeScheme, prelude::UniversalMerkleTree,
    };
    use tokio::time::sleep;

    use super::{testing::TmpDb, *};
    use crate::{
        availability::{BlockQueryData, LeafQueryData},
        data_source::storage::{UpdateAvailabilityStorage, pruning::PrunedHeightStorage},
        merklized_state::{MerklizedState, UpdateStateData},
        testing::mocks::{MockMerkleTree, MockTypes},
    };

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_migrations() {
        let db = TmpDb::init().await;
        let cfg = db.config();

        let connect = |migrations: bool, custom_migrations| {
            let cfg = cfg.clone();
            async move {
                let mut cfg = cfg.migrations(custom_migrations);
                if !migrations {
                    cfg = cfg.no_migrations();
                }
                let client = SqlStorage::connect(cfg, StorageConnectionType::Query).await?;
                Ok::<_, Error>(client)
            }
        };

        // Connecting with migrations disabled should fail if the database is not already up to date
        // (since we've just created a fresh database, it isn't).
        let err = connect(false, vec![]).await.unwrap_err();
        tracing::info!("connecting without running migrations failed as expected: {err}");

        // Now connect and run migrations to bring the database up to date.
        connect(true, vec![]).await.unwrap();
        // Now connecting without migrations should work.
        connect(false, vec![]).await.unwrap();

        // Connect with some custom migrations, to advance the schema even further. Pass in the
        // custom migrations out of order; they should still execute in order of version number.
        // The SQL commands used here will fail if not run in order.
        let migrations = vec![
            Migration::unapplied(
                "V9999__create_test_table.sql",
                "ALTER TABLE test ADD COLUMN data INTEGER;",
            )
            .unwrap(),
            Migration::unapplied(
                "V9998__create_test_table.sql",
                "CREATE TABLE test (x bigint);",
            )
            .unwrap(),
        ];
        connect(true, migrations.clone()).await.unwrap();

        // Connect using the default schema (no custom migrations) and not running migrations. This
        // should fail because the database is _ahead_ of the client in terms of schema.
        let err = connect(false, vec![]).await.unwrap_err();
        tracing::info!("connecting without running migrations failed as expected: {err}");

        // Connecting with the customized schema should work even without running migrations.
        connect(false, migrations).await.unwrap();
    }

    #[test]
    #[cfg(not(feature = "embedded-db"))]
    fn test_config_from_str() {
        let cfg = Config::from_str("postgresql://user:password@host:8080").unwrap();
        assert_eq!(cfg.db_opt.get_username(), "user");
        assert_eq!(cfg.db_opt.get_host(), "host");
        assert_eq!(cfg.db_opt.get_port(), 8080);
    }

    #[test]
    #[cfg(feature = "embedded-db")]
    fn test_config_from_str() {
        let cfg = Config::from_str("sqlite://data.db").unwrap();
        assert_eq!(cfg.db_opt.get_filename().to_string_lossy(), "data.db");
    }

    async fn vacuum(storage: &SqlStorage) {
        #[cfg(feature = "embedded-db")]
        let query = "PRAGMA incremental_vacuum(16000)";
        #[cfg(not(feature = "embedded-db"))]
        let query = "VACUUM";
        storage
            .pool
            .acquire()
            .await
            .unwrap()
            .execute(query)
            .await
            .unwrap();
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_target_period_pruning() {
        let db = TmpDb::init().await;
        let cfg = db.config();

        let mut storage = SqlStorage::connect(cfg, StorageConnectionType::Query)
            .await
            .unwrap();
        let mut leaf = LeafQueryData::<MockTypes>::genesis(
            &TestValidatedState::default(),
            &TestInstanceState::default(),
            TEST_VERSIONS.test,
        )
        .await;
        // insert some mock data
        for i in 0..20 {
            leaf.leaf.block_header_mut().block_number = i;
            leaf.leaf.block_header_mut().timestamp = Utc::now().timestamp() as u64;
            let mut tx = storage.write().await.unwrap();
            tx.insert_leaf(leaf.clone()).await.unwrap();
            tx.commit().await.unwrap();
        }

        let height_before_pruning = storage.get_minimum_height().await.unwrap().unwrap();

        // Set pruner config to default which has minimum retention set to 1 day
        storage.set_pruning_config(PrunerCfg::new());
        // No data will be pruned
        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();

        // Vacuum the database to reclaim space.
        // This is necessary to ensure the test passes.
        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
        vacuum(&storage).await;
        // Pruned height should be none
        assert!(pruned_height.is_none());

        let height_after_pruning = storage.get_minimum_height().await.unwrap().unwrap();

        assert_eq!(
            height_after_pruning, height_before_pruning,
            "some data has been pruned"
        );

        // Set pruner config with target retention set to 1s
        storage.set_pruning_config(PrunerCfg::new().with_target_retention(Duration::from_secs(1)));
        sleep(Duration::from_secs(2)).await;
        let usage_before_pruning = storage.get_disk_usage().await.unwrap();
        // All of the data is now older than 1s.
        // This would prune all the data as the target retention is set to 1s
        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
        // Vacuum the database to reclaim space.
        // This is necessary to ensure the test passes.
        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
        vacuum(&storage).await;

        // Pruned height should be some
        assert!(pruned_height.is_some());
        let usage_after_pruning = storage.get_disk_usage().await.unwrap();
        // All the tables should be empty
        // counting rows in header table
        let header_rows = storage
            .read()
            .await
            .unwrap()
            .fetch_one("select count(*) as count from header")
            .await
            .unwrap()
            .get::<i64, _>("count");
        // the table should be empty
        assert_eq!(header_rows, 0);

1442        // counting rows in the leaf table.
1443        // Deleting rows from the header table deletes rows in all the other tables,
1444        // as each table has an "ON DELETE CASCADE" foreign key constraint referencing the header table.
1445        let leaf_rows = storage
1446            .read()
1447            .await
1448            .unwrap()
1449            .fetch_one("select count(*) as count from leaf")
1450            .await
1451            .unwrap()
1452            .get::<i64, _>("count");
1453        // the table should be empty
1454        assert_eq!(leaf_rows, 0);
1455
1456        assert!(
1457            usage_before_pruning > usage_after_pruning,
1458            "disk usage should decrease after pruning"
1459        )
1460    }
1461
1462    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1463    async fn test_merklized_state_pruning() {
1464        let db = TmpDb::init().await;
1465        let config = db.config();
1466
1467        let storage = SqlStorage::connect(config, StorageConnectionType::Query)
1468            .await
1469            .unwrap();
1470        let mut test_tree: UniversalMerkleTree<_, _, _, 8, _> =
1471            MockMerkleTree::new(MockMerkleTree::tree_height());
1472
1473        // Insert some entries into the tree and the header table.
1474        // The header table is used by the get_path query to check whether a header exists for the block height.
1475        let mut tx = storage.write().await.unwrap();
1476
1477        for block_height in 0..250 {
1478            test_tree.update(block_height, block_height).unwrap();
1479
1480            // data field of the header
1481            let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});
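            // Upsert a minimal header row for this height; columns other than
            // height and data are placeholder values.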
1482            tx.upsert(
1483                "header",
1484                [
1485                    "height",
1486                    "hash",
1487                    "payload_hash",
1488                    "timestamp",
1489                    "data",
1490                    "ns_table",
1491                ],
1492                ["height"],
1493                [(
1494                    block_height as i64,
1495                    format!("randomHash{block_height}"),
1496                    "t".to_string(),
1497                    0,
1498                    test_data,
1499                    "ns_table".to_string(),
1500                )],
1501            )
1502            .await
1503            .unwrap();
1504            // proof for the index from the tree
1505            let (_, proof) = test_tree.lookup(block_height).expect_ok().unwrap();
1506            // traversal path for the index.
1507            let traversal_path =
1508                <usize as ToTraversalPath<8>>::to_traversal_path(&block_height, test_tree.height());
1509
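            // Persist the Merkle nodes along the proof's traversal path at this block height.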
1510            UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
1511                &mut tx,
1512                proof.clone(),
1513                traversal_path.clone(),
1514                block_height as u64,
1515            )
1516            .await
1517            .expect("failed to insert nodes");
1518        }
1519
1520        // update saved state height
1521        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, 250)
1522            .await
1523            .unwrap();
1524        tx.commit().await.unwrap();
1525
1526        let mut tx = storage.read().await.unwrap();
1527
1528        // Check that the data was inserted correctly:
1529        // there should be multiple nodes with the same path but different creation times.
1530        let (count,) = query_as::<(i64,)>(
1531            "SELECT count(*) FROM (SELECT count(*) as count FROM test_tree GROUP BY path having \
1532             count(*) > 1) AS s",
1533        )
1534        .fetch_one(tx.as_mut())
1535        .await
1536        .unwrap();
1537
1538        tracing::info!("Number of nodes with multiple snapshots: {count}");
1539        assert!(count > 0);
1540
1541        // This should delete all nodes with height < 250 that are not the newest node at their position
1542        let mut tx = storage.write().await.unwrap();
1543        tx.delete_batch(vec!["test_tree".to_string()], 250)
1544            .await
1545            .unwrap();
1546
1547        tx.commit().await.unwrap();
1548        let mut tx = storage.read().await.unwrap();
1549        let (count,) = query_as::<(i64,)>(
1550            "SELECT count(*) FROM (SELECT count(*) as count FROM test_tree GROUP BY path having \
1551             count(*) > 1) AS s",
1552        )
1553        .fetch_one(tx.as_mut())
1554        .await
1555        .unwrap();
1556
1557        tracing::info!("Number of nodes with multiple snapshots: {count}");
1558
1559        assert_eq!(count, 0);
1560    }
1561
1562    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1563    async fn test_minimum_retention_pruning() {
1564        let db = TmpDb::init().await;
1565
1566        let mut storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1567            .await
1568            .unwrap();
1569        let mut leaf = LeafQueryData::<MockTypes>::genesis(
1570            &TestValidatedState::default(),
1571            &TestInstanceState::default(),
1572            TEST_VERSIONS.test,
1573        )
1574        .await;
1575        // insert some mock data
1576        for i in 0..20 {
1577            leaf.leaf.block_header_mut().block_number = i;
1578            leaf.leaf.block_header_mut().timestamp = Utc::now().timestamp() as u64;
1579            let mut tx = storage.write().await.unwrap();
1580            tx.insert_leaf(leaf.clone()).await.unwrap();
1581            tx.commit().await.unwrap();
1582        }
1583
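        // Record the minimum retained block height so we can verify below whether anything was pruned.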
1584        let height_before_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1585        let cfg = PrunerCfg::new();
1586        // Set pruning_threshold to 1 byte.
1587        // SQL storage size is more than 1000 bytes even without any data indexed,
1588        // which means the disk usage always exceeds the threshold.
1589        // However, minimum retention is set to 24 hours by default, so the data would not be pruned.
1590        storage.set_pruning_config(cfg.clone().with_pruning_threshold(1));
1591        tracing::info!("{:?}", storage.get_pruning_config().unwrap());
1592        // Pruning should not delete any data:
1593        // all the data is younger than the minimum retention period, even though usage > threshold.
1594        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1595        // Vacuum the database to reclaim space.
1596        // This is necessary to ensure the test passes.
1597        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1598        vacuum(&storage).await;
1599
1600        // Pruned height should be none
1601        assert!(pruned_height.is_none());
1602
1603        let height_after_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1604
1605        assert_eq!(
1606            height_after_pruning, height_before_pruning,
1607            "some data has been pruned"
1608        );
1609
1610        // Change minimum retention to 1s
1611        storage.set_pruning_config(
1612            cfg.with_minimum_retention(Duration::from_secs(1))
1613                .with_pruning_threshold(1),
1614        );
1615        // sleep for 2s to make sure the data is older than minimum retention
1616        sleep(Duration::from_secs(2)).await;
1617        // This would prune all the data
1618        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1619        // Vacuum the database to reclaim space.
1620        // This is necessary to ensure the test passes.
1621        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1622        vacuum(&storage).await;
1623
1624        // Pruned height should be some
1625        assert!(pruned_height.is_some());
1626        // All the tables should be empty
1627        // counting rows in header table
1628        let header_rows = storage
1629            .read()
1630            .await
1631            .unwrap()
1632            .fetch_one("select count(*) as count from header")
1633            .await
1634            .unwrap()
1635            .get::<i64, _>("count");
1636        // the table should be empty
1637        assert_eq!(header_rows, 0);
1638    }
1639
1640    #[test_log::test(tokio::test)]
1642    async fn test_payload_pruning() {
1643        let db = TmpDb::init().await;
1644        let mut storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1645            .await
1646            .unwrap();
1647        storage.set_pruning_config(Default::default());
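        // Use the default pruner config; this test drives the pruner manually with
        // explicit `Pruner` state rather than relying on retention timing.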
1648
1649        // Insert some mock data.
1650        let mut leaf = LeafQueryData::<MockTypes>::genesis(
1651            &TestValidatedState::default(),
1652            &TestInstanceState::default(),
1653            TEST_VERSIONS.test,
1654        )
1655        .await;
1656        let block = BlockQueryData::<MockTypes>::genesis(
1657            &Default::default(),
1658            &Default::default(),
1659            TEST_VERSIONS.test.base,
1660        )
1661        .await;
1662        let vid = VidCommonQueryData::<MockTypes>::genesis(
1663            &Default::default(),
1664            &Default::default(),
1665            TEST_VERSIONS.test.base,
1666        )
1667        .await;
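        // Insert the genesis leaf, block payload, and VID common data in one transaction.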
1668        {
1669            let mut tx = storage.write().await.unwrap();
1670            tx.insert_leaf(leaf.clone()).await.unwrap();
1671            tx.insert_block(block.clone()).await.unwrap();
1672            tx.insert_vid(vid.clone(), None).await.unwrap();
1673            tx.commit().await.unwrap();
1674        }
1675
1676        // Insert a second leaf sharing the same payload.
1677        leaf.leaf.block_header_mut().block_number += 1;
1678        {
1679            let mut tx = storage.write().await.unwrap();
1680            tx.insert_leaf(leaf.clone()).await.unwrap();
1681            tx.commit().await.unwrap();
1682        }
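        // Both leaves reference the same payload, so the payload and VID tables
        // should each contain exactly one row.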
1683        {
1684            let mut tx = storage.read().await.unwrap();
1685            let (num_payloads,): (i64,) = query_as("SELECT count(*) FROM payload")
1686                .fetch_one(tx.as_mut())
1687                .await
1688                .unwrap();
1689            assert_eq!(num_payloads, 1);
1690            let (num_vid,): (i64,) = query_as("SELECT count(*) FROM vid_common")
1691                .fetch_one(tx.as_mut())
1692                .await
1693                .unwrap();
1694            assert_eq!(num_vid, 1);
1695        }
1696
1697        // Prune the first leaf but not the second (and thus not the payload or VID).
1698        let pruned_height = storage
1699            .prune(&mut Pruner {
1700                pruned_height: None,
1701                target_height: Some(0),
1702                minimum_retention_height: None,
1703            })
1704            .await
1705            .unwrap();
1706        tracing::info!(?pruned_height, "first pruning run complete");
1707        {
1708            let mut tx = storage.read().await.unwrap();
1709
1710            // First block is pruned.
1711            let err = tx
1712                .get_block(BlockId::<MockTypes>::Number(0))
1713                .await
1714                .unwrap_err();
1715            assert!(matches!(err, QueryError::NotFound), "{err:#}");
1716            let err = tx
1717                .get_vid_common(BlockId::<MockTypes>::Number(0))
1718                .await
1719                .unwrap_err();
1720            assert!(matches!(err, QueryError::NotFound), "{err:#}");
1721
1722            // Second block is still available.
1723            assert_eq!(
1724                tx.get_block(BlockId::<MockTypes>::Number(1)).await.unwrap(),
1725                BlockQueryData::new(leaf.header().clone(), block.payload)
1726            );
1727            assert_eq!(
1728                tx.get_vid_common(BlockId::<MockTypes>::Number(1))
1729                    .await
1730                    .unwrap(),
1731                VidCommonQueryData::new(leaf.header().clone(), vid.common)
1732            );
1733
1734            let (num_payloads,): (i64,) = query_as("SELECT count(*) FROM payload")
1735                .fetch_one(tx.as_mut())
1736                .await
1737                .unwrap();
1738            assert_eq!(num_payloads, 1);
1739
1740            let (num_vid,): (i64,) = query_as("SELECT count(*) FROM vid_common")
1741                .fetch_one(tx.as_mut())
1742                .await
1743                .unwrap();
1744            assert_eq!(num_vid, 1);
1745        }
1746
1747        // Now prune the second leaf, ensuring the payload and VID get deleted as well.
1748        let pruned_height = storage
1749            .prune(&mut Pruner {
1750                pruned_height,
1751                target_height: Some(1),
1752                minimum_retention_height: None,
1753            })
1754            .await
1755            .unwrap();
1756        tracing::info!(?pruned_height, "second pruning run complete");
1757
1758        let mut tx = storage.read().await.unwrap();
1759        for i in 0..2 {
1760            let err = tx
1761                .get_block(BlockId::<MockTypes>::Number(i))
1762                .await
1763                .unwrap_err();
1764            assert!(matches!(err, QueryError::NotFound), "{err:#}");
1765
1766            let err = tx
1767                .get_vid_common(BlockId::<MockTypes>::Number(i))
1768                .await
1769                .unwrap_err();
1770            assert!(matches!(err, QueryError::NotFound), "{err:#}");
1771        }
1772        let (num_payloads,): (i64,) = query_as("SELECT count(*) FROM payload")
1773            .fetch_one(tx.as_mut())
1774            .await
1775            .unwrap();
1776        assert_eq!(num_payloads, 0);
1777
1778        let (num_vid,): (i64,) = query_as("SELECT count(*) FROM vid_common")
1779            .fetch_one(tx.as_mut())
1780            .await
1781            .unwrap();
1782        assert_eq!(num_vid, 0);
1783    }
1784
1785    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1786    async fn test_pruned_height_storage() {
1787        let db = TmpDb::init().await;
1788        let cfg = db.config();
1789
1790        let storage = SqlStorage::connect(cfg, StorageConnectionType::Query)
1791            .await
1792            .unwrap();
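        // No pruned height has been saved yet, so loading it should return `None`.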
1793        assert!(
1794            storage
1795                .read()
1796                .await
1797                .unwrap()
1798                .load_pruned_height()
1799                .await
1800                .unwrap()
1801                .is_none()
1802        );
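        // Each saved height should overwrite the previous one and be read back verbatim.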
1803        for height in [10, 20, 30] {
1804            let mut tx = storage.write().await.unwrap();
1805            tx.save_pruned_height(height).await.unwrap();
1806            tx.commit().await.unwrap();
1807            assert_eq!(
1808                storage
1809                    .read()
1810                    .await
1811                    .unwrap()
1812                    .load_pruned_height()
1813                    .await
1814                    .unwrap(),
1815                Some(height)
1816            );
1817        }
1818    }
1819
1820    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1821    async fn test_transaction_upsert_retries() {
1822        let db = TmpDb::init().await;
1823        let config = db.config();
1824
1825        let storage = SqlStorage::connect(config, StorageConnectionType::Query)
1826            .await
1827            .unwrap();
1828
1829        let mut tx = storage.write().await.unwrap();
1830
1831        // Try to upsert into a table that does not exist.
1832        // This will fail, so our `upsert` function will enter the retry loop.
1833        // Since the table does not exist, all retries will eventually
1834        // fail and we expect an error to be returned.
1835        //
1836        // Previously, this case would cause a panic because we were calling
1837        // methods on `QueryBuilder` after `.build()` without first calling
1838        // `.reset()`, and according to the sqlx docs, that always panics.
1839        // Now, since we are properly calling `.reset()` inside `upsert()` for
1840        // the query builder, the function returns an error instead of panicking.
1841        tx.upsert("does_not_exist", ["test"], ["test"], [(1_i64,)])
1842            .await
1843            .unwrap_err();
1844    }
1845}