1use std::{collections::HashMap, iter, ops::Range};
15
16use ark_ff::PrimeField;
17use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
18use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
19use ark_std::{end_timer, start_timer};
20use config::AvidMConfig;
21use jf_merkle_tree::MerkleTreeScheme;
22use jf_utils::canonical;
23use p3_maybe_rayon::prelude::{
24 IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, ParallelSlice,
25};
26use serde::{Deserialize, Serialize};
27use tagged_base64::tagged;
28
29use crate::{
30 utils::bytes_to_field::{self, bytes_to_field, field_to_bytes},
31 VidError, VidResult, VidScheme,
32};
33
34mod config;
35
36pub mod namespaced;
37pub mod proofs;
38
// Compile-time selection of the hash configuration. Poseidon2 is the default;
// the `sha256` / `keccak256` features swap in the corresponding backend.
// NOTE(review): enabling both `sha256` and `keccak256` produces two conflicting
// `Config` definitions — presumably the features are meant to be mutually
// exclusive; confirm in Cargo.toml.
#[cfg(all(not(feature = "sha256"), not(feature = "keccak256")))]
type Config = config::Poseidon2Config;
#[cfg(feature = "sha256")]
type Config = config::Sha256Config;
#[cfg(feature = "keccak256")]
type Config = config::Keccak256Config;

