1use std::{
16 collections::{HashMap, HashSet, VecDeque},
17 sync::Arc,
18};
19
20use ark_serialize::CanonicalDeserialize;
21use async_trait::async_trait;
22use futures::stream::TryStreamExt;
23use hotshot_types::traits::node_implementation::NodeType;
24use jf_merkle_tree_compat::{
25 prelude::{MerkleNode, MerkleProof},
26 DigestAlgorithm, MerkleCommitment, ToTraversalPath,
27};
28use sqlx::types::{BitVec, JsonValue};
29
30use super::{
31 super::transaction::{query_as, Transaction, TransactionMode, Write},
32 DecodeError, QueryBuilder,
33};
34use crate::{
35 data_source::storage::{
36 pruning::PrunedHeightStorage,
37 sql::{build_where_in, sqlx::Row},
38 MerklizedStateHeightStorage, MerklizedStateStorage,
39 },
40 merklized_state::{MerklizedState, Snapshot},
41 QueryError, QueryResult,
42};
43
44#[async_trait]
45impl<Mode, Types, State, const ARITY: usize> MerklizedStateStorage<Types, State, ARITY>
46 for Transaction<Mode>
47where
48 Mode: TransactionMode,
49 Types: NodeType,
50 State: MerklizedState<Types, ARITY> + 'static,
51{
52 async fn get_path(
54 &mut self,
55 snapshot: Snapshot<Types, State, ARITY>,
56 key: State::Key,
57 ) -> QueryResult<MerkleProof<State::Entry, State::Key, State::T, ARITY>> {
58 let state_type = State::state_type();
59 let tree_height = State::tree_height();
60
61 let traversal_path = State::Key::to_traversal_path(&key, tree_height);
63 let (created, merkle_commitment) = self.snapshot_info(snapshot).await?;
64
65 let (query, sql) = build_get_path_query(state_type, traversal_path.clone(), created)?;
68 let rows = query.query(&sql).fetch_all(self.as_mut()).await?;
69
70 let nodes: Vec<Node> = rows.into_iter().map(|r| r.into()).collect();
71
72 let mut hash_ids = HashSet::new();
75 for node in nodes.iter() {
76 hash_ids.insert(node.hash_id);
77 if let Some(children) = &node.children {
78 let children: Vec<i32> =
79 serde_json::from_value(children.clone()).map_err(|e| QueryError::Error {
80 message: format!("Error deserializing 'children' into Vec<i32>: {e}"),
81 })?;
82 hash_ids.extend(children);
83 }
84 }
85
86 let hashes = if !hash_ids.is_empty() {
89 let (query, sql) = build_where_in("SELECT id, value FROM hash", "id", hash_ids)?;
90 query
91 .query_as(&sql)
92 .fetch(self.as_mut())
93 .try_collect::<HashMap<i32, Vec<u8>>>()
94 .await?
95 } else {
96 HashMap::new()
97 };
98
99 let mut proof_path = VecDeque::with_capacity(State::tree_height());
100 for Node {
101 hash_id,
102 children,
103 children_bitvec,
104 idx,
105 entry,
106 ..
107 } in nodes.iter()
108 {
109 {
110 let value = hashes.get(hash_id).ok_or(QueryError::Error {
111 message: format!("node's value references non-existent hash {hash_id}"),
112 })?;
113
114 match (children, children_bitvec, idx, entry) {
115 (Some(children), Some(children_bitvec), None, None) => {
117 let children: Vec<i32> =
118 serde_json::from_value(children.clone()).map_err(|e| {
119 QueryError::Error {
120 message: format!(
121 "Error deserializing 'children' into Vec<i32>: {e}"
122 ),
123 }
124 })?;
125 let mut children = children.iter();
126
127 let child_nodes = children_bitvec
130 .iter()
131 .map(|bit| {
132 if bit {
133 let hash_id = children.next().ok_or(QueryError::Error {
134 message: "node has fewer children than set bits".into(),
135 })?;
136 let value = hashes.get(hash_id).ok_or(QueryError::Error {
137 message: format!(
138 "node's child references non-existent hash {hash_id}"
139 ),
140 })?;
141 Ok(Arc::new(MerkleNode::ForgettenSubtree {
142 value: State::T::deserialize_compressed(value.as_slice())
143 .decode_error("malformed merkle node value")?,
144 }))
145 } else {
146 Ok(Arc::new(MerkleNode::Empty))
147 }
148 })
149 .collect::<QueryResult<Vec<_>>>()?;
150 proof_path.push_back(MerkleNode::Branch {
152 value: State::T::deserialize_compressed(value.as_slice())
153 .decode_error("malformed merkle node value")?,
154 children: child_nodes,
155 });
156 },
157 (None, None, Some(index), Some(entry)) => {
159 proof_path.push_back(MerkleNode::Leaf {
160 value: State::T::deserialize_compressed(value.as_slice())
161 .decode_error("malformed merkle node value")?,
162 pos: serde_json::from_value(index.clone())
163 .decode_error("malformed merkle node index")?,
164 elem: serde_json::from_value(entry.clone())
165 .decode_error("malformed merkle element")?,
166 });
167 },
168 (None, None, Some(_), None) => {
170 proof_path.push_back(MerkleNode::Empty);
171 },
172 _ => {
173 return Err(QueryError::Error {
174 message: "Invalid type of merkle node found".to_string(),
175 });
176 },
177 }
178 }
179 }
180
181 let init = if let Some(MerkleNode::Leaf { value, .. }) = proof_path.front() {
183 *value
184 } else {
185 while proof_path.len() <= State::tree_height() {
191 proof_path.push_front(MerkleNode::Empty);
192 }
193 State::T::default()
194 };
195 let commitment_from_path = traversal_path
196 .iter()
197 .zip(proof_path.iter().skip(1))
198 .try_fold(init, |val, (branch, node)| -> QueryResult<State::T> {
199 match node {
200 MerkleNode::Branch { value: _, children } => {
201 let data = children
202 .iter()
203 .map(|node| match node.as_ref() {
204 MerkleNode::ForgettenSubtree { value } => Ok(*value),
205 MerkleNode::Empty => Ok(State::T::default()),
206 _ => Err(QueryError::Error {
207 message: "Invalid child node".to_string(),
208 }),
209 })
210 .collect::<QueryResult<Vec<_>>>()?;
211
212 if data[*branch] != val {
213 tracing::warn!(
217 ?key,
218 parent = ?data[*branch],
219 child = ?val,
220 branch = %*branch,
221 %created,
222 %merkle_commitment,
223 "missing data in merklized state; parent-child mismatch",
224 );
225 return Err(QueryError::Missing);
226 }
227
228 State::Digest::digest(&data).map_err(|err| QueryError::Error {
229 message: format!("failed to update digest: {err:#}"),
230 })
231 },
232 MerkleNode::Empty => Ok(init),
233 _ => Err(QueryError::Error {
234 message: "Invalid type of Node in the proof".to_string(),
235 }),
236 }
237 })?;
238
239 if commitment_from_path != merkle_commitment.digest() {
240 return Err(QueryError::Error {
241 message: format!(
242 "Commitment calculated from merkle path ({commitment_from_path:?}) does not \
243 match the commitment in the header ({:?})",
244 merkle_commitment.digest()
245 ),
246 });
247 }
248
249 Ok(MerkleProof {
250 pos: key,
251 proof: proof_path.into(),
252 })
253 }
254}
255
256#[async_trait]
257impl<Mode: TransactionMode> MerklizedStateHeightStorage for Transaction<Mode> {
258 async fn get_last_state_height(&mut self) -> QueryResult<usize> {
259 let Some((height,)) = query_as::<(i64,)>("SELECT height from last_merklized_state_height")
260 .fetch_optional(self.as_mut())
261 .await?
262 else {
263 return Ok(0);
264 };
265 Ok(height as usize)
266 }
267}
268
269impl<Mode: TransactionMode> Transaction<Mode> {
270 async fn snapshot_info<Types, State, const ARITY: usize>(
276 &mut self,
277 snapshot: Snapshot<Types, State, ARITY>,
278 ) -> QueryResult<(i64, State::Commit)>
279 where
280 Types: NodeType,
281 State: MerklizedState<Types, ARITY>,
282 {
283 let header_state_commitment_field = State::header_state_commitment_field();
284
285 let (created, commit) = match snapshot {
286 Snapshot::Commit(commit) => {
287 let (height,) = query_as(&format!(
293 "SELECT height
294 FROM header
295 WHERE {header_state_commitment_field} = $1
296 LIMIT 1"
297 ))
298 .bind(commit.to_string())
299 .fetch_one(self.as_mut())
300 .await?;
301
302 (height, commit)
303 },
304 Snapshot::Index(created) => {
305 let created = created as i64;
306 let (commit,) = query_as::<(String,)>(&format!(
307 "SELECT {header_state_commitment_field} AS root_commitment
308 FROM header
309 WHERE height = $1
310 LIMIT 1"
311 ))
312 .bind(created)
313 .fetch_one(self.as_mut())
314 .await?;
315 let commit = serde_json::from_value(commit.into())
316 .decode_error("malformed state commitment")?;
317 (created, commit)
318 },
319 };
320
321 let height = self.get_last_state_height().await?;
323
324 if height < (created as usize) {
325 return Err(QueryError::NotFound);
326 }
327
328 let pruned_height = self
329 .load_pruned_height()
330 .await
331 .map_err(|e| QueryError::Error {
332 message: format!("failed to load pruned height: {e}"),
333 })?;
334
335 if pruned_height.is_some_and(|h| height <= h as usize) {
336 return Err(QueryError::NotFound);
337 }
338
339 Ok((created, commit))
340 }
341}
342
343pub(crate) fn build_hash_batch_insert(
345 hashes: &[Vec<u8>],
346) -> QueryResult<(QueryBuilder<'_>, String)> {
347 let mut query = QueryBuilder::default();
348 let params = hashes
349 .iter()
350 .map(|hash| Ok(format!("({})", query.bind(hash)?)))
351 .collect::<QueryResult<Vec<String>>>()?;
352 let sql = format!(
353 "INSERT INTO hash(value) values {} ON CONFLICT (value) DO UPDATE SET value = \
354 EXCLUDED.value returning value, id",
355 params.join(",")
356 );
357 Ok((query, sql))
358}
359
/// One stored version of a single merklized-state tree node, i.e. one row of a state table.
///
/// Which optional columns are populated determines the node kind, as decoded by `get_path`:
/// a branch has `children` and `children_bitvec`; a leaf has `idx` and `entry`; an empty node
/// has only `idx`.
#[derive(Debug, Default, Clone)]
pub(crate) struct Node {
    /// JSON array of branch indices locating the node. `build_get_path_query` looks nodes up by
    /// the key's traversal path reversed (presumably root-first).
    pub(crate) path: JsonValue,
    /// Block height at which this version was written. Lookups take the latest version with
    /// `created` at or below the snapshot height.
    pub(crate) created: i64,
    /// Id of this node's serialized value in the `hash` table.
    pub(crate) hash_id: i32,
    /// Branch nodes only: JSON array of `hash` table ids, one per set bit of `children_bitvec`,
    /// in slot order.
    pub(crate) children: Option<JsonValue>,
    /// Branch nodes only: one bit per child slot, set when that slot is non-empty.
    pub(crate) children_bitvec: Option<BitVec>,
    /// Leaf and empty nodes: JSON-encoded position (decoded as the leaf's `pos`).
    pub(crate) idx: Option<JsonValue>,
    /// Leaf nodes only: JSON-encoded element; `None` for an empty node.
    pub(crate) entry: Option<JsonValue>,
}
371
#[cfg(feature = "embedded-db")]
impl From<sqlx::sqlite::SqliteRow> for Node {
    /// Decodes a SQLite row into a [`Node`].
    ///
    /// The embedded backend stores `children_bitvec` as text with one `'0'`/`'1'` character per
    /// child slot; any character other than `'1'` decodes as an unset bit. Columns are read with
    /// `get_unchecked`, so their SQL types are not validated against the field types.
    fn from(row: sqlx::sqlite::SqliteRow) -> Self {
        let encoded_bits: Option<String> = row.get_unchecked("children_bitvec");
        let children_bitvec = encoded_bits.map(|text| {
            let mut bits = BitVec::new();
            for ch in text.chars() {
                bits.push(ch == '1');
            }
            bits
        });

        Node {
            path: row.get_unchecked("path"),
            created: row.get_unchecked("created"),
            hash_id: row.get_unchecked("hash_id"),
            children: row.get_unchecked("children"),
            children_bitvec,
            idx: row.get_unchecked("idx"),
            entry: row.get_unchecked("entry"),
        }
    }
}
390
391#[cfg(not(feature = "embedded-db"))]
392impl From<sqlx::postgres::PgRow> for Node {
393 fn from(row: sqlx::postgres::PgRow) -> Self {
394 Self {
395 path: row.get_unchecked("path"),
396 created: row.get_unchecked("created"),
397 hash_id: row.get_unchecked("hash_id"),
398 children: row.get_unchecked("children"),
399 children_bitvec: row.get_unchecked("children_bitvec"),
400 idx: row.get_unchecked("idx"),
401 entry: row.get_unchecked("entry"),
402 }
403 }
404}
405
406impl Node {
407 pub(crate) async fn upsert(
408 name: &str,
409 nodes: impl IntoIterator<Item = Self>,
410 tx: &mut Transaction<Write>,
411 ) -> anyhow::Result<()> {
412 tx.upsert(
413 name,
414 [
415 "path",
416 "created",
417 "hash_id",
418 "children",
419 "children_bitvec",
420 "idx",
421 "entry",
422 ],
423 ["path", "created"],
424 nodes.into_iter().map(|n| {
425 #[cfg(feature = "embedded-db")]
426 let children_bitvec: Option<String> = n
427 .children_bitvec
428 .clone()
429 .map(|b| b.iter().map(|bit| if bit { '1' } else { '0' }).collect());
430
431 #[cfg(not(feature = "embedded-db"))]
432 let children_bitvec = n.children_bitvec.clone();
433
434 (
435 n.path.clone(),
436 n.created,
437 n.hash_id,
438 n.children.clone(),
439 children_bitvec,
440 n.idx.clone(),
441 n.entry.clone(),
442 )
443 }),
444 )
445 .await
446 }
447}
448
449fn build_get_path_query<'q>(
450 table: &'static str,
451 traversal_path: Vec<usize>,
452 created: i64,
453) -> QueryResult<(QueryBuilder<'q>, String)> {
454 let mut query = QueryBuilder::default();
455 let mut traversal_path = traversal_path.into_iter().map(|x| x as i32);
456
457 let len = traversal_path.len();
459 let mut sub_queries = Vec::new();
460
461 query.bind(created)?;
462
463 for _ in 0..=len {
464 let path = traversal_path.clone().rev().collect::<Vec<_>>();
465 let path: serde_json::Value = path.into();
466 let node_path = query.bind(path)?;
467
468 let sub_query = format!(
469 "SELECT * FROM (SELECT * FROM {table} WHERE path = {node_path} AND created <= $1 \
470 ORDER BY created DESC LIMIT 1)",
471 );
472
473 sub_queries.push(sub_query);
474 traversal_path.next();
475 }
476
477 let mut sql: String = sub_queries.join(" UNION ");
478
479 sql = format!("SELECT * FROM ({sql}) as t ");
480
481 if cfg!(feature = "embedded-db") {
484 sql.push_str("ORDER BY length(t.path) DESC");
485 } else {
486 sql.push_str("ORDER BY t.path DESC");
487 }
488
489 Ok((query, sql))
490}
491
/// Integration tests for merklized state storage: proofs are written with `insert_merkle_nodes`
/// and read back through `get_path`, against a temporary database.
#[cfg(test)]
mod test {
    use futures::stream::StreamExt;
    use jf_merkle_tree_compat::{
        universal_merkle_tree::UniversalMerkleTree, LookupResult, MerkleTreeScheme,
        UniversalMerkleTreeScheme,
    };
    use rand::{seq::IteratorRandom, RngCore};

