1use std::{ops::Range, vec};
4
5use anyhow::anyhow;
6use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
7use jf_merkle_tree::{hasher::HasherNode, MerkleTreeScheme};
8use jf_utils::canonical;
9use serde::{Deserialize, Serialize};
10use sha2::Digest;
11use tagged_base64::tagged;
12
13use crate::{VidError, VidResult, VidScheme};
14
pub mod namespaced;
pub mod proofs;

/// Merkle tree over Keccak-256 leaf digests, used to commit to the encoded
/// shards produced by dispersal.
pub(crate) type MerkleTree =
    jf_merkle_tree::hasher::HasherMerkleTree<sha3::Keccak256, HasherNode<sha3::Keccak256>>;
/// Membership proof type of [`MerkleTree`], carried inside each share.
type MerkleProof = <MerkleTree as MerkleTreeScheme>::MembershipProof;
/// Commitment (root) type of [`MerkleTree`], wrapped by [`AvidmGf2Commit`].
type MerkleCommit = <MerkleTree as MerkleTreeScheme>::Commitment;
25
26pub struct AvidmGf2Scheme;
28
/// Public parameters of [`AvidmGf2Scheme`].
#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
pub struct AvidmGf2Param {
    /// Total weight across all storage nodes, i.e. the total number of shards
    /// produced by dispersal (original + recovery).
    pub total_weights: usize,
    /// Number of original shards; shares with at least this much cumulative
    /// weight suffice to recover the payload.
    pub recovery_threshold: usize,
}
37
38impl AvidmGf2Param {
39 pub fn new(recovery_threshold: usize, total_weights: usize) -> VidResult<Self> {
41 if recovery_threshold == 0 || total_weights < recovery_threshold {
42 return Err(VidError::InvalidParam);
43 }
44 Ok(Self {
45 total_weights,
46 recovery_threshold,
47 })
48 }
49}
50
/// One storage node's share of a dispersed payload.
#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
pub struct AvidmGf2Share {
    /// Contiguous range of shard indices assigned to this node; its length is
    /// the node's weight.
    range: Range<usize>,
    /// One encoded shard per index in `range`.
    #[serde(with = "canonical")]
    payload: Vec<Vec<u8>>,
    /// Merkle membership proof for each shard in `payload`, in the same order.
    #[serde(with = "canonical")]
    mt_proofs: Vec<MerkleProof>,
}
63
64impl AvidmGf2Share {
65 pub fn weight(&self) -> usize {
67 self.range.len()
68 }
69
70 pub fn validate(&self) -> bool {
72 self.payload.len() == self.range.len() && self.mt_proofs.len() == self.range.len()
73 }
74}
75
/// Commitment to a dispersed payload: the Merkle root over the Keccak-256
/// digests of all shards.
#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    Hash,
    CanonicalSerialize,
    CanonicalDeserialize,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
)]
#[tagged("AvidmGf2Commit")]
#[repr(C)]
pub struct AvidmGf2Commit {
    /// Merkle tree root committing to every shard digest.
    pub commit: MerkleCommit,
}
96
impl AsRef<[u8]> for AvidmGf2Commit {
    /// Exposes the raw bytes of the underlying Merkle root.
    fn as_ref(&self) -> &[u8] {
        self.commit.as_ref()
    }
}
102
impl AsRef<[u8; 32]> for AvidmGf2Commit {
    /// Fixed-size view of the commitment bytes.
    fn as_ref(&self) -> &[u8; 32] {
        // Keccak-256 output is 32 bytes, so the byte view always converts;
        // the `expect` documents that invariant rather than a runtime error.
        <Self as AsRef<[u8]>>::as_ref(self)
            .try_into()
            .expect("AvidmGf2Commit is always 32 bytes")
    }
}
110
111impl AvidmGf2Scheme {
112 pub fn setup(recovery_threshold: usize, total_weights: usize) -> VidResult<AvidmGf2Param> {
114 AvidmGf2Param::new(recovery_threshold, total_weights)
115 }
116
117 fn bit_padding(payload: &[u8], payload_len: usize) -> VidResult<Vec<u8>> {
118 if payload_len < payload.len() + 1 {
119 return Err(VidError::Argument(
120 "Payload length is too large to fit in the given payload length".to_string(),
121 ));
122 }
123 let mut padded = vec![0u8; payload_len];
124 padded[..payload.len()].copy_from_slice(payload);
125 padded[payload.len()] = 1u8;
126 Ok(padded)
127 }
128
129 fn raw_disperse(
130 param: &AvidmGf2Param,
131 payload: &[u8],
132 ) -> VidResult<(MerkleTree, Vec<Vec<u8>>)> {
133 let original_count = param.recovery_threshold;
134 let recovery_count = param.total_weights - param.recovery_threshold;
135 let mut shard_bytes = (payload.len() + 1).div_ceil(original_count);
137 if shard_bytes % 2 == 1 {
138 shard_bytes += 1;
139 }
140 let payload = Self::bit_padding(payload, shard_bytes * original_count)?;
141 let original = payload
142 .chunks(shard_bytes)
143 .map(|chunk| chunk.to_owned())
144 .collect::<Vec<_>>();
145 let recovery = if recovery_count == 0 {
146 vec![]
147 } else {
148 reed_solomon_simd::encode(original_count, recovery_count, &original)?
149 };
150
151 let shares = [original, recovery].concat();
152 let share_digests: Vec<_> = shares
153 .iter()
154 .map(|share| HasherNode::from(sha3::Keccak256::digest(share)))
155 .collect();
156 let mt = MerkleTree::from_elems(None, &share_digests)?;
157 Ok((mt, shares))
158 }
159}
160
impl VidScheme for AvidmGf2Scheme {
    type Param = AvidmGf2Param;
    type Share = AvidmGf2Share;
    type Commit = AvidmGf2Commit;

