1use std::{collections::HashMap, iter, ops::Range};
15
16use ark_ff::PrimeField;
17use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
18use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
19use ark_std::{end_timer, start_timer};
20use config::AvidMConfig;
21use jf_merkle_tree::MerkleTreeScheme;
22use jf_utils::canonical;
23use p3_maybe_rayon::prelude::{
24 IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, ParallelSlice,
25};
26use serde::{Deserialize, Serialize};
27use tagged_base64::tagged;
28
29use crate::{
30 utils::bytes_to_field::{self, bytes_to_field, field_to_bytes},
31 VidError, VidResult, VidScheme,
32};
33
34mod config;
35
36pub mod namespaced;
37pub mod proofs;
38
39#[cfg(all(not(feature = "sha256"), not(feature = "keccak256")))]
40type Config = config::Poseidon2Config;
41#[cfg(feature = "sha256")]
42type Config = config::Sha256Config;
43#[cfg(feature = "keccak256")]
44type Config = config::Keccak256Config;
45
46type F = <Config as AvidMConfig>::BaseField;
48type MerkleTree = <Config as AvidMConfig>::MerkleTree;
49type MerkleProof = <MerkleTree as MerkleTreeScheme>::MembershipProof;
50type MerkleCommit = <MerkleTree as MerkleTreeScheme>::Commitment;
51
/// Commitment to a dispersed payload.
///
/// Wraps the commitment of the Merkle tree built over the digests of all encoded rows.
/// `#[repr(C)]` fixes the in-memory layout that the `AsRef<[u8]>` / `AsRef<[u8; 32]>` impls
/// reinterpret as raw bytes.
#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    Hash,
    CanonicalSerialize,
    CanonicalDeserialize,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
)]
#[tagged("AvidMCommit")]
#[repr(C)]
pub struct AvidMCommit {
    /// Commitment of the Merkle tree whose leaves are the per-row digests.
    pub commit: MerkleCommit,
}
72
73impl AsRef<[u8]> for AvidMCommit {
74 fn as_ref(&self) -> &[u8] {
75 unsafe {
76 ::core::slice::from_raw_parts(
77 (self as *const Self) as *const u8,
78 ::core::mem::size_of::<Self>(),
79 )
80 }
81 }
82}
83
84impl AsRef<[u8; 32]> for AvidMCommit {
85 fn as_ref(&self) -> &[u8; 32] {
86 unsafe { ::core::slice::from_raw_parts((self as *const Self) as *const u8, 32) }
87 .try_into()
88 .unwrap()
89 }
90}
91
92#[derive(Clone, Debug, Hash, Serialize, Deserialize, Eq, PartialEq)]
94pub struct RawAvidMShare {
95 range: Range<usize>,
97 #[serde(with = "canonical")]
99 payload: Vec<Vec<F>>,
100 #[serde(with = "canonical")]
102 mt_proofs: Vec<MerkleProof>,
103}
104
105#[derive(Clone, Debug, Hash, Serialize, Deserialize, Eq, PartialEq)]
107pub struct AvidMShare {
108 index: u32,
110 payload_byte_len: usize,
112 content: RawAvidMShare,
114}
115
116#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
118pub struct AvidMParam {
119 pub total_weights: usize,
121 pub recovery_threshold: usize,
123}
124
125impl AvidMParam {
126 pub fn new(recovery_threshold: usize, total_weights: usize) -> VidResult<Self> {
128 if recovery_threshold == 0 || total_weights < recovery_threshold {
129 return Err(VidError::InvalidParam);
130 }
131 Ok(Self {
132 total_weights,
133 recovery_threshold,
134 })
135 }
136}
137
138#[inline]
140fn radix2_domain<F: PrimeField>(domain_size: usize) -> VidResult<Radix2EvaluationDomain<F>> {
141 Radix2EvaluationDomain::<F>::new(domain_size).ok_or_else(|| VidError::InvalidParam)
142}
143
144pub struct AvidMScheme;
146
147impl AvidMScheme {
148 pub fn setup(recovery_threshold: usize, total_weights: usize) -> VidResult<AvidMParam> {
150 AvidMParam::new(recovery_threshold, total_weights)
151 }
152}
153
154impl AvidMScheme {
155 fn pad_to_fields(param: &AvidMParam, payload: &[u8]) -> Vec<F> {
160 let elem_bytes_len = bytes_to_field::elem_byte_capacity::<F>();
162
163 let num_bytes_per_chunk = param.recovery_threshold * elem_bytes_len;
166
167 let remainder = (payload.len() + 1) % num_bytes_per_chunk;
168 let pad_num_zeros = (num_bytes_per_chunk - remainder) % num_bytes_per_chunk;
169
170 bytes_to_field::<_, F>(
172 payload
173 .iter()
174 .chain(iter::once(&1u8))
175 .chain(iter::repeat_n(&0u8, pad_num_zeros)),
176 )
177 .collect()
178 }
179
180 #[allow(clippy::type_complexity)]
188 #[inline]
189 fn raw_encode(param: &AvidMParam, payload: &[F]) -> VidResult<(MerkleTree, Vec<Vec<F>>)> {
190 let domain = radix2_domain::<F>(param.total_weights)?; let encoding_timer = start_timer!(|| "Encoding payload");
193
194 let codewords: Vec<_> = payload
196 .par_chunks(param.recovery_threshold)
197 .map(|chunk| {
198 let mut fft_vec = domain.fft(chunk); fft_vec.truncate(param.total_weights); fft_vec
201 })
202 .collect();
203 let raw_shares: Vec<_> = (0..param.total_weights)
206 .into_par_iter()
207 .map(|i| codewords.iter().map(|v| v[i]).collect::<Vec<F>>())
208 .collect();
209 end_timer!(encoding_timer);
210
211 let hash_timer = start_timer!(|| "Compressing each raw share");
212 let compressed_raw_shares = raw_shares
213 .par_iter()
214 .map(|v| Config::raw_share_digest(v))
215 .collect::<Result<Vec<_>, _>>()?;
216 end_timer!(hash_timer);
217
218 let mt_timer = start_timer!(|| "Constructing Merkle tree");
219 let mt = MerkleTree::from_elems(None, &compressed_raw_shares)?;
220 end_timer!(mt_timer);
221
222 Ok((mt, raw_shares))
223 }
224
225 fn pad_and_encode(param: &AvidMParam, payload: &[u8]) -> VidResult<(MerkleTree, Vec<Vec<F>>)> {
227 let payload = Self::pad_to_fields(param, payload);
228 Self::raw_encode(param, &payload)
229 }
230
231 fn distribute_shares(
233 param: &AvidMParam,
234 distribution: &[u32],
235 mt: MerkleTree,
236 raw_shares: Vec<Vec<F>>,
237 payload_byte_len: usize,
238 ) -> VidResult<(AvidMCommit, Vec<AvidMShare>)> {
239 let total_weights = distribution.iter().sum::<u32>() as usize;
241 if total_weights != param.total_weights {
242 return Err(VidError::Argument(
243 "Weight distribution is inconsistent with the given param".to_string(),
244 ));
245 }
246 if distribution.contains(&0u32) {
247 return Err(VidError::Argument("Weight cannot be zero".to_string()));
248 }
249
250 let distribute_timer = start_timer!(|| "Distribute codewords to the storage nodes");
251 let ranges: Vec<_> = distribution
255 .iter()
256 .scan(0, |sum, w| {
257 let prefix_sum = *sum;
258 *sum += w;
259 Some(prefix_sum as usize..*sum as usize)
260 })
261 .collect();
262 let shares: Vec<_> = ranges
263 .par_iter()
264 .map(|range| {
265 range
266 .clone()
267 .map(|k| raw_shares[k].to_owned())
268 .collect::<Vec<_>>()
269 })
270 .collect();
271 end_timer!(distribute_timer);
272
273 let mt_proof_timer = start_timer!(|| "Generate Merkle tree proofs");
274 let shares = shares
275 .into_iter()
276 .enumerate()
277 .map(|(i, payload)| AvidMShare {
278 index: i as u32,
279 payload_byte_len,
280 content: RawAvidMShare {
281 range: ranges[i].clone(),
282 payload,
283 mt_proofs: ranges[i]
284 .clone()
285 .map(|k| {
286 mt.lookup(k as u64)
287 .expect_ok()
288 .expect("MT lookup shouldn't fail")
289 .1
290 })
291 .collect::<Vec<_>>(),
292 },
293 })
294 .collect::<Vec<_>>();
295 end_timer!(mt_proof_timer);
296
297 let commit = AvidMCommit {
298 commit: mt.commitment(),
299 };
300
301 Ok((commit, shares))
302 }
303
304 pub(crate) fn verify_internal(
305 param: &AvidMParam,
306 commit: &AvidMCommit,
307 share: &RawAvidMShare,
308 ) -> VidResult<crate::VerificationResult> {
309 if share.range.end > param.total_weights
310 || share.range.len() != share.payload.len()
311 || share.range.len() != share.mt_proofs.len()
312 {
313 return Err(VidError::InvalidShare);
314 }
315 for (i, index) in share.range.clone().enumerate() {
316 let compressed_payload = Config::raw_share_digest(&share.payload[i])?;
317 if MerkleTree::verify(
318 commit.commit,
319 index as u64,
320 compressed_payload,
321 &share.mt_proofs[i],
322 )?
323 .is_err()
324 {
325 return Ok(Err(()));
326 }
327 }
328 Ok(Ok(()))
329 }
330
331 pub(crate) fn recover_fields(param: &AvidMParam, shares: &[AvidMShare]) -> VidResult<Vec<F>> {
332 let recovery_threshold: usize = param.recovery_threshold;
333
334 let num_polys = shares
337 .iter()
338 .find(|s| !s.content.payload.is_empty())
339 .ok_or(VidError::Argument("All shares are empty".to_string()))?
340 .content
341 .payload[0]
342 .len();
343
344 let mut raw_shares = HashMap::new();
345 for share in shares {
346 if share.content.range.len() != share.content.payload.len()
347 || share.content.range.end > param.total_weights
348 {
349 return Err(VidError::InvalidShare);
350 }
351 for (i, p) in share.content.range.clone().zip(&share.content.payload) {
352 if p.len() != num_polys {
353 return Err(VidError::InvalidShare);
354 }
355 if raw_shares.contains_key(&i) {
356 return Err(VidError::InvalidShare);
357 }
358 raw_shares.insert(i, p);
359 if raw_shares.len() >= recovery_threshold {
360 break;
361 }
362 }
363 if raw_shares.len() >= recovery_threshold {
364 break;
365 }
366 }
367
368 if raw_shares.len() < recovery_threshold {
369 return Err(VidError::InsufficientShares);
370 }
371
372 let domain = radix2_domain::<F>(param.total_weights)?;
373
374 let (x, raw_shares): (Vec<_>, Vec<_>) = raw_shares
377 .into_iter()
378 .map(|(i, p)| (domain.element(i), p))
379 .unzip();
380 Ok((0..num_polys)
382 .into_par_iter()
383 .map(|poly_index| {
384 jf_utils::reed_solomon_code::reed_solomon_erasure_decode(
385 x.iter().zip(raw_shares.iter().map(|p| p[poly_index])),
386 recovery_threshold,
387 )
388 .map_err(|err| VidError::Internal(err.into()))
389 })
390 .collect::<Result<Vec<_>, _>>()?
391 .into_iter()
392 .flatten()
393 .collect())
394 }
395}
396
397impl VidScheme for AvidMScheme {
398 type Param = AvidMParam;
399
400 type Share = AvidMShare;
401
402 type Commit = AvidMCommit;
403
404 fn commit(param: &Self::Param, payload: &[u8]) -> VidResult<Self::Commit> {
405 let (mt, _) = Self::pad_and_encode(param, payload)?;
406 Ok(AvidMCommit {
407 commit: mt.commitment(),
408 })
409 }
410
411 fn disperse(
412 param: &Self::Param,
413 distribution: &[u32],
414 payload: &[u8],
415 ) -> VidResult<(Self::Commit, Vec<Self::Share>)> {
416 let (mt, raw_shares) = Self::pad_and_encode(param, payload)?;
417 Self::distribute_shares(param, distribution, mt, raw_shares, payload.len())
418 }
419
420 fn verify_share(
421 param: &Self::Param,
422 commit: &Self::Commit,
423 share: &Self::Share,
424 ) -> VidResult<crate::VerificationResult> {
425 Self::verify_internal(param, commit, &share.content)
426 }
427
428 fn recover(
437 param: &Self::Param,
438 _commit: &Self::Commit,
439 shares: &[Self::Share],
440 ) -> VidResult<Vec<u8>> {
441 let mut bytes: Vec<u8> = field_to_bytes(Self::recover_fields(param, shares)?).collect();
442 if let Some(pad_index) = bytes.iter().rposition(|&b| b != 0) {
445 if bytes[pad_index] == 1u8 {
446 bytes.truncate(pad_index);
447 return Ok(bytes);
448 }
449 }
450 Err(VidError::Argument(
451 "Malformed payload, cannot find the padding position".to_string(),
452 ))
453 }
454}
455
#[cfg(test)]
pub mod tests {
    use rand::{seq::SliceRandom, RngCore};

