hotshot_query_service/data_source/storage/sql/
queries.rs1#![allow(clippy::needless_lifetimes)]
2use std::{
17 fmt::Display,
18 ops::{Bound, RangeBounds},
19};
20
21use anyhow::Context;
22use derivative::Derivative;
23use hotshot_types::{
24 simple_certificate::QuorumCertificate2,
25 traits::{
26 block_contents::{BlockHeader, BlockPayload},
27 node_implementation::NodeType,
28 },
29};
30use sqlx::{Arguments, FromRow, Row};
31
32use super::{Database, Db, Query, QueryAs, Transaction};
33use crate::{
34 Header, Leaf2, Payload, QueryError, QueryResult,
35 availability::{
36 BlockId, BlockQueryData, LeafQueryData, PayloadQueryData, QueryableHeader,
37 QueryablePayload, VidCommonQueryData,
38 },
39 data_source::storage::{PayloadMetadata, VidCommonMetadata},
40};
41
42pub(super) mod availability;
43pub(super) mod explorer;
44pub(super) mod node;
45pub(super) mod state;
46
47#[derive(Derivative, Default)]
86#[derivative(Debug)]
87pub struct QueryBuilder<'a> {
88 #[derivative(Debug = "ignore")]
89 arguments: <Db as Database>::Arguments<'a>,
90}
91
92impl<'q> QueryBuilder<'q> {
93 pub fn bind<T>(&mut self, arg: T) -> QueryResult<String>
95 where
96 T: 'q + sqlx::Encode<'q, Db> + sqlx::Type<Db>,
97 {
98 self.arguments.add(arg).map_err(|err| QueryError::Error {
99 message: format!("{err:#}"),
100 })?;
101
102 Ok(format!("${}", self.arguments.len()))
103 }
104
105 pub fn query(self, sql: &'q str) -> Query<'q> {
107 sqlx::query_with(sql, self.arguments)
108 }
109
110 pub fn query_as<T>(self, sql: &'q str) -> QueryAs<'q, T>
112 where
113 T: for<'r> FromRow<'r, <Db as Database>::Row>,
114 {
115 sqlx::query_as_with(sql, self.arguments)
116 }
117}
118
119impl QueryBuilder<'_> {
120 pub fn header_where_clause<Types: NodeType>(
122 &mut self,
123 id: BlockId<Types>,
124 ) -> QueryResult<String> {
125 let clause = match id {
126 BlockId::Number(n) => format!("h.height = {}", self.bind(n as i64)?),
127 BlockId::Hash(h) => format!("h.hash = {}", self.bind(h.to_string())?),
128 BlockId::PayloadHash(h) => format!("h.payload_hash = {}", self.bind(h.to_string())?),
129 };
130 Ok(clause)
131 }
132
133 pub fn bounds_to_where_clause<R>(&mut self, range: R, column: &str) -> QueryResult<String>
135 where
136 R: RangeBounds<usize>,
137 {
138 let mut bounds = vec![];
139
140 match range.start_bound() {
141 Bound::Included(n) => {
142 bounds.push(format!("{column} >= {}", self.bind(*n as i64)?));
143 },
144 Bound::Excluded(n) => {
145 bounds.push(format!("{column} > {}", self.bind(*n as i64)?));
146 },
147 Bound::Unbounded => {},
148 }
149 match range.end_bound() {
150 Bound::Included(n) => {
151 bounds.push(format!("{column} <= {}", self.bind(*n as i64)?));
152 },
153 Bound::Excluded(n) => {
154 bounds.push(format!("{column} < {}", self.bind(*n as i64)?));
155 },
156 Bound::Unbounded => {},
157 }
158
159 let mut where_clause = bounds.join(" AND ");
160 if !where_clause.is_empty() {
161 where_clause = format!(" WHERE {where_clause}");
162 }
163
164 Ok(where_clause)
165 }
166}
167
168const LEAF_COLUMNS: &str = "leaf, qc";
169
170impl<'r, Types> FromRow<'r, <Db as Database>::Row> for LeafQueryData<Types>
171where
172 Types: NodeType,
173{
174 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
175 let leaf = row.try_get("leaf")?;
176 let leaf: Leaf2<Types> = serde_json::from_value(leaf).decode_error("malformed leaf")?;
177
178 let qc = row.try_get("qc")?;
179 let qc: QuorumCertificate2<Types> =
180 serde_json::from_value(qc).decode_error("malformed QC")?;
181
182 Ok(Self { leaf, qc })
183 }
184}
185
186const BLOCK_COLUMNS: &str =
187 "h.hash AS hash, h.data AS header_data, p.size AS payload_size, p.data AS payload_data";
188
189impl<'r, Types> FromRow<'r, <Db as Database>::Row> for BlockQueryData<Types>
190where
191 Types: NodeType,
192 Header<Types>: QueryableHeader<Types>,
193 Payload<Types>: QueryablePayload<Types>,
194{
195 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
196 let size = row.try_get::<i32, _>("payload_size")? as u64;
198 let payload_data = row.try_get::<Vec<u8>, _>("payload_data")?;
199
200 let header_data = row.try_get("header_data")?;
202 let header: Header<Types> =
203 serde_json::from_value(header_data).decode_error("malformed header")?;
204
205 let payload = Payload::<Types>::from_bytes(&payload_data, header.metadata());
207
208 let hash: String = row.try_get("hash")?;
210 let hash = hash.parse().decode_error("malformed block hash")?;
211
212 Ok(Self {
213 num_transactions: payload.len(header.metadata()) as u64,
214 header,
215 payload,
216 size,
217 hash,
218 })
219 }
220}
221
222const PAYLOAD_COLUMNS: &str = BLOCK_COLUMNS;
223
224impl<'r, Types> FromRow<'r, <Db as Database>::Row> for PayloadQueryData<Types>
225where
226 Types: NodeType,
227 Header<Types>: QueryableHeader<Types>,
228 Payload<Types>: QueryablePayload<Types>,
229{
230 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
231 <BlockQueryData<Types> as FromRow<<Db as Database>::Row>>::from_row(row).map(Self::from)
232 }
233}
234
235const PAYLOAD_METADATA_COLUMNS: &str = "h.height AS height, h.hash AS hash, h.payload_hash AS \
236 payload_hash, p.size AS payload_size, p.num_transactions \
237 AS num_transactions";
238
239impl<'r, Types> FromRow<'r, <Db as Database>::Row> for PayloadMetadata<Types>
240where
241 Types: NodeType,
242 Header<Types>: QueryableHeader<Types>,
243{
244 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
245 Ok(Self {
246 height: row.try_get::<i64, _>("height")? as u64,
247 block_hash: row
248 .try_get::<String, _>("hash")?
249 .parse()
250 .decode_error("malformed block hash")?,
251 hash: row
252 .try_get::<String, _>("payload_hash")?
253 .parse()
254 .decode_error("malformed payload hash")?,
255 size: row.try_get::<i32, _>("payload_size")? as u64,
256 num_transactions: row.try_get::<i32, _>("num_transactions")? as u64,
257
258 namespaces: Default::default(),
260 })
261 }
262}
263
264const VID_COMMON_COLUMNS: &str = "h.height AS height, h.hash AS block_hash, h.payload_hash AS \
265 payload_hash, v.data AS common_data";
266
267impl<'r, Types> FromRow<'r, <Db as Database>::Row> for VidCommonQueryData<Types>
268where
269 Types: NodeType,
270 Header<Types>: QueryableHeader<Types>,
271 Payload<Types>: QueryablePayload<Types>,
272{
273 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
274 let height = row.try_get::<i64, _>("height")? as u64;
275 let block_hash: String = row.try_get("block_hash")?;
276 let block_hash = block_hash.parse().decode_error("malformed block hash")?;
277 let payload_hash: String = row.try_get("payload_hash")?;
278 let payload_hash = payload_hash
279 .parse()
280 .decode_error("malformed payload hash")?;
281 let common_data: Vec<u8> = row.try_get("common_data")?;
282 let common =
283 bincode::deserialize(&common_data).decode_error("malformed VID common data")?;
284 Ok(Self {
285 height,
286 block_hash,
287 payload_hash,
288 common,
289 })
290 }
291}
292
293const VID_COMMON_METADATA_COLUMNS: &str =
294 "h.height AS height, h.hash AS block_hash, h.payload_hash AS payload_hash";
295
296impl<'r, Types> FromRow<'r, <Db as Database>::Row> for VidCommonMetadata<Types>
297where
298 Types: NodeType,
299 Header<Types>: QueryableHeader<Types>,
300 Payload<Types>: QueryablePayload<Types>,
301{
302 fn from_row(row: &'r <Db as Database>::Row) -> sqlx::Result<Self> {
303 let height = row.try_get::<i64, _>("height")? as u64;
304 let block_hash: String = row.try_get("block_hash")?;
305 let block_hash = block_hash.parse().decode_error("malformed block hash")?;
306 let payload_hash: String = row.try_get("payload_hash")?;
307 let payload_hash = payload_hash
308 .parse()
309 .decode_error("malformed payload hash")?;
310 Ok(Self {
311 height,
312 block_hash,
313 payload_hash,
314 })
315 }
316}
317
318const HEADER_COLUMNS: &str = "h.data AS data";
319
320fn parse_header<Types>(row: <Db as Database>::Row) -> sqlx::Result<Header<Types>>
324where
325 Types: NodeType,
326{
327 let data = row.try_get("data")?;
329 serde_json::from_value(data).decode_error("malformed header")
330}
331
332impl From<sqlx::Error> for QueryError {
333 fn from(err: sqlx::Error) -> Self {
334 if matches!(err, sqlx::Error::RowNotFound) {
335 Self::NotFound
336 } else {
337 Self::Error {
338 message: err.to_string(),
339 }
340 }
341 }
342}
343
344impl<Mode> Transaction<Mode> {
345 pub async fn load_header<Types: NodeType>(
356 &mut self,
357 id: impl Into<BlockId<Types>> + Send,
358 ) -> QueryResult<Header<Types>> {
359 let mut query = QueryBuilder::default();
360 let where_clause = query.header_where_clause(id.into())?;
361 let sql = format!(
364 "SELECT {HEADER_COLUMNS}
365 FROM header AS h
366 WHERE {where_clause}
367 ORDER BY h.height
368 LIMIT 1"
369 );
370
371 let row = query.query(&sql).fetch_one(self.as_mut()).await?;
372 let header = parse_header::<Types>(row)?;
373
374 Ok(header)
375 }
376}
377
378pub(super) trait DecodeError {
379 type Ok;
380 fn decode_error(self, msg: impl Display) -> sqlx::Result<Self::Ok>;
381}
382
383impl<T, E> DecodeError for Result<T, E>
384where
385 E: std::error::Error + Send + Sync + 'static,
386{
387 type Ok = T;
388 fn decode_error(self, msg: impl Display) -> sqlx::Result<<Self as DecodeError>::Ok> {
389 self.context(msg.to_string())
390 .map_err(|err| sqlx::Error::Decode(err.into()))
391 }
392}