hotshot_query_service/data_source/storage/sql/
queries.rs

1#![allow(clippy::needless_lifetimes)]
2// Copyright (c) 2022 Espresso Systems (espressosys.com)
3// This file is part of the HotShot Query Service library.
4//
5// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
6// General Public License as published by the Free Software Foundation, either version 3 of the
7// License, or (at your option) any later version.
8// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
9// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10// General Public License for more details.
11// You should have received a copy of the GNU General Public License along with this program. If not,
12// see <https://www.gnu.org/licenses/>.
13
14//! Immutable query functionality of a SQL database.
15
16use std::{
17    fmt::Display,
18    ops::{Bound, RangeBounds},
19};
20
21use anyhow::Context;
22use derivative::Derivative;
23use hotshot_types::{
24    simple_certificate::{
25        LightClientStateUpdateCertificateV1, LightClientStateUpdateCertificateV2,
26        QuorumCertificate2,
27    },
28    traits::{
29        block_contents::{BlockHeader, BlockPayload},
30        node_implementation::NodeType,
31    },
32};
33use sqlx::{Arguments, FromRow, Row};
34
35use super::{Database, Db, Query, QueryAs, Transaction};
36use crate::{
37    availability::{
38        BlockId, BlockQueryData, LeafQueryData, PayloadQueryData, QueryableHeader,
39        QueryablePayload, StateCertQueryDataV2, VidCommonQueryData,
40    },
41    data_source::storage::{PayloadMetadata, VidCommonMetadata},
42    Header, Leaf2, Payload, QueryError, QueryResult,
43};
44
45pub(super) mod availability;
46pub(super) mod explorer;
47pub(super) mod node;
48pub(super) mod state;
49
50/// Helper type for programmatically constructing queries.
51///
52/// This type can be used to bind arguments of various types, similar to [`Query`] or [`QueryAs`].
53/// With [`QueryBuilder`], though, the arguments are bound *first* and the SQL statement is given
54/// last. Each time an argument is bound, a SQL fragment is returned as a string which can be used
55/// to represent that argument in the statement (e.g. `$1` for the first argument bound). This makes
56/// it easier to programmatically construct queries where the statement is not a compile time
57/// constant.
58///
59/// # Example
60///
61/// ```
62/// # use hotshot_query_service::{
63/// #   data_source::storage::sql::{
64/// #       Database, Db, QueryBuilder, Transaction,
65/// #   },
66/// #   QueryResult,
67/// # };
68/// # use sqlx::FromRow;
69/// async fn search_and_maybe_filter<T, Mode>(
70///     tx: &mut Transaction<Mode>,
71///     id: Option<i64>,
72/// ) -> QueryResult<Vec<T>>
73/// where
74///     for<'r> T: FromRow<'r, <Db as Database>::Row> + Send + Unpin,
75/// {
76///     let mut query = QueryBuilder::default();
77///     let mut sql = "SELECT * FROM table".into();
78///     if let Some(id) = id {
79///         sql = format!("{sql} WHERE id = {}", query.bind(id)?);
80///     }
81///     let results = query
82///         .query_as(&sql)
83///         .fetch_all(tx.as_mut())
84///         .await?;
85///     Ok(results)
86/// }
87/// ```
88#[derive(Derivative, Default)]
89#[derivative(Debug)]
90pub struct QueryBuilder<'a> {
91    #[derivative(Debug = "ignore")]
92    arguments: <Db as Database>::Arguments<'a>,
93}
94
95impl<'q> QueryBuilder<'q> {
96    /// Add an argument and return its name as a formal parameter in a SQL prepared statement.
97    pub fn bind<T>(&mut self, arg: T) -> QueryResult<String>
98    where
99        T: 'q + sqlx::Encode<'q, Db> + sqlx::Type<Db>,
100    {
101        self.arguments.add(arg).map_err(|err| QueryError::Error {
102            message: format!("{err:#}"),
103        })?;
104
105        Ok(format!("${}", self.arguments.len()))
106    }
107
108    /// Finalize the query with a constructed SQL statement.
109    pub fn query(self, sql: &'q str) -> Query<'q> {
110        sqlx::query_with(sql, self.arguments)
111    }
112
113    /// Finalize the query with a constructed SQL statement and a specified output type.
114    pub fn query_as<T>(self, sql: &'q str) -> QueryAs<'q, T>
115    where
116        T: for<'r> FromRow<'r, <Db as Database>::Row>,
117    {
118        sqlx::query_as_with(sql, self.arguments)
119    }
120}
121
122impl QueryBuilder<'_> {
123    /// Construct a SQL `WHERE` clause which filters for a header exactly matching `id`.
124    pub fn header_where_clause<Types: NodeType>(
125        &mut self,
126        id: BlockId<Types>,
127    ) -> QueryResult<String> {
128        let clause = match id {
129            BlockId::Number(n) => format!("h.height = {}", self.bind(n as i64)?),
130            BlockId::Hash(h) => format!("h.hash = {}", self.bind(h.to_string())?),
131            BlockId::PayloadHash(h) => format!("h.payload_hash = {}", self.bind(h.to_string())?),
132        };
133        Ok(clause)
134    }
135
136    /// Convert range bounds to a SQL `WHERE` clause constraining a given column.
137    pub fn bounds_to_where_clause<R>(&mut self, range: R, column: &str) -> QueryResult<String>
138    where
139        R: RangeBounds<usize>,
140    {
141        let mut bounds = vec![];
142
143        match range.start_bound() {
144            Bound::Included(n) => {
145                bounds.push(format!("{column} >= {}", self.bind(*n as i64)?));
146            },
147            Bound::Excluded(n) => {
148                bounds.push(format!("{column} > {}", self.bind(*n as i64)?));
149            },
150            Bound::Unbounded => {},
151        }
152        match range.end_bound() {
153            Bound::Included(n) => {
154                bounds.push(format!("{column} <= {}", self.bind(*n as i64)?));
155            },
156            Bound::Excluded(n) => {
157                bounds.push(format!("{column} < {}", self.bind(*n as i64)?));
158            },
159            Bound::Unbounded => {},
160        }
161
162        let mut where_clause = bounds.join(" AND ");
163        if !where_clause.is_empty() {
164            where_clause = format!(" WHERE {where_clause}");
165        }
166
167        Ok(where_clause)
168    }
169}
170
171const LEAF_COLUMNS: &str = "leaf, qc";
172
173impl<'r, Types> FromRow<'r, <Db as Database>::Row> for LeafQueryData<Types>
174where
175    Types: NodeType,
176{
177    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
178        let leaf = row.try_get("leaf")?;
179        let leaf: Leaf2<Types> = serde_json::from_value(leaf).decode_error("malformed leaf")?;
180
181        let qc = row.try_get("qc")?;
182        let qc: QuorumCertificate2<Types> =
183            serde_json::from_value(qc).decode_error("malformed QC")?;
184
185        Ok(Self { leaf, qc })
186    }
187}
188
189const BLOCK_COLUMNS: &str =
190    "h.hash AS hash, h.data AS header_data, p.size AS payload_size, p.data AS payload_data";
191
192impl<'r, Types> FromRow<'r, <Db as Database>::Row> for BlockQueryData<Types>
193where
194    Types: NodeType,
195    Header<Types>: QueryableHeader<Types>,
196    Payload<Types>: QueryablePayload<Types>,
197{
198    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
199        // First, check if we have the payload for this block yet.
200        let size: Option<i32> = row.try_get("payload_size")?;
201        let payload_data: Option<Vec<u8>> = row.try_get("payload_data")?;
202        let (size, payload_data) = size.zip(payload_data).ok_or(sqlx::Error::RowNotFound)?;
203        let size = size as u64;
204
205        // Reconstruct the full header.
206        let header_data = row.try_get("header_data")?;
207        let header: Header<Types> =
208            serde_json::from_value(header_data).decode_error("malformed header")?;
209
210        // Reconstruct the full block payload.
211        let payload = Payload::<Types>::from_bytes(&payload_data, header.metadata());
212
213        // Reconstruct the query data by adding metadata.
214        let hash: String = row.try_get("hash")?;
215        let hash = hash.parse().decode_error("malformed block hash")?;
216
217        Ok(Self {
218            num_transactions: payload.len(header.metadata()) as u64,
219            header,
220            payload,
221            size,
222            hash,
223        })
224    }
225}
226
227const PAYLOAD_COLUMNS: &str = BLOCK_COLUMNS;
228
229impl<'r, Types> FromRow<'r, <Db as Database>::Row> for PayloadQueryData<Types>
230where
231    Types: NodeType,
232    Header<Types>: QueryableHeader<Types>,
233    Payload<Types>: QueryablePayload<Types>,
234{
235    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
236        <BlockQueryData<Types> as FromRow<<Db as Database>::Row>>::from_row(row).map(Self::from)
237    }
238}
239
240const PAYLOAD_METADATA_COLUMNS: &str = "h.height AS height, h.hash AS hash, h.payload_hash AS \
241                                        payload_hash, p.size AS payload_size, p.num_transactions \
242                                        AS num_transactions";
243
244impl<'r, Types> FromRow<'r, <Db as Database>::Row> for PayloadMetadata<Types>
245where
246    Types: NodeType,
247    Header<Types>: QueryableHeader<Types>,
248{
249    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
250        Ok(Self {
251            height: row.try_get::<i64, _>("height")? as u64,
252            block_hash: row
253                .try_get::<String, _>("hash")?
254                .parse()
255                .decode_error("malformed block hash")?,
256            hash: row
257                .try_get::<String, _>("payload_hash")?
258                .parse()
259                .decode_error("malformed payload hash")?,
260            size: row
261                .try_get::<Option<i32>, _>("payload_size")?
262                .ok_or(sqlx::Error::RowNotFound)? as u64,
263            num_transactions: row
264                .try_get::<Option<i32>, _>("num_transactions")?
265                .ok_or(sqlx::Error::RowNotFound)? as u64,
266
267            // Per-namespace info must be loaded in a separate query.
268            namespaces: Default::default(),
269        })
270    }
271}
272
273const VID_COMMON_COLUMNS: &str = "h.height AS height, h.hash AS block_hash, h.payload_hash AS \
274                                  payload_hash, v.common AS common_data";
275
276impl<'r, Types> FromRow<'r, <Db as Database>::Row> for VidCommonQueryData<Types>
277where
278    Types: NodeType,
279    Header<Types>: QueryableHeader<Types>,
280    Payload<Types>: QueryablePayload<Types>,
281{
282    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
283        let height = row.try_get::<i64, _>("height")? as u64;
284        let block_hash: String = row.try_get("block_hash")?;
285        let block_hash = block_hash.parse().decode_error("malformed block hash")?;
286        let payload_hash: String = row.try_get("payload_hash")?;
287        let payload_hash = payload_hash
288            .parse()
289            .decode_error("malformed payload hash")?;
290        let common_data: Vec<u8> = row.try_get("common_data")?;
291        let common =
292            bincode::deserialize(&common_data).decode_error("malformed VID common data")?;
293        Ok(Self {
294            height,
295            block_hash,
296            payload_hash,
297            common,
298        })
299    }
300}
301
302const VID_COMMON_METADATA_COLUMNS: &str =
303    "h.height AS height, h.hash AS block_hash, h.payload_hash AS payload_hash";
304
305impl<'r, Types> FromRow<'r, <Db as Database>::Row> for VidCommonMetadata<Types>
306where
307    Types: NodeType,
308    Header<Types>: QueryableHeader<Types>,
309    Payload<Types>: QueryablePayload<Types>,
310{
311    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
312        let height = row.try_get::<i64, _>("height")? as u64;
313        let block_hash: String = row.try_get("block_hash")?;
314        let block_hash = block_hash.parse().decode_error("malformed block hash")?;
315        let payload_hash: String = row.try_get("payload_hash")?;
316        let payload_hash = payload_hash
317            .parse()
318            .decode_error("malformed payload hash")?;
319        Ok(Self {
320            height,
321            block_hash,
322            payload_hash,
323        })
324    }
325}
326
327const HEADER_COLUMNS: &str = "h.data AS data";
328
329// We can't implement `FromRow` for `Header<Types>` since `Header<Types>` is not actually a type
330// defined in this crate; it's just an alias for `Types::BlockHeader`. So this standalone function
331// will have to do.
332fn parse_header<Types>(row: <Db as Database>::Row) -> sqlx::Result<Header<Types>>
333where
334    Types: NodeType,
335{
336    // Reconstruct the full header.
337    let data = row.try_get("data")?;
338    serde_json::from_value(data).decode_error("malformed header")
339}
340
341impl From<sqlx::Error> for QueryError {
342    fn from(err: sqlx::Error) -> Self {
343        if matches!(err, sqlx::Error::RowNotFound) {
344            Self::NotFound
345        } else {
346            Self::Error {
347                message: err.to_string(),
348            }
349        }
350    }
351}
352
353const STATE_CERT_COLUMNS: &str = "state_cert";
354
355impl<'r, Types> FromRow<'r, <Db as Database>::Row> for StateCertQueryDataV2<Types>
356where
357    Types: NodeType,
358{
359    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
360        let state_cert: LightClientStateUpdateCertificateV2<Types> = {
361            let bytes: &[u8] = row.try_get("state_cert")?;
362            match bincode::deserialize::<LightClientStateUpdateCertificateV2<Types>>(bytes) {
363                Ok(cert) => cert,
364                Err(err) => {
365                    tracing::info!(
366                        "Falling back to V1 deserialization for LightClientStateUpdateCertificate"
367                    );
368
369                    match bincode::deserialize::<LightClientStateUpdateCertificateV1<Types>>(bytes)
370                    {
371                        Ok(legacy) => legacy.into(),
372                        Err(err_legacy) => {
373                            tracing::error!(
374                                "Failed to deserialize state_cert with v1 and v2 v2 error: {err}. \
375                                 v1 error: {err_legacy}",
376                            );
377                            return Err(sqlx::Error::Decode(err_legacy));
378                        },
379                    }
380                },
381            }
382        };
383        Ok(state_cert.into())
384    }
385}
386
387impl<Mode> Transaction<Mode> {
388    /// Load a header from storage.
389    ///
390    /// This function is similar to `AvailabilityStorage::get_header`, but
391    /// * does not require the `QueryablePayload<Types>` bound that that trait impl does
392    /// * makes it easier to specify types since the type parameter is on the function and not on a
393    ///   trait impl
394    /// * allows type conversions for the `id` parameter
395    ///
396    /// This more ergonomic interface is useful as loading headers is important for many SQL storage
397    /// functions, not just the `AvailabilityStorage` interface.
398    pub async fn load_header<Types: NodeType>(
399        &mut self,
400        id: impl Into<BlockId<Types>> + Send,
401    ) -> QueryResult<Header<Types>> {
402        let mut query = QueryBuilder::default();
403        let where_clause = query.header_where_clause(id.into())?;
404        // ORDER BY h.height ASC ensures that if there are duplicate blocks (this can happen when
405        // selecting by payload ID, as payloads are not unique), we return the first one.
406        let sql = format!(
407            "SELECT {HEADER_COLUMNS}
408               FROM header AS h
409              WHERE {where_clause}
410              ORDER BY h.height
411              LIMIT 1"
412        );
413
414        let row = query.query(&sql).fetch_one(self.as_mut()).await?;
415        let header = parse_header::<Types>(row)?;
416
417        Ok(header)
418    }
419}
420
421pub(super) trait DecodeError {
422    type Ok;
423    fn decode_error(self, msg: impl Display) -> sqlx::Result<Self::Ok>;
424}
425
426impl<T, E> DecodeError for Result<T, E>
427where
428    E: std::error::Error + Send + Sync + 'static,
429{
430    type Ok = T;
431    fn decode_error(self, msg: impl Display) -> sqlx::Result<<Self as DecodeError>::Ok> {
432        self.context(msg.to_string())
433            .map_err(|err| sqlx::Error::Decode(err.into()))
434    }
435}