hotshot_query_service/data_source/storage/sql/queries/
state.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the HotShot Query Service library.
3//
4// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
5// General Public License as published by the Free Software Foundation, either version 3 of the
6// License, or (at your option) any later version.
7// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
8// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
9// General Public License for more details.
10// You should have received a copy of the GNU General Public License along with this program. If not,
11// see <https://www.gnu.org/licenses/>.
12
13//! Merklized state storage implementation for a database query engine.
14
15use std::{
16    collections::{HashMap, HashSet, VecDeque},
17    sync::Arc,
18};
19
20#[cfg(not(feature = "embedded-db"))]
21use anyhow::Context;
22use ark_serialize::CanonicalDeserialize;
23use async_trait::async_trait;
24use futures::stream::TryStreamExt;
25use hotshot_types::traits::node_implementation::NodeType;
26use jf_merkle_tree_compat::{
27    DigestAlgorithm, MerkleCommitment, ToTraversalPath,
28    prelude::{MerkleNode, MerkleProof},
29};
30use sqlx::types::{BitVec, JsonValue};
31
32use super::{
33    super::transaction::{Transaction, TransactionMode, Write, query_as},
34    DecodeError, QueryBuilder,
35};
36use crate::{
37    QueryError, QueryResult,
38    data_source::storage::{
39        MerklizedStateHeightStorage, MerklizedStateStorage,
40        pruning::PrunedHeightStorage,
41        sql::{build_where_in, sqlx::Row},
42    },
43    merklized_state::{MerklizedState, Snapshot},
44};
45
#[async_trait]
impl<Mode, Types, State, const ARITY: usize> MerklizedStateStorage<Types, State, ARITY>
    for Transaction<Mode>
