1use std::ops::RangeBounds;
16
17use async_trait::async_trait;
18use futures::stream::{StreamExt, TryStreamExt};
19use hotshot_types::traits::node_implementation::NodeType;
20use snafu::OptionExt;
21use sqlx::FromRow;
22
23use super::{
24 super::transaction::{Transaction, TransactionMode, query},
25 BLOCK_COLUMNS, LEAF_COLUMNS, PAYLOAD_COLUMNS, PAYLOAD_METADATA_COLUMNS, QueryBuilder,
26 VID_COMMON_COLUMNS, VID_COMMON_METADATA_COLUMNS,
27};
28use crate::{
29 Header, MissingSnafu, Payload, QueryError, QueryResult,
30 availability::{
31 BlockId, BlockQueryData, LeafId, LeafQueryData, NamespaceInfo, NamespaceMap,
32 PayloadQueryData, QueryableHeader, QueryablePayload, TransactionHash, VidCommonQueryData,
33 },
34 data_source::storage::{
35 AvailabilityStorage, PayloadMetadata, VidCommonMetadata, sql::sqlx::Row,
36 },
37 types::HeightIndexed,
38};
39
40#[async_trait]
41impl<Mode, Types> AvailabilityStorage<Types> for Transaction<Mode>
42where
43 Types: NodeType,
44 Mode: TransactionMode,
45 Payload<Types>: QueryablePayload<Types>,
46 Header<Types>: QueryableHeader<Types>,
47{
48 async fn get_leaf(&mut self, id: LeafId<Types>) -> QueryResult<LeafQueryData<Types>> {
49 let mut query = QueryBuilder::default();
50 let where_clause = match id {
51 LeafId::Number(n) => format!("height = {}", query.bind(n as i64)?),
52 LeafId::Hash(h) => format!("hash = {}", query.bind(h.to_string())?),
53 };
54 let row = query
55 .query(&format!(
56 "SELECT {LEAF_COLUMNS} FROM leaf2 WHERE {where_clause} LIMIT 1"
57 ))
58 .fetch_one(self.as_mut())
59 .await?;
60 let leaf = LeafQueryData::from_row(&row)?;
61 Ok(leaf)
62 }
63
64 async fn get_block(&mut self, id: BlockId<Types>) -> QueryResult<BlockQueryData<Types>> {
65 let mut query = QueryBuilder::default();
66 let where_clause = query.header_where_clause(id)?;
67 let sql = format!(
70 "SELECT {BLOCK_COLUMNS}
71 FROM header AS h
72 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
73 WHERE {where_clause}
74 ORDER BY h.height
75 LIMIT 1"
76 );
77 let row = query.query(&sql).fetch_one(self.as_mut()).await?;
78 let block = BlockQueryData::from_row(&row)?;
79 Ok(block)
80 }
81
82 async fn get_header(&mut self, id: BlockId<Types>) -> QueryResult<Header<Types>> {
83 self.load_header(id).await
84 }
85
86 async fn get_payload(&mut self, id: BlockId<Types>) -> QueryResult<PayloadQueryData<Types>> {
87 let mut query = QueryBuilder::default();
88 let where_clause = query.header_where_clause(id)?;
89 let sql = format!(
92 "SELECT {PAYLOAD_COLUMNS}
93 FROM header AS h
94 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
95 WHERE {where_clause}
96 ORDER BY h.height
97 LIMIT 1"
98 );
99 let row = query.query(&sql).fetch_one(self.as_mut()).await?;
100 let payload = PayloadQueryData::from_row(&row)?;
101 Ok(payload)
102 }
103
104 async fn get_payload_metadata(
105 &mut self,
106 id: BlockId<Types>,
107 ) -> QueryResult<PayloadMetadata<Types>> {
108 let mut query = QueryBuilder::default();
109 let where_clause = query.header_where_clause(id)?;
110 let sql = format!(
113 "SELECT {PAYLOAD_METADATA_COLUMNS}
114 FROM header AS h
115 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
116 WHERE {where_clause}
117 ORDER BY h.height ASC
118 LIMIT 1"
119 );
120 let row = query
121 .query(&sql)
122 .fetch_optional(self.as_mut())
123 .await?
124 .context(MissingSnafu)?;
125 let mut payload = PayloadMetadata::from_row(&row)?;
126 payload.namespaces = self
127 .load_namespaces::<Types>(payload.height(), payload.size)
128 .await?;
129 Ok(payload)
130 }
131
132 async fn get_vid_common(
133 &mut self,
134 id: BlockId<Types>,
135 ) -> QueryResult<VidCommonQueryData<Types>> {
136 let mut query = QueryBuilder::default();
137 let where_clause = query.header_where_clause(id)?;
138 let sql = format!(
141 "SELECT {VID_COMMON_COLUMNS}
142 FROM header AS h
143 JOIN vid_common AS v ON h.payload_hash = v.hash
144 WHERE {where_clause}
145 ORDER BY h.height
146 LIMIT 1"
147 );
148 let row = query.query(&sql).fetch_one(self.as_mut()).await?;
149 let common = VidCommonQueryData::from_row(&row)?;
150 Ok(common)
151 }
152
153 async fn get_vid_common_metadata(
154 &mut self,
155 id: BlockId<Types>,
156 ) -> QueryResult<VidCommonMetadata<Types>> {
157 let mut query = QueryBuilder::default();
158 let where_clause = query.header_where_clause(id)?;
159 let sql = format!(
162 "SELECT {VID_COMMON_METADATA_COLUMNS}
163 FROM header AS h
164 JOIN vid_common AS v ON h.payload_hash = v.hash
165 WHERE {where_clause}
166 ORDER BY h.height ASC
167 LIMIT 1"
168 );
169 let row = query.query(&sql).fetch_one(self.as_mut()).await?;
170 let common = VidCommonMetadata::from_row(&row)?;
171 Ok(common)
172 }
173
174 async fn get_leaf_range<R>(
175 &mut self,
176 range: R,
177 ) -> QueryResult<Vec<QueryResult<LeafQueryData<Types>>>>
178 where
179 R: RangeBounds<usize> + Send,
180 {
181 let mut query = QueryBuilder::default();
182 let where_clause = query.bounds_to_where_clause(range, "height")?;
183 let sql = format!("SELECT {LEAF_COLUMNS} FROM leaf2 {where_clause} ORDER BY height ASC");
184 Ok(query
185 .query(&sql)
186 .fetch(self.as_mut())
187 .map(|res| LeafQueryData::from_row(&res?))
188 .map_err(QueryError::from)
189 .collect()
190 .await)
191 }
192
193 async fn get_block_range<R>(
194 &mut self,
195 range: R,
196 ) -> QueryResult<Vec<QueryResult<BlockQueryData<Types>>>>
197 where
198 R: RangeBounds<usize> + Send,
199 {
200 let mut query = QueryBuilder::default();
201 let where_clause = query.bounds_to_where_clause(range, "h.height")?;
202 let sql = format!(
203 "SELECT {BLOCK_COLUMNS}
204 FROM header AS h
205 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
206 {where_clause}
207 ORDER BY h.height"
208 );
209 Ok(query
210 .query(&sql)
211 .fetch(self.as_mut())
212 .map(|res| BlockQueryData::from_row(&res?))
213 .map_err(QueryError::from)
214 .collect()
215 .await)
216 }
217
218 async fn get_header_range<R>(
219 &mut self,
220 range: R,
221 ) -> QueryResult<Vec<QueryResult<Header<Types>>>>
222 where
223 R: RangeBounds<usize> + Send,
224 {
225 let mut query = QueryBuilder::default();
226 let where_clause = query.bounds_to_where_clause(range, "h.height")?;
227
228 let headers = query
229 .query(&format!(
230 "SELECT data
231 FROM header AS h
232 {where_clause}
233 ORDER BY h.height"
234 ))
235 .fetch(self.as_mut())
236 .map(|res| serde_json::from_value(res?.get("data")).unwrap())
237 .collect()
238 .await;
239
240 Ok(headers)
241 }
242
243 async fn get_payload_range<R>(
244 &mut self,
245 range: R,
246 ) -> QueryResult<Vec<QueryResult<PayloadQueryData<Types>>>>
247 where
248 R: RangeBounds<usize> + Send,
249 {
250 let mut query = QueryBuilder::default();
251 let where_clause = query.bounds_to_where_clause(range, "h.height")?;
252 let sql = format!(
253 "SELECT {PAYLOAD_COLUMNS}
254 FROM header AS h
255 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
256 {where_clause}
257 ORDER BY h.height"
258 );
259 Ok(query
260 .query(&sql)
261 .fetch(self.as_mut())
262 .map(|res| PayloadQueryData::from_row(&res?))
263 .map_err(QueryError::from)
264 .collect()
265 .await)
266 }
267
268 async fn get_payload_metadata_range<R>(
269 &mut self,
270 range: R,
271 ) -> QueryResult<Vec<QueryResult<PayloadMetadata<Types>>>>
272 where
273 R: RangeBounds<usize> + Send + 'static,
274 {
275 let mut query = QueryBuilder::default();
276 let where_clause = query.bounds_to_where_clause(range, "h.height")?;
277 let sql = format!(
278 "SELECT {PAYLOAD_METADATA_COLUMNS}
279 FROM header AS h
280 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
281 {where_clause}
282 ORDER BY h.height ASC"
283 );
284 let rows = query
285 .query(&sql)
286 .fetch(self.as_mut())
287 .collect::<Vec<_>>()
288 .await;
289 let mut payloads = vec![];
290 for row in rows {
291 let res = async {
292 let mut meta = PayloadMetadata::from_row(&row?)?;
293 meta.namespaces = self
294 .load_namespaces::<Types>(meta.height(), meta.size)
295 .await?;
296 Ok(meta)
297 }
298 .await;
299 payloads.push(res);
300 }
301 Ok(payloads)
302 }
303
304 async fn get_vid_common_range<R>(
305 &mut self,
306 range: R,
307 ) -> QueryResult<Vec<QueryResult<VidCommonQueryData<Types>>>>
308 where
309 R: RangeBounds<usize> + Send,
310 {
311 let mut query = QueryBuilder::default();
312 let where_clause = query.bounds_to_where_clause(range, "h.height")?;
313 let sql = format!(
314 "SELECT {VID_COMMON_COLUMNS}
315 FROM header AS h
316 JOIN vid_common AS v ON h.payload_hash = v.hash
317 {where_clause}
318 ORDER BY h.height"
319 );
320 Ok(query
321 .query(&sql)
322 .fetch(self.as_mut())
323 .map(|res| VidCommonQueryData::from_row(&res?))
324 .map_err(QueryError::from)
325 .collect()
326 .await)
327 }
328
329 async fn get_vid_common_metadata_range<R>(
330 &mut self,
331 range: R,
332 ) -> QueryResult<Vec<QueryResult<VidCommonMetadata<Types>>>>
333 where
334 R: RangeBounds<usize> + Send,
335 {
336 let mut query = QueryBuilder::default();
337 let where_clause = query.bounds_to_where_clause(range, "h.height")?;
338 let sql = format!(
339 "SELECT {VID_COMMON_METADATA_COLUMNS}
340 FROM header AS h
341 JOIN vid_common AS v ON h.payload_hash = v.hash
342 {where_clause}
343 ORDER BY h.height ASC"
344 );
345 Ok(query
346 .query(&sql)
347 .fetch(self.as_mut())
348 .map(|res| VidCommonMetadata::from_row(&res?))
349 .map_err(QueryError::from)
350 .collect()
351 .await)
352 }
353
354 async fn get_block_with_transaction(
355 &mut self,
356 hash: TransactionHash<Types>,
357 ) -> QueryResult<BlockQueryData<Types>> {
358 let mut query = QueryBuilder::default();
359 let hash_param = query.bind(hash.to_string())?;
360
361 let sql = format!(
364 "SELECT {BLOCK_COLUMNS}
365 FROM header AS h
366 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
367 JOIN transactions AS t ON t.block_height = h.height
368 WHERE t.hash = {hash_param}
369 ORDER BY t.block_height, t.ns_id, t.position
370 LIMIT 1"
371 );
372 let row = query.query(&sql).fetch_one(self.as_mut()).await?;
373 Ok(BlockQueryData::from_row(&row)?)
374 }
375
376 async fn first_available_leaf(&mut self, from: u64) -> QueryResult<LeafQueryData<Types>> {
377 let row = query(&format!(
378 "SELECT {LEAF_COLUMNS} FROM leaf2 WHERE height >= $1 ORDER BY height ASC LIMIT 1"
379 ))
380 .bind(from as i64)
381 .fetch_one(self.as_mut())
382 .await?;
383 let leaf = LeafQueryData::from_row(&row)?;
384 Ok(leaf)
385 }
386}
387
388impl<Mode> Transaction<Mode>
389where
390 Mode: TransactionMode,
391{
392 async fn load_namespaces<Types>(
393 &mut self,
394 height: u64,
395 payload_size: u64,
396 ) -> QueryResult<NamespaceMap<Types>>
397 where
398 Types: NodeType,
399 Header<Types>: QueryableHeader<Types>,
400 Payload<Types>: QueryablePayload<Types>,
401 {
402 let header = self
403 .get_header(BlockId::<Types>::from(height as usize))
404 .await?;
405 let map = query(
406 "SELECT ns_id, ns_index, max(position) + 1 AS count
407 FROM transactions
408 WHERE block_height = $1
409 GROUP BY ns_id, ns_index",
410 )
411 .bind(height as i64)
412 .fetch(self.as_mut())
413 .map_ok(|row| {
414 let ns = row.get::<i64, _>("ns_index").into();
415 let id = row.get::<i64, _>("ns_id").into();
416 let num_transactions = row.get::<i64, _>("count") as u64;
417 let size = header.namespace_size(&ns, payload_size as usize);
418 (
419 id,
420 NamespaceInfo {
421 num_transactions,
422 size,
423 },
424 )
425 })
426 .try_collect()
427 .await?;
428 Ok(map)
429 }
430}
431
#[cfg(test)]
mod test {
    use hotshot_example_types::node_types::TEST_VERSIONS;
    use hotshot_types::{data::VidCommon, vid::advz::advz_scheme};
    use jf_advz::VidScheme;
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::{
        data_source::{
            Transaction, VersionedDataSource,
            sql::testing::TmpDb,
            storage::{SqlStorage, StorageConnectionType, UpdateAvailabilityStorage},
        },
        testing::mocks::MockTypes,
    };

