hotshot_query_service/data_source/storage/sql/
queries.rs

1#![allow(clippy::needless_lifetimes)]
2// Copyright (c) 2022 Espresso Systems (espressosys.com)
3// This file is part of the HotShot Query Service library.
4//
5// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
6// General Public License as published by the Free Software Foundation, either version 3 of the
7// License, or (at your option) any later version.
8// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
9// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10// General Public License for more details.
11// You should have received a copy of the GNU General Public License along with this program. If not,
12// see <https://www.gnu.org/licenses/>.
13
14//! Immutable query functionality of a SQL database.
15
16use std::{
17    fmt::Display,
18    ops::{Bound, RangeBounds},
19};
20
21use anyhow::Context;
22use derivative::Derivative;
23use hotshot_types::{
24    simple_certificate::QuorumCertificate2,
25    traits::{
26        block_contents::{BlockHeader, BlockPayload},
27        node_implementation::NodeType,
28    },
29};
30use sqlx::{Arguments, FromRow, Row};
31
32use super::{Database, Db, Query, QueryAs, Transaction};
33use crate::{
34    Header, Leaf2, Payload, QueryError, QueryResult,
35    availability::{
36        BlockId, BlockQueryData, LeafQueryData, PayloadQueryData, QueryableHeader,
37        QueryablePayload, VidCommonQueryData,
38    },
39    data_source::storage::{PayloadMetadata, VidCommonMetadata},
40};
41
42pub(super) mod availability;
43pub(super) mod explorer;
44pub(super) mod node;
45pub(super) mod state;
46
47/// Helper type for programmatically constructing queries.
48///
49/// This type can be used to bind arguments of various types, similar to [`Query`] or [`QueryAs`].
50/// With [`QueryBuilder`], though, the arguments are bound *first* and the SQL statement is given
51/// last. Each time an argument is bound, a SQL fragment is returned as a string which can be used
52/// to represent that argument in the statement (e.g. `$1` for the first argument bound). This makes
53/// it easier to programmatically construct queries where the statement is not a compile time
54/// constant.
55///
56/// # Example
57///
58/// ```
59/// # use hotshot_query_service::{
60/// #   data_source::storage::sql::{
61/// #       Database, Db, QueryBuilder, Transaction,
62/// #   },
63/// #   QueryResult,
64/// # };
65/// # use sqlx::FromRow;
66/// async fn search_and_maybe_filter<T, Mode>(
67///     tx: &mut Transaction<Mode>,
68///     id: Option<i64>,
69/// ) -> QueryResult<Vec<T>>
70/// where
71///     for<'r> T: FromRow<'r, <Db as Database>::Row> + Send + Unpin,
72/// {
73///     let mut query = QueryBuilder::default();
74///     let mut sql = "SELECT * FROM table".into();
75///     if let Some(id) = id {
76///         sql = format!("{sql} WHERE id = {}", query.bind(id)?);
77///     }
78///     let results = query
79///         .query_as(&sql)
80///         .fetch_all(tx.as_mut())
81///         .await?;
82///     Ok(results)
83/// }
84/// ```
85#[derive(Derivative, Default)]
86#[derivative(Debug)]
87pub struct QueryBuilder<'a> {
88    #[derivative(Debug = "ignore")]
89    arguments: <Db as Database>::Arguments<'a>,
90}
91
92impl<'q> QueryBuilder<'q> {
93    /// Add an argument and return its name as a formal parameter in a SQL prepared statement.
94    pub fn bind<T>(&mut self, arg: T) -> QueryResult<String>
95    where
96        T: 'q + sqlx::Encode<'q, Db> + sqlx::Type<Db>,
97    {
98        self.arguments.add(arg).map_err(|err| QueryError::Error {
99            message: format!("{err:#}"),
100        })?;
101
102        Ok(format!("${}", self.arguments.len()))
103    }
104
105    /// Finalize the query with a constructed SQL statement.
106    pub fn query(self, sql: &'q str) -> Query<'q> {
107        sqlx::query_with(sql, self.arguments)
108    }
109
110    /// Finalize the query with a constructed SQL statement and a specified output type.
111    pub fn query_as<T>(self, sql: &'q str) -> QueryAs<'q, T>
112    where
113        T: for<'r> FromRow<'r, <Db as Database>::Row>,
114    {
115        sqlx::query_as_with(sql, self.arguments)
116    }
117}
118
119impl QueryBuilder<'_> {
120    /// Construct a SQL `WHERE` clause which filters for a header exactly matching `id`.
121    pub fn header_where_clause<Types: NodeType>(
122        &mut self,
123        id: BlockId<Types>,
124    ) -> QueryResult<String> {
125        let clause = match id {
126            BlockId::Number(n) => format!("h.height = {}", self.bind(n as i64)?),
127            BlockId::Hash(h) => format!("h.hash = {}", self.bind(h.to_string())?),
128            BlockId::PayloadHash(h) => format!("h.payload_hash = {}", self.bind(h.to_string())?),
129        };
130        Ok(clause)
131    }
132
133    /// Convert range bounds to a SQL `WHERE` clause constraining a given column.
134    pub fn bounds_to_where_clause<R>(&mut self, range: R, column: &str) -> QueryResult<String>
135    where
136        R: RangeBounds<usize>,
137    {
138        let mut bounds = vec![];
139
140        match range.start_bound() {
141            Bound::Included(n) => {
142                bounds.push(format!("{column} >= {}", self.bind(*n as i64)?));
143            },
144            Bound::Excluded(n) => {
145                bounds.push(format!("{column} > {}", self.bind(*n as i64)?));
146            },
147            Bound::Unbounded => {},
148        }
149        match range.end_bound() {
150            Bound::Included(n) => {
151                bounds.push(format!("{column} <= {}", self.bind(*n as i64)?));
152            },
153            Bound::Excluded(n) => {
154                bounds.push(format!("{column} < {}", self.bind(*n as i64)?));
155            },
156            Bound::Unbounded => {},
157        }
158
159        let mut where_clause = bounds.join(" AND ");
160        if !where_clause.is_empty() {
161            where_clause = format!(" WHERE {where_clause}");
162        }
163
164        Ok(where_clause)
165    }
166}
167
168const LEAF_COLUMNS: &str = "leaf, qc";
169
170impl<'r, Types> FromRow<'r, <Db as Database>::Row> for LeafQueryData<Types>
171where
172    Types: NodeType,
173{
174    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
175        let leaf = row.try_get("leaf")?;
176        let leaf: Leaf2<Types> = serde_json::from_value(leaf).decode_error("malformed leaf")?;
177
178        let qc = row.try_get("qc")?;
179        let qc: QuorumCertificate2<Types> =
180            serde_json::from_value(qc).decode_error("malformed QC")?;
181
182        Ok(Self { leaf, qc })
183    }
184}
185
186const BLOCK_COLUMNS: &str =
187    "h.hash AS hash, h.data AS header_data, p.size AS payload_size, p.data AS payload_data";
188
189impl<'r, Types> FromRow<'r, <Db as Database>::Row> for BlockQueryData<Types>
190where
191    Types: NodeType,
192    Header<Types>: QueryableHeader<Types>,
193    Payload<Types>: QueryablePayload<Types>,
194{
195    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
196        // First, check if we have the payload for this block yet.
197        let size = row.try_get::<i32, _>("payload_size")? as u64;
198        let payload_data = row.try_get::<Vec<u8>, _>("payload_data")?;
199
200        // Reconstruct the full header.
201        let header_data = row.try_get("header_data")?;
202        let header: Header<Types> =
203            serde_json::from_value(header_data).decode_error("malformed header")?;
204
205        // Reconstruct the full block payload.
206        let payload = Payload::<Types>::from_bytes(&payload_data, header.metadata());
207
208        // Reconstruct the query data by adding metadata.
209        let hash: String = row.try_get("hash")?;
210        let hash = hash.parse().decode_error("malformed block hash")?;
211
212        Ok(Self {
213            num_transactions: payload.len(header.metadata()) as u64,
214            header,
215            payload,
216            size,
217            hash,
218        })
219    }
220}
221
222const PAYLOAD_COLUMNS: &str = BLOCK_COLUMNS;
223
224impl<'r, Types> FromRow<'r, <Db as Database>::Row> for PayloadQueryData<Types>
225where
226    Types: NodeType,
227    Header<Types>: QueryableHeader<Types>,
228    Payload<Types>: QueryablePayload<Types>,
229{
230    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
231        <BlockQueryData<Types> as FromRow<<Db as Database>::Row>>::from_row(row).map(Self::from)
232    }
233}
234
235const PAYLOAD_METADATA_COLUMNS: &str = "h.height AS height, h.hash AS hash, h.payload_hash AS \
236                                        payload_hash, p.size AS payload_size, p.num_transactions \
237                                        AS num_transactions";
238
239impl<'r, Types> FromRow<'r, <Db as Database>::Row> for PayloadMetadata<Types>
240where
241    Types: NodeType,
242    Header<Types>: QueryableHeader<Types>,
243{
244    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
245        Ok(Self {
246            height: row.try_get::<i64, _>("height")? as u64,
247            block_hash: row
248                .try_get::<String, _>("hash")?
249                .parse()
250                .decode_error("malformed block hash")?,
251            hash: row
252                .try_get::<String, _>("payload_hash")?
253                .parse()
254                .decode_error("malformed payload hash")?,
255            size: row.try_get::<i32, _>("payload_size")? as u64,
256            num_transactions: row.try_get::<i32, _>("num_transactions")? as u64,
257
258            // Per-namespace info must be loaded in a separate query.
259            namespaces: Default::default(),
260        })
261    }
262}
263
264const VID_COMMON_COLUMNS: &str = "h.height AS height, h.hash AS block_hash, h.payload_hash AS \
265                                  payload_hash, v.data AS common_data";
266
267impl<'r, Types> FromRow<'r, <Db as Database>::Row> for VidCommonQueryData<Types>
268where
269    Types: NodeType,
270    Header<Types>: QueryableHeader<Types>,
271    Payload<Types>: QueryablePayload<Types>,
272{
273    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
274        let height = row.try_get::<i64, _>("height")? as u64;
275        let block_hash: String = row.try_get("block_hash")?;
276        let block_hash = block_hash.parse().decode_error("malformed block hash")?;
277        let payload_hash: String = row.try_get("payload_hash")?;
278        let payload_hash = payload_hash
279            .parse()
280            .decode_error("malformed payload hash")?;
281        let common_data: Vec<u8> = row.try_get("common_data")?;
282        let common =
283            bincode::deserialize(&common_data).decode_error("malformed VID common data")?;
284        Ok(Self {
285            height,
286            block_hash,
287            payload_hash,
288            common,
289        })
290    }
291}
292
293const VID_COMMON_METADATA_COLUMNS: &str =
294    "h.height AS height, h.hash AS block_hash, h.payload_hash AS payload_hash";
295
296impl<'r, Types> FromRow<'r, <Db as Database>::Row> for VidCommonMetadata<Types>
297where
298    Types: NodeType,
299    Header<Types>: QueryableHeader<Types>,
300    Payload<Types>: QueryablePayload<Types>,
301{
302    fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
303        let height = row.try_get::<i64, _>("height")? as u64;
304        let block_hash: String = row.try_get("block_hash")?;
305        let block_hash = block_hash.parse().decode_error("malformed block hash")?;
306        let payload_hash: String = row.try_get("payload_hash")?;
307        let payload_hash = payload_hash
308            .parse()
309            .decode_error("malformed payload hash")?;
310        Ok(Self {
311            height,
312            block_hash,
313            payload_hash,
314        })
315    }
316}
317
318const HEADER_COLUMNS: &str = "h.data AS data";
319
320// We can't implement `FromRow` for `Header<Types>` since `Header<Types>` is not actually a type
321// defined in this crate; it's just an alias for `Types::BlockHeader`. So this standalone function
322// will have to do.
323fn parse_header<Types>(row: <Db as Database>::Row) -> sqlx::Result<Header<Types>>
324where
325    Types: NodeType,
326{
327    // Reconstruct the full header.
328    let data = row.try_get("data")?;
329    serde_json::from_value(data).decode_error("malformed header")
330}
331
332impl From<sqlx::Error> for QueryError {
333    fn from(err: sqlx::Error) -> Self {
334        if matches!(err, sqlx::Error::RowNotFound) {
335            Self::NotFound
336        } else {
337            Self::Error {
338                message: err.to_string(),
339            }
340        }
341    }
342}
343
344impl<Mode> Transaction<Mode> {
345    /// Load a header from storage.
346    ///
347    /// This function is similar to `AvailabilityStorage::get_header`, but
348    /// * does not require the `QueryablePayload<Types>` bound that that trait impl does
349    /// * makes it easier to specify types since the type parameter is on the function and not on a
350    ///   trait impl
351    /// * allows type conversions for the `id` parameter
352    ///
353    /// This more ergonomic interface is useful as loading headers is important for many SQL storage
354    /// functions, not just the `AvailabilityStorage` interface.
355    pub async fn load_header<Types: NodeType>(
356        &mut self,
357        id: impl Into<BlockId<Types>> + Send,
358    ) -> QueryResult<Header<Types>> {
359        let mut query = QueryBuilder::default();
360        let where_clause = query.header_where_clause(id.into())?;
361        // ORDER BY h.height ASC ensures that if there are duplicate blocks (this can happen when
362        // selecting by payload ID, as payloads are not unique), we return the first one.
363        let sql = format!(
364            "SELECT {HEADER_COLUMNS}
365               FROM header AS h
366              WHERE {where_clause}
367              ORDER BY h.height
368              LIMIT 1"
369        );
370
371        let row = query.query(&sql).fetch_one(self.as_mut()).await?;
372        let header = parse_header::<Types>(row)?;
373
374        Ok(header)
375    }
376}
377
378pub(super) trait DecodeError {
379    type Ok;
380    fn decode_error(self, msg: impl Display) -> sqlx::Result<Self::Ok>;
381}
382
383impl<T, E> DecodeError for Result<T, E>
384where
385    E: std::error::Error + Send + Sync + 'static,
386{
387    type Ok = T;
388    fn decode_error(self, msg: impl Display) -> sqlx::Result<<Self as DecodeError>::Ok> {
389        self.context(msg.to_string())
390            .map_err(|err| sqlx::Error::Decode(err.into()))
391    }
392}