where
    Mode: TransactionMode,
    Types: NodeType,
    State: MerklizedState<Types, ARITY> + 'static,
{
    /// Retrieves a Merkle path from the database.
    ///
    /// Loads every stored node on the path from `key`'s leaf up to the root as
    /// of the given `snapshot`, rebuilds the corresponding [`MerkleProof`],
    /// and verifies the recomputed root digest against the state commitment
    /// recorded in the header for that snapshot before returning the proof.
    ///
    /// # Errors
    /// * [`QueryError::Missing`] if stored nodes are internally inconsistent
    ///   (a stale parent/child pair).
    /// * [`QueryError::Error`] for malformed rows or a commitment mismatch.
    async fn get_path(
        &mut self,
        snapshot: Snapshot<Types, State, ARITY>,
        key: State::Key,
    ) -> QueryResult<MerkleProof<State::Entry, State::Key, State::T, ARITY>> {
        let state_type = State::state_type();
        let tree_height = State::tree_height();

        // Get the traversal path of the index
        let traversal_path = State::Key::to_traversal_path(&key, tree_height);
        // `created` is the block height the snapshot refers to; only node
        // versions with `created <=` this height are considered below.
        let (created, merkle_commitment) = self.snapshot_info(snapshot).await?;

        // Get all the nodes in the path to the index.
        // Order by pos DESC is to return nodes from the leaf to the root
        let (query, sql) = build_get_path_query(state_type, traversal_path.clone(), created)?;
        let rows = query.query(&sql).fetch_all(self.as_mut()).await?;

        let nodes: Vec<Node> = rows.into_iter().map(|r| r.into()).collect();

        // insert all the hash ids to a hashset which is used to query later
        // HashSet is used to avoid duplicates
        let mut hash_ids = HashSet::new();
        for node in nodes.iter() {
            hash_ids.insert(node.hash_id);
            if let Some(children) = &node.children {
                let children: Vec<i32> =
                    serde_json::from_value(children.clone()).map_err(|e| QueryError::Error {
                        message: format!("Error deserializing 'children' into Vec<i32>: {e}"),
                    })?;
                hash_ids.extend(children);
            }
        }

        // Find all the hash values and create a hashmap
        // Hashmap will be used to get the hash value of the nodes children and the node itself.
        let hashes = if !hash_ids.is_empty() {
            let (query, sql) = build_where_in("SELECT id, value FROM hash", "id", hash_ids)?;
            query
                .query_as(&sql)
                .fetch(self.as_mut())
                .try_collect::<HashMap<i32, Vec<u8>>>()
                .await?
        } else {
            HashMap::new()
        };

        // Convert the raw rows into MerkleNodes, leaf first (rows were ordered
        // leaf-to-root by the query above).
        let mut proof_path = VecDeque::with_capacity(State::tree_height());
        for Node {
            hash_id,
            children,
            children_bitvec,
            idx,
            entry,
            ..
        } in nodes.iter()
        {
            {
                let value = hashes.get(hash_id).ok_or(QueryError::Error {
                    message: format!("node's value references non-existent hash {hash_id}"),
                })?;

                // The populated column combination determines the node kind.
                match (children, children_bitvec, idx, entry) {
                    // If the row has children then its a branch
                    (Some(children), Some(children_bitvec), None, None) => {
                        let children: Vec<i32> =
                            serde_json::from_value(children.clone()).map_err(|e| {
                                QueryError::Error {
                                    message: format!(
                                        "Error deserializing 'children' into Vec<i32>: {e}"
                                    ),
                                }
                            })?;
                        let mut children = children.iter();

                        // Reconstruct the Children MerkleNodes from storage.
                        // Children bit_vec is used to create forgotten  or empty node
                        // (a set bit consumes the next stored child hash id).
                        let child_nodes = children_bitvec
                            .iter()
                            .map(|bit| {
                                if bit {
                                    let hash_id = children.next().ok_or(QueryError::Error {
                                        message: "node has fewer children than set bits".into(),
                                    })?;
                                    let value = hashes.get(hash_id).ok_or(QueryError::Error {
                                        message: format!(
                                            "node's child references non-existent hash {hash_id}"
                                        ),
                                    })?;
                                    Ok(Arc::new(MerkleNode::ForgettenSubtree {
                                        value: State::T::deserialize_compressed(value.as_slice())
                                            .decode_error("malformed merkle node value")?,
                                    }))
                                } else {
                                    Ok(Arc::new(MerkleNode::Empty))
                                }
                            })
                            .collect::<QueryResult<Vec<_>>>()?;
                        // Use the Children merkle nodes to reconstruct the branch node
                        proof_path.push_back(MerkleNode::Branch {
                            value: State::T::deserialize_compressed(value.as_slice())
                                .decode_error("malformed merkle node value")?,
                            children: child_nodes,
                        });
                    },
                    // If it has an entry, it's a leaf
                    (None, None, Some(index), Some(entry)) => {
                        proof_path.push_back(MerkleNode::Leaf {
                            value: State::T::deserialize_compressed(value.as_slice())
                                .decode_error("malformed merkle node value")?,
                            pos: serde_json::from_value(index.clone())
                                .decode_error("malformed merkle node index")?,
                            elem: serde_json::from_value(entry.clone())
                                .decode_error("malformed merkle element")?,
                        });
                    },
                    // Otherwise, it's empty.
                    (None, None, Some(_), None) => {
                        proof_path.push_back(MerkleNode::Empty);
                    },
                    _ => {
                        return Err(QueryError::Error {
                            message: "Invalid type of merkle node found".to_string(),
                        });
                    },
                }
            }
        }

        // Reconstruct the merkle commitment from the path
        let init = if let Some(MerkleNode::Leaf { value, .. }) = proof_path.front() {
            *value
        } else {
            // If the path ends in a branch (or, as a special case, if the path and thus the entire
            // tree is empty), we are looking up an entry that is not present in the tree. We always
            // store all the nodes on all the paths to all the entries in the tree, so the only
            // nodes we could be missing are empty nodes from unseen entries. Thus, we can
            // reconstruct what the path should be by prepending empty nodes.
            while proof_path.len() <= State::tree_height() {
                proof_path.push_front(MerkleNode::Empty);
            }
            State::T::default()
        };
        // Walk from the leaf toward the root, checking at every branch that
        // the child slot selected by the traversal path holds the hash we
        // computed one level below, then folding in the branch digest.
        let commitment_from_path = traversal_path
            .iter()
            .zip(proof_path.iter().skip(1))
            .try_fold(init, |val, (branch, node)| -> QueryResult<State::T> {
                match node {
                    MerkleNode::Branch { value: _, children } => {
                        let data = children
                            .iter()
                            .map(|node| match node.as_ref() {
                                MerkleNode::ForgettenSubtree { value } => Ok(*value),
                                MerkleNode::Empty => Ok(State::T::default()),
                                _ => Err(QueryError::Error {
                                    message: "Invalid child node".to_string(),
                                }),
                            })
                            .collect::<QueryResult<Vec<_>>>()?;

                        if data[*branch] != val {
                            // This can only happen if data is missing: we have an old version of
                            // one of the nodes in the path, which is why it is not matching up with
                            // its parent.
                            tracing::warn!(
                                ?key,
                                parent = ?data[*branch],
                                child = ?val,
                                branch = %*branch,
                                %created,
                                %merkle_commitment,
                                "missing data in merklized state; parent-child mismatch",
                            );
                            return Err(QueryError::Missing);
                        }

                        State::Digest::digest(&data).map_err(|err| QueryError::Error {
                            message: format!("failed to update digest: {err:#}"),
                        })
                    },
                    MerkleNode::Empty => Ok(init),
                    _ => Err(QueryError::Error {
                        message: "Invalid type of Node in the proof".to_string(),
                    }),
                }
            })?;

        // The recomputed root must match the commitment recorded in the header.
        if commitment_from_path != merkle_commitment.digest() {
            return Err(QueryError::Error {
                message: format!(
                    "Commitment calculated from merkle path ({commitment_from_path:?}) does not \
                     match the commitment in the header ({:?})",
                    merkle_commitment.digest()
                ),
            });
        }

        Ok(MerkleProof {
            pos: key,
            proof: proof_path.into(),
        })
    }
}
257
258#[async_trait]
259impl<Mode: TransactionMode> MerklizedStateHeightStorage for Transaction<Mode> {
260    async fn get_last_state_height(&mut self) -> QueryResult<usize> {
261        let Some((height,)) = query_as::<(i64,)>("SELECT height from last_merklized_state_height")
262            .fetch_optional(self.as_mut())
263            .await?
264        else {
265            return Ok(0);
266        };
267        Ok(height as usize)
268    }
269}
270
impl<Mode: TransactionMode> Transaction<Mode> {
    /// Get information identifying a [`Snapshot`].
    ///
    /// If the given snapshot is known to the database, this function returns
    /// * The block height at which the snapshot was created
    /// * A digest of the Merkle commitment to the snapshotted state
    ///
    /// # Errors
    /// * [`QueryError::NotFound`] if the snapshot is newer than the last
    ///   materialized state height, or if the stored state has been pruned.
    /// * Any error from the underlying header queries.
    async fn snapshot_info<Types, State, const ARITY: usize>(
        &mut self,
        snapshot: Snapshot<Types, State, ARITY>,
    ) -> QueryResult<(i64, State::Commit)>
    where
        Types: NodeType,
        State: MerklizedState<Types, ARITY>,
    {
        let header_state_commitment_field = State::header_state_commitment_field();

        let (created, commit) = match snapshot {
            Snapshot::Commit(commit) => {
                // Get the block height using the merkle commitment. It is possible that multiple
                // headers will have the same state commitment. In this case we don't care which
                // height we get, since any query against equivalent states will yield equivalent
                // results, regardless of which block the state is from. Thus, we can make this
                // query fast with `LIMIT 1` and no `ORDER BY`.
                let (height,) = query_as(&format!(
                    "SELECT height
                       FROM header
                      WHERE {header_state_commitment_field} = $1
                      LIMIT 1"
                ))
                .bind(commit.to_string())
                .fetch_one(self.as_mut())
                .await?;

                (height, commit)
            },
            Snapshot::Index(created) => {
                // Inverse lookup: resolve the height to the commitment stored
                // in that block's header.
                let created = created as i64;
                let (commit,) = query_as::<(String,)>(&format!(
                    "SELECT {header_state_commitment_field} AS root_commitment
                       FROM header
                      WHERE height = $1
                      LIMIT 1"
                ))
                .bind(created)
                .fetch_one(self.as_mut())
                .await?;
                // The commitment is stored as a string; round-trip it through a
                // JSON string value to deserialize into `State::Commit`.
                let commit = serde_json::from_value(commit.into())
                    .decode_error("malformed state commitment")?;
                (created, commit)
            },
        };

        // Make sure the requested snapshot is up to date.
        let height = self.get_last_state_height().await?;

        // Reject snapshots beyond the latest materialized state.
        if height < (created as usize) {
            return Err(QueryError::NotFound);
        }

        let pruned_height = self
            .load_pruned_height()
            .await
            .map_err(|e| QueryError::Error {
                message: format!("failed to load pruned height: {e}"),
            })?;

        // NOTE(review): this compares the *latest* state height (not `created`)
        // against the pruned height, so a snapshot older than the pruning
        // horizon is still served as long as any newer state exists — confirm
        // this is the intended pruning semantics.
        if pruned_height.is_some_and(|h| height <= h as usize) {
            return Err(QueryError::NotFound);
        }

        Ok((created, commit))
    }
}
344
// TODO: create a generic upsert function with retries that returns the column
#[cfg(feature = "embedded-db")]
pub(crate) fn build_hash_batch_insert(
    hashes: &[Vec<u8>],
) -> QueryResult<(QueryBuilder<'_>, String)> {
    // Bind each hash as its own parameter, collecting the "($n)" value groups
    // for the VALUES clause.
    let mut query = QueryBuilder::default();
    let mut params = Vec::with_capacity(hashes.len());
    for hash in hashes {
        params.push(format!("({})", query.bind(hash)?));
    }

    // The no-op ON CONFLICT update lets RETURNING report ids for rows that
    // already existed as well as freshly inserted ones.
    let sql = format!(
        "INSERT INTO hash(value) values {} ON CONFLICT (value) DO UPDATE SET value = \
         EXCLUDED.value returning value, id",
        params.join(",")
    );
    Ok((query, sql))
}
362
363/// Batch insert hashes using UNNEST for large batches (postgres only).
364/// Returns a map from hash bytes to their database IDs.
365#[cfg(not(feature = "embedded-db"))]
366pub(crate) async fn batch_insert_hashes(
367    hashes: Vec<Vec<u8>>,
368    tx: &mut Transaction<Write>,
369) -> QueryResult<HashMap<Vec<u8>, i32>> {
370    if hashes.is_empty() {
371        return Ok(HashMap::new());
372    }
373
374    // Use UNNEST-based batch insert (more efficient and avoids parameter limits)
375    let sql = "INSERT INTO hash(value) SELECT * FROM UNNEST($1::bytea[]) ON CONFLICT (value) DO \
376               UPDATE SET value = EXCLUDED.value RETURNING value, id";
377
378    let result: HashMap<Vec<u8>, i32> = sqlx::query_as(sql)
379        .bind(&hashes)
380        .fetch(tx.as_mut())
381        .try_collect()
382        .await
383        .map_err(|e| QueryError::Error {
384            message: format!("batch hash insert failed: {e}"),
385        })?;
386
387    Ok(result)
388}
389
/// Type alias for a merkle proof with its traversal path.
///
/// The second element is the key's traversal path: a vector of branch indices
/// consumed one per tree level when persisting the proof's nodes (see
/// `collect_nodes_from_proofs`).
pub(crate) type ProofWithPath<Entry, Key, T, const ARITY: usize> =
    (MerkleProof<Entry, Key, T, ARITY>, Vec<usize>);