    /// Heights 0 and 1 carry the same payload, differing only in block number. Storing the
    /// block and VID data for height 1 must make them resolvable at height 0 as well.
    #[tokio::test]
    #[test_log::test]
    async fn test_duplicate_payload() {
        let storage = TmpDb::init().await;
        let db = SqlStorage::connect(storage.config(), StorageConnectionType::Query)
            .await
            .unwrap();
        let mut vid_scheme = advz_scheme(2);

        // Height 0: genesis leaf, block and VID common data.
        let genesis_leaf = LeafQueryData::<MockTypes>::genesis(
            &Default::default(),
            &Default::default(),
            TEST_VERSIONS.test,
        )
        .await;
        let genesis_block = BlockQueryData::<MockTypes>::genesis(
            &Default::default(),
            &Default::default(),
            TEST_VERSIONS.test.base,
        )
        .await;
        let dispersal = vid_scheme.disperse([]).unwrap();
        let genesis_common = VidCommonQueryData::<MockTypes>::new(
            genesis_leaf.header().clone(),
            VidCommon::V0(dispersal.common.clone()),
        );

        // Height 1: identical payload, only the block number is bumped.
        let mut next_leaf = genesis_leaf.clone();
        next_leaf.leaf.block_header_mut().block_number += 1;
        let next_block =
            BlockQueryData::new(next_leaf.header().clone(), genesis_block.payload().clone());
        let next_common =
            VidCommonQueryData::new(next_leaf.header().clone(), VidCommon::V0(dispersal.common));

        let leaves = [genesis_leaf, next_leaf];
        let blocks = [genesis_block, next_block];
        let commons = [genesis_common, next_common];

        // Store only the leaf for height 0.
        {
            let mut tx = db.write().await.unwrap();
            tx.insert_leaf(leaves[0].clone()).await.unwrap();
            tx.commit().await.unwrap();
        }

        // Its block and VID common data are not available yet.
        {
            let mut tx = db.read().await.unwrap();
            assert_eq!(tx.get_leaf(LeafId::Number(0)).await.unwrap(), leaves[0]);
            assert_absent(
                tx.get_block(BlockId::<MockTypes>::Number(0))
                    .await
                    .unwrap_err(),
            );
            assert_absent(
                tx.get_vid_common(BlockId::<MockTypes>::Number(0))
                    .await
                    .unwrap_err(),
            );
        }

        // Store everything for height 1.
        {
            let mut tx = db.write().await.unwrap();
            tx.insert_leaf(leaves[1].clone()).await.unwrap();
            tx.insert_block(blocks[1].clone()).await.unwrap();
            tx.insert_vid(commons[1].clone(), None).await.unwrap();
            tx.commit().await.unwrap();
        }

        // The shared payload now satisfies lookups at both heights.
        for (height, ((leaf, block), common)) in
            leaves.iter().zip(&blocks).zip(&commons).enumerate()
        {
            let mut tx = db.read().await.unwrap();
            assert_eq!(tx.get_leaf(LeafId::Number(height)).await.unwrap(), *leaf);
            assert_eq!(tx.get_block(BlockId::Number(height)).await.unwrap(), *block);
            assert_eq!(
                tx.get_vid_common(BlockId::Number(height)).await.unwrap(),
                *common
            );
        }
    }

