use std::{ops::Range, vec};

use anyhow::anyhow;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use jf_merkle_tree::{hasher::HasherNode, MerkleTreeScheme};
use jf_utils::canonical;
use serde::{Deserialize, Serialize};
use sha2::Digest;
use tagged_base64::tagged;

use crate::{VidError, VidResult, VidScheme};

pub mod namespaced;
pub mod proofs;

/// Merkle tree over the Keccak256 digests of all erasure-coded shards.
pub(crate) type MerkleTree =
    jf_merkle_tree::hasher::HasherMerkleTree<sha3::Keccak256, HasherNode<sha3::Keccak256>>;
type MerkleProof = <MerkleTree as MerkleTreeScheme>::MembershipProof;
type MerkleCommit = <MerkleTree as MerkleTreeScheme>::Commitment;

/// An AVID-M verifiable information dispersal scheme whose erasure code is
/// provided by `reed_solomon_simd`.
pub struct AvidmGf2Scheme;

/// Public parameters of the scheme.
#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
pub struct AvidmGf2Param {
    /// Total weight across all storage nodes, i.e. the total number of shards.
    pub total_weights: usize,
    /// Minimum total weight required to recover the payload.
    pub recovery_threshold: usize,
}

impl AvidmGf2Param {
    /// Construct parameters, rejecting a zero recovery threshold or a
    /// threshold that exceeds the total weight.
    pub fn new(recovery_threshold: usize, total_weights: usize) -> VidResult<Self> {
        if recovery_threshold == 0 || total_weights < recovery_threshold {
            return Err(VidError::InvalidParam);
        }
        Ok(Self {
            total_weights,
            recovery_threshold,
        })
    }
}
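
// A quick illustration (ours, not from the original source): with
// `AvidmGf2Param::new(2, 4)`, any collection of shares whose total weight is
// at least 2 suffices for recovery, while `new(0, 4)` and `new(5, 4)` both
// return `Err(VidError::InvalidParam)`.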

/// A share for a single storage node, covering a contiguous range of shard
/// indices whose length equals the node's weight.
#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
pub struct AvidmGf2Share {
    /// The range of shard indices held by this share.
    range: Range<usize>,
    /// The shards themselves, one byte vector per index in `range`.
    #[serde(with = "canonical")]
    payload: Vec<Vec<u8>>,
    /// Merkle membership proofs, one per shard in `range`.
    #[serde(with = "canonical")]
    mt_proofs: Vec<MerkleProof>,
}

impl AvidmGf2Share {
    /// The weight of this share, i.e. the number of shards it carries.
    pub fn weight(&self) -> usize {
        self.range.len()
    }

    /// Check that the numbers of shards and proofs match the claimed range.
    pub fn validate(&self) -> bool {
        self.payload.len() == self.range.len() && self.mt_proofs.len() == self.range.len()
    }
}

/// Commitment to a dispersed payload: the root of the Merkle tree built over
/// the digests of all shards.
#[derive(
    Clone,
    Copy,
    Debug,
    Hash,
    CanonicalSerialize,
    CanonicalDeserialize,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
)]
#[tagged("AvidmGf2Commit")]
#[repr(C)]
pub struct AvidmGf2Commit {
    /// The underlying Merkle tree commitment.
    pub commit: MerkleCommit,
}

impl AsRef<[u8]> for AvidmGf2Commit {
    fn as_ref(&self) -> &[u8] {
        self.commit.as_ref()
    }
}

impl AvidmGf2Scheme {
    /// Set up the public parameters from a recovery threshold and the total
    /// weight of all storage nodes.
    pub fn setup(recovery_threshold: usize, total_weights: usize) -> VidResult<AvidmGf2Param> {
        AvidmGf2Param::new(recovery_threshold, total_weights)
    }

    /// Pad `payload` to exactly `payload_len` bytes by appending a `1u8`
    /// marker followed by zeros, so that `recover` can locate the end of the
    /// original payload.
    fn bit_padding(payload: &[u8], payload_len: usize) -> VidResult<Vec<u8>> {
        if payload_len < payload.len() + 1 {
            return Err(VidError::Argument(
                "Payload is too long to fit into the requested padded length".to_string(),
            ));
        }
        let mut padded = vec![0u8; payload_len];
        padded[..payload.len()].copy_from_slice(payload);
        padded[payload.len()] = 1u8;
        Ok(padded)
    }
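
    // Padding layout sketch (illustrative, not from the original source):
    //
    //     bit_padding(&[0xAA, 0xBB], 6) == Ok(vec![0xAA, 0xBB, 0x01, 0x00, 0x00, 0x00])
    //
    // `recover` strips everything from the last non-zero byte (the `0x01`
    // marker) onward to restore the original payload.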

    /// Erasure-code the (padded) payload into `total_weights` shards and build
    /// a Merkle tree over their Keccak256 digests. The first
    /// `recovery_threshold` shards are the original data; the rest are
    /// recovery shards.
    fn raw_disperse(
        param: &AvidmGf2Param,
        payload: &[u8],
    ) -> VidResult<(MerkleTree, Vec<Vec<u8>>)> {
        let original_count = param.recovery_threshold;
        let recovery_count = param.total_weights - param.recovery_threshold;
        let mut shard_bytes = (payload.len() + 1).div_ceil(original_count);
        // `reed_solomon_simd` requires an even number of bytes per shard.
        if shard_bytes % 2 == 1 {
            shard_bytes += 1;
        }
        let payload = Self::bit_padding(payload, shard_bytes * original_count)?;
        let original = payload
            .chunks(shard_bytes)
            .map(|chunk| chunk.to_owned())
            .collect::<Vec<_>>();
        let recovery = reed_solomon_simd::encode(original_count, recovery_count, &original)?;

        let shares = [original, recovery].concat();
        let share_digests: Vec<_> = shares
            .iter()
            .map(|share| HasherNode::from(sha3::Keccak256::digest(share)))
            .collect();
        let mt = MerkleTree::from_elems(None, &share_digests)?;
        Ok((mt, shares))
    }
}

impl VidScheme for AvidmGf2Scheme {
    type Param = AvidmGf2Param;
    type Share = AvidmGf2Share;
    type Commit = AvidmGf2Commit;

    fn commit(param: &Self::Param, payload: &[u8]) -> VidResult<Self::Commit> {
        let (mt, _) = Self::raw_disperse(param, payload)?;
        Ok(Self::Commit {
            commit: mt.commitment(),
        })
    }

    fn disperse(
        param: &Self::Param,
        distribution: &[u32],
        payload: &[u8],
    ) -> VidResult<(Self::Commit, Vec<Self::Share>)> {
        let (mt, shares) = Self::raw_disperse(param, payload)?;
        let commit = AvidmGf2Commit {
            commit: mt.commitment(),
        };
        // Assign each node a contiguous range of shard indices, with length
        // equal to its weight, via a prefix sum over the distribution.
        let ranges: Vec<_> = distribution
            .iter()
            .scan(0, |sum, w| {
                let prefix_sum = *sum;
                *sum += w;
                Some(prefix_sum as usize..*sum as usize)
            })
            .collect();
        let shares: Vec<_> = ranges
            .into_iter()
            .map(|range| AvidmGf2Share {
                range: range.clone(),
                payload: shares[range.clone()].to_vec(),
                mt_proofs: range
                    .map(|k| {
                        mt.lookup(k as u64)
                            .expect_ok()
                            .expect("MT lookup shouldn't fail")
                            .1
                    })
                    .collect::<Vec<_>>(),
            })
            .collect();
        Ok((commit, shares))
    }
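
    // Range assignment sketch (illustrative, not from the original source): a
    // distribution of `[2, 1, 3]` yields the ranges `0..2`, `2..3`, and
    // `3..6`, so the shares partition the shard indices `0..total_weights`.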

    fn verify_share(
        _param: &Self::Param,
        commit: &Self::Commit,
        share: &Self::Share,
    ) -> VidResult<crate::VerificationResult> {
        if !share.validate() {
            return Err(VidError::InvalidShare);
        }
        for (i, index) in share.range.clone().enumerate() {
            let payload_digest = HasherNode::from(sha3::Keccak256::digest(&share.payload[i]));
            // A failed membership proof means the share is inconsistent with
            // the commitment: report a verification failure, not an error.
            if MerkleTree::verify(
                commit.commit,
                index as u64,
                payload_digest,
                &share.mt_proofs[i],
            )?
            .is_err()
            {
                return Ok(Err(()));
            }
        }
        Ok(Ok(()))
    }

    fn recover(
        param: &Self::Param,
        _commit: &Self::Commit,
        shares: &[Self::Share],
    ) -> VidResult<Vec<u8>> {
        let original_count = param.recovery_threshold;
        let recovery_count = param.total_weights - param.recovery_threshold;
        // Infer the shard size from the first non-empty share.
        let Some(first_share) = shares.iter().find(|s| !s.payload.is_empty()) else {
            return Err(VidError::InsufficientShares);
        };
        let shard_bytes = first_share.payload[0].len();

        let mut decoder = reed_solomon_simd::ReedSolomonDecoder::new(
            original_count,
            recovery_count,
            shard_bytes,
        )?;
        let mut original_shares = vec![None; original_count];
        for share in shares {
            if !share.validate() || share.payload.iter().any(|p| p.len() != shard_bytes) {
                return Err(VidError::InvalidShare);
            }
            for (i, index) in share.range.clone().enumerate() {
                if index < original_count {
                    original_shares[index] = Some(share.payload[i].as_ref());
                    decoder.add_original_shard(index, &share.payload[i])?;
                } else {
                    decoder.add_recovery_shard(index - original_count, &share.payload[i])?;
                }
            }
        }
        let result = decoder.decode()?;
        // Fill the gaps among the original shards with the decoded ones.
        original_shares
            .iter_mut()
            .enumerate()
            .for_each(|(i, share)| {
                if share.is_none() {
                    *share = result.restored_original(i);
                }
            });
        if original_shares.iter().any(|share| share.is_none()) {
            return Err(VidError::Internal(anyhow!(
                "Failed to recover the payload."
            )));
        }
        let mut recovered: Vec<_> = original_shares
            .into_iter()
            .flat_map(|share| share.unwrap().to_vec())
            .collect();
        // Strip the padding: the last non-zero byte must be the `1u8` marker
        // appended by `bit_padding`.
        match recovered.iter().rposition(|&b| b != 0) {
            Some(pad_index) if recovered[pad_index] == 1u8 => {
                recovered.truncate(pad_index);
                Ok(recovered)
            },
            _ => Err(VidError::Argument(
                "Malformed payload: cannot find the padding marker".to_string(),
            )),
        }
    }
}

#[cfg(test)]
pub mod tests {
    use rand::{seq::SliceRandom, RngCore};

    use super::AvidmGf2Scheme;
    use crate::VidScheme;

    #[test]
    fn round_trip() {
        let num_storage_nodes_list = [4, 9, 16];
        let payload_byte_lens = [1, 31, 32, 500];

        let mut rng = jf_utils::test_rng();

        for num_storage_nodes in num_storage_nodes_list {
            // Sample a random weight in 1..=5 for each storage node.
            let weights: Vec<u32> = (0..num_storage_nodes)
                .map(|_| rng.next_u32() % 5 + 1)
                .collect();
            let total_weights: u32 = weights.iter().sum();
            let recovery_threshold = total_weights.div_ceil(3) as usize;
            let params = AvidmGf2Scheme::setup(recovery_threshold, total_weights as usize).unwrap();

            for payload_byte_len in payload_byte_lens {
                let payload = {
                    let mut bytes_random = vec![0u8; payload_byte_len];
                    rng.fill_bytes(&mut bytes_random);
                    bytes_random
                };

                let (commit, mut shares) =
                    AvidmGf2Scheme::disperse(&params, &weights, &payload).unwrap();

                assert_eq!(shares.len(), num_storage_nodes);

                // Every share produced by `disperse` must verify against the commitment.
                shares.iter().for_each(|share| {
                    assert!(AvidmGf2Scheme::verify_share(&params, &commit, share)
                        .is_ok_and(|r| r.is_ok()))
                });

                // Recover from a random subset of shares whose total weight
                // just exceeds the recovery threshold.
                shares.shuffle(&mut rng);
                let mut cumulated_weights = 0;
                let mut cut_index = 0;
                while cumulated_weights <= recovery_threshold {
                    cumulated_weights += shares[cut_index].weight();
                    cut_index += 1;
                }
                let payload_recovered =
                    AvidmGf2Scheme::recover(&params, &commit, &shares[..cut_index]).unwrap();
                assert_eq!(payload_recovered, payload);
            }
        }
    }
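
    // An extra negative test (ours, a minimal sketch not present in the
    // original suite): corrupting a single payload byte of a share should make
    // verification against the commitment report a failure.
    #[test]
    fn tampered_share_fails_verification() {
        let mut rng = jf_utils::test_rng();
        let weights = vec![1u32; 4];
        let params = AvidmGf2Scheme::setup(2, 4).unwrap();
        let mut payload = vec![0u8; 64];
        rng.fill_bytes(&mut payload);

        let (commit, mut shares) = AvidmGf2Scheme::disperse(&params, &weights, &payload).unwrap();
        // Flip one bit in the first shard of the first share.
        shares[0].payload[0][0] ^= 1;
        assert!(AvidmGf2Scheme::verify_share(&params, &commit, &shares[0])
            .is_ok_and(|r| r.is_err()));
    }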
}