393
/// Collects nodes and hashes from merkle proofs.
/// Returns (nodes, hashes) for batch insertion.
///
/// Each proof node becomes a [`Node`] row (plus its raw hash bytes, and for
/// branches the raw hashes of its children) so callers can first batch-insert
/// all hashes and then resolve the returned byte values to hash ids.
pub(crate) fn collect_nodes_from_proofs<Entry, Key, T, const ARITY: usize>(
    proofs: &[ProofWithPath<Entry, Key, T, ARITY>],
) -> QueryResult<(Vec<NodeWithHashes>, HashSet<Vec<u8>>)>
where
    Entry: jf_merkle_tree_compat::Element + serde::Serialize,
    Key: jf_merkle_tree_compat::Index + serde::Serialize,
    T: jf_merkle_tree_compat::NodeValue,
{
    let mut nodes = Vec::new();
    let mut hashes = HashSet::new();

    for (proof, traversal_path) in proofs {
        let pos = &proof.pos;
        let path = &proof.proof;
        // Shared iterator over the traversal path; advanced once per node so
        // each successive node's stored path drops one leading branch index.
        let mut trav_path = traversal_path.iter().map(|n| *n as i32);

        for node in path.iter() {
            match node {
                MerkleNode::Empty => {
                    let index =
                        serde_json::to_value(pos.clone()).map_err(|e| QueryError::Error {
                            message: format!("malformed merkle position: {e}"),
                        })?;
                    let node_path: Vec<i32> = trav_path.clone().rev().collect();
                    // An all-zero 32-byte value stands in for the hash of an
                    // empty node.
                    nodes.push((
                        Node {
                            path: node_path.into(),
                            idx: Some(index),
                            ..Default::default()
                        },
                        None,
                        [0_u8; 32].to_vec(),
                    ));
                    hashes.insert([0_u8; 32].to_vec());
                },
                MerkleNode::ForgettenSubtree { .. } => {
                    // Forgotten subtrees carry no persistable data; refuse the
                    // whole batch rather than store an incomplete path.
                    return Err(QueryError::Error {
                        message: "Node in the Merkle path contains a forgotten subtree".into(),
                    });
                },
                MerkleNode::Leaf { value, pos, elem } => {
                    // Serialize the leaf's hash value for the hash table.
                    let mut leaf_commit = Vec::new();
                    value.serialize_compressed(&mut leaf_commit).map_err(|e| {
                        QueryError::Error {
                            message: format!("malformed merkle leaf commitment: {e}"),
                        }
                    })?;

                    let node_path: Vec<i32> = trav_path.clone().rev().collect();

                    let index =
                        serde_json::to_value(pos.clone()).map_err(|e| QueryError::Error {
                            message: format!("malformed merkle position: {e}"),
                        })?;
                    let entry = serde_json::to_value(elem).map_err(|e| QueryError::Error {
                        message: format!("malformed merkle element: {e}"),
                    })?;

                    nodes.push((
                        Node {
                            path: node_path.into(),
                            idx: Some(index),
                            entry: Some(entry),
                            ..Default::default()
                        },
                        None,
                        leaf_commit.clone(),
                    ));

                    hashes.insert(leaf_commit);
                },
                MerkleNode::Branch { value, children } => {
                    // Serialize the branch's own hash value.
                    let mut branch_hash = Vec::new();
                    value.serialize_compressed(&mut branch_hash).map_err(|e| {
                        QueryError::Error {
                            message: format!("malformed merkle branch hash: {e}"),
                        }
                    })?;

                    // One bit per child slot: set bits correspond (in order) to
                    // the serialized hashes collected in `children_values`.
                    let mut children_bitvec = BitVec::new();
                    let mut children_values = Vec::new();
                    for child in children {
                        let child = child.as_ref();
                        match child {
                            MerkleNode::Empty => {
                                children_bitvec.push(false);
                            },
                            MerkleNode::Branch { value, .. }
                            | MerkleNode::Leaf { value, .. }
                            | MerkleNode::ForgettenSubtree { value } => {
                                let mut hash = Vec::new();
                                value.serialize_compressed(&mut hash).map_err(|e| {
                                    QueryError::Error {
                                        message: format!("malformed merkle node hash: {e}"),
                                    }
                                })?;

                                children_values.push(hash);
                                children_bitvec.push(true);
                            },
                        }
                    }

                    let node_path: Vec<i32> = trav_path.clone().rev().collect();
                    // `children` ids are left None here; they are resolved from
                    // the returned child hash bytes after the hashes are
                    // inserted.
                    nodes.push((
                        Node {
                            path: node_path.into(),
                            children: None,
                            children_bitvec: Some(children_bitvec),
                            ..Default::default()
                        },
                        Some(children_values.clone()),
                        branch_hash.clone(),
                    ));
                    hashes.insert(branch_hash);
                    hashes.extend(children_values);
                },
            }

            // Move one level up for the next node on the path.
            trav_path.next();
        }
    }

    Ok((nodes, hashes))
}
521
// Represents a row in a state table
#[derive(Debug, Default, Clone)]
pub(crate) struct Node {
    /// JSON array of branch indices locating this node in the tree.
    pub(crate) path: JsonValue,
    /// Block height at which this version of the node was created.
    pub(crate) created: i64,
    /// Id of this node's hash value in the `hash` table.
    pub(crate) hash_id: i32,
    /// Branch nodes only: JSON array of child hash ids (`Vec<i32>`).
    pub(crate) children: Option<JsonValue>,
    /// Branch nodes only: one bit per child slot; set bits mark present
    /// children, in the same order as `children`.
    pub(crate) children_bitvec: Option<BitVec>,
    /// Leaf/empty nodes: serialized key position.
    pub(crate) idx: Option<JsonValue>,
    /// Leaf nodes only: serialized element.
    pub(crate) entry: Option<JsonValue>,
}

