// vid/avidm_gf2.rs

1//! This module implements the AVID-M scheme over GF2
2
3use std::{ops::Range, vec};
4
5use anyhow::anyhow;
6use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
7use jf_merkle_tree::{hasher::HasherNode, MerkleTreeScheme};
8use jf_utils::canonical;
9use serde::{Deserialize, Serialize};
10use sha2::Digest;
11use tagged_base64::tagged;
12
13use crate::{VidError, VidResult, VidScheme};
14
15/// Namespaced AvidmGf2 scheme
16pub mod namespaced;
17/// Namespace proofs for AvidmGf2 scheme
18pub mod proofs;
19
/// Merkle tree scheme used in the VID: Keccak256-hashed tree whose leaves are
/// the Keccak256 digests of the encoded payload chunks.
pub(crate) type MerkleTree =
    jf_merkle_tree::hasher::HasherMerkleTree<sha3::Keccak256, HasherNode<sha3::Keccak256>>;
/// Membership (inclusion) proof type of [`MerkleTree`].
type MerkleProof = <MerkleTree as MerkleTreeScheme>::MembershipProof;
/// Commitment (root) type of [`MerkleTree`].
type MerkleCommit = <MerkleTree as MerkleTreeScheme>::Commitment;
25
/// Marker type for the AVID-M scheme over GF2. Carries no state; all
/// functionality lives in associated functions and the [`VidScheme`] impl.
pub struct AvidmGf2Scheme;
28
/// VID Parameters
#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
pub struct AvidmGf2Param {
    /// Total weights of all storage nodes; equals the total number of encoded
    /// chunks (original + recovery) produced at dispersal.
    pub total_weights: usize,
    /// Minimum collective weights required to recover the original payload;
    /// equals the number of original (non-recovery) chunks.
    pub recovery_threshold: usize,
}
37
38impl AvidmGf2Param {
39    /// Construct a new [`AvidmGf2Param`].
40    pub fn new(recovery_threshold: usize, total_weights: usize) -> VidResult<Self> {
41        if recovery_threshold == 0 || total_weights < recovery_threshold {
42            return Err(VidError::InvalidParam);
43        }
44        Ok(Self {
45            total_weights,
46            recovery_threshold,
47        })
48    }
49}
50
/// VID Share type to be distributed among the parties.
#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
pub struct AvidmGf2Share {
    /// Range of chunk indices this share covers in the encoded payload.
    range: Range<usize>,
    /// Actual share content: one encoded chunk per index in `range`.
    #[serde(with = "canonical")]
    payload: Vec<Vec<u8>>,
    /// Merkle membership proof for each chunk in `payload`, in the same order.
    #[serde(with = "canonical")]
    mt_proofs: Vec<MerkleProof>,
}
63
64impl AvidmGf2Share {
65    /// Get the weight of this share
66    pub fn weight(&self) -> usize {
67        self.range.len()
68    }
69
70    /// Validate the share structure.
71    pub fn validate(&self) -> bool {
72        self.payload.len() == self.range.len() && self.mt_proofs.len() == self.range.len()
73    }
74}
75
/// VID Commitment type: a thin wrapper around the Merkle tree root over the
/// encoded chunk digests.
#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    Hash,
    CanonicalSerialize,
    CanonicalDeserialize,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
)]
#[tagged("AvidmGf2Commit")]
#[repr(C)]
pub struct AvidmGf2Commit {
    /// VID commitment is the Merkle tree root
    pub commit: MerkleCommit,
}
96
impl AsRef<[u8]> for AvidmGf2Commit {
    // Delegates to the byte representation of the underlying Merkle root.
    fn as_ref(&self) -> &[u8] {
        self.commit.as_ref()
    }
}