    use super::*;
    use crate::{
        data_source::{
            storage::sql::{testing::TmpDb, *},
            VersionedDataSource,
        },
        merklized_state::UpdateStateData,
        testing::mocks::{MockMerkleTree, MockTypes},
    };

    /// Stores proofs for 27 keys at height 1 and checks that every path round-trips. It then
    /// updates key 0 at height 2 and checks that both snapshots (heights 1 and 2) return the
    /// proof that was current at that height.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_merklized_state_storage() {
        let db = TmpDb::init().await;
        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
            .await
            .unwrap();

        let mut test_tree: UniversalMerkleTree<_, _, _, 8, _> =
            MockMerkleTree::new(MockMerkleTree::tree_height());
        let block_height = 1;

        let mut tx = storage.write().await.unwrap();
        for i in 0..27 {
            test_tree.update(i, i).unwrap();

            // Keep the header's state commitment in sync with the tree after each update.
            let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});
            tx.upsert(
                "header",
                ["height", "hash", "payload_hash", "timestamp", "data"],
                ["height"],
                [(
                    block_height as i64,
                    format!("randomHash{i}"),
                    "t".to_string(),
                    0,
                    test_data,
                )],
            )
            .await
            .unwrap();
            let (_, proof) = test_tree.lookup(i).expect_ok().unwrap();
            let traversal_path =
                <usize as ToTraversalPath<8>>::to_traversal_path(&i, test_tree.height());

            UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
                &mut tx,
                proof.clone(),
                traversal_path.clone(),
                block_height as u64,
            )
            .await
            .expect("failed to insert nodes");
        }
        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, block_height)
            .await
            .unwrap();
        tx.commit().await.unwrap();

        for i in 0..27 {
            let mut tx = storage.read().await.unwrap();
            let merkle_path = tx
                .get_path(
                    Snapshot::<_, MockMerkleTree, 8>::Index(block_height as u64),
                    i,
                )
                .await
                .unwrap();

            let (_, proof) = test_tree.lookup(i).expect_ok().unwrap();

            tracing::info!("merkle path {:?}", merkle_path);

            assert_eq!(merkle_path, proof.clone(), "merkle paths mismatch");
        }

        // Remember the height-1 proof for key 0, then change key 0 and persist it at height 2.
        let (_, proof_bh_1) = test_tree.lookup(0).expect_ok().unwrap();
        test_tree.update(0, 99).unwrap();
        let mut tx = storage.write().await.unwrap();
        let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});
        tx.upsert(
            "header",
            ["height", "hash", "payload_hash", "timestamp", "data"],
            ["height"],
            [(
                2i64,
                "randomstring".to_string(),
                "t".to_string(),
                0,
                test_data,
            )],
        )
        .await
        .unwrap();
        let (_, proof_bh_2) = test_tree.lookup(0).expect_ok().unwrap();
        let traversal_path =
            <usize as ToTraversalPath<8>>::to_traversal_path(&0, test_tree.height());
        UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
            &mut tx,
            proof_bh_2.clone(),
            traversal_path.clone(),
            2,
        )
        .await
        .expect("failed to insert nodes");
        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, 2)
            .await
            .unwrap();
        tx.commit().await.unwrap();

        // The leaf for key 0 is stored under its reversed traversal path; it should now have
        // two versions, one per height.
        let node_path = traversal_path
            .into_iter()
            .rev()
            .map(|n| n as i32)
            .collect::<Vec<_>>();

        let mut tx = storage.read().await.unwrap();
        let rows = query("SELECT * from test_tree where path = $1 ORDER BY created")
            .bind(serde_json::to_value(node_path).unwrap())
            .fetch(tx.as_mut());

        let nodes: Vec<Node> = rows.map(|res| res.unwrap().into()).collect().await;
        assert!(nodes.len() == 2, "incorrect number of nodes");
        assert_eq!(nodes[0].created, 1, "wrong block height");
        assert_eq!(nodes[1].created, 2, "wrong block height");

        let path_with_bh_2 = storage
            .read()
            .await
            .unwrap()
            .get_path(Snapshot::<_, MockMerkleTree, 8>::Index(2), 0)
            .await
            .unwrap();

        assert_eq!(path_with_bh_2, proof_bh_2);
        let path_with_bh_1 = storage
            .read()
            .await
            .unwrap()
            .get_path(Snapshot::<_, MockMerkleTree, 8>::Index(1), 0)
            .await
            .unwrap();
        assert_eq!(path_with_bh_1, proof_bh_1);
    }

    /// Stores a membership proof for key 0 at height 1. Key 0 is then removed and the resulting
    /// non-membership proof is stored at height 2. The test checks that each height serves its
    /// own proof.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_merklized_state_non_membership_proof() {
        let db = TmpDb::init().await;
        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
            .await
            .unwrap();

        let mut test_tree = MockMerkleTree::new(MockMerkleTree::tree_height());
        let block_height = 1;
        test_tree.update(0, 0).unwrap();
        let commitment = test_tree.commitment();

        let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(commitment).unwrap()});
        let mut tx = storage.write().await.unwrap();
        tx.upsert(
            "header",
            ["height", "hash", "payload_hash", "timestamp", "data"],
            ["height"],
            [(
                block_height as i64,
                "randomString".to_string(),
                "t".to_string(),
                0,
                test_data,
            )],
        )
        .await
        .unwrap();
        let (_, proof_before_remove) = test_tree.lookup(0).expect_ok().unwrap();
        let traversal_path =
            <usize as ToTraversalPath<8>>::to_traversal_path(&0, test_tree.height());
        UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
            &mut tx,
            proof_before_remove.clone(),
            traversal_path.clone(),
            block_height as u64,
        )
        .await
        .expect("failed to insert nodes");
        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, block_height)
            .await
            .unwrap();
        tx.commit().await.unwrap();
        let merkle_path = storage
            .read()
            .await
            .unwrap()
            .get_path(
                Snapshot::<_, MockMerkleTree, 8>::Index(block_height as u64),
                0,
            )
            .await
            .unwrap();

        assert_eq!(
            merkle_path,
            proof_before_remove.clone(),
            "merkle paths mismatch"
        );

        test_tree.remove(0).expect("failed to delete index 0 ");

        let proof_after_remove = test_tree.universal_lookup(0).expect_not_found().unwrap();

        // Persist the non-membership proof for key 0 as the height-2 state.
        let mut tx = storage.write().await.unwrap();
        UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
            &mut tx,
            proof_after_remove.clone(),
            traversal_path.clone(),
            2_u64,
        )
        .await
        .expect("failed to insert nodes");
        tx.upsert(
            "header",
            ["height", "hash", "payload_hash", "timestamp", "data"],
            ["height"],
            [(
                2i64,
                "randomString2".to_string(),
                "t".to_string(),
                0,
                serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()}),
            )],
        )
        .await
        .unwrap();
        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, 2)
            .await
            .unwrap();
        tx.commit().await.unwrap();
        let non_membership_path = storage
            .read()
            .await
            .unwrap()
            .get_path(Snapshot::<_, MockMerkleTree, 8>::Index(2_u64), 0)
            .await
            .unwrap();
        assert_eq!(
            non_membership_path, proof_after_remove,
            "merkle paths dont match"
        );

        // The older snapshot must still yield the membership proof.
        let proof_bh_1 = storage
            .read()
            .await
            .unwrap()
            .get_path(Snapshot::<_, MockMerkleTree, 8>::Index(1_u64), 0)
            .await
            .unwrap();
        assert_eq!(proof_bh_1, proof_before_remove, "merkle paths dont match");
    }

    /// Requests a non-membership proof for a key (100) that was never inserted, at heights 0..=2.
    /// The tree starts empty and gains key `i` at height `i + 1`. Each proof must verify against
    /// the tree's commitment.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_merklized_state_non_membership_proof_unseen_entry() {
        let db = TmpDb::init().await;
        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
            .await
            .unwrap();

        let mut test_tree = MockMerkleTree::new(MockMerkleTree::tree_height());

        for i in 0..=2 {
            tracing::info!(i, ?test_tree, "testing non-membership proof");
            let mut tx = storage.write().await.unwrap();

            tx.upsert(
                "header",
                ["height", "hash", "payload_hash", "timestamp", "data"],
                ["height"],
                [(
                    i as i64,
                    format!("hash{i}"),
                    "t".to_string(),
                    0,
                    serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()})
                )],
            )
            .await
            .unwrap();
            UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, i)
                .await
                .unwrap();
            tx.commit().await.unwrap();

            let proof = storage
                .read()
                .await
                .unwrap()
                .get_path(
                    Snapshot::<MockTypes, MockMerkleTree, 8>::Index(i as u64),
                    100,
                )
                .await
                .unwrap();
            assert_eq!(proof.elem(), None);

            assert!(
                MockMerkleTree::non_membership_verify(test_tree.commitment(), 100, proof).unwrap()
            );

            // Add key `i` to the tree; its nodes become visible from height `i + 1`.
            test_tree.update(i, i).unwrap();
            let (_, proof) = test_tree.lookup(i).expect_ok().unwrap();
            let traversal_path = ToTraversalPath::<8>::to_traversal_path(&i, test_tree.height());
            let mut tx = storage.write().await.unwrap();
            UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
                &mut tx,
                proof,
                traversal_path,
                (i + 1) as u64,
            )
            .await
            .expect("failed to insert nodes");
            tx.commit().await.unwrap();
        }
    }

    /// Looks up a path by `Snapshot::Commit` (state commitment) instead of by height.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_merklized_storage_with_commit() {
        let db = TmpDb::init().await;
        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
            .await
            .unwrap();

        let mut test_tree = MockMerkleTree::new(MockMerkleTree::tree_height());
        let block_height = 1;
        test_tree.update(0, 0).unwrap();
        let commitment = test_tree.commitment();

        let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(commitment).unwrap()});
        let mut tx = storage.write().await.unwrap();
        tx.upsert(
            "header",
            ["height", "hash", "payload_hash", "timestamp", "data"],
            ["height"],
            [(
                block_height as i64,
                "randomString".to_string(),
                "t".to_string(),
                0,
                test_data,
            )],
        )
        .await
        .unwrap();
        let (_, proof) = test_tree.lookup(0).expect_ok().unwrap();
        let traversal_path =
            <usize as ToTraversalPath<8>>::to_traversal_path(&0, test_tree.height());
        UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
            &mut tx,
            proof.clone(),
            traversal_path.clone(),
            block_height as u64,
        )
        .await
        .expect("failed to insert nodes");
        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, block_height)
            .await
            .unwrap();
        tx.commit().await.unwrap();

        let merkle_proof = storage
            .read()
            .await
            .unwrap()
            .get_path(Snapshot::<_, MockMerkleTree, 8>::Commit(commitment), 0)
            .await
            .unwrap();

        let (_, proof) = test_tree.lookup(0).expect_ok().unwrap();

        assert_eq!(merkle_proof, proof.clone(), "merkle paths mismatch");
    }

    /// Checks that `get_path` fails rather than returning a wrong proof when stored state is
    /// inconsistent with the header. The two inconsistencies tested are: nodes written without
    /// updating the header commitment, and an internal node version deleted from storage.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_merklized_state_missing_state() {
        let db = TmpDb::init().await;
        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
            .await
            .unwrap();

        let mut test_tree = MockMerkleTree::new(MockMerkleTree::tree_height());
        let block_height = 1;
        let mut tx = storage.write().await.unwrap();
        for i in 0..27 {
            test_tree.update(i, i).unwrap();
            let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});
            tx.upsert(
                "header",
                ["height", "hash", "payload_hash", "timestamp", "data"],
                ["height"],
                [(
                    block_height as i64,
                    format!("rarndomString{i}"),
                    "t".to_string(),
                    0,
                    test_data,
                )],
            )
            .await
            .unwrap();
            let (_, proof) = test_tree.lookup(i).expect_ok().unwrap();
            let traversal_path =
                <usize as ToTraversalPath<8>>::to_traversal_path(&i, test_tree.height());
            UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
                &mut tx,
                proof.clone(),
                traversal_path.clone(),
                block_height as u64,
            )
            .await
            .expect("failed to insert nodes");
            UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, block_height)
                .await
                .unwrap();
        }

        // Overwrite key 1's nodes at height 1 without updating the header commitment.
        test_tree.update(1, 100).unwrap();
        let traversal_path =
            <usize as ToTraversalPath<8>>::to_traversal_path(&1, test_tree.height());
        let (_, proof) = test_tree.lookup(1).expect_ok().unwrap();

        UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
            &mut tx,
            proof.clone(),
            traversal_path.clone(),
            block_height as u64,
        )
        .await
        .expect("failed to insert nodes");
        tx.commit().await.unwrap();

        // The recomputed root no longer matches the header, so the lookup must fail.
        let merkle_path = storage
            .read()
            .await
            .unwrap()
            .get_path(
                Snapshot::<_, MockMerkleTree, 8>::Index(block_height as u64),
                1,
            )
            .await;
        assert!(merkle_path.is_err());

        // Once the header commitment is updated to match, the lookup succeeds.
        let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});
        let mut tx = storage.write().await.unwrap();
        tx.upsert(
            "header",
            ["height", "hash", "payload_hash", "timestamp", "data"],
            ["height"],
            [(
                block_height as i64,
                "randomStringgg".to_string(),
                "t".to_string(),
                0,
                test_data,
            )],
        )
        .await
        .unwrap();
        tx.commit().await.unwrap();
        let merkle_proof = storage
            .read()
            .await
            .unwrap()
            .get_path(
                Snapshot::<_, MockMerkleTree, 8>::Index(block_height as u64),
                1,
            )
            .await
            .unwrap();
        assert_eq!(merkle_proof, proof, "path dont match");

        test_tree.update(1, 200).unwrap();

        let (_, proof) = test_tree.lookup(1).expect_ok().unwrap();
        let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});

        let mut tx = storage.write().await.unwrap();
        tx.upsert(
            "header",
            ["height", "hash", "payload_hash", "timestamp", "data"],
            ["height"],
            [(
                2i64,
                "randomHashString".to_string(),
                "t".to_string(),
                0,
                test_data,
            )],
        )
        .await
        .unwrap();
        UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
            &mut tx,
            proof.clone(),
            traversal_path.clone(),
            2_u64,
        )
        .await
        .expect("failed to insert nodes");

        // Delete the height-2 version of key 1's parent node (the traversal path minus its
        // leaf-level entry, reversed). The height-1 version is then used instead, which no
        // longer commits to the new leaf.
        let node_path = traversal_path
            .iter()
            .skip(1)
            .rev()
            .map(|n| *n as i32)
            .collect::<Vec<_>>();
        tx.execute(
            query(&format!(
                "DELETE FROM {} WHERE created = 2 and path = $1",
                MockMerkleTree::state_type()
            ))
            .bind(serde_json::to_value(node_path).unwrap()),
        )
        .await
        .expect("failed to delete internal node");
        tx.commit().await.unwrap();

        let merkle_path = storage
            .read()
            .await
            .unwrap()
            .get_path(Snapshot::<_, MockMerkleTree, 8>::Index(2_u64), 1)
            .await;

        assert!(merkle_path.is_err());
    }

    /// Randomized insert/delete workload. The tree is stored at height 1 and then modified and
    /// stored at height 2. Proofs for every touched key, plus a key that is never inserted, must
    /// match the in-memory tree at each height. The height-1 snapshot must remain intact after
    /// height 2 is written.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_merklized_state_snapshot() {
        let db = TmpDb::init().await;
        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
            .await
            .unwrap();

        let mut test_tree = MockMerkleTree::new(MockMerkleTree::tree_height());

        // `randomize` only generates `u32` keys, so this key is never inserted and always
        // yields a non-membership proof.
        const RESERVED_KEY: usize = (u32::MAX as usize) + 1;

        /// Applies 50 random operations to `tree`, recording the expected value of each touched
        /// key in `expected` (`None` once deleted). It inserts when no key is present;
        /// otherwise it picks insert or delete with roughly equal odds.
        #[tracing::instrument(skip(tree, expected))]
        fn randomize(tree: &mut MockMerkleTree, expected: &mut HashMap<usize, Option<usize>>) {
            let mut rng = rand::thread_rng();
            tracing::info!("randomizing tree");

            for _ in 0..50 {
                if !expected.values().any(|v| v.is_some()) || rng.next_u32().is_multiple_of(2) {
                    let key = rng.next_u32() as usize;
                    let val = rng.next_u32() as usize;
                    tracing::info!(key, val, "inserting");

                    tree.update(key, val).unwrap();
                    expected.insert(key, Some(val));
                } else {
                    let key = expected
                        .iter()
                        .filter_map(|(k, v)| if v.is_some() { Some(k) } else { None })
                        .choose(&mut rng)
                        .unwrap();
                    tracing::info!(key, "deleting");

                    tree.remove(key).unwrap();
                    expected.insert(*key, None);
                }
            }
        }

        /// Persists the proof of every key in `expected` at `block_height`, writes a header
        /// with the tree's commitment, and advances the last state height.
        #[tracing::instrument(skip(storage, tree, expected))]
        async fn store(
            storage: &SqlStorage,
            tree: &MockMerkleTree,
            expected: &HashMap<usize, Option<usize>>,
            block_height: u64,
        ) {
            tracing::info!("persisting tree");
            let mut tx = storage.write().await.unwrap();

            for key in expected.keys() {
                let proof = match tree.universal_lookup(key) {
                    LookupResult::Ok(_, proof) => proof,
                    LookupResult::NotFound(proof) => proof,
                    LookupResult::NotInMemory => panic!("failed to find key {key}"),
                };
                let traversal_path = ToTraversalPath::<8>::to_traversal_path(key, tree.height());
                UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
                    &mut tx,
                    proof,
                    traversal_path,
                    block_height,
                )
                .await
                .unwrap();
            }
            tx
                .upsert("header", ["height", "hash", "payload_hash", "timestamp", "data"], ["height"],
                    [(
                        block_height as i64,
                        format!("hash{block_height}"),
                        "hash".to_string(),
                        0i64,
                        serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(tree.commitment()).unwrap()}),
                    )],
                )
                .await
                .unwrap();
            UpdateStateData::<MockTypes, MockMerkleTree, 8>::set_last_state_height(
                &mut tx,
                block_height as usize,
            )
            .await
            .unwrap();
            tx.commit().await.unwrap();
        }

        /// Checks that stored proofs at `block_height` equal the in-memory tree's proofs and
        /// verify against its commitment. This covers every key in `expected` and `RESERVED_KEY`.
        #[tracing::instrument(skip(storage, tree, expected))]
        async fn validate(
            storage: &SqlStorage,
            tree: &MockMerkleTree,
            expected: &HashMap<usize, Option<usize>>,
            block_height: u64,
        ) {
            tracing::info!("validating snapshot");

            let snapshot = Snapshot::<_, MockMerkleTree, 8>::Index(block_height);

            for (key, val) in expected {
                let proof = match tree.universal_lookup(key) {
                    LookupResult::Ok(_, proof) => proof,
                    LookupResult::NotFound(proof) => proof,
                    LookupResult::NotInMemory => panic!("failed to find key {key}"),
                };
                assert_eq!(
                    proof,
                    storage
                        .read()
                        .await
                        .unwrap()
                        .get_path(snapshot, *key)
                        .await
                        .unwrap()
                );
                assert_eq!(val.as_ref(), proof.elem());
                if val.is_some() {
                    MockMerkleTree::verify(tree.commitment(), key, proof)
                        .unwrap()
                        .unwrap();
                } else {
                    assert!(
                        MockMerkleTree::non_membership_verify(tree.commitment(), key, proof)
                            .unwrap()
                    );
                }
            }

            let proof = match tree.universal_lookup(RESERVED_KEY) {
                LookupResult::Ok(_, proof) => proof,
                LookupResult::NotFound(proof) => proof,
                LookupResult::NotInMemory => panic!("failed to find reserved key {RESERVED_KEY}"),
            };
            assert_eq!(
                proof,
                storage
                    .read()
                    .await
                    .unwrap()
                    .get_path(snapshot, RESERVED_KEY)
                    .await
                    .unwrap()
            );
            assert_eq!(proof.elem(), None);
            assert!(
                MockMerkleTree::non_membership_verify(tree.commitment(), RESERVED_KEY, proof)
                    .unwrap()
            );
        }

        let mut expected = HashMap::<usize, Option<usize>>::new();
        randomize(&mut test_tree, &mut expected);

        store(&storage, &test_tree, &expected, 1).await;
        validate(&storage, &test_tree, &expected, 1).await;

        let mut expected2 = expected.clone();
        let mut test_tree2 = test_tree.clone();
        randomize(&mut test_tree2, &mut expected2);
        store(&storage, &test_tree2, &expected2, 2).await;
        validate(&storage, &test_tree2, &expected2, 2).await;

        // Writing height 2 must not disturb the height-1 snapshot.
        validate(&storage, &test_tree, &expected, 1).await;
    }

    /// For trees of 1 to 3 keys: deleting the stored leaf row of the last key must make its
    /// lookup fail with `QueryError::Missing`, while the remaining keys still resolve.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_merklized_state_missing_leaf() {
        for tree_size in 1..=3 {
            let db = TmpDb::init().await;
            let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
                .await
                .unwrap();

            let mut test_tree = MockMerkleTree::new(MockMerkleTree::tree_height());
            for i in 0..tree_size {
                test_tree.update(i, i).unwrap();
            }

            let mut tx = storage.write().await.unwrap();

            tx.upsert(
                "header",
                ["height", "hash", "payload_hash", "timestamp", "data"],
                ["height"],
                [(
                    0i64,
                    "hash".to_string(),
                    "hash".to_string(),
                    0,
                    serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()}),
                )],
            )
            .await
            .unwrap();

            for i in 0..tree_size {
                let proof = test_tree.lookup(i).expect_ok().unwrap().1;
                let traversal_path =
                    ToTraversalPath::<8>::to_traversal_path(&i, test_tree.height());
                UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
                    &mut tx,
                    proof,
                    traversal_path,
                    0,
                )
                .await
                .unwrap();
            }
            UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, 0)
                .await
                .unwrap();
            tx.commit().await.unwrap();

            let snapshot = Snapshot::<MockTypes, MockMerkleTree, 8>::Index(0);
            for i in 0..tree_size {
                let proof = test_tree.lookup(i).expect_ok().unwrap().1;
                assert_eq!(
                    proof,
                    storage
                        .read()
                        .await
                        .unwrap()
                        .get_path(snapshot, i)
                        .await
                        .unwrap()
                );
                assert_eq!(*proof.elem().unwrap(), i);
            }

            // Delete the leaf row(s) for the last key, matched by the JSON-encoded `idx` column.
            let index = serde_json::to_value(tree_size - 1).unwrap();
            let mut tx = storage.write().await.unwrap();

            tx.execute(
                query(&format!(
                    "DELETE FROM {} WHERE idx = $1",
                    MockMerkleTree::state_type()
                ))
                .bind(serde_json::to_value(index).unwrap()),
            )
            .await
            .unwrap();
            tx.commit().await.unwrap();

            for i in 0..tree_size - 1 {
                let proof = test_tree.lookup(i).expect_ok().unwrap().1;
                assert_eq!(
                    proof,
                    storage
                        .read()
                        .await
                        .unwrap()
                        .get_path(snapshot, i)
                        .await
                        .unwrap()
                );
                assert_eq!(*proof.elem().unwrap(), i);
            }

            // The parent still commits to the deleted leaf, so the reconstructed path has a
            // parent-child mismatch.
            let err = storage
                .read()
                .await
                .unwrap()
                .get_path(snapshot, tree_size - 1)
                .await
                .unwrap_err();
            assert!(matches!(err, QueryError::Missing));
        }
    }
}