/// Type alias for node data with optional children hashes and node hash.
/// Used during batch collection before database insertion:
/// `(row, child hash bytes for branches, this node's hash bytes)`.
pub(crate) type NodeWithHashes = (Node, Option<Vec<Vec<u8>>>, Vec<u8>);
537
#[cfg(feature = "embedded-db")]
impl From<sqlx::sqlite::SqliteRow> for Node {
    /// Decode a SQLite row into a [`Node`].
    ///
    /// SQLite has no native bit-vector column type, so `children_bitvec` is
    /// stored as a string of '0'/'1' characters and decoded here.
    fn from(row: sqlx::sqlite::SqliteRow) -> Self {
        let children_bitvec = row
            .get_unchecked::<Option<String>, _>("children_bitvec")
            .map(|bits| bits.chars().map(|c| c == '1').collect::<BitVec>());

        Self {
            path: row.get_unchecked("path"),
            created: row.get_unchecked("created"),
            hash_id: row.get_unchecked("hash_id"),
            children: row.get_unchecked("children"),
            children_bitvec,
            idx: row.get_unchecked("idx"),
            entry: row.get_unchecked("entry"),
        }
    }
}
556
#[cfg(not(feature = "embedded-db"))]
impl From<sqlx::postgres::PgRow> for Node {
    /// Decode a Postgres row into a [`Node`].
    ///
    /// Unlike SQLite, Postgres decodes every column natively — including the
    /// bit-varying `children_bitvec` straight into a `BitVec` — so this is a
    /// straight field-by-field read.
    fn from(row: sqlx::postgres::PgRow) -> Self {
        Self {
            path: row.get_unchecked("path"),
            created: row.get_unchecked("created"),
            hash_id: row.get_unchecked("hash_id"),
            children: row.get_unchecked("children"),
            children_bitvec: row.get_unchecked("children_bitvec"),
            idx: row.get_unchecked("idx"),
            entry: row.get_unchecked("entry"),
        }
    }
}
571
impl Node {
    /// Insert or update a batch of nodes in the state table `name`.
    ///
    /// Rows conflict on `(path, created)`; conflicting rows have their
    /// remaining columns overwritten with the new values.
    pub(crate) async fn upsert(
        name: &str,
        nodes: impl IntoIterator<Item = Self>,
        tx: &mut Transaction<Write>,
    ) -> anyhow::Result<()> {
        let nodes: Vec<_> = nodes.into_iter().collect();

        // Use UNNEST-based batch insert for postgres (more efficient and avoids parameter limits)
        #[cfg(not(feature = "embedded-db"))]
        return Self::upsert_batch_unnest(name, nodes, tx).await;

        #[cfg(feature = "embedded-db")]
        {
            // SQLite path: chunk the rows (20 per statement, 7 columns each)
            // to keep the number of bound parameters per statement small.
            for node_chunk in nodes.chunks(20) {
                let rows: Vec<_> = node_chunk
                    .iter()
                    .map(|n| {
                        // SQLite has no bit-vector type; encode the bitvec as a
                        // '0'/'1' string (the `From<SqliteRow>` impl decodes it).
                        let children_bitvec: Option<String> = n
                            .children_bitvec
                            .clone()
                            .map(|b| b.iter().map(|bit| if bit { '1' } else { '0' }).collect());

                        (
                            n.path.clone(),
                            n.created,
                            n.hash_id,
                            n.children.clone(),
                            children_bitvec,
                            n.idx.clone(),
                            n.entry.clone(),
                        )
                    })
                    .collect();

                tx.upsert(
                    name,
                    [
                        "path",
                        "created",
                        "hash_id",
                        "children",
                        "children_bitvec",
                        "idx",
                        "entry",
                    ],
                    ["path", "created"],
                    rows,
                )
                .await?;
            }
            Ok(())
        }
    }

