vid/
avidm_gf2.rs

1//! This module implements the AVID-M scheme over GF2
2
3use std::{ops::Range, vec};
4
5use anyhow::anyhow;
6use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
7use jf_merkle_tree::{hasher::HasherNode, MerkleTreeScheme};
8use jf_utils::canonical;
9use serde::{Deserialize, Serialize};
10use sha2::Digest;
11use tagged_base64::tagged;
12
13use crate::{VidError, VidResult, VidScheme};
14
/// Namespaced AvidmGf2 scheme
pub mod namespaced;
/// Namespace proofs for AvidmGf2 scheme
pub mod proofs;

/// Merkle tree scheme used in the VID: Keccak256-hashed binary tree whose
/// leaves are shard digests.
pub(crate) type MerkleTree =
    jf_merkle_tree::hasher::HasherMerkleTree<sha3::Keccak256, HasherNode<sha3::Keccak256>>;
/// Membership proof for a single shard leaf in [`MerkleTree`].
type MerkleProof = <MerkleTree as MerkleTreeScheme>::MembershipProof;
/// Merkle root type; wrapped by [`AvidmGf2Commit`] as the VID commitment.
type MerkleCommit = <MerkleTree as MerkleTreeScheme>::Commitment;
25
/// Dummy struct for AVID-M scheme over GF2; carries no state — all
/// functionality lives in associated functions and the [`VidScheme`] impl.
pub struct AvidmGf2Scheme;
28
/// VID Parameters
#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
pub struct AvidmGf2Param {
    /// Total weights of all storage nodes; equals the total number of shards
    /// (original + recovery) produced by the encoder.
    pub total_weights: usize,
    /// Minimum collective weights required to recover the original payload.
    /// Also the number of original (data) shards in the Reed-Solomon code.
    pub recovery_threshold: usize,
}
37
38impl AvidmGf2Param {
39    /// Construct a new [`AvidmGf2Param`].
40    pub fn new(recovery_threshold: usize, total_weights: usize) -> VidResult<Self> {
41        if recovery_threshold == 0 || total_weights < recovery_threshold {
42            return Err(VidError::InvalidParam);
43        }
44        Ok(Self {
45            total_weights,
46            recovery_threshold,
47        })
48    }
49}
50
/// VID Share type to be distributed among the parties.
#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
pub struct AvidmGf2Share {
    /// Range of this share in the encoded payload; its length is the share's
    /// weight (number of shards held).
    range: Range<usize>,
    /// Actual share content: one byte vector per shard index in `range`.
    #[serde(with = "canonical")]
    payload: Vec<Vec<u8>>,
    /// Merkle proof of the content, one per shard in `payload`.
    #[serde(with = "canonical")]
    mt_proofs: Vec<MerkleProof>,
}
63
64impl AvidmGf2Share {
65    /// Get the weight of this share
66    pub fn weight(&self) -> usize {
67        self.range.len()
68    }
69
70    /// Validate the share structure.
71    pub fn validate(&self) -> bool {
72        self.payload.len() == self.range.len() && self.mt_proofs.len() == self.range.len()
73    }
74}
75
/// VID Commitment type
#[derive(
    Clone,
    Copy,
    Debug,
    Hash,
    CanonicalSerialize,
    CanonicalDeserialize,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
)]
#[tagged("AvidmGf2Commit")]
// NOTE(review): repr(C) pins the field layout — presumably relied on for
// byte-level serialization/interop; confirm before removing.
#[repr(C)]
pub struct AvidmGf2Commit {
    /// VID commitment is the Merkle tree root
    pub commit: MerkleCommit,
}
95
impl AsRef<[u8]> for AvidmGf2Commit {
    /// Borrow the commitment as raw bytes (delegates to the Merkle root's
    /// own byte representation).
    fn as_ref(&self) -> &[u8] {
        self.commit.as_ref()
    }
}
101
102impl AvidmGf2Scheme {
103    /// Setup an instance for AVID-M scheme
104    pub fn setup(recovery_threshold: usize, total_weights: usize) -> VidResult<AvidmGf2Param> {
105        AvidmGf2Param::new(recovery_threshold, total_weights)
106    }
107
108    fn bit_padding(payload: &[u8], payload_len: usize) -> VidResult<Vec<u8>> {
109        if payload_len < payload.len() + 1 {
110            return Err(VidError::Argument(
111                "Payload length is too large to fit in the given payload length".to_string(),
112            ));
113        }
114        let mut padded = vec![0u8; payload_len];
115        padded[..payload.len()].copy_from_slice(payload);
116        padded[payload.len()] = 1u8;
117        Ok(padded)
118    }
119
120    fn raw_disperse(
121        param: &AvidmGf2Param,
122        payload: &[u8],
123    ) -> VidResult<(MerkleTree, Vec<Vec<u8>>)> {
124        let original_count = param.recovery_threshold;
125        let recovery_count = param.total_weights - param.recovery_threshold;
126        // Bit padding, we append an 1u8 to the end of the payload.
127        let mut shard_bytes = (payload.len() + 1).div_ceil(original_count);
128        if shard_bytes % 2 == 1 {
129            shard_bytes += 1;
130        }
131        let payload = Self::bit_padding(payload, shard_bytes * original_count)?;
132        let original = payload
133            .chunks(shard_bytes)
134            .map(|chunk| chunk.to_owned())
135            .collect::<Vec<_>>();
136        let recovery = reed_solomon_simd::encode(original_count, recovery_count, &original)?;
137
138        let shares = [original, recovery].concat();
139        let share_digests: Vec<_> = shares
140            .iter()
141            .map(|share| HasherNode::from(sha3::Keccak256::digest(share)))
142            .collect();
143        let mt = MerkleTree::from_elems(None, &share_digests)?;
144        Ok((mt, shares))
145    }
146}
147
148impl VidScheme for AvidmGf2Scheme {
149    type Param = AvidmGf2Param;
150    type Share = AvidmGf2Share;
151    type Commit = AvidmGf2Commit;
152
153    fn commit(param: &Self::Param, payload: &[u8]) -> VidResult<Self::Commit> {
154        let (mt, _) = Self::raw_disperse(param, payload)?;
155        Ok(Self::Commit {
156            commit: mt.commitment(),
157        })
158    }
159
160    fn disperse(
161        param: &Self::Param,
162        distribution: &[u32],
163        payload: &[u8],
164    ) -> VidResult<(Self::Commit, Vec<Self::Share>)> {
165        let (mt, shares) = Self::raw_disperse(param, payload)?;
166        let commit = AvidmGf2Commit {
167            commit: mt.commitment(),
168        };
169        let ranges: Vec<_> = distribution
170            .iter()
171            .scan(0, |sum, w| {
172                let prefix_sum = *sum;
173                *sum += w;
174                Some(prefix_sum as usize..*sum as usize)
175            })
176            .collect();
177        let shares: Vec<_> = ranges
178            .into_iter()
179            .map(|range| AvidmGf2Share {
180                range: range.clone(),
181                payload: shares[range.clone()].to_vec(),
182                // TODO(Chengyu): switch to batch proof generation
183                mt_proofs: range
184                    .map(|k| {
185                        mt.lookup(k as u64)
186                            .expect_ok()
187                            .expect("MT lookup shouldn't fail")
188                            .1
189                    })
190                    .collect::<Vec<_>>(),
191            })
192            .collect();
193        Ok((commit, shares))
194    }
195
196    fn verify_share(
197        _param: &Self::Param,
198        commit: &Self::Commit,
199        share: &Self::Share,
200    ) -> VidResult<crate::VerificationResult> {
201        if !share.validate() {
202            return Err(VidError::InvalidShare);
203        }
204        for (i, index) in share.range.clone().enumerate() {
205            let payload_digest = HasherNode::from(sha3::Keccak256::digest(&share.payload[i]));
206            // TODO(Chengyu): switch to batch verification
207            if MerkleTree::verify(
208                commit.commit,
209                index as u64,
210                payload_digest,
211                &share.mt_proofs[i],
212            )?
213            .is_err()
214            {
215                return Ok(Err(()));
216            }
217        }
218        Ok(Ok(()))
219    }
220
221    fn recover(
222        param: &Self::Param,
223        _commit: &Self::Commit,
224        shares: &[Self::Share],
225    ) -> VidResult<Vec<u8>> {
226        let original_count = param.recovery_threshold;
227        let recovery_count = param.total_weights - param.recovery_threshold;
228        // Find the first non-empty share
229        let Some(first_share) = shares.iter().find(|s| !s.payload.is_empty()) else {
230            return Err(VidError::InsufficientShares);
231        };
232        let shard_bytes = first_share.payload[0].len();
233
234        let mut decoder = reed_solomon_simd::ReedSolomonDecoder::new(
235            original_count,
236            recovery_count,
237            shard_bytes,
238        )?;
239        let mut original_shares = vec![None; original_count];
240        for share in shares {
241            if !share.validate() || share.payload.iter().any(|p| p.len() != shard_bytes) {
242                return Err(VidError::InvalidShare);
243            }
244            for (i, index) in share.range.clone().enumerate() {
245                if index < original_count {
246                    original_shares[index] = Some(share.payload[i].as_ref());
247                    decoder.add_original_shard(index, &share.payload[i])?;
248                } else {
249                    decoder.add_recovery_shard(index - original_count, &share.payload[i])?;
250                }
251            }
252        }
253        let result = decoder.decode()?;
254        original_shares
255            .iter_mut()
256            .enumerate()
257            .for_each(|(i, share)| {
258                if share.is_none() {
259                    *share = result.restored_original(i);
260                }
261            });
262        if original_shares.iter().any(|share| share.is_none()) {
263            return Err(VidError::Internal(anyhow!(
264                "Failed to recover the payload."
265            )));
266        }
267        let mut recovered: Vec<_> = original_shares
268            .into_iter()
269            .flat_map(|share| share.unwrap().to_vec())
270            .collect();
271        match recovered.iter().rposition(|&b| b != 0) {
272            Some(pad_index) if recovered[pad_index] == 1u8 => {
273                recovered.truncate(pad_index);
274                Ok(recovered)
275            },
276            _ => Err(VidError::Argument(
277                "Malformed payload, cannot find the padding position".to_string(),
278            )),
279        }
280    }
281}
282
/// Unit tests
#[cfg(test)]
pub mod tests {
    use rand::{seq::SliceRandom, RngCore};

