1use std::collections::BTreeMap;
8
9use sha2::{Digest, Sha256};
10
11use crate::traits::{
12 node_implementation::{ConsensusTime, NodeType},
13 storage::StoreDrbProgressFn,
14};
15
/// Input to (or a resumable checkpoint of) a DRB computation.
///
/// Consumed by `compute_drb_result`, which hashes `value` starting from
/// `iteration` until `DIFFICULTY_LEVEL` total iterations have been applied.
pub struct DrbInput {
    /// Epoch this DRB computation belongs to; forwarded to the progress-storage callback.
    pub epoch: u64,
    /// Number of hash iterations already applied to `value`.
    pub iteration: u64,
    /// Hash value reached after `iteration` iterations.
    pub value: [u8; 32],
}
24
/// A 32-byte DRB seed input.
pub type DrbSeedInput = [u8; 32];

/// A 32-byte DRB result (also used for intermediate checkpoint values).
pub type DrbResult = [u8; 32];

/// Total number of SHA-256 iterations `compute_drb_result` applies.
///
/// This is an arbitrary placeholder until hashing time has been benchmarked
/// (see `difficulty_level`).
pub const DIFFICULTY_LEVEL: u64 = 10;

/// Number of hash iterations between persisted progress checkpoints.
pub const DRB_CHECKPOINT_INTERVAL: u64 = 3;

/// All-zero DRB seed input.
pub const INITIAL_DRB_SEED_INPUT: DrbSeedInput = [0; 32];

/// All-zero DRB result; `DrbResults::new` seeds epochs 1 and 2 with it.
pub const INITIAL_DRB_RESULT: DrbResult = [0; 32];