    /// Postgres-only batch upsert that passes each column as an array through
    /// `UNNEST`, so the whole batch is a single statement with 7 parameters
    /// regardless of row count.
    #[cfg(not(feature = "embedded-db"))]
    async fn upsert_batch_unnest(
        name: &str,
        nodes: Vec<Self>,
        tx: &mut Transaction<Write>,
    ) -> anyhow::Result<()> {
        if nodes.is_empty() {
            return Ok(());
        }

        // Deduplicate nodes by (path, created) - keep the last occurrence
        // This is required because UNNEST + ON CONFLICT cannot handle duplicates in the same batch
        let mut deduped = HashMap::new();
        for node in nodes {
            deduped.insert((node.path.to_string(), node.created), node);
        }

        // Split the rows into one parallel vector per column for UNNEST.
        let mut paths = Vec::with_capacity(deduped.len());
        let mut createds = Vec::with_capacity(deduped.len());
        let mut hash_ids = Vec::with_capacity(deduped.len());
        let mut childrens = Vec::with_capacity(deduped.len());
        let mut children_bitvecs = Vec::with_capacity(deduped.len());
        let mut idxs = Vec::with_capacity(deduped.len());
        let mut entries = Vec::with_capacity(deduped.len());

        for node in deduped.into_values() {
            paths.push(node.path);
            createds.push(node.created);
            hash_ids.push(node.hash_id);
            childrens.push(node.children);
            children_bitvecs.push(node.children_bitvec);
            idxs.push(node.idx);
            entries.push(node.entry);
        }

        let sql = format!(
            r#"
            INSERT INTO "{name}" (path, created, hash_id, children, children_bitvec, idx, entry)
            SELECT * FROM UNNEST($1::jsonb[], $2::bigint[], $3::int[], $4::jsonb[], $5::bit varying[], $6::jsonb[], $7::jsonb[])
            ON CONFLICT (path, created) DO UPDATE SET
                hash_id = EXCLUDED.hash_id,
                children = EXCLUDED.children,
                children_bitvec = EXCLUDED.children_bitvec,
                idx = EXCLUDED.idx,
                entry = EXCLUDED.entry
            "#
        );

        sqlx::query(&sql)
            .bind(&paths)
            .bind(&createds)
            .bind(&hash_ids)
            .bind(&childrens)
            .bind(&children_bitvecs)
            .bind(&idxs)
            .bind(&entries)
            .execute(tx.as_mut())
            .await
            .context("batch upsert with UNNEST failed")?;

        Ok(())
    }
}
690
691fn build_get_path_query<'q>(
692    table: &'static str,
693    traversal_path: Vec<usize>,
694    created: i64,
695) -> QueryResult<(QueryBuilder<'q>, String)> {
696    let mut query = QueryBuilder::default();
697    let mut traversal_path = traversal_path.into_iter().map(|x| x as i32);
698
699    // We iterate through the path vector skipping the first element after each iteration
700    let len = traversal_path.len();
701    let mut sub_queries = Vec::new();
702
703    query.bind(created)?;
704
705    for _ in 0..=len {
706        let path = traversal_path.clone().rev().collect::<Vec<_>>();
707        let path: serde_json::Value = path.into();
708        let node_path = query.bind(path)?;
709
710        let sub_query = format!(
711            "SELECT * FROM (SELECT * FROM {table} WHERE path = {node_path} AND created <= $1 \
712             ORDER BY created DESC LIMIT 1) AS latest_node",
713        );
714
715        sub_queries.push(sub_query);
716        traversal_path.next();
717    }
718
719    let mut sql: String = sub_queries.join(" UNION ");
720
721    sql = format!("SELECT * FROM ({sql}) as t ");
722
723    // PostgreSQL already orders JSON arrays by length, so no additional function is needed
724    // For SQLite, `length()` is used to sort by length.
725    if cfg!(feature = "embedded-db") {
726        sql.push_str("ORDER BY length(t.path) DESC");
727    } else {
728        sql.push_str("ORDER BY t.path DESC");
729    }
730
731    Ok((query, sql))
732}
733
734#[cfg(test)]
735mod test {
736    use futures::stream::StreamExt;
737    use jf_merkle_tree_compat::{
738        LookupResult, MerkleTreeScheme, UniversalMerkleTreeScheme,
739        universal_merkle_tree::UniversalMerkleTree,
740    };
741    use rand::{RngCore, seq::IteratorRandom};
742
743    use super::*;
744    use crate::{
745        data_source::{
746            VersionedDataSource,
747            storage::sql::{testing::TmpDb, *},
748        },
749        merklized_state::UpdateStateData,
750        testing::mocks::{MockMerkleTree, MockTypes},
751    };
752
753    #[test_log::test(tokio::test(flavor = "multi_thread"))]
754    async fn test_merklized_state_storage() {
755        // In this test we insert some entries into the tree and update the database
756        // Each entry's merkle path is compared with the path from the tree
757
758        let db = TmpDb::init().await;
759        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
760            .await
761            .unwrap();
762
763        // define a test tree
764        let mut test_tree: UniversalMerkleTree<_, _, _, 8, _> =
765            MockMerkleTree::new(MockMerkleTree::tree_height());
766        let block_height = 1;
767
768        // insert some entries into the tree and the header table
769        // Header table is used the get_path query to check if the header exists for the block height.
770        let mut tx = storage.write().await.unwrap();
771        for i in 0..27 {
772            test_tree.update(i, i).unwrap();
773
774            // data field of the header
775            let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});
776            tx.upsert(
777                "header",
778                [
779                    "height",
780                    "hash",
781                    "payload_hash",
782                    "timestamp",
783                    "ns_table",
784                    "data",
785                ],
786                ["height"],
787                [(
788                    block_height as i64,
789                    format!("randomHash{i}"),
790                    "t".to_string(),
791                    0,
792                    "ns_table".to_string(),
793                    test_data,
794                )],
795            )
796            .await
797            .unwrap();
798            // proof for the index from the tree
799            let (_, proof) = test_tree.lookup(i).expect_ok().unwrap();
800            // traversal path for the index.
801            let traversal_path =
802                <usize as ToTraversalPath<8>>::to_traversal_path(&i, test_tree.height());
803
804            UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
805                &mut tx,
806                proof.clone(),
807                traversal_path.clone(),
808                block_height as u64,
809            )
810            .await
811            .expect("failed to insert nodes");
812        }
813        // update saved state height
814        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, block_height)
815            .await
816            .unwrap();
817        tx.commit().await.unwrap();
818
819        //Get the path and check if it matches the lookup
820        for i in 0..27 {
821            // Query the path for the index
822            let mut tx = storage.read().await.unwrap();
823            let merkle_path = tx
824                .get_path(
825                    Snapshot::<_, MockMerkleTree, 8>::Index(block_height as u64),
826                    i,
827                )
828                .await
829                .unwrap();
830
831            let (_, proof) = test_tree.lookup(i).expect_ok().unwrap();
832
833            tracing::info!("merkle path {:?}", merkle_path);
834
835            // merkle path from the storage should match the path from test tree
836            assert_eq!(merkle_path, proof.clone(), "merkle paths mismatch");
837        }
838
839        // Get the proof of index 0 with bh = 1
840        let (_, proof_bh_1) = test_tree.lookup(0).expect_ok().unwrap();
841        // Inserting Index 0 again with created (bh) = 2
842        // Our database should then have 2 versions of this leaf node
843        // Update the node so that proof is also updated
844        test_tree.update(0, 99).unwrap();
845        // Also update the merkle commitment in the header
846
847        // data field of the header
848        let mut tx = storage.write().await.unwrap();
849        let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});
850        tx.upsert(
851            "header",
852            [
853                "height",
854                "hash",
855                "payload_hash",
856                "timestamp",
857                "data",
858                "ns_table",
859            ],
860            ["height"],
861            [(
862                2i64,
863                "randomstring".to_string(),
864                "t".to_string(),
865                0,
866                test_data,
867                "ns_table".to_string(),
868            )],
869        )
870        .await
871        .unwrap();
872        let (_, proof_bh_2) = test_tree.lookup(0).expect_ok().unwrap();
873        // traversal path for the index.
874        let traversal_path =
875            <usize as ToTraversalPath<8>>::to_traversal_path(&0, test_tree.height());
876        // Update storage to insert a new version of this code
877
878        UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
879            &mut tx,
880            proof_bh_2.clone(),
881            traversal_path.clone(),
882            2,
883        )
884        .await
885        .expect("failed to insert nodes");
886        // update saved state height
887        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, 2)
888            .await
889            .unwrap();
890        tx.commit().await.unwrap();
891
892        let node_path = traversal_path
893            .into_iter()
894            .rev()
895            .map(|n| n as i32)
896            .collect::<Vec<_>>();
897
898        // Find all the nodes of Index 0 in table
899        let mut tx = storage.read().await.unwrap();
900        let rows = query("SELECT * from test_tree where path = $1 ORDER BY created")
901            .bind(serde_json::to_value(node_path).unwrap())
902            .fetch(tx.as_mut());
903
904        let nodes: Vec<Node> = rows.map(|res| res.unwrap().into()).collect().await;
905        // There should be only 2 versions of this node
906        assert!(nodes.len() == 2, "incorrect number of nodes");
907        assert_eq!(nodes[0].created, 1, "wrong block height");
908        assert_eq!(nodes[1].created, 2, "wrong block height");
909
910        // Now we can have two snapshots for Index 0
911        // One with created = 1 and other with 2
912        // Query snapshot with created = 2
913
914        let path_with_bh_2 = storage
915            .read()
916            .await
917            .unwrap()
918            .get_path(Snapshot::<_, MockMerkleTree, 8>::Index(2), 0)
919            .await
920            .unwrap();
921
922        assert_eq!(path_with_bh_2, proof_bh_2);
923        let path_with_bh_1 = storage
924            .read()
925            .await
926            .unwrap()
927            .get_path(Snapshot::<_, MockMerkleTree, 8>::Index(1), 0)
928            .await
929            .unwrap();
930        assert_eq!(path_with_bh_1, proof_bh_1);
931    }
932
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_merklized_state_non_membership_proof() {
        // This test updates the Merkle tree with a new entry and inserts the corresponding
        // Merkle nodes into the database with created = 1.
        // A Merkle node is then deleted from the tree.
        // The database is then updated to reflect the deletion of the entry with a
        // created (block height) of 2.
        // As the leaf node becomes a non-member, we do a universal lookup to obtain its
        // non-membership proof path.
        // It is expected that the path retrieved from the tree matches the path obtained
        // from the database.

        let db = TmpDb::init().await;
        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
            .await
            .unwrap();

        // define a test tree
        let mut test_tree = MockMerkleTree::new(MockMerkleTree::tree_height());
        let block_height = 1;
        // insert an entry into the tree
        test_tree.update(0, 0).unwrap();
        let commitment = test_tree.commitment();

        // data field of the header, carrying the merkle root commitment
        let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(commitment).unwrap()});
        // insert the header with merkle commitment
        let mut tx = storage.write().await.unwrap();
        tx.upsert(
            "header",
            [
                "height",
                "hash",
                "payload_hash",
                "timestamp",
                "data",
                "ns_table",
            ],
            ["height"],
            [(
                block_height as i64,
                "randomString".to_string(),
                "t".to_string(),
                0,
                test_data,
                "ns_table".to_string(),
            )],
        )
        .await
        .unwrap();
        // proof for the index from the tree
        let (_, proof_before_remove) = test_tree.lookup(0).expect_ok().unwrap();
        // traversal path for the index
        let traversal_path =
            <usize as ToTraversalPath<8>>::to_traversal_path(&0, test_tree.height());
        // insert merkle nodes
        UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
            &mut tx,
            proof_before_remove.clone(),
            traversal_path.clone(),
            block_height as u64,
        )
        .await
        .expect("failed to insert nodes");
        // update saved state height
        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, block_height)
            .await
            .unwrap();
        tx.commit().await.unwrap();
        // the path from the db and the tree should match
        let merkle_path = storage
            .read()
            .await
            .unwrap()
            .get_path(
                Snapshot::<_, MockMerkleTree, 8>::Index(block_height as u64),
                0,
            )
            .await
            .unwrap();

        // merkle path from the storage should match the path from the test tree
        assert_eq!(
            merkle_path,
            proof_before_remove.clone(),
            "merkle paths mismatch"
        );

        // Deleting the index 0
        test_tree.remove(0).expect("failed to delete index 0 ");

        // Update the database with the non-membership proof.
        // Created = 2 in this case.
        let proof_after_remove = test_tree.universal_lookup(0).expect_not_found().unwrap();

        let mut tx = storage.write().await.unwrap();
        UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
            &mut tx,
            proof_after_remove.clone(),
            traversal_path.clone(),
            2_u64,
        )
        .await
        .expect("failed to insert nodes");
        // Insert the new header
        tx.upsert(
                "header",
                ["height", "hash", "payload_hash", "timestamp", "data", "ns_table"],
                ["height"],
                [(
                    2i64,
                    "randomString2".to_string(),
                    "t".to_string(),
                    0,
                    serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()}),
                    "ns_table".to_string(),
                )],
            )
            .await
            .unwrap();
        // update saved state height
        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, 2)
            .await
            .unwrap();
        tx.commit().await.unwrap();
        // Get the non-membership proof from storage
        let non_membership_path = storage
            .read()
            .await
            .unwrap()
            .get_path(Snapshot::<_, MockMerkleTree, 8>::Index(2_u64), 0)
            .await
            .unwrap();
        // Assert that the paths from the db and the tree are equal
        assert_eq!(
            non_membership_path, proof_after_remove,
            "merkle paths dont match"
        );

        // Query the membership proof, i.e. the proof with created = 1.
        // This proof should be equal to the proof before deletion.
        // Assert that the paths from the db and the tree are equal.

        let proof_bh_1 = storage
            .read()
            .await
            .unwrap()
            .get_path(Snapshot::<_, MockMerkleTree, 8>::Index(1_u64), 0)
            .await
            .unwrap();
        assert_eq!(proof_bh_1, proof_before_remove, "merkle paths dont match");
    }
