hotshot_query_service/data_source/storage/sql/
queries.rs

1#![allow(clippy::needless_lifetimes)]
2// Copyright (c) 2022 Espresso Systems (espressosys.com)
3// This file is part of the HotShot Query Service library.
4//
5// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
6// General Public License as published by the Free Software Foundation, either version 3 of the
7// License, or (at your option) any later version.
8// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
9// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10// General Public License for more details.
11// You should have received a copy of the GNU General Public License along with this program. If not,
12// see <https://www.gnu.org/licenses/>.
13
14//! Immutable query functionality of a SQL database.
15
16use std::{
17    fmt::Display,
18    ops::{Bound, RangeBounds},
19};
20
21use anyhow::Context;
22use derivative::Derivative;
23use hotshot_types::{
24    simple_certificate::QuorumCertificate2,
25    traits::{
26        block_contents::{BlockHeader, BlockPayload},
27        node_implementation::NodeType,
28    },
29};
30use sqlx::{Arguments, FromRow, Row};
31
32use super::{Database, Db, Query, QueryAs, Transaction};
33use crate::{
34    availability::{
35        BlockId, BlockQueryData, LeafQueryData, PayloadQueryData, QueryableHeader,
36        QueryablePayload, VidCommonQueryData,
37    },
38    data_source::storage::{PayloadMetadata, VidCommonMetadata},
39    Header, Leaf2, Payload, QueryError, QueryResult,
40};
41
42pub(super) mod availability;
43pub(super) mod explorer;
44pub(super) mod node;
45pub(super) mod state;
46
/// Helper type for programmatically constructing queries.
///
/// This type can be used to bind arguments of various types, similar to [`Query`] or [`QueryAs`].
/// With [`QueryBuilder`], though, the arguments are bound *first* and the SQL statement is given
/// last. Each time an argument is bound, a SQL fragment is returned as a string which can be used
/// to represent that argument in the statement (e.g. `$1` for the first argument bound). This makes
/// it easier to programmatically construct queries where the statement is not a compile-time
/// constant.
///
/// Placeholders are numbered in the order in which arguments are bound, and the bound arguments
/// are stored in the builder itself. A placeholder returned by [`QueryBuilder::bind`] is therefore
/// only meaningful in a statement finalized by the same builder via [`QueryBuilder::query`] or
/// [`QueryBuilder::query_as`].
///
/// # Example
///
/// ```
/// # use hotshot_query_service::{
/// #   data_source::storage::sql::{
/// #       Database, Db, QueryBuilder, Transaction,
/// #   },
/// #   QueryResult,
/// # };
/// # use sqlx::FromRow;
/// async fn search_and_maybe_filter<T, Mode>(
///     tx: &mut Transaction<Mode>,
///     id: Option<i64>,
/// ) -> QueryResult<Vec<T>>
/// where
///     for<'r> T: FromRow<'r, <Db as Database>::Row> + Send + Unpin,
/// {
///     let mut query = QueryBuilder::default();
///     let mut sql = "SELECT * FROM table".into();
///     if let Some(id) = id {
///         sql = format!("{sql} WHERE id = {}", query.bind(id)?);
///     }
///     let results = query
///         .query_as(&sql)
///         .fetch_all(tx.as_mut())
///         .await?;
///     Ok(results)
/// }
/// ```
#[derive(Derivative, Default)]
#[derivative(Debug)]
pub struct QueryBuilder<'a> {
    // Arguments bound so far, in binding order. They are omitted from the derived `Debug`
    // output. NOTE(review): presumably because the database's `Arguments` type does not
    // implement `Debug` — confirm.
    #[derivative(Debug = "ignore")]
    arguments: <Db as Database>::Arguments<'a>,
}
91
92impl<'q> QueryBuilder<'q> {
93    /// Add an argument and return its name as a formal parameter in a SQL prepared statement.
94    pub fn bind<T>(&mut self, arg: T) -> QueryResult<String>
95    where
96        T: 'q + sqlx::Encode<'q, Db> + sqlx::Type<Db>,
97    {
98        self.arguments.add(arg).map_err(|err| QueryError::Error {
99            message: format!("{err:#}"),
100        })?;
101
102        Ok(format!("${}", self.arguments.len()))
103    }
104
105    /// Finalize the query with a constructed SQL statement.
106    pub fn query(self, sql: &'q str) -> Query<'q> {
107        sqlx::query_with(sql, self.arguments)
108    }
109
110    /// Finalize the query with a constructed SQL statement and a specified output type.
111    pub fn query_as<T>(self, sql: &'q str) -> QueryAs<'q, T>
112    where
113        T: for<'r> FromRow<'r, <Db as Database>::Row>,
114    {
115        sqlx::query_as_with(sql, self.arguments)
116    }
117}
118
119impl QueryBuilder<'_> {
120    /// Construct a SQL `WHERE` clause which filters for a header exactly matching `id`.
121    pub fn header_where_clause<Types: NodeType>(
122        &mut self,
123        id: BlockId<Types>,
124    ) -> QueryResult<String> {
125        let clause = match id {
126            BlockId::Number(n) => format!("h.height = {}", self.bind(n as i64)?),
127            BlockId::Hash(h) => format!("h.hash = {}", self.bind(h.to_string())?),
128            BlockId::PayloadHash(h) => format!("h.payload_hash = {}", self.bind(h.to_string())?),
129        };
130        Ok(clause)
131    }
132
133    /// Convert range bounds to a SQL `WHERE` clause constraining a given column.
134    pub fn bounds_to_where_clause<R>(&mut self, range: R, column: &str) -> QueryResult<String>
135    where
136        R: RangeBounds<usize>,
137    {
138        let mut bounds = vec![];
139
140        match range.start_bound() {
141            Bound::Included(n) => {
142                bounds.push(format!("{column} >= {}", self.bind(*n as i64)?));
143            },
144            Bound::Excluded(n) => {
145                bounds.push(format!("{column} > {}", self.bind(*n as i64)?));
146            },
147            Bound::Unbounded => {},
148        }
149        match range.end_bound() {
150            Bound::Included(n) => {
151                bounds.push(format!("{column} <= {}", self.bind(*n as i64)?));
152            },
153            Bound::Excluded(n) => {
154                bounds.push(format!("{column} < {}", self.bind(*n as i64)?));
155            },
156            Bound::Unbounded => {},
157        }
158
159        let mut where_clause = bounds.join(" AND ");
160        if !where_clause.is_empty() {
161            where_clause = format!(" WHERE {where_clause}");
162        }
163
164        Ok(where_clause)
165    }
166}
167
168const LEAF_COLUMNS: &str = "leaf, qc";
169
170impl<'r, Types> FromRow<'r, <Db as Database>::Row> for LeafQueryData<Types>
171where
172    Types: NodeType,
173{
174    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
175        let leaf = row.try_get("leaf")?;
176        let leaf: Leaf2<Types> = serde_json::from_value(leaf).decode_error("malformed leaf")?;
177
178        let qc = row.try_get("qc")?;
179        let qc: QuorumCertificate2<Types> =
180            serde_json::from_value(qc).decode_error("malformed QC")?;
181
182        Ok(Self { leaf, qc })
183    }
184}
185
186const BLOCK_COLUMNS: &str =
187    "h.hash AS hash, h.data AS header_data, p.size AS payload_size, p.data AS payload_data";
188
189impl<'r, Types> FromRow<'r, <Db as Database>::Row> for BlockQueryData<Types>
190where
191    Types: NodeType,
192    Header<Types>: QueryableHeader<Types>,
193    Payload<Types>: QueryablePayload<Types>,
194{
195    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
196        // First, check if we have the payload for this block yet.
197        let size: Option<i32> = row.try_get("payload_size")?;
198        let payload_data: Option<Vec<u8>> = row.try_get("payload_data")?;
199        let (size, payload_data) = size.zip(payload_data).ok_or(sqlx::Error::RowNotFound)?;
200        let size = size as u64;
201
202        // Reconstruct the full header.
203        let header_data = row.try_get("header_data")?;
204        let header: Header<Types> =
205            serde_json::from_value(header_data).decode_error("malformed header")?;
206
207        // Reconstruct the full block payload.
208        let payload = Payload::<Types>::from_bytes(&payload_data, header.metadata());
209
210        // Reconstruct the query data by adding metadata.
211        let hash: String = row.try_get("hash")?;
212        let hash = hash.parse().decode_error("malformed block hash")?;
213
214        Ok(Self {
215            num_transactions: payload.len(header.metadata()) as u64,
216            header,
217            payload,
218            size,
219            hash,
220        })
221    }
222}
223
224const PAYLOAD_COLUMNS: &str = BLOCK_COLUMNS;
225
226impl<'r, Types> FromRow<'r, <Db as Database>::Row> for PayloadQueryData<Types>
227where
228    Types: NodeType,
229    Header<Types>: QueryableHeader<Types>,
230    Payload<Types>: QueryablePayload<Types>,
231{
232    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
233        <BlockQueryData<Types> as FromRow<<Db as Database>::Row>>::from_row(row).map(Self::from)
234    }
235}
236
237const PAYLOAD_METADATA_COLUMNS: &str = "h.height AS height, h.hash AS hash, h.payload_hash AS \
238                                        payload_hash, p.size AS payload_size, p.num_transactions \
239                                        AS num_transactions";
240
241impl<'r, Types> FromRow<'r, <Db as Database>::Row> for PayloadMetadata<Types>
242where
243    Types: NodeType,
244    Header<Types>: QueryableHeader<Types>,
245{
246    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
247        Ok(Self {
248            height: row.try_get::<i64, _>("height")? as u64,
249            block_hash: row
250                .try_get::<String, _>("hash")?
251                .parse()
252                .decode_error("malformed block hash")?,
253            hash: row
254                .try_get::<String, _>("payload_hash")?
255                .parse()
256                .decode_error("malformed payload hash")?,
257            size: row
258                .try_get::<Option<i32>, _>("payload_size")?
259                .ok_or(sqlx::Error::RowNotFound)? as u64,
260            num_transactions: row
261                .try_get::<Option<i32>, _>("num_transactions")?
262                .ok_or(sqlx::Error::RowNotFound)? as u64,
263
264            // Per-namespace info must be loaded in a separate query.
265            namespaces: Default::default(),
266        })
267    }
268}
269
270const VID_COMMON_COLUMNS: &str = "h.height AS height, h.hash AS block_hash, h.payload_hash AS \
271                                  payload_hash, v.common AS common_data";
272
273impl<'r, Types> FromRow<'r, <Db as Database>::Row> for VidCommonQueryData<Types>
274where
275    Types: NodeType,
276    Header<Types>: QueryableHeader<Types>,
277    Payload<Types>: QueryablePayload<Types>,
278{
279    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
280        let height = row.try_get::<i64, _>("height")? as u64;
281        let block_hash: String = row.try_get("block_hash")?;
282        let block_hash = block_hash.parse().decode_error("malformed block hash")?;
283        let payload_hash: String = row.try_get("payload_hash")?;
284        let payload_hash = payload_hash
285            .parse()
286            .decode_error("malformed payload hash")?;
287        let common_data: Vec<u8> = row.try_get("common_data")?;
288        let common =
289            bincode::deserialize(&common_data).decode_error("malformed VID common data")?;
290        Ok(Self {
291            height,
292            block_hash,
293            payload_hash,
294            common,
295        })
296    }
297}
298
299const VID_COMMON_METADATA_COLUMNS: &str =
300    "h.height AS height, h.hash AS block_hash, h.payload_hash AS payload_hash";
301
302impl<'r, Types> FromRow<'r, <Db as Database>::Row> for VidCommonMetadata<Types>
303where
304    Types: NodeType,
305    Header<Types>: QueryableHeader<Types>,
306    Payload<Types>: QueryablePayload<Types>,
307{
308    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
309        let height = row.try_get::<i64, _>("height")? as u64;
310        let block_hash: String = row.try_get("block_hash")?;
311        let block_hash = block_hash.parse().decode_error("malformed block hash")?;
312        let payload_hash: String = row.try_get("payload_hash")?;
313        let payload_hash = payload_hash
314            .parse()
315            .decode_error("malformed payload hash")?;
316        Ok(Self {
317            height,
318            block_hash,
319            payload_hash,
320        })
321    }
322}
323
324const HEADER_COLUMNS: &str = "h.data AS data";
325
326// We can't implement `FromRow` for `Header<Types>` since `Header<Types>` is not actually a type
327// defined in this crate; it's just an alias for `Types::BlockHeader`. So this standalone function
328// will have to do.
329fn parse_header<Types>(row: <Db as Database>::Row) -> sqlx::Result<Header<Types>>
330where
331    Types: NodeType,
332{
333    // Reconstruct the full header.
334    let data = row.try_get("data")?;
335    serde_json::from_value(data).decode_error("malformed header")
336}
337
338impl From<sqlx::Error> for QueryError {
339    fn from(err: sqlx::Error) -> Self {
340        if matches!(err, sqlx::Error::RowNotFound) {
341            Self::NotFound
342        } else {
343            Self::Error {
344                message: err.to_string(),
345            }
346        }
347    }
348}
349
350impl<Mode> Transaction<Mode> {
351    /// Load a header from storage.
352    ///
353    /// This function is similar to `AvailabilityStorage::get_header`, but
354    /// * does not require the `QueryablePayload<Types>` bound that that trait impl does
355    /// * makes it easier to specify types since the type parameter is on the function and not on a
356    ///   trait impl
357    /// * allows type conversions for the `id` parameter
358    ///
359    /// This more ergonomic interface is useful as loading headers is important for many SQL storage
360    /// functions, not just the `AvailabilityStorage` interface.
361    pub async fn load_header<Types: NodeType>(
362        &mut self,
363        id: impl Into<BlockId<Types>> + Send,
364    ) -> QueryResult<Header<Types>> {
365        let mut query = QueryBuilder::default();
366        let where_clause = query.header_where_clause(id.into())?;
367        // ORDER BY h.height ASC ensures that if there are duplicate blocks (this can happen when
368        // selecting by payload ID, as payloads are not unique), we return the first one.
369        let sql = format!(
370            "SELECT {HEADER_COLUMNS}
371               FROM header AS h
372              WHERE {where_clause}
373              ORDER BY h.height
374              LIMIT 1"
375        );
376
377        let row = query.query(&sql).fetch_one(self.as_mut()).await?;
378        let header = parse_header::<Types>(row)?;
379
380        Ok(header)
381    }
382}
383
384pub(super) trait DecodeError {
385    type Ok;
386    fn decode_error(self, msg: impl Display) -> sqlx::Result<Self::Ok>;
387}
388
389impl<T, E> DecodeError for Result<T, E>
390where
391    E: std::error::Error + Send + Sync + 'static,
392{
393    type Ok = T;
394    fn decode_error(self, msg: impl Display) -> sqlx::Result<<Self as DecodeError>::Ok> {
395        self.context(msg.to_string())
396            .map_err(|err| sqlx::Error::Decode(err.into()))
397    }
398}