1use std::ops::{Bound, RangeBounds};
16
17use alloy::primitives::map::HashMap;
18use anyhow::anyhow;
19use async_trait::async_trait;
20use futures::stream::{StreamExt, TryStreamExt};
21use hotshot_types::{
22 data::VidShare,
23 simple_certificate::CertificatePair,
24 traits::{block_contents::BlockHeader, node_implementation::NodeType},
25};
26use snafu::OptionExt;
27use sqlx::Row;
28
29use super::{
30 super::transaction::{query, query_as, Transaction, TransactionMode, Write},
31 parse_header, DecodeError, QueryBuilder, HEADER_COLUMNS,
32};
33use crate::{
34 availability::{NamespaceId, QueryableHeader},
35 data_source::storage::{
36 Aggregate, AggregatesStorage, NodeStorage, PayloadMetadata, UpdateAggregatesStorage,
37 },
38 node::{BlockId, SyncStatus, TimeWindowQueryData, WindowStart},
39 types::HeightIndexed,
40 Header, MissingSnafu, NotFoundSnafu, QueryError, QueryResult,
41};
42
43#[async_trait]
44impl<Mode, Types> NodeStorage<Types> for Transaction<Mode>
45where
46 Mode: TransactionMode,
47 Types: NodeType,
48 Header<Types>: QueryableHeader<Types>,
49{
50 async fn block_height(&mut self) -> QueryResult<usize> {
51 match query_as::<(Option<i64>,)>("SELECT max(height) FROM header")
52 .fetch_one(self.as_mut())
53 .await?
54 {
55 (Some(height),) => {
56 Ok(height as usize + 1)
59 },
60 (None,) => {
61 Ok(0)
63 },
64 }
65 }
66
67 async fn count_transactions_in_range(
68 &mut self,
69 range: impl RangeBounds<usize> + Send,
70 namespace: Option<NamespaceId<Types>>,
71 ) -> QueryResult<usize> {
72 let namespace: i64 = namespace.map(|ns| ns.into()).unwrap_or(-1);
73 let Some((from, to)) = aggregate_range_bounds::<Types>(self, range).await? else {
74 return Ok(0);
75 };
76 let (count,) = query_as::<(i64,)>(
77 "SELECT num_transactions FROM aggregate WHERE height = $1 AND namespace = $2",
78 )
79 .bind(to as i64)
80 .bind(namespace)
81 .fetch_one(self.as_mut())
82 .await?;
83 let mut count = count as usize;
84
85 if from > 0 {
86 let (prev_count,) = query_as::<(i64,)>(
87 "SELECT num_transactions FROM aggregate WHERE height = $1 AND namespace = $2",
88 )
89 .bind((from - 1) as i64)
90 .bind(namespace)
91 .fetch_one(self.as_mut())
92 .await?;
93 count = count.saturating_sub(prev_count as usize);
94 }
95
96 Ok(count)
97 }
98
99 async fn payload_size_in_range(
100 &mut self,
101 range: impl RangeBounds<usize> + Send,
102 namespace: Option<NamespaceId<Types>>,
103 ) -> QueryResult<usize> {
104 let namespace: i64 = namespace.map(|ns| ns.into()).unwrap_or(-1);
105 let Some((from, to)) = aggregate_range_bounds::<Types>(self, range).await? else {
106 return Ok(0);
107 };
108 let (size,) = query_as::<(i64,)>(
109 "SELECT payload_size FROM aggregate WHERE height = $1 AND namespace = $2",
110 )
111 .bind(to as i64)
112 .bind(namespace)
113 .fetch_one(self.as_mut())
114 .await?;
115 let mut size = size as usize;
116
117 if from > 0 {
118 let (prev_size,) = query_as::<(i64,)>(
119 "SELECT payload_size FROM aggregate WHERE height = $1 AND namespace = $2",
120 )
121 .bind((from - 1) as i64)
122 .bind(namespace)
123 .fetch_one(self.as_mut())
124 .await?;
125 size = size.saturating_sub(prev_size as usize);
126 }
127
128 Ok(size)
129 }
130
131 async fn vid_share<ID>(&mut self, id: ID) -> QueryResult<VidShare>
132 where
133 ID: Into<BlockId<Types>> + Send + Sync,
134 {
135 let mut query = QueryBuilder::default();
136 let where_clause = query.header_where_clause(id.into())?;
137 let sql = format!(
140 "SELECT v.share AS share FROM vid2 AS v
141 JOIN header AS h ON v.height = h.height
142 WHERE {where_clause}
143 ORDER BY h.height
144 LIMIT 1"
145 );
146 let (share_data,) = query
147 .query_as::<(Option<Vec<u8>>,)>(&sql)
148 .fetch_one(self.as_mut())
149 .await?;
150 let share_data = share_data.context(MissingSnafu)?;
151 let share = bincode::deserialize(&share_data).decode_error("malformed VID share")?;
152 Ok(share)
153 }
154
155 async fn sync_status(&mut self) -> QueryResult<SyncStatus> {
156 let sql = "SELECT l.max_height, l.total_leaves, p.null_payloads, v.total_vid, \
180 vn.null_vid, pruned_height FROM
181 (SELECT max(leaf2.height) AS max_height, count(*) AS total_leaves FROM leaf2) AS l,
182 (SELECT count(*) AS null_payloads FROM payload WHERE data IS NULL) AS p,
183 (SELECT count(*) AS total_vid FROM vid2) AS v,
184 (SELECT count(*) AS null_vid FROM vid2 WHERE share IS NULL) AS vn,
185 (SELECT(SELECT last_height FROM pruned_height ORDER BY id DESC LIMIT 1) as \
186 pruned_height)
187 ";
188 let row = query(sql)
189 .fetch_optional(self.as_mut())
190 .await?
191 .context(NotFoundSnafu)?;
192
193 let block_height = match row.get::<Option<i64>, _>("max_height") {
194 Some(height) => {
195 height as usize + 1
198 },
199 None => {
200 0
202 },
203 };
204 let total_leaves = row.get::<i64, _>("total_leaves") as usize;
205 let null_payloads = row.get::<i64, _>("null_payloads") as usize;
206 let total_vid = row.get::<i64, _>("total_vid") as usize;
207 let null_vid = row.get::<i64, _>("null_vid") as usize;
208 let pruned_height = row
209 .get::<Option<i64>, _>("pruned_height")
210 .map(|h| h as usize);
211
212 let missing_leaves = block_height.saturating_sub(total_leaves);
213 let missing_blocks = missing_leaves + null_payloads;
214 let missing_vid_common = block_height.saturating_sub(total_vid);
215 let missing_vid_shares = missing_vid_common + null_vid;
216
217 Ok(SyncStatus {
218 missing_leaves,
219 missing_blocks,
220 missing_vid_common,
221 missing_vid_shares,
222 pruned_height,
223 })
224 }
225
226 async fn get_header_window(
227 &mut self,
228 start: impl Into<WindowStart<Types>> + Send + Sync,
229 end: u64,
230 limit: usize,
231 ) -> QueryResult<TimeWindowQueryData<Header<Types>>> {
232 let first_block = match start.into() {
234 WindowStart::Time(t) => {
235 return self.time_window::<Types>(t, end, limit).await;
240 },
241 WindowStart::Height(h) => h,
242 WindowStart::Hash(h) => self.load_header::<Types>(h).await?.block_number(),
243 };
244
245 let sql = format!(
249 "SELECT {HEADER_COLUMNS}
250 FROM header AS h
251 WHERE h.height >= $1 AND h.timestamp < $2
252 ORDER BY h.height
253 LIMIT $3"
254 );
255 let rows = query(&sql)
256 .bind(first_block as i64)
257 .bind(end as i64)
258 .bind(limit as i64)
259 .fetch(self.as_mut());
260 let window = rows
261 .map(|row| parse_header::<Types>(row?))
262 .try_collect::<Vec<_>>()
263 .await?;
264
265 let prev = if first_block > 0 {
267 Some(self.load_header::<Types>(first_block as usize - 1).await?)
268 } else {
269 None
270 };
271
272 let next = if window.len() < limit {
273 let sql = format!(
282 "SELECT {HEADER_COLUMNS}
283 FROM header AS h
284 WHERE h.timestamp >= $1
285 ORDER BY h.timestamp, h.height
286 LIMIT 1"
287 );
288 query(&sql)
289 .bind(end as i64)
290 .fetch_optional(self.as_mut())
291 .await?
292 .map(parse_header::<Types>)
293 .transpose()?
294 } else {
295 tracing::debug!(limit, "cutting off header window request due to limit");
299 None
300 };
301
302 Ok(TimeWindowQueryData { window, prev, next })
303 }
304
305 async fn latest_qc_chain(&mut self) -> QueryResult<Option<[CertificatePair<Types>; 2]>> {
306 let Some((json,)) = query_as("SELECT qcs FROM latest_qc_chain LIMIT 1")
307 .fetch_optional(self.as_mut())
308 .await?
309 else {
310 return Ok(None);
311 };
312 let qcs = serde_json::from_value(json).decode_error("malformed QC")?;
313 Ok(qcs)
314 }
315}
316
317impl<Types, Mode: TransactionMode> AggregatesStorage<Types> for Transaction<Mode>
318where
319 Types: NodeType,
320 Header<Types>: QueryableHeader<Types>,
321{
322 async fn aggregates_height(&mut self) -> anyhow::Result<usize> {
323 let (height,): (i64,) = query_as("SELECT coalesce(max(height) + 1, 0) FROM aggregate")
324 .fetch_one(self.as_mut())
325 .await?;
326 Ok(height as usize)
327 }
328
329 async fn load_prev_aggregate(&mut self) -> anyhow::Result<Option<Aggregate<Types>>> {
330 let res: (Option<i64>,) =
333 query_as("SELECT max(height) FROM aggregate WHERE namespace = -1")
334 .fetch_one(self.as_mut())
335 .await?;
336
337 let (Some(max_height),) = res else {
338 return Ok(None);
339 };
340
341 let rows: Vec<(i64, i64, i64)> = query_as(
342 r#"
343 SELECT namespace, num_transactions, payload_size from aggregate WHERE height = $1
344 "#,
345 )
346 .bind(max_height)
347 .fetch_all(self.as_mut())
348 .await?;
349
350 let mut num_transactions = HashMap::default();
351 let mut payload_size = HashMap::default();
352
353 for (namespace_id, num_tx, payload_sz) in rows {
354 let key = if namespace_id == -1 {
358 None
359 } else {
360 Some(namespace_id.into())
361 };
362 num_transactions.insert(key, num_tx as usize);
363 payload_size.insert(key, payload_sz as usize);
364 }
365
366 Ok(Some(Aggregate {
367 height: max_height,
368 num_transactions,
369 payload_size,
370 }))
371 }
372}
373
374impl<Types: NodeType> UpdateAggregatesStorage<Types> for Transaction<Write>
375where
376 Header<Types>: QueryableHeader<Types>,
377{
378 async fn update_aggregates(
379 &mut self,
380 prev: Aggregate<Types>,
381 blocks: &[PayloadMetadata<Types>],
382 ) -> anyhow::Result<Aggregate<Types>> {
383 let height = blocks[0].height();
384 let (prev_tx_count, prev_size) = (prev.num_transactions, prev.payload_size);
385
386 let mut rows = Vec::new();
387
388 let aggregates = blocks
390 .iter()
391 .scan(
392 (height, prev_tx_count, prev_size),
393 |(height, tx_count, size), block| {
394 if *height != block.height {
395 return Some(Err(anyhow!(
396 "blocks in update_aggregates are not sequential; expected {}, got {}",
397 *height,
398 block.height()
399 )));
400 }
401 *height += 1;
402
403 *tx_count.entry(None).or_insert(0) += block.num_transactions as usize;
408 *size.entry(None).or_insert(0) += block.size as usize;
409
410 rows.push((
413 block.height as i64,
414 -1,
415 tx_count[&None] as i64,
416 size[&None] as i64,
417 ));
418
419 for (&ns_id, info) in &block.namespaces {
421 let key = Some(ns_id);
422
423 *tx_count.entry(key).or_insert(0) += info.num_transactions as usize;
424 *size.entry(key).or_insert(0) += info.size as usize;
425 }
426
427 for ns_id in tx_count.keys().filter_map(|k| k.as_ref()) {
431 let key = Some(*ns_id);
432 rows.push((
433 block.height as i64,
434 (*ns_id).into(),
435 tx_count[&key] as i64,
436 size[&key] as i64,
437 ));
438 }
439
440 Some(Ok((block.height as i64, tx_count.clone(), size.clone())))
441 },
442 )
443 .collect::<anyhow::Result<Vec<_>>>()?;
444 let last_aggregate = aggregates.last().cloned();
445
446 let (height, num_transactions, payload_size) =
447 last_aggregate.ok_or_else(|| anyhow!("no row"))?;
448
449 self.upsert(
450 "aggregate",
451 ["height", "namespace", "num_transactions", "payload_size"],
452 ["height", "namespace"],
453 rows,
454 )
455 .await?;
456 Ok(Aggregate {
457 height,
458 num_transactions,
459 payload_size,
460 })
461 }
462}
463
464impl<Mode: TransactionMode> Transaction<Mode> {
465 async fn time_window<Types: NodeType>(
466 &mut self,
467 start: u64,
468 end: u64,
469 limit: usize,
470 ) -> QueryResult<TimeWindowQueryData<Header<Types>>> {
471 let sql = format!(
483 "SELECT {HEADER_COLUMNS}
484 FROM header AS h
485 WHERE h.timestamp >= $1 AND h.timestamp < $2
486 ORDER BY h.timestamp, h.height
487 LIMIT $3"
488 );
489 let rows = query(&sql)
490 .bind(start as i64)
491 .bind(end as i64)
492 .bind(limit as i64)
493 .fetch(self.as_mut());
494 let window: Vec<_> = rows
495 .map(|row| parse_header::<Types>(row?))
496 .try_collect()
497 .await?;
498
499 let next = if window.len() < limit {
500 let sql = format!(
502 "SELECT {HEADER_COLUMNS}
503 FROM header AS h
504 WHERE h.timestamp >= $1
505 ORDER BY h.timestamp, h.height
506 LIMIT 1"
507 );
508 query(&sql)
509 .bind(end as i64)
510 .fetch_optional(self.as_mut())
511 .await?
512 .map(parse_header::<Types>)
513 .transpose()?
514 } else {
515 tracing::debug!(limit, "cutting off header window request due to limit");
519 None
520 };
521
522 if window.is_empty() && next.is_none() {
530 return Err(QueryError::NotFound);
531 }
532
533 let sql = format!(
535 "SELECT {HEADER_COLUMNS}
536 FROM header AS h
537 WHERE h.timestamp < $1
538 ORDER BY h.timestamp DESC, h.height DESC
539 LIMIT 1"
540 );
541 let prev = query(&sql)
542 .bind(start as i64)
543 .fetch_optional(self.as_mut())
544 .await?
545 .map(parse_header::<Types>)
546 .transpose()?;
547
548 Ok(TimeWindowQueryData { window, prev, next })
549 }
550}
551
552async fn aggregate_range_bounds<Types>(
557 tx: &mut Transaction<impl TransactionMode>,
558 range: impl RangeBounds<usize>,
559) -> QueryResult<Option<(usize, usize)>>
560where
561 Types: NodeType,
562 Header<Types>: QueryableHeader<Types>,
563{
564 let from = match range.start_bound() {
565 Bound::Included(from) => *from,
566 Bound::Excluded(from) => *from + 1,
567 Bound::Unbounded => 0,
568 };
569 let to = match range.end_bound() {
570 Bound::Included(to) => *to,
571 Bound::Excluded(0) => return Ok(None),
572 Bound::Excluded(to) => *to - 1,
573 Bound::Unbounded => {
574 let height = AggregatesStorage::<Types>::aggregates_height(tx)
575 .await
576 .map_err(|err| QueryError::Error {
577 message: format!("{err:#}"),
578 })?;
579 if height == 0 {
580 return Ok(None);
581 }
582 if height < from {
583 return Ok(None);
584 }
585 height - 1
586 },
587 };
588 Ok(Some((from, to)))
589}