1080
1081    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1082    async fn test_merklized_state_non_membership_proof_unseen_entry() {
1083        let db = TmpDb::init().await;
1084        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1085            .await
1086            .unwrap();
1087
1088        // define a test tree
1089        let mut test_tree = MockMerkleTree::new(MockMerkleTree::tree_height());
1090
1091        // For each case (where the root is empty, a leaf, and a branch) test getting a
1092        // non-membership proof for an entry node the database has never seen.
1093        for i in 0..=2 {
1094            tracing::info!(i, ?test_tree, "testing non-membership proof");
1095            let mut tx = storage.write().await.unwrap();
1096
1097            // Insert a dummy header
1098            tx.upsert(
1099                "header",
1100                ["height", "hash", "payload_hash", "timestamp", "data", "ns_table"],
1101                ["height"],
1102                [(
1103                    i as i64,
1104                    format!("hash{i}"),
1105                    "t".to_string(),
1106                    0,
1107                    serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()}),
1108                    "ns_table".to_string(),
1109                )],
1110            )
1111            .await
1112            .unwrap();
1113            // update saved state height
1114            UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, i)
1115                .await
1116                .unwrap();
1117            tx.commit().await.unwrap();
1118
1119            // get a non-membership proof for a never-before-seen node.
1120            let proof = storage
1121                .read()
1122                .await
1123                .unwrap()
1124                .get_path(
1125                    Snapshot::<MockTypes, MockMerkleTree, 8>::Index(i as u64),
1126                    100,
1127                )
1128                .await
1129                .unwrap();
1130            assert_eq!(proof.elem(), None);
1131
1132            assert!(
1133                MockMerkleTree::non_membership_verify(test_tree.commitment(), 100, proof).unwrap()
1134            );
1135
1136            // insert an additional node into the tree.
1137            test_tree.update(i, i).unwrap();
1138            let (_, proof) = test_tree.lookup(i).expect_ok().unwrap();
1139            let traversal_path = ToTraversalPath::<8>::to_traversal_path(&i, test_tree.height());
1140            let mut tx = storage.write().await.unwrap();
1141            UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
1142                &mut tx,
1143                proof,
1144                traversal_path,
1145                (i + 1) as u64,
1146            )
1147            .await
1148            .expect("failed to insert nodes");
1149            tx.commit().await.unwrap();
1150        }
1151    }
1152
1153    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1154    async fn test_merklized_storage_with_commit() {
1155        // This test insert a merkle path into the database and queries the path using the merkle commitment
1156
1157        let db = TmpDb::init().await;
1158        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1159            .await
1160            .unwrap();
1161
1162        // define a test tree
1163        let mut test_tree = MockMerkleTree::new(MockMerkleTree::tree_height());
1164        let block_height = 1;
1165        //insert an entry into the tree
1166        test_tree.update(0, 0).unwrap();
1167        let commitment = test_tree.commitment();
1168
1169        let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(commitment).unwrap()});
1170        // insert the header with merkle commitment
1171        let mut tx = storage.write().await.unwrap();
1172        tx.upsert(
1173            "header",
1174            [
1175                "height",
1176                "hash",
1177                "payload_hash",
1178                "timestamp",
1179                "data",
1180                "ns_table",
1181            ],
1182            ["height"],
1183            [(
1184                block_height as i64,
1185                "randomString".to_string(),
1186                "t".to_string(),
1187                0,
1188                test_data,
1189                "ns_table".to_string(),
1190            )],
1191        )
1192        .await
1193        .unwrap();
1194        // proof for the index from the tree
1195        let (_, proof) = test_tree.lookup(0).expect_ok().unwrap();
1196        // traversal path for the index.
1197        let traversal_path =
1198            <usize as ToTraversalPath<8>>::to_traversal_path(&0, test_tree.height());
1199        // insert merkle nodes
1200        UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
1201            &mut tx,
1202            proof.clone(),
1203            traversal_path.clone(),
1204            block_height as u64,
1205        )
1206        .await
1207        .expect("failed to insert nodes");
1208        // update saved state height
1209        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, block_height)
1210            .await
1211            .unwrap();
1212        tx.commit().await.unwrap();
1213
1214        let merkle_proof = storage
1215            .read()
1216            .await
1217            .unwrap()
1218            .get_path(Snapshot::<_, MockMerkleTree, 8>::Commit(commitment), 0)
1219            .await
1220            .unwrap();
1221
1222        let (_, proof) = test_tree.lookup(0).expect_ok().unwrap();
1223
1224        assert_eq!(merkle_proof, proof.clone(), "merkle paths mismatch");
1225    }
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_merklized_state_missing_state() {
        // This test checks that the header commitment matches the root hash.
        // For this, the header merkle root commitment field is not updated, which should
        // result in an error.
        // The full merkle path verification is also done by recomputing the root hash.
        // An index and its corresponding merkle nodes with created (bh) = 1 are inserted.
        // The entry of the index is updated, and the updated nodes are inserted with
        // created (bh) = 2.
        // A node which is in the traversal path with bh = 2 is deleted, so get_path
        // should return an error as an older version of one of the nodes is used.

        let db = TmpDb::init().await;
        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
            .await
            .unwrap();

        // define a test tree
        let mut test_tree = MockMerkleTree::new(MockMerkleTree::tree_height());
        let block_height = 1;
        // insert entries into the tree, all recorded against the same block height

        let mut tx = storage.write().await.unwrap();
        for i in 0..27 {
            test_tree.update(i, i).unwrap();
            let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});
            // insert the header with merkle commitment
            tx.upsert(
                "header",
                [
                    "height",
                    "hash",
                    "payload_hash",
                    "timestamp",
                    "data",
                    "ns_table",
                ],
                ["height"],
                [(
                    block_height as i64,
                    format!("rarndomString{i}"),
                    "t".to_string(),
                    0,
                    test_data,
                    "ns_table".to_string(),
                )],
            )
            .await
            .unwrap();
            // proof for the index from the tree
            let (_, proof) = test_tree.lookup(i).expect_ok().unwrap();
            // traversal path for the index
            let traversal_path =
                <usize as ToTraversalPath<8>>::to_traversal_path(&i, test_tree.height());
            // insert merkle nodes
            UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
                &mut tx,
                proof.clone(),
                traversal_path.clone(),
                block_height as u64,
            )
            .await
            .expect("failed to insert nodes");
            // update saved state height
            UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, block_height)
                .await
                .unwrap();
        }

        test_tree.update(1, 100).unwrap();
        // insert the updated merkle path without updating the header
        let traversal_path =
            <usize as ToTraversalPath<8>>::to_traversal_path(&1, test_tree.height());
        let (_, proof) = test_tree.lookup(1).expect_ok().unwrap();

        // insert merkle nodes
        UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
            &mut tx,
            proof.clone(),
            traversal_path.clone(),
            block_height as u64,
        )
        .await
        .expect("failed to insert nodes");
        tx.commit().await.unwrap();

        // The stored header still has the stale commitment, so path retrieval must fail.
        let merkle_path = storage
            .read()
            .await
            .unwrap()
            .get_path(
                Snapshot::<_, MockMerkleTree, 8>::Index(block_height as u64),
                1,
            )
            .await;
        assert!(merkle_path.is_err());

        let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});
        // insert the header with the refreshed merkle commitment
        let mut tx = storage.write().await.unwrap();
        tx.upsert(
            "header",
            [
                "height",
                "hash",
                "payload_hash",
                "timestamp",
                "data",
                "ns_table",
            ],
            ["height"],
            [(
                block_height as i64,
                "randomStringgg".to_string(),
                "t".to_string(),
                0,
                test_data,
                "ns_table".to_string(),
            )],
        )
        .await
        .unwrap();
        tx.commit().await.unwrap();
        // Querying the path again should now succeed
        let merkle_proof = storage
            .read()
            .await
            .unwrap()
            .get_path(
                Snapshot::<_, MockMerkleTree, 8>::Index(block_height as u64),
                1,
            )
            .await
            .unwrap();
        assert_eq!(merkle_proof, proof, "path dont match");

        // Update the tree again for index 1 with created (bh) = 2,
        // then delete one of the nodes in the traversal path.
        test_tree.update(1, 200).unwrap();

        let (_, proof) = test_tree.lookup(1).expect_ok().unwrap();
        let test_data = serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()});

        // insert the header with merkle commitment
        let mut tx = storage.write().await.unwrap();
        tx.upsert(
            "header",
            [
                "height",
                "hash",
                "payload_hash",
                "timestamp",
                "data",
                "ns_table",
            ],
            ["height"],
            [(
                2i64,
                "randomHashString".to_string(),
                "t".to_string(),
                0,
                test_data,
                "ns_table".to_string(),
            )],
        )
        .await
        .unwrap();
        UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
            &mut tx,
            proof.clone(),
            traversal_path.clone(),
            2_u64,
        )
        .await
        .expect("failed to insert nodes");

        // Deleting one internal node: skipping the first traversal-path component
        // addresses the leaf's parent (paths are stored reversed, root-to-leaf).
        let node_path = traversal_path
            .iter()
            .skip(1)
            .rev()
            .map(|n| *n as i32)
            .collect::<Vec<_>>();
        tx.execute(
            query(&format!(
                "DELETE FROM {} WHERE created = 2 and path = $1",
                MockMerkleTree::state_type()
            ))
            .bind(serde_json::to_value(node_path).unwrap()),
        )
        .await
        .expect("failed to delete internal node");
        tx.commit().await.unwrap();

        // With an internal node missing at bh = 2, reconstruction must fail.
        let merkle_path = storage
            .read()
            .await
            .unwrap()
            .get_path(Snapshot::<_, MockMerkleTree, 8>::Index(2_u64), 1)
            .await;

        assert!(merkle_path.is_err());
    }
