vid/
avidm.rs

//! This module implements the AVID-M scheme, which is named after the DispersedLedger paper <https://www.usenix.org/conference/nsdi22/presentation/yang>.
//!
//! To disperse a payload to a number of storage nodes according to a weight
//! distribution, the payload is first converted into field elements and then
//! divided into chunks of `k` elements each, and each chunk is then encoded
//! into `n` field elements using a Reed-Solomon code. The parameter `n` equals
//! the total weight of all storage nodes, and `k` is the minimum collective
//! weight required to recover the original payload. After the encoding, the
//! result can be viewed as `n` vectors of field elements, each of length equal
//! to the number of chunks. The VID commitment is obtained by Merklizing these
//! `n` vectors. For dispersal, each storage node gets some vectors and their
//! Merkle proofs according to its weight.
13
14use std::{collections::HashMap, iter, ops::Range};
15
16use ark_ff::PrimeField;
17use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
18use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
19use ark_std::{end_timer, start_timer};
20use config::AvidMConfig;
21use jf_merkle_tree::MerkleTreeScheme;
22use jf_utils::canonical;
23use p3_maybe_rayon::prelude::{
24    IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, ParallelSlice,
25};
26use serde::{Deserialize, Serialize};
27use tagged_base64::tagged;
28
29use crate::{
30    utils::bytes_to_field::{self, bytes_to_field, field_to_bytes},
31    VidError, VidResult, VidScheme,
32};
33
34mod config;
35
36pub mod namespaced;
37pub mod proofs;
38
39#[cfg(all(not(feature = "sha256"), not(feature = "keccak256")))]
40type Config = config::Poseidon2Config;
41#[cfg(feature = "sha256")]
42type Config = config::Sha256Config;
43#[cfg(feature = "keccak256")]
44type Config = config::Keccak256Config;
45
46// Type alias for convenience
47type F = <Config as AvidMConfig>::BaseField;
48type MerkleTree = <Config as AvidMConfig>::MerkleTree;
49type MerkleProof = <MerkleTree as MerkleTreeScheme>::MembershipProof;
50type MerkleCommit = <MerkleTree as MerkleTreeScheme>::Commitment;
51
52/// Commit type for AVID-M scheme.
53#[derive(
54    Clone,
55    Copy,
56    Debug,
57    Hash,
58    CanonicalSerialize,
59    CanonicalDeserialize,
60    Eq,
61    PartialEq,
62    Ord,
63    PartialOrd,
64)]
65#[tagged("AvidMCommit")]
66#[repr(C)]
67pub struct AvidMCommit {
68    /// Root commitment of the Merkle tree.
69    pub commit: MerkleCommit,
70}
71
72impl AsRef<[u8]> for AvidMCommit {
73    fn as_ref(&self) -> &[u8] {
74        unsafe {
75            ::core::slice::from_raw_parts(
76                (self as *const Self) as *const u8,
77                ::core::mem::size_of::<Self>(),
78            )
79        }
80    }
81}
82
83impl AsRef<[u8; 32]> for AvidMCommit {
84    fn as_ref(&self) -> &[u8; 32] {
85        unsafe { ::core::slice::from_raw_parts((self as *const Self) as *const u8, 32) }
86            .try_into()
87            .unwrap()
88    }
89}
90
91/// Share type to be distributed among the parties.
92#[derive(Clone, Debug, Hash, Serialize, Deserialize, Eq, PartialEq)]
93pub struct RawAvidMShare {
94    /// Range of this share in the encoded payload.
95    range: Range<usize>,
96    /// Actual share content.
97    #[serde(with = "canonical")]
98    payload: Vec<Vec<F>>,
99    /// Merkle proof of the content.
100    #[serde(with = "canonical")]
101    mt_proofs: Vec<MerkleProof>,
102}
103
104/// Share type to be distributed among the parties.
105#[derive(Clone, Debug, Hash, Serialize, Deserialize, Eq, PartialEq)]
106pub struct AvidMShare {
107    /// Index number of the given share.
108    index: u32,
109    /// The length of payload in bytes.
110    payload_byte_len: usize,
111    /// Content of this AvidMShare.
112    content: RawAvidMShare,
113}
114
115/// Public parameters of the AVID-M scheme.
116#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
117pub struct AvidMParam {
118    /// Total weights of all storage nodes
119    pub total_weights: usize,
120    /// Minimum collective weights required to recover the original payload.
121    pub recovery_threshold: usize,
122}
123
124impl AvidMParam {
125    /// Construct a new [`AvidMParam`].
126    pub fn new(recovery_threshold: usize, total_weights: usize) -> VidResult<Self> {
127        if recovery_threshold == 0 || total_weights < recovery_threshold {
128            return Err(VidError::InvalidParam);
129        }
130        Ok(Self {
131            total_weights,
132            recovery_threshold,
133        })
134    }
135}
136
137/// Helper: initialize a FFT domain
138#[inline]
139fn radix2_domain<F: PrimeField>(domain_size: usize) -> VidResult<Radix2EvaluationDomain<F>> {
140    Radix2EvaluationDomain::<F>::new(domain_size).ok_or_else(|| VidError::InvalidParam)
141}
142
143/// Dummy struct for AVID-M scheme.
144pub struct AvidMScheme;
145
146impl AvidMScheme {
147    /// Setup an instance for AVID-M scheme
148    pub fn setup(recovery_threshold: usize, total_weights: usize) -> VidResult<AvidMParam> {
149        AvidMParam::new(recovery_threshold, total_weights)
150    }
151}
152
153impl AvidMScheme {
154    /// Helper function.
155    /// Transform the payload bytes into a list of fields elements.
156    /// This function also pads the bytes with a 1 in the end, following by many 0's
157    /// until the length of the output is a multiple of `param.recovery_threshold`.
158    fn pad_to_fields(param: &AvidMParam, payload: &[u8]) -> Vec<F> {
159        // The number of bytes that can be encoded into a single F element.
160        let elem_bytes_len = bytes_to_field::elem_byte_capacity::<F>();
161
162        // A "chunk" is a byte slice whose size holds exactly `recovery_threshold`
163        // F elements.
164        let num_bytes_per_chunk = param.recovery_threshold * elem_bytes_len;
165
166        let remainder = (payload.len() + 1) % num_bytes_per_chunk;
167        let pad_num_zeros = (num_bytes_per_chunk - remainder) % num_bytes_per_chunk;
168
169        // Pad the payload with a 1 and many 0's.
170        bytes_to_field::<_, F>(
171            payload
172                .iter()
173                .chain(iter::once(&1u8))
174                .chain(iter::repeat_n(&0u8, pad_num_zeros)),
175        )
176        .collect()
177    }
178
179    /// Helper function.
180    /// Let `k = recovery_threshold` and `n = total_weights`. This function
181    /// partition the `payload` into many chunks, each containing `k` field
182    /// elements. Then each chunk is encoded into `n` field element with Reed
183    /// Solomon erasure code. They are then re-organized as `n` vectors, each
184    /// collecting one field element from each chunk. These `n` vectors are
185    /// then Merklized for commitment and membership proof generation.
186    #[allow(clippy::type_complexity)]
187    #[inline]
188    fn raw_encode(param: &AvidMParam, payload: &[F]) -> VidResult<(MerkleTree, Vec<Vec<F>>)> {
189        let domain = radix2_domain::<F>(param.total_weights)?; // See docs at `domains`.
190
191        let encoding_timer = start_timer!(|| "Encoding payload");
192
193        // RS-encode each chunk
194        let codewords: Vec<_> = payload
195            .par_chunks(param.recovery_threshold)
196            .map(|chunk| {
197                let mut fft_vec = domain.fft(chunk); // RS-encode the chunk
198                fft_vec.truncate(param.total_weights); // truncate the useless evaluations
199                fft_vec
200            })
201            .collect();
202        // Generate `total_weights` raw shares. Each share collects one field element
203        // from each encode chunk.
204        let raw_shares: Vec<_> = (0..param.total_weights)
205            .into_par_iter()
206            .map(|i| codewords.iter().map(|v| v[i]).collect::<Vec<F>>())
207            .collect();
208        end_timer!(encoding_timer);
209
210        let hash_timer = start_timer!(|| "Compressing each raw share");
211        let compressed_raw_shares = raw_shares
212            .par_iter()
213            .map(|v| Config::raw_share_digest(v))
214            .collect::<Result<Vec<_>, _>>()?;
215        end_timer!(hash_timer);
216
217        let mt_timer = start_timer!(|| "Constructing Merkle tree");
218        let mt = MerkleTree::from_elems(None, &compressed_raw_shares)?;
219        end_timer!(mt_timer);
220
221        Ok((mt, raw_shares))
222    }
223
224    /// Short hand for `pad_to_field` and `raw_encode`.
225    fn pad_and_encode(param: &AvidMParam, payload: &[u8]) -> VidResult<(MerkleTree, Vec<Vec<F>>)> {
226        let payload = Self::pad_to_fields(param, payload);
227        Self::raw_encode(param, &payload)
228    }
229
230    /// Consume in the constructed Merkle tree and the raw shares from `raw_encode`, provide the AvidM commitment and shares.
231    fn distribute_shares(
232        param: &AvidMParam,
233        distribution: &[u32],
234        mt: MerkleTree,
235        raw_shares: Vec<Vec<F>>,
236        payload_byte_len: usize,
237    ) -> VidResult<(AvidMCommit, Vec<AvidMShare>)> {
238        // let payload_byte_len = payload.len();
239        let total_weights = distribution.iter().sum::<u32>() as usize;
240        if total_weights != param.total_weights {
241            return Err(VidError::Argument(
242                "Weight distribution is inconsistent with the given param".to_string(),
243            ));
244        }
245        if distribution.contains(&0u32) {
246            return Err(VidError::Argument("Weight cannot be zero".to_string()));
247        }
248
249        let distribute_timer = start_timer!(|| "Distribute codewords to the storage nodes");
250        // Distribute the raw shares to each storage node according to the weight
251        // distribution. For each chunk, storage `i` gets `distribution[i]`
252        // consecutive raw shares ranging as `ranges[i]`.
253        let ranges: Vec<_> = distribution
254            .iter()
255            .scan(0, |sum, w| {
256                let prefix_sum = *sum;
257                *sum += w;
258                Some(prefix_sum as usize..*sum as usize)
259            })
260            .collect();
261        let shares: Vec<_> = ranges
262            .par_iter()
263            .map(|range| {
264                range
265                    .clone()
266                    .map(|k| raw_shares[k].to_owned())
267                    .collect::<Vec<_>>()
268            })
269            .collect();
270        end_timer!(distribute_timer);
271
272        let mt_proof_timer = start_timer!(|| "Generate Merkle tree proofs");
273        let shares = shares
274            .into_iter()
275            .enumerate()
276            .map(|(i, payload)| AvidMShare {
277                index: i as u32,
278                payload_byte_len,
279                content: RawAvidMShare {
280                    range: ranges[i].clone(),
281                    payload,
282                    mt_proofs: ranges[i]
283                        .clone()
284                        .map(|k| {
285                            mt.lookup(k as u64)
286                                .expect_ok()
287                                .expect("MT lookup shouldn't fail")
288                                .1
289                        })
290                        .collect::<Vec<_>>(),
291                },
292            })
293            .collect::<Vec<_>>();
294        end_timer!(mt_proof_timer);
295
296        let commit = AvidMCommit {
297            commit: mt.commitment(),
298        };
299
300        Ok((commit, shares))
301    }
302
303    pub(crate) fn verify_internal(
304        param: &AvidMParam,
305        commit: &AvidMCommit,
306        share: &RawAvidMShare,
307    ) -> VidResult<crate::VerificationResult> {
308        if share.range.end > param.total_weights
309            || share.range.len() != share.payload.len()
310            || share.range.len() != share.mt_proofs.len()
311        {
312            return Err(VidError::InvalidShare);
313        }
314        for (i, index) in share.range.clone().enumerate() {
315            let compressed_payload = Config::raw_share_digest(&share.payload[i])?;
316            if MerkleTree::verify(
317                commit.commit,
318                index as u64,
319                compressed_payload,
320                &share.mt_proofs[i],
321            )?
322            .is_err()
323            {
324                return Ok(Err(()));
325            }
326        }
327        Ok(Ok(()))
328    }
329
330    pub(crate) fn recover_fields(param: &AvidMParam, shares: &[AvidMShare]) -> VidResult<Vec<F>> {
331        let recovery_threshold: usize = param.recovery_threshold;
332
333        // Each share's payload contains some evaluations from `num_polys`
334        // polynomials.
335        let num_polys = shares
336            .iter()
337            .find(|s| !s.content.payload.is_empty())
338            .ok_or(VidError::Argument("All shares are empty".to_string()))?
339            .content
340            .payload[0]
341            .len();
342
343        let mut raw_shares = HashMap::new();
344        for share in shares {
345            if share.content.range.len() != share.content.payload.len()
346                || share.content.range.end > param.total_weights
347            {
348                return Err(VidError::InvalidShare);
349            }
350            for (i, p) in share.content.range.clone().zip(&share.content.payload) {
351                if p.len() != num_polys {
352                    return Err(VidError::InvalidShare);
353                }
354                if raw_shares.contains_key(&i) {
355                    return Err(VidError::InvalidShare);
356                }
357                raw_shares.insert(i, p);
358                if raw_shares.len() >= recovery_threshold {
359                    break;
360                }
361            }
362            if raw_shares.len() >= recovery_threshold {
363                break;
364            }
365        }
366
367        if raw_shares.len() < recovery_threshold {
368            return Err(VidError::InsufficientShares);
369        }
370
371        let domain = radix2_domain::<F>(param.total_weights)?;
372
373        // Lagrange interpolation
374        // step 1: find all evaluation points and their raw shares
375        let (x, raw_shares): (Vec<_>, Vec<_>) = raw_shares
376            .into_iter()
377            .map(|(i, p)| (domain.element(i), p))
378            .unzip();
379        // step 2: interpolate each polynomial
380        Ok((0..num_polys)
381            .into_par_iter()
382            .map(|poly_index| {
383                jf_utils::reed_solomon_code::reed_solomon_erasure_decode(
384                    x.iter().zip(raw_shares.iter().map(|p| p[poly_index])),
385                    recovery_threshold,
386                )
387                .map_err(|err| VidError::Internal(err.into()))
388            })
389            .collect::<Result<Vec<_>, _>>()?
390            .into_iter()
391            .flatten()
392            .collect())
393    }
394}
395
396impl VidScheme for AvidMScheme {
397    type Param = AvidMParam;
398
399    type Share = AvidMShare;
400
401    type Commit = AvidMCommit;
402
403    fn commit(param: &Self::Param, payload: &[u8]) -> VidResult<Self::Commit> {
404        let (mt, _) = Self::pad_and_encode(param, payload)?;
405        Ok(AvidMCommit {
406            commit: mt.commitment(),
407        })
408    }
409
410    fn disperse(
411        param: &Self::Param,
412        distribution: &[u32],
413        payload: &[u8],
414    ) -> VidResult<(Self::Commit, Vec<Self::Share>)> {
415        let (mt, raw_shares) = Self::pad_and_encode(param, payload)?;
416        Self::distribute_shares(param, distribution, mt, raw_shares, payload.len())
417    }
418
419    fn verify_share(
420        param: &Self::Param,
421        commit: &Self::Commit,
422        share: &Self::Share,
423    ) -> VidResult<crate::VerificationResult> {
424        Self::verify_internal(param, commit, &share.content)
425    }
426
427    /// Recover payload data from shares.
428    ///
429    /// # Requirements
430    /// - Total weight of all shares must be at least `recovery_threshold`.
431    /// - Each share's `payload` must have equal length.
432    /// - All shares must be verified under the given commitment.
433    ///
434    /// Shares beyond `recovery_threshold` are ignored.
435    fn recover(
436        param: &Self::Param,
437        _commit: &Self::Commit,
438        shares: &[Self::Share],
439    ) -> VidResult<Vec<u8>> {
440        let mut bytes: Vec<u8> = field_to_bytes(Self::recover_fields(param, shares)?).collect();
441        // Remove the trimming zeros and the last 1 to get the actual payload bytes.
442        // See `pad_to_fields`.
443        if let Some(pad_index) = bytes.iter().rposition(|&b| b != 0) {
444            if bytes[pad_index] == 1u8 {
445                bytes.truncate(pad_index);
446                return Ok(bytes);
447            }
448        }
449        Err(VidError::Argument(
450            "Malformed payload, cannot find the padding position".to_string(),
451        ))
452    }
453}
454
/// Unit tests
#[cfg(test)]
pub mod tests {
    use rand::{seq::SliceRandom, RngCore};