    use super::F;
    use crate::{avidm::AvidMScheme, utils::bytes_to_field, VidScheme};

    #[test]
    fn test_padding() {
        let elem_bytes_len = bytes_to_field::elem_byte_capacity::<F>();
        let param = AvidMScheme::setup(2usize, 5usize).unwrap();

        // One payload byte plus the `1u8` marker fit in the first element; the second element
        // is pure zero padding so the output fills one chunk of `recovery_threshold` elements.
        let bytes = vec![2u8; 1];
        let padded = AvidMScheme::pad_to_fields(&param, &bytes);
        assert_eq!(padded.len(), 2usize);
        assert_eq!(padded, [F::from(2u32 + u8::MAX as u32 + 1), F::from(0)]);

        // Two full elements of payload push the marker into a second chunk.
        let bytes = vec![2u8; elem_bytes_len * 2];
        let padded = AvidMScheme::pad_to_fields(&param, &bytes);
        assert_eq!(padded.len(), 4usize);
    }

    #[test]
    fn round_trip() {
        // (recovery_threshold, num_storage_nodes)
        let params_list = [(2, 4), (3, 9), (5, 6), (15, 16)];
        let payload_byte_lens = [1, 31, 32, 500];

        let mut rng = jf_utils::test_rng();

        for (recovery_threshold, num_storage_nodes) in params_list {
            let weights: Vec<u32> = (0..num_storage_nodes)
                .map(|_| rng.next_u32() % 5 + 1)
                .collect();
            let total_weights: u32 = weights.iter().sum();
            let params = AvidMScheme::setup(recovery_threshold, total_weights as usize).unwrap();

            for payload_byte_len in payload_byte_lens {
                println!(
                    "recovery_threshold:: {recovery_threshold} num_storage_nodes: \
                     {num_storage_nodes} payload_byte_len: {payload_byte_len}"
                );
                println!("weights: {weights:?}");

                let payload = {
                    let mut bytes_random = vec![0u8; payload_byte_len];
                    rng.fill_bytes(&mut bytes_random);
                    bytes_random
                };

                let (commit, mut shares) =
                    AvidMScheme::disperse(&params, &weights, &payload).unwrap();

                assert_eq!(shares.len(), num_storage_nodes);

                shares.iter().for_each(|share| {
                    assert!(
                        AvidMScheme::verify_share(&params, &commit, share).is_ok_and(|r| r.is_ok())
                    )
                });

                // Pick a random subset whose accumulated weight just reaches the threshold, so
                // recovery is exercised at the boundary rather than with spare weight.
                shares.shuffle(&mut rng);
                let mut cumulated_weights = 0;
                let mut cut_index = 0;
                while cumulated_weights < recovery_threshold {
                    cumulated_weights += shares[cut_index].content.range.len();
                    cut_index += 1;
                }
                let payload_recovered =
                    AvidMScheme::recover(&params, &commit, &shares[..cut_index]).unwrap();
                assert_eq!(payload_recovered, payload);

                // Without the last selected share the subset is below the threshold.
                assert!(
                    AvidMScheme::recover(&params, &commit, &shares[..cut_index - 1]).is_err()
                );
            }
        }
    }

    #[test]
    #[cfg(feature = "print-trace")]
    fn round_trip_breakdown() {
        use ark_std::{end_timer, start_timer};

        let mut rng = jf_utils::test_rng();

        let params = AvidMScheme::setup(50usize, 200usize).unwrap();
        let weights = vec![2u32; 100usize];
        let payload_byte_len = 1024 * 1024 * 32;
        let payload = {
            let mut bytes_random = vec![0u8; payload_byte_len];
            rng.fill_bytes(&mut bytes_random);
            bytes_random
        };

        let disperse_timer = start_timer!(|| format!("Disperse {} bytes", payload_byte_len));
        let (commit, shares) = AvidMScheme::disperse(&params, &weights, &payload).unwrap();
        end_timer!(disperse_timer);

        let recover_timer = start_timer!(|| "Recovery");
        AvidMScheme::recover(&params, &commit, &shares).unwrap();
        end_timer!(recover_timer);
    }
}