1426
1427    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1428    async fn test_merklized_state_snapshot() {
1429        let db = TmpDb::init().await;
1430        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1431            .await
1432            .unwrap();
1433
1434        // Define a test tree
1435        let mut test_tree = MockMerkleTree::new(MockMerkleTree::tree_height());
1436
1437        // We will sample random keys as u32. This is a value that is not a valid u32 and thus is a
1438        // key we will never insert into the tree.
1439        const RESERVED_KEY: usize = (u32::MAX as usize) + 1;
1440
1441        // Randomly insert and delete some entries. For each entry we insert, we also keep track of
1442        // whether the entry should be in the tree using a HashMap.
1443        #[tracing::instrument(skip(tree, expected))]
1444        fn randomize(tree: &mut MockMerkleTree, expected: &mut HashMap<usize, Option<usize>>) {
1445            let mut rng = rand::thread_rng();
1446            tracing::info!("randomizing tree");
1447
1448            for _ in 0..50 {
1449                // We flip a coin to decide whether to insert or delete, unless the tree is empty,
1450                // in which case we can only insert.
1451                if !expected.values().any(|v| v.is_some()) || rng.next_u32().is_multiple_of(2) {
1452                    // Insert.
1453                    let key = rng.next_u32() as usize;
1454                    let val = rng.next_u32() as usize;
1455                    tracing::info!(key, val, "inserting");
1456
1457                    tree.update(key, val).unwrap();
1458                    expected.insert(key, Some(val));
1459                } else {
1460                    // Delete.
1461                    let key = expected
1462                        .iter()
1463                        .filter_map(|(k, v)| if v.is_some() { Some(k) } else { None })
1464                        .choose(&mut rng)
1465                        .unwrap();
1466                    tracing::info!(key, "deleting");
1467
1468                    tree.remove(key).unwrap();
1469                    expected.insert(*key, None);
1470                }
1471            }
1472        }
1473
        // Commit the tree (and the `expected` map of touched keys) to storage at the
        // given block height.
        #[tracing::instrument(skip(storage, tree, expected))]
        async fn store(
            storage: &SqlStorage,
            tree: &MockMerkleTree,
            expected: &HashMap<usize, Option<usize>>,
            block_height: u64,
        ) {
            tracing::info!("persisting tree");
            let mut tx = storage.write().await.unwrap();

            // Persist a proof (membership or non-membership) for every key ever touched.
            for key in expected.keys() {
                let proof = match tree.universal_lookup(key) {
                    LookupResult::Ok(_, proof) => proof,
                    LookupResult::NotFound(proof) => proof,
                    LookupResult::NotInMemory => panic!("failed to find key {key}"),
                };
                let traversal_path = ToTraversalPath::<8>::to_traversal_path(key, tree.height());
                UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
                    &mut tx,
                    proof,
                    traversal_path,
                    block_height,
                )
                .await
                .unwrap();
            }
            // insert the header with merkle commitment
            tx
            .upsert("header", ["height", "hash", "payload_hash", "timestamp", "data", "ns_table"], ["height"],
                [(
                    block_height as i64,
                    format!("hash{block_height}"),
                    "hash".to_string(),
                    0i64,
                    serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(tree.commitment()).unwrap()}),
                    "ns_table".to_string(),
                )],
            )
            .await
            .unwrap();
            // record the saved state height so snapshots at this height are queryable
            UpdateStateData::<MockTypes, MockMerkleTree, 8>::set_last_state_height(
                &mut tx,
                block_height as usize,
            )
            .await
            .unwrap();
            tx.commit().await.unwrap();
        }
