hotshot_query_service/data_source/storage/sql/
queries.rs1#![allow(clippy::needless_lifetimes)]
2use std::{
17 fmt::Display,
18 ops::{Bound, RangeBounds},
19};
20
21use anyhow::Context;
22use derivative::Derivative;
23use hotshot_types::{
24 simple_certificate::QuorumCertificate2,
25 traits::{
26 block_contents::{BlockHeader, BlockPayload},
27 node_implementation::NodeType,
28 },
29};
30use sqlx::{Arguments, FromRow, Row};
31
32use super::{Database, Db, Query, QueryAs, Transaction};
33use crate::{
34 availability::{
35 BlockId, BlockQueryData, LeafQueryData, PayloadQueryData, QueryableHeader,
36 QueryablePayload, VidCommonQueryData,
37 },
38 data_source::storage::{PayloadMetadata, VidCommonMetadata},
39 Header, Leaf2, Payload, QueryError, QueryResult,
40};
41
42pub(super) mod availability;
43pub(super) mod explorer;
44pub(super) mod node;
45pub(super) mod state;
46
47#[derive(Derivative, Default)]
86#[derivative(Debug)]
87pub struct QueryBuilder<'a> {
88 #[derivative(Debug = "ignore")]
89 arguments: <Db as Database>::Arguments<'a>,
90}
91
92impl<'q> QueryBuilder<'q> {
93 pub fn bind<T>(&mut self, arg: T) -> QueryResult<String>
95 where
96 T: 'q + sqlx::Encode<'q, Db> + sqlx::Type<Db>,
97 {
98 self.arguments.add(arg).map_err(|err| QueryError::Error {
99 message: format!("{err:#}"),
100 })?;
101
102 Ok(format!("${}", self.arguments.len()))
103 }
104
105 pub fn query(self, sql: &'q str) -> Query<'q> {
107 sqlx::query_with(sql, self.arguments)
108 }
109
110 pub fn query_as<T>(self, sql: &'q str) -> QueryAs<'q, T>
112 where
113 T: for<'r> FromRow<'r, <Db as Database>::Row>,
114 {
115 sqlx::query_as_with(sql, self.arguments)
116 }
117}
118
119impl QueryBuilder<'_> {
120 pub fn header_where_clause<Types: NodeType>(
122 &mut self,
123 id: BlockId<Types>,
124 ) -> QueryResult<String> {
125 let clause = match id {
126 BlockId::Number(n) => format!("h.height = {}", self.bind(n as i64)?),
127 BlockId::Hash(h) => format!("h.hash = {}", self.bind(h.to_string())?),
128 BlockId::PayloadHash(h) => format!("h.payload_hash = {}", self.bind(h.to_string())?),
129 };
130 Ok(clause)
131 }
132
133 pub fn bounds_to_where_clause<R>(&mut self, range: R, column: &str) -> QueryResult<String>
135 where
136 R: RangeBounds<usize>,
137 {
138 let mut bounds = vec![];
139
140 match range.start_bound() {
141 Bound::Included(n) => {
142 bounds.push(format!("{column} >= {}", self.bind(*n as i64)?));
143 },
144 Bound::Excluded(n) => {
145 bounds.push(format!("{column} > {}", self.bind(*n as i64)?));
146 },
147 Bound::Unbounded => {},
148 }
149 match range.end_bound() {
150 Bound::Included(n) => {
151 bounds.push(format!("{column} <= {}", self.bind(*n as i64)?));
152 },
153 Bound::Excluded(n) => {
154 bounds.push(format!("{column} < {}", self.bind(*n as i64)?));
155 },
156 Bound::Unbounded => {},
157 }
158
159 let mut where_clause = bounds.join(" AND ");
160 if !where_clause.is_empty() {
161 where_clause = format!(" WHERE {where_clause}");
162 }
163
164 Ok(where_clause)
165 }
166}
167
168const LEAF_COLUMNS: &str = "leaf, qc";
169
170impl<'r, Types> FromRow<'r, <Db as Database>::Row> for LeafQueryData<Types>
171where
172 Types: NodeType,
173{
174 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
175 let leaf = row.try_get("leaf")?;
176 let leaf: Leaf2<Types> = serde_json::from_value(leaf).decode_error("malformed leaf")?;
177
178 let qc = row.try_get("qc")?;
179 let qc: QuorumCertificate2<Types> =
180 serde_json::from_value(qc).decode_error("malformed QC")?;
181
182 Ok(Self { leaf, qc })
183 }
184}
185
186const BLOCK_COLUMNS: &str =
187 "h.hash AS hash, h.data AS header_data, p.size AS payload_size, p.data AS payload_data";
188
189impl<'r, Types> FromRow<'r, <Db as Database>::Row> for BlockQueryData<Types>
190where
191 Types: NodeType,
192 Header<Types>: QueryableHeader<Types>,
193 Payload<Types>: QueryablePayload<Types>,
194{
195 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
196 let size: Option<i32> = row.try_get("payload_size")?;
198 let payload_data: Option<Vec<u8>> = row.try_get("payload_data")?;
199 let (size, payload_data) = size.zip(payload_data).ok_or(sqlx::Error::RowNotFound)?;
200 let size = size as u64;
201
202 let header_data = row.try_get("header_data")?;
204 let header: Header<Types> =
205 serde_json::from_value(header_data).decode_error("malformed header")?;
206
207 let payload = Payload::<Types>::from_bytes(&payload_data, header.metadata());
209
210 let hash: String = row.try_get("hash")?;
212 let hash = hash.parse().decode_error("malformed block hash")?;
213
214 Ok(Self {
215 num_transactions: payload.len(header.metadata()) as u64,
216 header,
217 payload,
218 size,
219 hash,
220 })
221 }
222}
223
224const PAYLOAD_COLUMNS: &str = BLOCK_COLUMNS;
225
226impl<'r, Types> FromRow<'r, <Db as Database>::Row> for PayloadQueryData<Types>
227where
228 Types: NodeType,
229 Header<Types>: QueryableHeader<Types>,
230 Payload<Types>: QueryablePayload<Types>,
231{
232 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
233 <BlockQueryData<Types> as FromRow<<Db as Database>::Row>>::from_row(row).map(Self::from)
234 }
235}
236
237const PAYLOAD_METADATA_COLUMNS: &str = "h.height AS height, h.hash AS hash, h.payload_hash AS \
238 payload_hash, p.size AS payload_size, p.num_transactions \
239 AS num_transactions";
240
241impl<'r, Types> FromRow<'r, <Db as Database>::Row> for PayloadMetadata<Types>
242where
243 Types: NodeType,
244 Header<Types>: QueryableHeader<Types>,
245{
246 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
247 Ok(Self {
248 height: row.try_get::<i64, _>("height")? as u64,
249 block_hash: row
250 .try_get::<String, _>("hash")?
251 .parse()
252 .decode_error("malformed block hash")?,
253 hash: row
254 .try_get::<String, _>("payload_hash")?
255 .parse()
256 .decode_error("malformed payload hash")?,
257 size: row
258 .try_get::<Option<i32>, _>("payload_size")?
259 .ok_or(sqlx::Error::RowNotFound)? as u64,
260 num_transactions: row
261 .try_get::<Option<i32>, _>("num_transactions")?
262 .ok_or(sqlx::Error::RowNotFound)? as u64,
263
264 namespaces: Default::default(),
266 })
267 }
268}
269
270const VID_COMMON_COLUMNS: &str = "h.height AS height, h.hash AS block_hash, h.payload_hash AS \
271 payload_hash, v.common AS common_data";
272
273impl<'r, Types> FromRow<'r, <Db as Database>::Row> for VidCommonQueryData<Types>
274where
275 Types: NodeType,
276 Header<Types>: QueryableHeader<Types>,
277 Payload<Types>: QueryablePayload<Types>,
278{
279 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
280 let height = row.try_get::<i64, _>("height")? as u64;
281 let block_hash: String = row.try_get("block_hash")?;
282 let block_hash = block_hash.parse().decode_error("malformed block hash")?;
283 let payload_hash: String = row.try_get("payload_hash")?;
284 let payload_hash = payload_hash
285 .parse()
286 .decode_error("malformed payload hash")?;
287 let common_data: Vec<u8> = row.try_get("common_data")?;
288 let common =
289 bincode::deserialize(&common_data).decode_error("malformed VID common data")?;
290 Ok(Self {
291 height,
292 block_hash,
293 payload_hash,
294 common,
295 })
296 }
297}
298
299const VID_COMMON_METADATA_COLUMNS: &str =
300 "h.height AS height, h.hash AS block_hash, h.payload_hash AS payload_hash";
301
302impl<'r, Types> FromRow<'r, <Db as Database>::Row> for VidCommonMetadata<Types>
303where
304 Types: NodeType,
305 Header<Types>: QueryableHeader<Types>,
306 Payload<Types>: QueryablePayload<Types>,
307{
308 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
309 let height = row.try_get::<i64, _>("height")? as u64;
310 let block_hash: String = row.try_get("block_hash")?;
311 let block_hash = block_hash.parse().decode_error("malformed block hash")?;
312 let payload_hash: String = row.try_get("payload_hash")?;
313 let payload_hash = payload_hash
314 .parse()
315 .decode_error("malformed payload hash")?;
316 Ok(Self {
317 height,
318 block_hash,
319 payload_hash,
320 })
321 }
322}
323
324const HEADER_COLUMNS: &str = "h.data AS data";
325
326fn parse_header<Types>(row: <Db as Database>::Row) -> sqlx::Result<Header<Types>>
330where
331 Types: NodeType,
332{
333 let data = row.try_get("data")?;
335 serde_json::from_value(data).decode_error("malformed header")
336}
337
338impl From<sqlx::Error> for QueryError {
339 fn from(err: sqlx::Error) -> Self {
340 if matches!(err, sqlx::Error::RowNotFound) {
341 Self::NotFound
342 } else {
343 Self::Error {
344 message: err.to_string(),
345 }
346 }
347 }
348}
349
350impl<Mode> Transaction<Mode> {
351 pub async fn load_header<Types: NodeType>(
362 &mut self,
363 id: impl Into<BlockId<Types>> + Send,
364 ) -> QueryResult<Header<Types>> {
365 let mut query = QueryBuilder::default();
366 let where_clause = query.header_where_clause(id.into())?;
367 let sql = format!(
370 "SELECT {HEADER_COLUMNS}
371 FROM header AS h
372 WHERE {where_clause}
373 ORDER BY h.height
374 LIMIT 1"
375 );
376
377 let row = query.query(&sql).fetch_one(self.as_mut()).await?;
378 let header = parse_header::<Types>(row)?;
379
380 Ok(header)
381 }
382}
383
384pub(super) trait DecodeError {
385 type Ok;
386 fn decode_error(self, msg: impl Display) -> sqlx::Result<Self::Ok>;
387}
388
389impl<T, E> DecodeError for Result<T, E>
390where
391 E: std::error::Error + Send + Sync + 'static,
392{
393 type Ok = T;
394 fn decode_error(self, msg: impl Display) -> sqlx::Result<<Self as DecodeError>::Ok> {
395 self.context(msg.to_string())
396 .map_err(|err| sqlx::Error::Decode(err.into()))
397 }
398}