    /// Commits to `payload` by erasure-encoding it and returning the Merkle
    /// root over the shard digests; the shards themselves are discarded.
    fn commit(param: &Self::Param, payload: &[u8]) -> VidResult<Self::Commit> {
        let (mt, _) = Self::raw_disperse(param, payload)?;
        Ok(Self::Commit {
            commit: mt.commitment(),
        })
    }

    /// Encodes `payload` and splits the shards among nodes according to
    /// `distribution`: node `i` receives a contiguous run of `distribution[i]`
    /// shards, each paired with its Merkle membership proof.
    ///
    /// NOTE(review): this assumes `distribution` sums to
    /// `param.total_weights`; a larger sum would panic on the `shares[range]`
    /// slice below — TODO confirm all callers uphold this.
    fn disperse(
        param: &Self::Param,
        distribution: &[u32],
        payload: &[u8],
    ) -> VidResult<(Self::Commit, Vec<Self::Share>)> {
        let (mt, shares) = Self::raw_disperse(param, payload)?;
        let commit = AvidmGf2Commit {
            commit: mt.commitment(),
        };
        // Prefix-sum the weights into one contiguous index range per node.
        let ranges: Vec<_> = distribution
            .iter()
            .scan(0, |sum, w| {
                let prefix_sum = *sum;
                *sum += w;
                Some(prefix_sum as usize..*sum as usize)
            })
            .collect();
        let shares: Vec<_> = ranges
            .into_iter()
            .map(|range| AvidmGf2Share {
                range: range.clone(),
                payload: shares[range.clone()].to_vec(),
                // One membership proof per assigned shard index.
                mt_proofs: range
                    .map(|k| {
                        mt.lookup(k as u64)
                            .expect_ok()
                            .expect("MT lookup shouldn't fail")
                            .1
                    })
                    .collect::<Vec<_>>(),
            })
            .collect();
        Ok((commit, shares))
    }

    /// Verifies every shard in `share` against the Merkle commitment.
    ///
    /// Returns `Ok(Err(()))` when some shard's digest fails Merkle
    /// verification, and `Err(VidError::InvalidShare)` when the share is
    /// internally inconsistent (mismatched list lengths).
    fn verify_share(
        _param: &Self::Param,
        commit: &Self::Commit,
        share: &Self::Share,
    ) -> VidResult<crate::VerificationResult> {
        if !share.validate() {
            return Err(VidError::InvalidShare);
        }
        for (i, index) in share.range.clone().enumerate() {
            let payload_digest = HasherNode::from(sha3::Keccak256::digest(&share.payload[i]));
            if MerkleTree::verify(
                commit.commit,
                index as u64,
                payload_digest,
                &share.mt_proofs[i],
            )?
            .is_err()
            {
                return Ok(Err(()));
            }
        }
        Ok(Ok(()))
    }