1523
1524        #[tracing::instrument(skip(storage, tree, expected))]
1525        async fn validate(
1526            storage: &SqlStorage,
1527            tree: &MockMerkleTree,
1528            expected: &HashMap<usize, Option<usize>>,
1529            block_height: u64,
1530        ) {
1531            tracing::info!("validating snapshot");
1532
1533            // Check that we can get a correct path for each key that we touched.
1534            let snapshot = Snapshot::<_, MockMerkleTree, 8>::Index(block_height);
1535
1536            for (key, val) in expected {
1537                let proof = match tree.universal_lookup(key) {
1538                    LookupResult::Ok(_, proof) => proof,
1539                    LookupResult::NotFound(proof) => proof,
1540                    LookupResult::NotInMemory => panic!("failed to find key {key}"),
1541                };
1542                assert_eq!(
1543                    proof,
1544                    storage
1545                        .read()
1546                        .await
1547                        .unwrap()
1548                        .get_path(snapshot, *key)
1549                        .await
1550                        .unwrap()
1551                );
1552                assert_eq!(val.as_ref(), proof.elem());
1553                // Check path is valid for test_tree
1554                if val.is_some() {
1555                    MockMerkleTree::verify(tree.commitment(), key, proof)
1556                        .unwrap()
1557                        .unwrap();
1558                } else {
1559                    assert!(
1560                        MockMerkleTree::non_membership_verify(tree.commitment(), key, proof)
1561                            .unwrap()
1562                    );
1563                }
1564            }
1565
1566            // Check that we can even get a non-membership proof for a key that we never touched.
1567            let proof = match tree.universal_lookup(RESERVED_KEY) {
1568                LookupResult::Ok(_, proof) => proof,
1569                LookupResult::NotFound(proof) => proof,
1570                LookupResult::NotInMemory => panic!("failed to find reserved key {RESERVED_KEY}"),
1571            };
1572            assert_eq!(
1573                proof,
1574                storage
1575                    .read()
1576                    .await
1577                    .unwrap()
1578                    .get_path(snapshot, RESERVED_KEY)
1579                    .await
1580                    .unwrap()
1581            );
1582            assert_eq!(proof.elem(), None);
1583            // Check path is valid for test_tree
1584            assert!(
1585                MockMerkleTree::non_membership_verify(tree.commitment(), RESERVED_KEY, proof)
1586                    .unwrap()
1587            );
1588        }
1589
1590        // Create a randomized Merkle tree.
1591        let mut expected = HashMap::<usize, Option<usize>>::new();
1592        randomize(&mut test_tree, &mut expected);
1593
1594        // Commit the randomized tree to storage.
1595        store(&storage, &test_tree, &expected, 1).await;
1596        validate(&storage, &test_tree, &expected, 1).await;
1597
1598        // Make random edits and commit another snapshot.
1599        let mut expected2 = expected.clone();
1600        let mut test_tree2 = test_tree.clone();
1601        randomize(&mut test_tree2, &mut expected2);
1602        store(&storage, &test_tree2, &expected2, 2).await;
1603        validate(&storage, &test_tree2, &expected2, 2).await;
1604
1605        // Ensure the original snapshot is still valid.
1606        validate(&storage, &test_tree, &expected, 1).await;
1607    }
1608
1609    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1610    async fn test_merklized_state_missing_leaf() {
1611        // Check that if a leaf is missing but its ancestors are present/key is in the tree, we
1612        // catch it rather than interpreting the entry as an empty node by default. Note that this
1613        // scenario should be impossible in normal usage, since we never store or delete partial
1614        // paths. But we should never return an invalid proof even in extreme cases like database
1615        // corruption.
1616
1617        for tree_size in 1..=3 {
1618            let db = TmpDb::init().await;
1619            let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1620                .await
1621                .unwrap();
1622
1623            // Define a test tree
1624            let mut test_tree = MockMerkleTree::new(MockMerkleTree::tree_height());
1625            for i in 0..tree_size {
1626                test_tree.update(i, i).unwrap();
1627            }
1628
1629            let mut tx = storage.write().await.unwrap();
1630
1631            // Insert a header with the tree commitment.
1632            tx.upsert(
1633                "header",
1634                ["height", "hash", "payload_hash", "timestamp", "data", "ns_table"],
1635                ["height"],
1636                [(
1637                    0i64,
1638                    "hash".to_string(),
1639                    "hash".to_string(),
1640                    0,
1641                    serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()}),
1642                    "ns_table".to_string(),
1643                )],
1644            )
1645            .await
1646            .unwrap();
1647
1648            // Insert Merkle nodes.
1649            for i in 0..tree_size {
1650                let proof = test_tree.lookup(i).expect_ok().unwrap().1;
1651                let traversal_path =
1652                    ToTraversalPath::<8>::to_traversal_path(&i, test_tree.height());
1653                UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
1654                    &mut tx,
1655                    proof,
1656                    traversal_path,
1657                    0,
1658                )
1659                .await
1660                .unwrap();
1661            }
1662            UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(&mut tx, 0)
1663                .await
1664                .unwrap();
1665            tx.commit().await.unwrap();
1666
1667            // Test that we can get all the entries.
1668            let snapshot = Snapshot::<MockTypes, MockMerkleTree, 8>::Index(0);
1669            for i in 0..tree_size {
1670                let proof = test_tree.lookup(i).expect_ok().unwrap().1;
1671                assert_eq!(
1672                    proof,
1673                    storage
1674                        .read()
1675                        .await
1676                        .unwrap()
1677                        .get_path(snapshot, i)
1678                        .await
1679                        .unwrap()
1680                );
1681                assert_eq!(*proof.elem().unwrap(), i);
1682            }
1683
1684            // Now delete the leaf node for the last entry we inserted, corrupting the database.
1685            let index = serde_json::to_value(tree_size - 1).unwrap();
1686            let mut tx = storage.write().await.unwrap();
1687
1688            tx.execute(
1689                query(&format!(
1690                    "DELETE FROM {} WHERE idx = $1",
1691                    MockMerkleTree::state_type()
1692                ))
1693                .bind(serde_json::to_value(index).unwrap()),
1694            )
1695            .await
1696            .unwrap();
1697            tx.commit().await.unwrap();
1698
1699            // Test that we can still get the entries we didn't delete.
1700            for i in 0..tree_size - 1 {
1701                let proof = test_tree.lookup(i).expect_ok().unwrap().1;
1702                assert_eq!(
1703                    proof,
1704                    storage
1705                        .read()
1706                        .await
1707                        .unwrap()
1708                        .get_path(snapshot, i)
1709                        .await
1710                        .unwrap()
1711                );
1712                assert_eq!(*proof.elem().unwrap(), i);
1713            }
1714
1715            // Looking up the entry we deleted fails, rather than return an invalid path.
1716            let err = storage
1717                .read()
1718                .await
1719                .unwrap()
1720                .get_path(snapshot, tree_size - 1)
1721                .await
1722                .unwrap_err();
1723            assert!(matches!(err, QueryError::Missing));
1724        }
1725    }
1726}