impl AsRef<[u8; 32]> for AvidmGf2Commit {
    fn as_ref(&self) -> &[u8; 32] {
        // The expect is sound because the commitment is a Keccak256-based
        // Merkle root, which is always 32 bytes.
        <Self as AsRef<[u8]>>::as_ref(self)
            .try_into()
            .expect("AvidmGf2Commit is always 32 bytes")
    }
}
110
111impl AvidmGf2Scheme {
112    /// Setup an instance for AVID-M scheme
113    pub fn setup(recovery_threshold: usize, total_weights: usize) -> VidResult<AvidmGf2Param> {
114        AvidmGf2Param::new(recovery_threshold, total_weights)
115    }
116
117    fn bit_padding(payload: &[u8], payload_len: usize) -> VidResult<Vec<u8>> {
118        if payload_len < payload.len() + 1 {
119            return Err(VidError::Argument(
120                "Payload length is too large to fit in the given payload length".to_string(),
121            ));
122        }
123        let mut padded = vec![0u8; payload_len];
124        padded[..payload.len()].copy_from_slice(payload);
125        padded[payload.len()] = 1u8;
126        Ok(padded)
127    }
128
129    fn raw_disperse(
130        param: &AvidmGf2Param,
131        payload: &[u8],
132    ) -> VidResult<(MerkleTree, Vec<Vec<u8>>)> {
133        let original_count = param.recovery_threshold;
134        let recovery_count = param.total_weights - param.recovery_threshold;
135        // Bit padding, we append an 1u8 to the end of the payload.
136        let mut shard_bytes = (payload.len() + 1).div_ceil(original_count);
137        if shard_bytes % 2 == 1 {
138            shard_bytes += 1;
139        }
140        let payload = Self::bit_padding(payload, shard_bytes * original_count)?;
141        let original = payload
142            .chunks(shard_bytes)
143            .map(|chunk| chunk.to_owned())
144            .collect::<Vec<_>>();
145        let recovery = if recovery_count == 0 {
146            vec![]
147        } else {
148            reed_solomon_simd::encode(original_count, recovery_count, &original)?
149        };
150
151        let shares = [original, recovery].concat();
152        let share_digests: Vec<_> = shares
153            .iter()
154            .map(|share| HasherNode::from(sha3::Keccak256::digest(share)))
155            .collect();
156        let mt = MerkleTree::from_elems(None, &share_digests)?;
157        Ok((mt, shares))
158    }
159}
160
161impl VidScheme for AvidmGf2Scheme {
162    type Param = AvidmGf2Param;
163    type Share = AvidmGf2Share;
164    type Commit = AvidmGf2Commit;
165
166    fn commit(param: &Self::Param, payload: &[u8]) -> VidResult<Self::Commit> {
167        let (mt, _) = Self::raw_disperse(param, payload)?;
168        Ok(Self::Commit {
169            commit: mt.commitment(),
170        })
171    }
172
173    fn disperse(
174        param: &Self::Param,
175        distribution: &[u32],
176        payload: &[u8],
177    ) -> VidResult<(Self::Commit, Vec<Self::Share>)> {
178        let (mt, shares) = Self::raw_disperse(param, payload)?;
179        let commit = AvidmGf2Commit {
180            commit: mt.commitment(),
181        };
182        let ranges: Vec<_> = distribution
183            .iter()
184            .scan(0, |sum, w| {
185                let prefix_sum = *sum;
186                *sum += w;
187                Some(prefix_sum as usize..*sum as usize)
188            })
189            .collect();
190        let shares: Vec<_> = ranges
191            .into_iter()
192            .map(|range| AvidmGf2Share {
193                range: range.clone(),
194                payload: shares[range.clone()].to_vec(),
195                // TODO(Chengyu): switch to batch proof generation
196                mt_proofs: range
197                    .map(|k| {
198                        mt.lookup(k as u64)
199                            .expect_ok()
200                            .expect("MT lookup shouldn't fail")
201                            .1
202                    })
203                    .collect::<Vec<_>>(),
204            })
205            .collect();
206        Ok((commit, shares))
207    }
208
209    fn verify_share(
210        _param: &Self::Param,
211        commit: &Self::Commit,
212        share: &Self::Share,
213    ) -> VidResult<crate::VerificationResult> {
214        if !share.validate() {
215            return Err(VidError::InvalidShare);
216        }
217        for (i, index) in share.range.clone().enumerate() {
218            let payload_digest = HasherNode::from(sha3::Keccak256::digest(&share.payload[i]));
219            // TODO(Chengyu): switch to batch verification
220            if MerkleTree::verify(
221                commit.commit,
222                index as u64,
223                payload_digest,
224                &share.mt_proofs[i],
225            )?
226            .is_err()
227            {
228                return Ok(Err(()));
229            }
230        }
231        Ok(Ok(()))
232    }
233
234    fn recover(
235        param: &Self::Param,
236        _commit: &Self::Commit,
237        shares: &[Self::Share],
238    ) -> VidResult<Vec<u8>> {
239        let original_count = param.recovery_threshold;
240        let recovery_count = param.total_weights - param.recovery_threshold;
241        // Find the first non-empty share
242        let Some(first_share) = shares.iter().find(|s| !s.payload.is_empty()) else {
243            return Err(VidError::InsufficientShares);
244        };
245        let shard_bytes = first_share.payload[0].len();
246
247        let mut original_shares: Vec<Option<Vec<u8>>> = vec![None; original_count];
248        if recovery_count == 0 {
249            // Edge case where there are no recovery shares
250            for share in shares {
251                if !share.validate() || share.payload.iter().any(|p| p.len() != shard_bytes) {
252                    return Err(VidError::InvalidShare);
253                }
254                for (i, index) in share.range.clone().enumerate() {
255                    if index < original_count {
256                        original_shares[index] = Some(share.payload[i].clone());
257                    }
258                }
259            }
260        } else {
261            let mut decoder = reed_solomon_simd::ReedSolomonDecoder::new(
262                original_count,
263                recovery_count,
264                shard_bytes,
265            )?;
266            for share in shares {
267                if !share.validate() || share.payload.iter().any(|p| p.len() != shard_bytes) {
268                    return Err(VidError::InvalidShare);
269                }
270                for (i, index) in share.range.clone().enumerate() {
271                    if index < original_count {
272                        original_shares[index] = Some(share.payload[i].clone());
273                        decoder.add_original_shard(index, &share.payload[i])?;
274                    } else {
275                        decoder.add_recovery_shard(index - original_count, &share.payload[i])?;
276                    }
277                }
278            }
279
280            let result = decoder.decode()?;
281            original_shares
282                .iter_mut()
283                .enumerate()
284                .for_each(|(i, share)| {
285                    if share.is_none() {
286                        *share = result.restored_original(i).map(|s| s.to_vec());
287                    }
288                });
289        }
290        if original_shares.iter().any(|share| share.is_none()) {
291            return Err(VidError::Internal(anyhow!(
292                "Failed to recover the payload."
293            )));
294        }
295        let mut recovered: Vec<_> = original_shares
296            .into_iter()
297            .flat_map(|share| share.unwrap())
298            .collect();
299        match recovered.iter().rposition(|&b| b != 0) {
300            Some(pad_index) if recovered[pad_index] == 1u8 => {
301                recovered.truncate(pad_index);
302                Ok(recovered)
303            },
304            _ => Err(VidError::Argument(
305                "Malformed payload, cannot find the padding position".to_string(),
306            )),
307        }
308    }
309}
310
/// Unit tests
#[cfg(test)]
pub mod tests {
    use rand::{seq::SliceRandom, RngCore};