/// How many epochs behind the current one `DrbResults::garbage_collect` retains.
pub const KEEP_PREVIOUS_RESULT_COUNT: u64 = 8;
54
/// Placeholder for a benchmarked DRB difficulty level.
///
/// # Panics
///
/// Always panics. Until hashing time has been benchmarked, callers should use the
/// arbitrary `DIFFICULTY_LEVEL` constant instead.
#[must_use]
pub fn difficulty_level() -> u64 {
    unimplemented!(
        "Use an arbitrary `DIFFICULTY_LEVEL` for now before we bench the hash time."
    )
}
65
66#[must_use]
73pub fn compute_drb_result(
74 drb_input: DrbInput,
75 store_drb_progress: StoreDrbProgressFn,
76) -> DrbResult {
77 let mut hash = drb_input.value.to_vec();
78 let mut iteration = drb_input.iteration;
79 let remaining_iterations = DIFFICULTY_LEVEL
80 .checked_sub(iteration)
81 .unwrap_or_else( ||
82 panic!(
83 "DRB difficulty level {} exceeds the iteration {} of the input we were given. This is a fatal error",
84 DIFFICULTY_LEVEL,
85 iteration
86 )
87 );
88
89 let final_checkpoint = remaining_iterations / DRB_CHECKPOINT_INTERVAL;
90
91 for _ in 0..final_checkpoint {
93 for _ in 0..DRB_CHECKPOINT_INTERVAL {
94 hash = Sha256::digest(hash).to_vec();
97 }
98
99 let mut partial_drb_result = [0u8; 32];
100 partial_drb_result.copy_from_slice(&hash);
101
102 iteration += DRB_CHECKPOINT_INTERVAL;
103
104 let storage = store_drb_progress.clone();
105 tokio::spawn(async move {
106 storage(drb_input.epoch, iteration, partial_drb_result).await;
107 });
108 }
109
110 let final_checkpoint_iteration = iteration;
111
112 for _ in final_checkpoint_iteration..DIFFICULTY_LEVEL {
114 hash = Sha256::digest(hash).to_vec();
115 iteration += 1;
116 }
117
118 let mut drb_result = [0u8; 32];
120 drb_result.copy_from_slice(&hash);
121 drb_result
122}
123
124#[derive(Clone, Debug)]
126pub struct DrbResults<TYPES: NodeType> {
127 pub results: BTreeMap<TYPES::Epoch, DrbResult>,
129}
130
131impl<TYPES: NodeType> DrbResults<TYPES> {
132 #[must_use]
133 pub fn new() -> Self {
135 Self {
136 results: BTreeMap::from([
137 (TYPES::Epoch::new(1), INITIAL_DRB_RESULT),
138 (TYPES::Epoch::new(2), INITIAL_DRB_RESULT),
139 ]),
140 }
141 }
142
143 pub fn store_result(&mut self, epoch: TYPES::Epoch, result: DrbResult) {
144 self.results.insert(epoch, result);
145 }
146
147 pub fn garbage_collect(&mut self, epoch: TYPES::Epoch) {
149 if epoch.u64() < KEEP_PREVIOUS_RESULT_COUNT {
150 return;
151 }
152
153 let retain_epoch = epoch - KEEP_PREVIOUS_RESULT_COUNT;
154 self.results = self.results.split_off(&retain_epoch);
158 }
159}
160
161impl<TYPES: NodeType> Default for DrbResults<TYPES> {
162 fn default() -> Self {
163 Self::new()
164 }
165}
166
167pub mod election {
183 use alloy::primitives::{U256, U512};
184 use sha2::{Digest, Sha256, Sha512};
185
186 use crate::traits::signature_key::{SignatureKey, StakeTableEntryType};
187
188 fn cyclic_xor(drb: [u8; 32], public_key: Vec<u8>) -> Vec<u8> {
190 let drb: Vec<u8> = drb.to_vec();
191
192 let mut result: Vec<u8> = vec![];
193
194 for (drb_byte, public_key_byte) in public_key.iter().zip(drb.iter().cycle()) {
195 result.push(drb_byte ^ public_key_byte);
196 }
197
198 result
199 }
200
201 pub fn generate_stake_cdf<Key: SignatureKey, Entry: StakeTableEntryType<Key>>(
203 mut stake_table: Vec<Entry>,
204 drb: [u8; 32],
205 ) -> RandomizedCommittee<Entry> {
206 stake_table.sort_by(|a, b| {
208 cyclic_xor(drb, a.public_key().to_bytes())
209 .cmp(&cyclic_xor(drb, b.public_key().to_bytes()))
210 });
211
212 let mut hasher = Sha256::new();
213
214 let mut cumulative_stake = U256::from(0);
215 let mut cdf = vec![];
216
217 for entry in stake_table {
218 cumulative_stake += entry.stake();
219 hasher.update(entry.public_key().to_bytes());
220
221 cdf.push((entry, cumulative_stake));
222 }
223
224 RandomizedCommittee {
225 cdf,
226 stake_table_hash: hasher.finalize().into(),
227 drb,
228 }
229 }
230
231 pub fn select_randomized_leader<
239 SignatureKey,
240 Entry: StakeTableEntryType<SignatureKey> + Clone,
241 >(
242 randomized_committee: &RandomizedCommittee<Entry>,
243 view: u64,
244 ) -> Entry {
245 let RandomizedCommittee {
246 cdf,
247 stake_table_hash,
248 drb,
249 } = randomized_committee;
250 let mut hasher = Sha512::new();
252 hasher.update(drb);
253 hasher.update(view.to_le_bytes());
254 hasher.update(stake_table_hash);
255 let raw_breakpoint: [u8; 64] = hasher.finalize().into();
256
257 let remainder: U512 =
259 U512::from_le_bytes(raw_breakpoint) % U512::from(cdf.last().unwrap().1);
260
261 let breakpoint: U256 = U256::from_le_slice(&remainder.to_le_bytes_vec()[0..32]);
263
264 let index = cdf.partition_point(|(_, cumulative_stake)| breakpoint >= *cumulative_stake);
270
271 cdf[index].0.clone()
273 }
274
275 #[derive(Clone, Debug)]
276 pub struct RandomizedCommittee<Entry> {
277 cdf: Vec<(Entry, U256)>,
279 stake_table_hash: [u8; 32],
281 drb: [u8; 32],
283 }
284}
285
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use alloy::primitives::U256;
    use rand::{rngs::StdRng, RngCore, SeedableRng};
    use sha2::{Digest, Sha256};

    use super::election::{generate_stake_cdf, select_randomized_leader};
    use crate::{
        signature_key::BLSPubKey,
        stake_table::StakeTableEntry,
        traits::signature_key::{BuilderSignatureKey, StakeTableEntryType},
    };

    /// Leader-selection frequencies over many views should track stake proportions.
    ///
    /// Stakes come from a fixed-seed RNG so that any failure is reproducible.
    #[test]
    fn test_randomized_leader() {
        let mut rng = StdRng::seed_from_u64(0x5eed);
        let drb: [u8; 32] = Sha256::digest(b"drb").into();

        let stake_table_entries: Vec<_> = (0..10)
            .map(|i| StakeTableEntry {
                stake_key: BLSPubKey::generated_from_seed_indexed([0u8; 32], i).0,
                stake_amount: U256::from(rng.next_u64() % 100 + 1),
            })
            .collect();
        let randomized_committee = generate_stake_cdf(stake_table_entries.clone(), drb);

        let num_views = 100_000;
        let mut selected = HashMap::<_, u64>::new();
        for view in 0..num_views {
            let leader = select_randomized_leader(&randomized_committee, view);
            *selected.entry(leader).or_insert(0) += 1;
        }

        let total_stake = stake_table_entries
            .iter()
            .map(|e| e.stake())
            .sum::<U256>()
            .to::<u64>() as f64;

        // L1 distance between the expected (stake-proportional) and observed selection
        // distributions. This equals twice the total variation distance.
        let l1_distance: f64 = stake_table_entries
            .iter()
            .map(|entry| {
                let expected = entry.stake().to::<u64>() as f64 / total_stake;
                let actual = *selected.get(entry).unwrap_or(&0) as f64 / num_views as f64;
                (expected - actual).abs()
            })
            .sum();

        // A NaN distance would also fail this comparison.
        assert!(
            l1_distance < 0.03,
            "leader selection deviates from stake distribution: L1 distance {l1_distance}"
        );
    }
}