    /// Heights 0 and 1 share a payload, but height 1's mock metadata differs. Per this test's
    /// name, that gives it a different namespace table. VID common data (keyed by payload hash)
    /// is shared, while the full block at height 0 must stay missing.
    #[tokio::test]
    #[test_log::test]
    async fn test_same_payload_different_ns_table() {
        let storage = TmpDb::init().await;
        let db = SqlStorage::connect(storage.config(), StorageConnectionType::Query)
            .await
            .unwrap();
        let mut vid_scheme = advz_scheme(2);

        // Height 0: genesis leaf, block and VID common data.
        let genesis_leaf = LeafQueryData::<MockTypes>::genesis(
            &Default::default(),
            &Default::default(),
            TEST_VERSIONS.test,
        )
        .await;
        let genesis_block = BlockQueryData::<MockTypes>::genesis(
            &Default::default(),
            &Default::default(),
            TEST_VERSIONS.test.base,
        )
        .await;
        let dispersal = vid_scheme.disperse([]).unwrap();
        let genesis_common = VidCommonQueryData::<MockTypes>::new(
            genesis_leaf.header().clone(),
            VidCommon::V0(dispersal.common.clone()),
        );

        // Height 1: same payload, different block number and metadata.
        let mut next_leaf = genesis_leaf.clone();
        next_leaf.leaf.block_header_mut().block_number += 1;
        next_leaf.leaf.block_header_mut().metadata.num_transactions += 1;
        let next_block =
            BlockQueryData::new(next_leaf.header().clone(), genesis_block.payload().clone());
        let next_common =
            VidCommonQueryData::new(next_leaf.header().clone(), VidCommon::V0(dispersal.common));

        let leaves = [genesis_leaf, next_leaf];
        let blocks = [genesis_block, next_block];
        let commons = [genesis_common, next_common];

        // Store only the leaf for height 0.
        {
            let mut tx = db.write().await.unwrap();
            tx.insert_leaf(leaves[0].clone()).await.unwrap();
            tx.commit().await.unwrap();
        }

        // Its block and VID common data are not available yet.
        {
            let mut tx = db.read().await.unwrap();
            assert_eq!(tx.get_leaf(LeafId::Number(0)).await.unwrap(), leaves[0]);
            assert_absent(
                tx.get_block(BlockId::<MockTypes>::Number(0))
                    .await
                    .unwrap_err(),
            );
            assert_absent(
                tx.get_vid_common(BlockId::<MockTypes>::Number(0))
                    .await
                    .unwrap_err(),
            );
        }

        // Store everything for height 1.
        {
            let mut tx = db.write().await.unwrap();
            tx.insert_leaf(leaves[1].clone()).await.unwrap();
            tx.insert_block(blocks[1].clone()).await.unwrap();
            tx.insert_vid(commons[1].clone(), None).await.unwrap();
            tx.commit().await.unwrap();
        }

        let mut tx = db.read().await.unwrap();
        // Leaves and VID common data resolve at both heights.
        for (height, (leaf, common)) in leaves.iter().zip(&commons).enumerate() {
            assert_eq!(tx.get_leaf(LeafId::Number(height)).await.unwrap(), *leaf);
            assert_eq!(
                tx.get_vid_common(BlockId::Number(height)).await.unwrap(),
                *common
            );
        }

        // The full block resolves only where the namespace table matches.
        assert_absent(
            tx.get_block(BlockId::<MockTypes>::Number(0))
                .await
                .unwrap_err(),
        );
        assert_eq!(tx.get_block(BlockId::Number(1)).await.unwrap(), blocks[1]);
    }

    /// Assert that `err` reports an absent object (`Missing` or `NotFound`).
    fn assert_absent(err: QueryError) {
        assert!(
            matches!(err, QueryError::Missing | QueryError::NotFound),
            "{err:#}"
        );
    }
}