// Short-hands for the types induced by the active configuration.
type F = <Config as AvidMConfig>::BaseField;
type MerkleTree = <Config as AvidMConfig>::MerkleTree;
type MerkleProof = <MerkleTree as MerkleTreeScheme>::MembershipProof;
type MerkleCommit = <MerkleTree as MerkleTreeScheme>::Commitment;
51
/// Commitment of an AVID-M dispersal: the root of the Merkle tree built over
/// the digests of all raw shares.
///
/// `repr(C)` pins the field layout, which the `AsRef<[u8]>` impl below relies
/// on for its raw-byte view.
#[derive(
    Clone,
    Copy,
    Debug,
    Hash,
    CanonicalSerialize,
    CanonicalDeserialize,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
)]
#[tagged("AvidMCommit")]
#[repr(C)]
pub struct AvidMCommit {
    /// Merkle tree root committing to the compressed raw shares.
    pub commit: MerkleCommit,
}
71
impl AsRef<[u8]> for AvidMCommit {
    /// View the commitment as its raw in-memory bytes.
    fn as_ref(&self) -> &[u8] {
        // SAFETY: `Self` is `repr(C)` with a single field, the pointer comes
        // from a valid reference, and the slice length is exactly
        // `size_of::<Self>()`; the returned slice borrows `self`, so the
        // memory stays valid for the borrow's lifetime.
        // NOTE(review): this is only sound if `MerkleCommit` has no padding
        // bytes — reading uninitialized padding as `u8` is UB; confirm for
        // every configured commitment type.
        unsafe {
            ::core::slice::from_raw_parts(
                (self as *const Self) as *const u8,
                ::core::mem::size_of::<Self>(),
            )
        }
    }
}
82
83impl AsRef<[u8; 32]> for AvidMCommit {
84 fn as_ref(&self) -> &[u8; 32] {
85 unsafe { ::core::slice::from_raw_parts((self as *const Self) as *const u8, 32) }
86 .try_into()
87 .unwrap()
88 }
89}
90
/// The inner content of an AVID-M share, without per-node bookkeeping.
#[derive(Clone, Debug, Hash, Serialize, Deserialize, Eq, PartialEq)]
pub struct RawAvidMShare {
    // Half-open range of raw-share indices covered by this share.
    range: Range<usize>,
    // One vector per index in `range`: the k-th inner vector holds the
    // evaluation of every chunk polynomial at domain point k.
    #[serde(with = "canonical")]
    payload: Vec<Vec<F>>,
    // Merkle membership proofs, expected one per index in `range`.
    #[serde(with = "canonical")]
    mt_proofs: Vec<MerkleProof>,
}
103
/// The share assigned to a single storage node in an AVID-M dispersal.
#[derive(Clone, Debug, Hash, Serialize, Deserialize, Eq, PartialEq)]
pub struct AvidMShare {
    // Index of the storage node holding this share.
    index: u32,
    // Byte length of the original (unpadded) payload.
    payload_byte_len: usize,
    // Raw share content together with its Merkle proofs.
    content: RawAvidMShare,
}
114
/// Public parameters of the AVID-M scheme.
#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
pub struct AvidMParam {
    /// Total weight across all storage nodes; equals the number of raw shares.
    pub total_weights: usize,
    /// Minimum total weight (raw shares) required to recover the payload.
    pub recovery_threshold: usize,
}
123
124impl AvidMParam {
125 pub fn new(recovery_threshold: usize, total_weights: usize) -> VidResult<Self> {
127 if recovery_threshold == 0 || total_weights < recovery_threshold {
128 return Err(VidError::InvalidParam);
129 }
130 Ok(Self {
131 total_weights,
132 recovery_threshold,
133 })
134 }
135}
136
137#[inline]
139fn radix2_domain<F: PrimeField>(domain_size: usize) -> VidResult<Radix2EvaluationDomain<F>> {
140 Radix2EvaluationDomain::<F>::new(domain_size).ok_or_else(|| VidError::InvalidParam)
141}
142
143pub struct AvidMScheme;
145
146impl AvidMScheme {
147 pub fn setup(recovery_threshold: usize, total_weights: usize) -> VidResult<AvidMParam> {
149 AvidMParam::new(recovery_threshold, total_weights)
150 }
151}
152
153impl AvidMScheme {
154 fn pad_to_fields(param: &AvidMParam, payload: &[u8]) -> Vec<F> {
159 let elem_bytes_len = bytes_to_field::elem_byte_capacity::<F>();
161
162 let num_bytes_per_chunk = param.recovery_threshold * elem_bytes_len;
165
166 let remainder = (payload.len() + 1) % num_bytes_per_chunk;
167 let pad_num_zeros = (num_bytes_per_chunk - remainder) % num_bytes_per_chunk;
168
169 bytes_to_field::<_, F>(
171 payload
172 .iter()
173 .chain(iter::once(&1u8))
174 .chain(iter::repeat_n(&0u8, pad_num_zeros)),
175 )
176 .collect()
177 }
178
    /// Encode a field-element payload into `total_weights` raw shares plus a
    /// Merkle tree committing to the digest of every share.
    ///
    /// Each `recovery_threshold`-sized chunk of `payload` is treated as a
    /// polynomial in coefficient form and evaluated over a radix-2 domain;
    /// raw share `i` collects the i-th evaluation of every chunk polynomial.
    #[allow(clippy::type_complexity)]
    #[inline]
    fn raw_encode(param: &AvidMParam, payload: &[F]) -> VidResult<(MerkleTree, Vec<Vec<F>>)> {
        let domain = radix2_domain::<F>(param.total_weights)?;
        let encoding_timer = start_timer!(|| "Encoding payload");

        // Reed-Solomon encode each chunk in parallel: the FFT yields
        // evaluations over the whole (power-of-two) domain; keep only the
        // first `total_weights` points.
        let codewords: Vec<_> = payload
            .par_chunks(param.recovery_threshold)
            .map(|chunk| {
                let mut fft_vec = domain.fft(chunk);
                fft_vec.truncate(param.total_weights);
                fft_vec
            })
            .collect();
        // Transpose: raw share `i` is the column `codewords[..][i]`.
        let raw_shares: Vec<_> = (0..param.total_weights)
            .into_par_iter()
            .map(|i| codewords.iter().map(|v| v[i]).collect::<Vec<F>>())
            .collect();
        end_timer!(encoding_timer);

        let hash_timer = start_timer!(|| "Compressing each raw share");
        // Digest each raw share so the Merkle tree commits to fixed-size leaves.
        let compressed_raw_shares = raw_shares
            .par_iter()
            .map(|v| Config::raw_share_digest(v))
            .collect::<Result<Vec<_>, _>>()?;
        end_timer!(hash_timer);

        let mt_timer = start_timer!(|| "Constructing Merkle tree");
        let mt = MerkleTree::from_elems(None, &compressed_raw_shares)?;
        end_timer!(mt_timer);

        Ok((mt, raw_shares))
    }
223
224 fn pad_and_encode(param: &AvidMParam, payload: &[u8]) -> VidResult<(MerkleTree, Vec<Vec<F>>)> {
226 let payload = Self::pad_to_fields(param, payload);
227 Self::raw_encode(param, &payload)
228 }
229
    /// Split the raw shares among storage nodes according to `distribution`
    /// (one weight per node) and attach Merkle membership proofs.
    ///
    /// Node `i` with weight `w_i` receives the raw shares whose indices lie
    /// in `[w_0 + .. + w_{i-1}, w_0 + .. + w_i)`.
    ///
    /// # Errors
    /// Returns [`VidError::Argument`] when the weights don't sum to
    /// `param.total_weights` or when any weight is zero.
    fn distribute_shares(
        param: &AvidMParam,
        distribution: &[u32],
        mt: MerkleTree,
        raw_shares: Vec<Vec<F>>,
        payload_byte_len: usize,
    ) -> VidResult<(AvidMCommit, Vec<AvidMShare>)> {
        let total_weights = distribution.iter().sum::<u32>() as usize;
        if total_weights != param.total_weights {
            return Err(VidError::Argument(
                "Weight distribution is inconsistent with the given param".to_string(),
            ));
        }
        if distribution.iter().any(|&w| w == 0) {
            return Err(VidError::Argument("Weight cannot be zero".to_string()));
        }

        let distribute_timer = start_timer!(|| "Distribute codewords to the storage nodes");
        // Prefix-sum the weights into half-open index ranges, one per node.
        let ranges: Vec<_> = distribution
            .iter()
            .scan(0, |sum, w| {
                let prefix_sum = *sum;
                *sum += w;
                Some(prefix_sum as usize..*sum as usize)
            })
            .collect();
        // Copy out each node's slice of the raw shares, in parallel.
        let shares: Vec<_> = ranges
            .par_iter()
            .map(|range| {
                range
                    .clone()
                    .map(|k| raw_shares[k].to_owned())
                    .collect::<Vec<_>>()
            })
            .collect();
        end_timer!(distribute_timer);

        let mt_proof_timer = start_timer!(|| "Generate Merkle tree proofs");
        // Wrap each node's payload with its index, range, and one Merkle
        // proof per covered raw-share index.
        let shares = shares
            .into_iter()
            .enumerate()
            .map(|(i, payload)| AvidMShare {
                index: i as u32,
                payload_byte_len,
                content: RawAvidMShare {
                    range: ranges[i].clone(),
                    payload,
                    mt_proofs: ranges[i]
                        .clone()
                        .map(|k| {
                            // Lookup cannot fail: `k < total_weights`, and the
                            // tree was built over exactly that many leaves.
                            mt.lookup(k as u64)
                                .expect_ok()
                                .expect("MT lookup shouldn't fail")
                                .1
                        })
                        .collect::<Vec<_>>(),
                },
            })
            .collect::<Vec<_>>();
        end_timer!(mt_proof_timer);

        let commit = AvidMCommit {
            commit: mt.commitment(),
        };

        Ok((commit, shares))
    }
302
303 pub(crate) fn verify_internal(
304 param: &AvidMParam,
305 commit: &AvidMCommit,
306 share: &RawAvidMShare,
307 ) -> VidResult<crate::VerificationResult> {
308 if share.range.end > param.total_weights || share.range.len() != share.payload.len() {
309 return Err(VidError::InvalidShare);
310 }
311 for (i, index) in share.range.clone().enumerate() {
312 let compressed_payload = Config::raw_share_digest(&share.payload[i])?;
313 if MerkleTree::verify(
314 commit.commit,
315 index as u64,
316 compressed_payload,
317 &share.mt_proofs[i],
318 )?
319 .is_err()
320 {
321 return Ok(Err(()));
322 }
323 }
324 Ok(Ok(()))
325 }
326
    /// Interpolate the payload field elements from a collection of shares.
    ///
    /// Gathers at least `recovery_threshold` distinct raw shares (keyed by
    /// their evaluation-domain index) and runs Reed-Solomon erasure decoding
    /// for each chunk polynomial. Shares are NOT verified against any
    /// commitment here; callers must verify them separately.
    ///
    /// # Errors
    /// - [`VidError::Argument`] when all shares are empty.
    /// - [`VidError::InvalidShare`] on inconsistent range/payload lengths or
    ///   duplicated raw-share indices.
    /// - [`VidError::InsufficientShares`] when fewer than
    ///   `recovery_threshold` distinct raw shares are available.
    pub(crate) fn recover_fields(param: &AvidMParam, shares: &[AvidMShare]) -> VidResult<Vec<F>> {
        let recovery_threshold: usize = param.recovery_threshold;

        // Number of chunk polynomials, read off the first non-empty share.
        let num_polys = shares
            .iter()
            .find(|s| !s.content.payload.is_empty())
            .ok_or(VidError::Argument("All shares are empty".to_string()))?
            .content
            .payload[0]
            .len();

        // Collect raw shares by index, rejecting duplicates and stopping as
        // soon as we have enough for decoding.
        let mut raw_shares = HashMap::new();
        for share in shares {
            // Structural sanity checks per share.
            if share.content.range.len() != share.content.payload.len()
                || share.content.range.end > param.total_weights
            {
                return Err(VidError::InvalidShare);
            }
            for (i, p) in share.content.range.clone().zip(&share.content.payload) {
                if p.len() != num_polys {
                    return Err(VidError::InvalidShare);
                }
                // A repeated index indicates a malformed input set.
                if raw_shares.contains_key(&i) {
                    return Err(VidError::InvalidShare);
                }
                raw_shares.insert(i, p);
                if raw_shares.len() >= recovery_threshold {
                    break;
                }
            }
            if raw_shares.len() >= recovery_threshold {
                break;
            }
        }

        if raw_shares.len() < recovery_threshold {
            return Err(VidError::InsufficientShares);
        }

        let domain = radix2_domain::<F>(param.total_weights)?;

        // Map raw-share indices to their evaluation points in the domain.
        let (x, raw_shares): (Vec<_>, Vec<_>) = raw_shares
            .into_iter()
            .map(|(i, p)| (domain.element(i), p))
            .unzip();
        // Decode each polynomial independently (in parallel), then flatten the
        // recovered coefficient vectors in chunk order.
        Ok((0..num_polys)
            .into_par_iter()
            .map(|poly_index| {
                jf_utils::reed_solomon_code::reed_solomon_erasure_decode(
                    x.iter().zip(raw_shares.iter().map(|p| p[poly_index])),
                    recovery_threshold,
                )
                .map_err(|err| VidError::Internal(err.into()))
            })
            .collect::<Result<Vec<_>, _>>()?
            .into_iter()
            .flatten()
            .collect())
    }
391}
392
impl VidScheme for AvidMScheme {
    type Param = AvidMParam;