    use super::AvidmGf2Scheme;
    use crate::VidScheme;

    /// Shared round-trip driver: for several committee sizes and payload
    /// lengths, disperse a random payload, verify every share, then recover
    /// the payload from a shuffled prefix of shares whose cumulative weight
    /// just reaches the recovery threshold. `threshold_of` maps the committee
    /// total weight to the recovery threshold, letting the tests exercise
    /// both the normal case and the no-recovery-shards edge case.
    fn run_round_trip(threshold_of: fn(u32) -> usize) {
        // play with these items
        let num_storage_nodes_list = [4, 9, 16];
        let payload_byte_lens = [1, 31, 32, 500];

        let mut rng = jf_utils::test_rng();

        for num_storage_nodes in num_storage_nodes_list {
            // random per-node weights in 1..=5
            let weights: Vec<u32> = (0..num_storage_nodes)
                .map(|_| rng.next_u32() % 5 + 1)
                .collect();
            let total_weights: u32 = weights.iter().sum();
            let recovery_threshold = threshold_of(total_weights);
            let params = AvidmGf2Scheme::setup(recovery_threshold, total_weights as usize).unwrap();

            for payload_byte_len in payload_byte_lens {
                let payload = {
                    let mut bytes_random = vec![0u8; payload_byte_len];
                    rng.fill_bytes(&mut bytes_random);
                    bytes_random
                };

                let (commit, mut shares) =
                    AvidmGf2Scheme::disperse(&params, &weights, &payload).unwrap();

                assert_eq!(shares.len(), num_storage_nodes);

                // every share must verify against the commitment
                shares.iter().for_each(|share| {
                    assert!(AvidmGf2Scheme::verify_share(&params, &commit, share)
                        .is_ok_and(|r| r.is_ok()))
                });

                // test payload recovery on a random subset of shares whose
                // cumulative weight just meets the threshold (all shares when
                // the threshold equals the total weight)
                shares.shuffle(&mut rng);
                let mut cumulated_weights = 0;
                let mut cut_index = 0;
                while cumulated_weights < recovery_threshold {
                    cumulated_weights += shares[cut_index].weight();
                    cut_index += 1;
                }
                let payload_recovered =
                    AvidmGf2Scheme::recover(&params, &commit, &shares[..cut_index]).unwrap();
                assert_eq!(payload_recovered, payload);
            }
        }
    }

    #[test]
    fn round_trip() {
        // threshold at a third of the total weight: recovery shards exist
        run_round_trip(|total_weights| total_weights.div_ceil(3) as usize);
    }

    #[test]
    fn round_trip_edge_case() {
        // threshold equals total weight: no recovery shards are generated
        run_round_trip(|total_weights| total_weights as usize);
    }
}