    use super::F;
    use crate::{avidm::AvidMScheme, utils::bytes_to_field, VidScheme};

    /// Checks the padding rule: a `1` marker byte followed by zeros up to a
    /// whole number of chunks.
    #[test]
    fn test_padding() {
        let elem_bytes_len = bytes_to_field::elem_byte_capacity::<F>();
        let param = AvidMScheme::setup(2usize, 5usize).unwrap();
        let bytes = vec![2u8; 1];
        let padded = AvidMScheme::pad_to_fields(&param, &bytes);
        assert_eq!(padded.len(), 2usize);
        // Bytes `[2, 1]` followed by zeros, i.e. 2 + 256 in the first element.
        assert_eq!(padded, [F::from(2u32 + u8::MAX as u32 + 1), F::from(0)]);

        // A payload that exactly fills one chunk needs a second chunk for the
        // marker byte.
        let bytes = vec![2u8; elem_bytes_len * 2];
        let padded = AvidMScheme::pad_to_fields(&param, &bytes);
        assert_eq!(padded.len(), 4usize);
    }

    /// Disperses random payloads, verifies all shares, and recovers the
    /// payload from a random subset of shares.
    #[test]
    fn round_trip() {
        // play with these items
        let params_list = [(2, 4), (3, 9), (5, 6), (15, 16)];
        let payload_byte_lens = [1, 31, 32, 500];

        // more items as a function of the above

        let mut rng = jf_utils::test_rng();

        for (recovery_threshold, num_storage_nodes) in params_list {
            let weights: Vec<u32> = (0..num_storage_nodes)
                .map(|_| rng.next_u32() % 5 + 1)
                .collect();
            let total_weights: u32 = weights.iter().sum();
            let params = AvidMScheme::setup(recovery_threshold, total_weights as usize).unwrap();

            for payload_byte_len in payload_byte_lens {
                println!(
                    "recovery_threshold:: {recovery_threshold} num_storage_nodes: \
                     {num_storage_nodes} payload_byte_len: {payload_byte_len}"
                );
                println!("weights: {weights:?}");

                let payload = {
                    let mut bytes_random = vec![0u8; payload_byte_len];
                    rng.fill_bytes(&mut bytes_random);
                    bytes_random
                };

                let (commit, mut shares) =
                    AvidMScheme::disperse(&params, &weights, &payload).unwrap();

                assert_eq!(shares.len(), num_storage_nodes);

                // verify shares
                shares.iter().for_each(|share| {
                    assert!(
                        AvidMScheme::verify_share(&params, &commit, share).is_ok_and(|r| r.is_ok())
                    )
                });

                // test payload recovery on a random subset of shares
                shares.shuffle(&mut rng);
                let mut cumulated_weights = 0;
                let mut cut_index = 0;
                // Stop as soon as the threshold is reached (`<`, not `<=`), so
                // that recovery is exercised with the minimal sufficient set.
                // The loop stays in bounds because the total weight is at
                // least `recovery_threshold` (enforced by `setup`).
                while cumulated_weights < recovery_threshold {
                    cumulated_weights += shares[cut_index].content.range.len();
                    cut_index += 1;
                }
                let payload_recovered =
                    AvidMScheme::recover(&params, &commit, &shares[..cut_index]).unwrap();
                assert_eq!(payload_recovered, payload);
            }
        }
    }

    /// Times dispersal and recovery of a large payload; only built with the
    /// `print-trace` feature.
    #[test]
    #[cfg(feature = "print-trace")]
    fn round_trip_breakdown() {
        use ark_std::{end_timer, start_timer};

        let mut rng = jf_utils::test_rng();

        let params = AvidMScheme::setup(50usize, 200usize).unwrap();
        let weights = vec![2u32; 100usize];
        let payload_byte_len = 1024 * 1024 * 32; // 32MB

        let payload = {
            let mut bytes_random = vec![0u8; payload_byte_len];
            rng.fill_bytes(&mut bytes_random);
            bytes_random
        };

        let disperse_timer = start_timer!(|| format!("Disperse {} bytes", payload_byte_len));
        let (commit, shares) = AvidMScheme::disperse(&params, &weights, &payload).unwrap();
        end_timer!(disperse_timer);

        let recover_timer = start_timer!(|| "Recovery");
        AvidMScheme::recover(&params, &commit, &shares).unwrap();
        end_timer!(recover_timer);
    }
}