    type Share = AvidMShare;

    type Commit = AvidMCommit;

    /// Commit to `payload` without producing shares: encode, then take the
    /// Merkle root.
    fn commit(param: &Self::Param, payload: &[u8]) -> VidResult<Self::Commit> {
        let (mt, _) = Self::pad_and_encode(param, payload)?;
        Ok(AvidMCommit {
            commit: mt.commitment(),
        })
    }

    /// Encode `payload` and distribute shares to storage nodes according to
    /// `distribution` (one weight per node).
    fn disperse(
        param: &Self::Param,
        distribution: &[u32],
        payload: &[u8],
    ) -> VidResult<(Self::Commit, Vec<Self::Share>)> {
        let (mt, raw_shares) = Self::pad_and_encode(param, payload)?;
        Self::distribute_shares(param, distribution, mt, raw_shares, payload.len())
    }

    /// Verify a share's Merkle proofs against the commitment.
    fn verify_share(
        param: &Self::Param,
        commit: &Self::Commit,
        share: &Self::Share,
    ) -> VidResult<crate::VerificationResult> {
        Self::verify_internal(param, commit, &share.content)
    }

    /// Recover the original payload bytes from sufficiently many shares.
    /// The commitment is unused here; shares are assumed already verified.
    fn recover(
        param: &Self::Param,
        _commit: &Self::Commit,
        shares: &[Self::Share],
    ) -> VidResult<Vec<u8>> {
        let mut bytes: Vec<u8> = field_to_bytes(Self::recover_fields(param, shares)?).collect();
        // Strip the padding added by `pad_to_fields`: the last non-zero byte
        // must be the 0x01 marker; everything before it is the payload.
        if let Some(pad_index) = bytes.iter().rposition(|&b| b != 0) {
            if bytes[pad_index] == 1u8 {
                bytes.truncate(pad_index);
                return Ok(bytes);
            }
        }
        Err(VidError::Argument(
            "Malformed payload, cannot find the padding position".to_string(),
        ))
    }
}
451
452#[cfg(test)]
454pub mod tests {
455 use rand::{seq::SliceRandom, RngCore};
456
457 use super::F;
458 use crate::{avid_m::AvidMScheme, utils::bytes_to_field, VidScheme};
459
    /// Padding layout: payload bytes, then a 0x01 marker, then zeros up to a
    /// multiple of `recovery_threshold * elem_byte_capacity` bytes.
    #[test]
    fn test_padding() {
        let elem_bytes_len = bytes_to_field::elem_byte_capacity::<F>();
        // recovery_threshold = 2 => each chunk holds 2 field elements.
        let param = AvidMScheme::setup(2usize, 5usize).unwrap();
        let bytes = vec![2u8; 1];
        let padded = AvidMScheme::pad_to_fields(&param, &bytes);
        assert_eq!(padded.len(), 2usize);
        // Little-endian bytes [2, 1 (marker), 0, ...] => first element is
        // 2 + 256; the second element is all-zero padding.
        assert_eq!(padded, [F::from(2u32 + u8::MAX as u32 + 1), F::from(0)]);

        // Exactly two elements' worth of payload bytes: the marker spills
        // into a fresh chunk, rounding the result up to 4 elements (2 chunks).
        let bytes = vec![2u8; elem_bytes_len * 2];
        let padded = AvidMScheme::pad_to_fields(&param, &bytes);
        assert_eq!(padded.len(), 4usize);
    }
473
    /// End-to-end disperse/verify/recover over several parameter sets,
    /// payload sizes, and random weight distributions.
    #[test]
    fn round_trip() {
        // (recovery_threshold, num_storage_nodes) pairs.
        let params_list = [(2, 4), (3, 9), (5, 6), (15, 16)];
        let payload_byte_lens = [1, 31, 32, 500];

        let mut rng = jf_utils::test_rng();

        for (recovery_threshold, num_storage_nodes) in params_list {
            // Random per-node weights in 1..=5; total weight defines the
            // number of raw shares.
            let weights: Vec<u32> = (0..num_storage_nodes)
                .map(|_| rng.next_u32() % 5 + 1)
                .collect();
            let total_weights: u32 = weights.iter().sum();
            let params = AvidMScheme::setup(recovery_threshold, total_weights as usize).unwrap();

            for payload_byte_len in payload_byte_lens {
                println!(
                    "recovery_threshold:: {} num_storage_nodes: {} payload_byte_len: {}",
                    recovery_threshold, num_storage_nodes, payload_byte_len
                );
                println!("weights: {:?}", weights);

                let payload = {
                    let mut bytes_random = vec![0u8; payload_byte_len];
                    rng.fill_bytes(&mut bytes_random);
                    bytes_random
                };

                let (commit, mut shares) =
                    AvidMScheme::disperse(&params, &weights, &payload).unwrap();

                assert_eq!(shares.len(), num_storage_nodes);

                // Every share must verify against the commitment.
                shares.iter().for_each(|share| {
                    assert!(
                        AvidMScheme::verify_share(&params, &commit, share).is_ok_and(|r| r.is_ok())
                    )
                });

                // Take a random prefix of shares whose cumulative weight
                // exceeds the recovery threshold, then recover from it.
                shares.shuffle(&mut rng);
                let mut cumulated_weights = 0;
                let mut cut_index = 0;
                while cumulated_weights <= recovery_threshold {
                    cumulated_weights += shares[cut_index].content.range.len();
                    cut_index += 1;
                }
                let payload_recovered =
                    AvidMScheme::recover(&params, &commit, &shares[..cut_index]).unwrap();
                assert_eq!(payload_recovered, payload);
            }
        }
    }
530
    /// Timing breakdown of disperse/recover on a 32 MiB payload; only built
    /// when the `print-trace` feature enables the ark-std timers.
    #[test]
    #[cfg(feature = "print-trace")]
    fn round_trip_breakdown() {
        use ark_std::{end_timer, start_timer};

        let mut rng = jf_utils::test_rng();

        let params = AvidMScheme::setup(50usize, 200usize).unwrap();
        let weights = vec![2u32; 100usize];
        // 32 MiB of random payload.
        let payload_byte_len = 1024 * 1024 * 32;
        let payload = {
            let mut bytes_random = vec![0u8; payload_byte_len];
            rng.fill_bytes(&mut bytes_random);
            bytes_random
        };

        let disperse_timer = start_timer!(|| format!("Disperse {} bytes", payload_byte_len));
        let (commit, shares) = AvidMScheme::disperse(&params, &weights, &payload).unwrap();
        end_timer!(disperse_timer);

        let recover_timer = start_timer!(|| "Recovery");
        AvidMScheme::recover(&params, &commit, &shares).unwrap();
        end_timer!(recover_timer);
    }
556}