    use super::AvidmGf2Scheme;
    use crate::VidScheme;

    /// End-to-end round trip: disperse a random payload across weighted
    /// storage nodes, verify every share, then recover the payload from a
    /// random subset whose cumulative weight just exceeds the threshold.
    #[test]
    fn round_trip() {
        // play with these items
        let num_storage_nodes_list = [4, 9, 16];
        let payload_byte_lens = [1, 31, 32, 500];

        // more items as a function of the above

        let mut rng = jf_utils::test_rng();

        for num_storage_nodes in num_storage_nodes_list {
            // Random per-node weights in 1..=5.
            let weights: Vec<u32> = (0..num_storage_nodes)
                .map(|_| rng.next_u32() % 5 + 1)
                .collect();
            let total_weights: u32 = weights.iter().sum();
            // Threshold: one third of the total weight, rounded up.
            let recovery_threshold = total_weights.div_ceil(3) as usize;
            let params = AvidmGf2Scheme::setup(recovery_threshold, total_weights as usize).unwrap();

            for payload_byte_len in payload_byte_lens {
                let payload = {
                    let mut bytes_random = vec![0u8; payload_byte_len];
                    rng.fill_bytes(&mut bytes_random);
                    bytes_random
                };

                let (commit, mut shares) =
                    AvidmGf2Scheme::disperse(&params, &weights, &payload).unwrap();

                // One share per storage node.
                assert_eq!(shares.len(), num_storage_nodes);

                // verify shares
                shares.iter().for_each(|share| {
                    assert!(AvidmGf2Scheme::verify_share(&params, &commit, share)
                        .is_ok_and(|r| r.is_ok()))
                });

                // test payload recovery on a random subset of shares
                shares.shuffle(&mut rng);
                // Take a prefix of the shuffled shares whose cumulative
                // weight strictly exceeds the recovery threshold.
                let mut cumulated_weights = 0;
                let mut cut_index = 0;
                while cumulated_weights <= recovery_threshold {
                    cumulated_weights += shares[cut_index].weight();
                    cut_index += 1;
                }
                let payload_recovered =
                    AvidmGf2Scheme::recover(&params, &commit, &shares[..cut_index]).unwrap();
                assert_eq!(payload_recovered, payload);
            }
        }
    }
}