    /// Reassembles the payload from a set of shares.
    ///
    /// Original shards present in `shares` are taken directly; any missing
    /// ones are reconstructed via Reed-Solomon decoding from the recovery
    /// shards. Finally the bit-padding (a `1` marker followed by zeros) is
    /// stripped to restore the exact original payload.
    ///
    /// # Errors
    /// - [`VidError::InsufficientShares`] when every share is empty.
    /// - [`VidError::InvalidShare`] on internally inconsistent shares or
    ///   shards whose size differs from the first non-empty share's.
    /// - [`VidError::Internal`] when not all original shards can be restored.
    /// - [`VidError::Argument`] when the padding marker cannot be located.
    fn recover(
        param: &Self::Param,
        _commit: &Self::Commit,
        shares: &[Self::Share],
    ) -> VidResult<Vec<u8>> {
        let original_count = param.recovery_threshold;
        let recovery_count = param.total_weights - param.recovery_threshold;
        // The shard size is read off the first non-empty share; all other
        // shards must match it.
        let Some(first_share) = shares.iter().find(|s| !s.payload.is_empty()) else {
            return Err(VidError::InsufficientShares);
        };
        let shard_bytes = first_share.payload[0].len();

        let mut original_shares: Vec<Option<Vec<u8>>> = vec![None; original_count];
        if recovery_count == 0 {
            // No parity shards exist: simply collect the original shards.
            for share in shares {
                if !share.validate() || share.payload.iter().any(|p| p.len() != shard_bytes) {
                    return Err(VidError::InvalidShare);
                }
                for (i, index) in share.range.clone().enumerate() {
                    if index < original_count {
                        original_shares[index] = Some(share.payload[i].clone());
                    }
                }
            }
        } else {
            let mut decoder = reed_solomon_simd::ReedSolomonDecoder::new(
                original_count,
                recovery_count,
                shard_bytes,
            )?;
            for share in shares {
                if !share.validate() || share.payload.iter().any(|p| p.len() != shard_bytes) {
                    return Err(VidError::InvalidShare);
                }
                for (i, index) in share.range.clone().enumerate() {
                    if index < original_count {
                        original_shares[index] = Some(share.payload[i].clone());
                        decoder.add_original_shard(index, &share.payload[i])?;
                    } else {
                        // Indices past the originals address recovery shards.
                        decoder.add_recovery_shard(index - original_count, &share.payload[i])?;
                    }
                }
            }

            // Fill the gaps with the decoder's reconstructed original shards.
            let result = decoder.decode()?;
            original_shares
                .iter_mut()
                .enumerate()
                .for_each(|(i, share)| {
                    if share.is_none() {
                        *share = result.restored_original(i).map(|s| s.to_vec());
                    }
                });
        }
        if original_shares.iter().any(|share| share.is_none()) {
            return Err(VidError::Internal(anyhow!(
                "Failed to recover the payload."
            )));
        }
        let mut recovered: Vec<_> = original_shares
            .into_iter()
            .flat_map(|share| share.unwrap())
            .collect();
        // The last non-zero byte must be the `1` padding marker; everything
        // from it onward is padding and is dropped.
        match recovered.iter().rposition(|&b| b != 0) {
            Some(pad_index) if recovered[pad_index] == 1u8 => {
                recovered.truncate(pad_index);
                Ok(recovered)
            },
            _ => Err(VidError::Argument(
                "Malformed payload, cannot find the padding position".to_string(),
            )),
        }
    }
}
310
#[cfg(test)]
pub mod tests {
    use rand::{seq::SliceRandom, RngCore};

    use super::AvidmGf2Scheme;
    use crate::VidScheme;

    /// Shared driver for the round-trip tests.
    ///
    /// For several node counts and payload lengths: disperses a random
    /// payload, checks every share verifies against the commitment, then
    /// recovers from the smallest shuffled prefix of shares whose cumulative
    /// weight reaches the recovery threshold and asserts the payload round
    /// trips. `threshold_of` maps the total weight to the recovery threshold
    /// under test.
    fn round_trip_with(threshold_of: impl Fn(u32) -> usize) {
        let num_storage_nodes_list = [4, 9, 16];
        let payload_byte_lens = [1, 31, 32, 500];

        let mut rng = jf_utils::test_rng();

        for num_storage_nodes in num_storage_nodes_list {
            // Random per-node weights in 1..=5.
            let weights: Vec<u32> = (0..num_storage_nodes)
                .map(|_| rng.next_u32() % 5 + 1)
                .collect();
            let total_weights: u32 = weights.iter().sum();
            let recovery_threshold = threshold_of(total_weights);
            let params = AvidmGf2Scheme::setup(recovery_threshold, total_weights as usize).unwrap();

            for payload_byte_len in payload_byte_lens {
                let payload = {
                    let mut bytes_random = vec![0u8; payload_byte_len];
                    rng.fill_bytes(&mut bytes_random);
                    bytes_random
                };

                let (commit, mut shares) =
                    AvidmGf2Scheme::disperse(&params, &weights, &payload).unwrap();

                assert_eq!(shares.len(), num_storage_nodes);

                // Every dispersed share must pass verification.
                shares.iter().for_each(|share| {
                    assert!(AvidmGf2Scheme::verify_share(&params, &commit, share)
                        .is_ok_and(|r| r.is_ok()))
                });

                // Recover from a random subset of shares that just meets the
                // threshold (when the threshold equals the total weight this
                // is all of them).
                shares.shuffle(&mut rng);
                let mut cumulated_weights = 0;
                let mut cut_index = 0;
                while cumulated_weights < recovery_threshold {
                    cumulated_weights += shares[cut_index].weight();
                    cut_index += 1;
                }
                let payload_recovered =
                    AvidmGf2Scheme::recover(&params, &commit, &shares[..cut_index]).unwrap();
                assert_eq!(payload_recovered, payload);
            }
        }
    }

    /// Typical setting: the threshold is a third of the total weight.
    #[test]
    fn round_trip() {
        round_trip_with(|total_weights| total_weights.div_ceil(3) as usize);
    }

    /// Edge case: the threshold equals the total weight, so recovery needs
    /// every share and no recovery shards are generated.
    #[test]
    fn round_trip_edge_case() {
        round_trip_with(|total_weights| total_weights as usize);
    }
}