hotshot_query_service/data_source/storage/sql/
queries.rs1#![allow(clippy::needless_lifetimes)]
2use std::{
17 fmt::Display,
18 ops::{Bound, RangeBounds},
19};
20
21use anyhow::Context;
22use derivative::Derivative;
23use hotshot_types::{
24 simple_certificate::{
25 LightClientStateUpdateCertificateV1, LightClientStateUpdateCertificateV2,
26 QuorumCertificate2,
27 },
28 traits::{
29 block_contents::{BlockHeader, BlockPayload},
30 node_implementation::NodeType,
31 },
32};
33use sqlx::{Arguments, FromRow, Row};
34
35use super::{Database, Db, Query, QueryAs, Transaction};
36use crate::{
37 availability::{
38 BlockId, BlockQueryData, LeafQueryData, PayloadQueryData, QueryableHeader,
39 QueryablePayload, StateCertQueryDataV2, VidCommonQueryData,
40 },
41 data_source::storage::{PayloadMetadata, VidCommonMetadata},
42 Header, Leaf2, Payload, QueryError, QueryResult,
43};
44
45pub(super) mod availability;
46pub(super) mod explorer;
47pub(super) mod node;
48pub(super) mod state;
49
50#[derive(Derivative, Default)]
89#[derivative(Debug)]
90pub struct QueryBuilder<'a> {
91 #[derivative(Debug = "ignore")]
92 arguments: <Db as Database>::Arguments<'a>,
93}
94
95impl<'q> QueryBuilder<'q> {
96 pub fn bind<T>(&mut self, arg: T) -> QueryResult<String>
98 where
99 T: 'q + sqlx::Encode<'q, Db> + sqlx::Type<Db>,
100 {
101 self.arguments.add(arg).map_err(|err| QueryError::Error {
102 message: format!("{err:#}"),
103 })?;
104
105 Ok(format!("${}", self.arguments.len()))
106 }
107
108 pub fn query(self, sql: &'q str) -> Query<'q> {
110 sqlx::query_with(sql, self.arguments)
111 }
112
113 pub fn query_as<T>(self, sql: &'q str) -> QueryAs<'q, T>
115 where
116 T: for<'r> FromRow<'r, <Db as Database>::Row>,
117 {
118 sqlx::query_as_with(sql, self.arguments)
119 }
120}
121
122impl QueryBuilder<'_> {
123 pub fn header_where_clause<Types: NodeType>(
125 &mut self,
126 id: BlockId<Types>,
127 ) -> QueryResult<String> {
128 let clause = match id {
129 BlockId::Number(n) => format!("h.height = {}", self.bind(n as i64)?),
130 BlockId::Hash(h) => format!("h.hash = {}", self.bind(h.to_string())?),
131 BlockId::PayloadHash(h) => format!("h.payload_hash = {}", self.bind(h.to_string())?),
132 };
133 Ok(clause)
134 }
135
136 pub fn bounds_to_where_clause<R>(&mut self, range: R, column: &str) -> QueryResult<String>
138 where
139 R: RangeBounds<usize>,
140 {
141 let mut bounds = vec![];
142
143 match range.start_bound() {
144 Bound::Included(n) => {
145 bounds.push(format!("{column} >= {}", self.bind(*n as i64)?));
146 },
147 Bound::Excluded(n) => {
148 bounds.push(format!("{column} > {}", self.bind(*n as i64)?));
149 },
150 Bound::Unbounded => {},
151 }
152 match range.end_bound() {
153 Bound::Included(n) => {
154 bounds.push(format!("{column} <= {}", self.bind(*n as i64)?));
155 },
156 Bound::Excluded(n) => {
157 bounds.push(format!("{column} < {}", self.bind(*n as i64)?));
158 },
159 Bound::Unbounded => {},
160 }
161
162 let mut where_clause = bounds.join(" AND ");
163 if !where_clause.is_empty() {
164 where_clause = format!(" WHERE {where_clause}");
165 }
166
167 Ok(where_clause)
168 }
169}
170
171const LEAF_COLUMNS: &str = "leaf, qc";
172
173impl<'r, Types> FromRow<'r, <Db as Database>::Row> for LeafQueryData<Types>
174where
175 Types: NodeType,
176{
177 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
178 let leaf = row.try_get("leaf")?;
179 let leaf: Leaf2<Types> = serde_json::from_value(leaf).decode_error("malformed leaf")?;
180
181 let qc = row.try_get("qc")?;
182 let qc: QuorumCertificate2<Types> =
183 serde_json::from_value(qc).decode_error("malformed QC")?;
184
185 Ok(Self { leaf, qc })
186 }
187}
188
189const BLOCK_COLUMNS: &str =
190 "h.hash AS hash, h.data AS header_data, p.size AS payload_size, p.data AS payload_data";
191
192impl<'r, Types> FromRow<'r, <Db as Database>::Row> for BlockQueryData<Types>
193where
194 Types: NodeType,
195 Header<Types>: QueryableHeader<Types>,
196 Payload<Types>: QueryablePayload<Types>,
197{
198 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
199 let size: Option<i32> = row.try_get("payload_size")?;
201 let payload_data: Option<Vec<u8>> = row.try_get("payload_data")?;
202 let (size, payload_data) = size.zip(payload_data).ok_or(sqlx::Error::RowNotFound)?;
203 let size = size as u64;
204
205 let header_data = row.try_get("header_data")?;
207 let header: Header<Types> =
208 serde_json::from_value(header_data).decode_error("malformed header")?;
209
210 let payload = Payload::<Types>::from_bytes(&payload_data, header.metadata());
212
213 let hash: String = row.try_get("hash")?;
215 let hash = hash.parse().decode_error("malformed block hash")?;
216
217 Ok(Self {
218 num_transactions: payload.len(header.metadata()) as u64,
219 header,
220 payload,
221 size,
222 hash,
223 })
224 }
225}
226
227const PAYLOAD_COLUMNS: &str = BLOCK_COLUMNS;
228
229impl<'r, Types> FromRow<'r, <Db as Database>::Row> for PayloadQueryData<Types>
230where
231 Types: NodeType,
232 Header<Types>: QueryableHeader<Types>,
233 Payload<Types>: QueryablePayload<Types>,
234{
235 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
236 <BlockQueryData<Types> as FromRow<<Db as Database>::Row>>::from_row(row).map(Self::from)
237 }
238}
239
240const PAYLOAD_METADATA_COLUMNS: &str = "h.height AS height, h.hash AS hash, h.payload_hash AS \
241 payload_hash, p.size AS payload_size, p.num_transactions \
242 AS num_transactions";
243
244impl<'r, Types> FromRow<'r, <Db as Database>::Row> for PayloadMetadata<Types>
245where
246 Types: NodeType,
247 Header<Types>: QueryableHeader<Types>,
248{
249 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
250 Ok(Self {
251 height: row.try_get::<i64, _>("height")? as u64,
252 block_hash: row
253 .try_get::<String, _>("hash")?
254 .parse()
255 .decode_error("malformed block hash")?,
256 hash: row
257 .try_get::<String, _>("payload_hash")?
258 .parse()
259 .decode_error("malformed payload hash")?,
260 size: row
261 .try_get::<Option<i32>, _>("payload_size")?
262 .ok_or(sqlx::Error::RowNotFound)? as u64,
263 num_transactions: row
264 .try_get::<Option<i32>, _>("num_transactions")?
265 .ok_or(sqlx::Error::RowNotFound)? as u64,
266
267 namespaces: Default::default(),
269 })
270 }
271}
272
273const VID_COMMON_COLUMNS: &str = "h.height AS height, h.hash AS block_hash, h.payload_hash AS \
274 payload_hash, v.common AS common_data";
275
276impl<'r, Types> FromRow<'r, <Db as Database>::Row> for VidCommonQueryData<Types>
277where
278 Types: NodeType,
279 Header<Types>: QueryableHeader<Types>,
280 Payload<Types>: QueryablePayload<Types>,
281{
282 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
283 let height = row.try_get::<i64, _>("height")? as u64;
284 let block_hash: String = row.try_get("block_hash")?;
285 let block_hash = block_hash.parse().decode_error("malformed block hash")?;
286 let payload_hash: String = row.try_get("payload_hash")?;
287 let payload_hash = payload_hash
288 .parse()
289 .decode_error("malformed payload hash")?;
290 let common_data: Vec<u8> = row.try_get("common_data")?;
291 let common =
292 bincode::deserialize(&common_data).decode_error("malformed VID common data")?;
293 Ok(Self {
294 height,
295 block_hash,
296 payload_hash,
297 common,
298 })
299 }
300}
301
302const VID_COMMON_METADATA_COLUMNS: &str =
303 "h.height AS height, h.hash AS block_hash, h.payload_hash AS payload_hash";
304
305impl<'r, Types> FromRow<'r, <Db as Database>::Row> for VidCommonMetadata<Types>
306where
307 Types: NodeType,
308 Header<Types>: QueryableHeader<Types>,
309 Payload<Types>: QueryablePayload<Types>,
310{
311 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
312 let height = row.try_get::<i64, _>("height")? as u64;
313 let block_hash: String = row.try_get("block_hash")?;
314 let block_hash = block_hash.parse().decode_error("malformed block hash")?;
315 let payload_hash: String = row.try_get("payload_hash")?;
316 let payload_hash = payload_hash
317 .parse()
318 .decode_error("malformed payload hash")?;
319 Ok(Self {
320 height,
321 block_hash,
322 payload_hash,
323 })
324 }
325}
326
327const HEADER_COLUMNS: &str = "h.data AS data";
328
329fn parse_header<Types>(row: <Db as Database>::Row) -> sqlx::Result<Header<Types>>
333where
334 Types: NodeType,
335{
336 let data = row.try_get("data")?;
338 serde_json::from_value(data).decode_error("malformed header")
339}
340
341impl From<sqlx::Error> for QueryError {
342 fn from(err: sqlx::Error) -> Self {
343 if matches!(err, sqlx::Error::RowNotFound) {
344 Self::NotFound
345 } else {
346 Self::Error {
347 message: err.to_string(),
348 }
349 }
350 }
351}
352
353const STATE_CERT_COLUMNS: &str = "state_cert";
354
355impl<'r, Types> FromRow<'r, <Db as Database>::Row> for StateCertQueryDataV2<Types>
356where
357 Types: NodeType,
358{
359 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
360 let state_cert: LightClientStateUpdateCertificateV2<Types> = {
361 let bytes: &[u8] = row.try_get("state_cert")?;
362 match bincode::deserialize::<LightClientStateUpdateCertificateV2<Types>>(bytes) {
363 Ok(cert) => cert,
364 Err(err) => {
365 tracing::info!(
366 "Falling back to V1 deserialization for LightClientStateUpdateCertificate"
367 );
368
369 match bincode::deserialize::<LightClientStateUpdateCertificateV1<Types>>(bytes)
370 {
371 Ok(legacy) => legacy.into(),
372 Err(err_legacy) => {
373 tracing::error!(
374 "Failed to deserialize state_cert with v1 and v2 v2 error: {err}. \
375 v1 error: {err_legacy}",
376 );
377 return Err(sqlx::Error::Decode(err_legacy));
378 },
379 }
380 },
381 }
382 };
383 Ok(state_cert.into())
384 }
385}
386
387impl<Mode> Transaction<Mode> {
388 pub async fn load_header<Types: NodeType>(
399 &mut self,
400 id: impl Into<BlockId<Types>> + Send,
401 ) -> QueryResult<Header<Types>> {
402 let mut query = QueryBuilder::default();
403 let where_clause = query.header_where_clause(id.into())?;
404 let sql = format!(
407 "SELECT {HEADER_COLUMNS}
408 FROM header AS h
409 WHERE {where_clause}
410 ORDER BY h.height
411 LIMIT 1"
412 );
413
414 let row = query.query(&sql).fetch_one(self.as_mut()).await?;
415 let header = parse_header::<Types>(row)?;
416
417 Ok(header)
418 }
419}
420
421pub(super) trait DecodeError {
422 type Ok;
423 fn decode_error(self, msg: impl Display) -> sqlx::Result<Self::Ok>;
424}
425
426impl<T, E> DecodeError for Result<T, E>
427where
428 E: std::error::Error + Send + Sync + 'static,
429{
430 type Ok = T;
431 fn decode_error(self, msg: impl Display) -> sqlx::Result<<Self as DecodeError>::Ok> {
432 self.context(msg.to_string())
433 .map_err(|err| sqlx::Error::Decode(err.into()))
434 }
435}