1use std::{collections::BTreeMap, sync::Arc, time::Instant};
8
9use futures::future::BoxFuture;
10use serde::{Deserialize, Serialize};
11use sha2::{Digest, Sha256};
12use vbs::version::Version;
13use versions::DRB_AND_HEADER_UPGRADE_VERSION;
14
15use crate::{
16 HotShotConfig,
17 data::EpochNumber,
18 traits::{
19 node_implementation::NodeType,
20 storage::{LoadDrbProgressFn, StoreDrbProgressFn},
21 },
22};
23
/// Input to (and checkpointed progress of) the iterated-hash DRB calculation
/// performed by [`compute_drb_result`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct DrbInput {
    /// The epoch this DRB calculation belongs to
    pub epoch: u64,
    /// Number of SHA-256 iterations already applied to produce `value`
    pub iteration: u64,
    /// The current hash value after `iteration` rounds
    pub value: [u8; 32],
    /// Total number of SHA-256 iterations the calculation must perform
    pub difficulty_level: u64,
}
35
/// Async callback that selects the DRB difficulty (iteration count) to use
/// for a given protocol [`Version`].
pub type DrbDifficultySelectorFn =
    Arc<dyn Fn(Version) -> BoxFuture<'static, u64> + Send + Sync + 'static>;
38
39pub fn drb_difficulty_selector<TYPES: NodeType>(
40 config: &HotShotConfig<TYPES>,
41) -> DrbDifficultySelectorFn {
42 let base_difficulty = config.drb_difficulty;
43 let upgrade_difficulty = config.drb_upgrade_difficulty;
44 Arc::new(move |version| {
45 Box::pin(async move {
46 if version >= DRB_AND_HEADER_UPGRADE_VERSION {
47 upgrade_difficulty
48 } else {
49 base_difficulty
50 }
51 })
52 })
53}
54
/// Arbitrary placeholder difficulty (iteration count) used until the hash
/// time has been benchmarked — see [`difficulty_level`].
pub const DIFFICULTY_LEVEL: u64 = 10;

/// Number of hash iterations between persisted checkpoints of DRB progress.
pub const DRB_CHECKPOINT_INTERVAL: u64 = 1_000_000_000;

/// All-zero placeholder seed input (usage not shown in this file — presumably
/// the bootstrap seed before any real DRB seed exists; confirm at call sites).
pub const INITIAL_DRB_SEED_INPUT: [u8; 32] = [0; 32];
/// DRB result used for epochs that predate any computed result; see
/// [`DrbResults::new`], which seeds epochs 1 and 2 with it.
pub const INITIAL_DRB_RESULT: [u8; 32] = [0; 32];

/// 32-byte seed input to a DRB calculation.
pub type DrbSeedInput = [u8; 32];

/// 32-byte output of the iterated-hash DRB calculation.
pub type DrbResult = [u8; 32];

/// Number of past epochs whose DRB results are retained by
/// [`DrbResults::garbage_collect`].
pub const KEEP_PREVIOUS_RESULT_COUNT: u64 = 8;
84
/// Intended to return the benchmarked DRB difficulty level.
///
/// # Panics
///
/// Always panics: callers should use the placeholder [`DIFFICULTY_LEVEL`]
/// constant until the hash time has been benchmarked.
#[must_use]
pub fn difficulty_level() -> u64 {
    unimplemented!("Use an arbitrary `DIFFICULTY_LEVEL` for now before we bench the hash time.");
}
95
96#[must_use]
103pub async fn compute_drb_result(
104 drb_input: DrbInput,
105 store_drb_progress: StoreDrbProgressFn,
106 load_drb_progress: LoadDrbProgressFn,
107) -> DrbResult {
108 tracing::warn!("Beginning DRB calculation with input {:?}", drb_input);
109 let mut drb_input = drb_input;
110
111 if let Ok(loaded_drb_input) = load_drb_progress(drb_input.epoch).await {
112 if loaded_drb_input.difficulty_level != drb_input.difficulty_level {
113 tracing::error!(
114 "We are calculating the DRB result with input {drb_input:?}, but we had \
115 previously stored {loaded_drb_input:?} with a different difficulty level for \
116 this epoch. Discarding the value from storage"
117 );
118 } else if loaded_drb_input.iteration >= drb_input.iteration {
119 drb_input = loaded_drb_input;
120 }
121 }
122
123 let mut hash = drb_input.value.to_vec();
124 let mut iteration = drb_input.iteration;
125 let remaining_iterations = drb_input
126 .difficulty_level
127 .checked_sub(iteration)
128 .unwrap_or_else(|| {
129 panic!(
130 "DRB difficulty level {} exceeds the iteration {} of the input we were given. \
131 This is a fatal error",
132 drb_input.difficulty_level, iteration
133 )
134 });
135
136 let final_checkpoint = remaining_iterations / DRB_CHECKPOINT_INTERVAL;
137
138 let mut last_time = Instant::now();
139 let mut last_iteration = iteration;
140
141 for _ in 0..final_checkpoint {
143 hash = tokio::task::spawn_blocking(move || {
144 let mut hash_tmp = hash.clone();
145 for _ in 0..DRB_CHECKPOINT_INTERVAL {
146 hash_tmp = Sha256::digest(&hash_tmp).to_vec();
149 }
150
151 hash_tmp
152 })
153 .await
154 .expect("DRB calculation failed: this should never happen");
155
156 let mut partial_drb_result = [0u8; 32];
157 partial_drb_result.copy_from_slice(&hash);
158
159 iteration += DRB_CHECKPOINT_INTERVAL;
160
161 let updated_drb_input = DrbInput {
162 epoch: drb_input.epoch,
163 iteration,
164 value: partial_drb_result,
165 difficulty_level: drb_input.difficulty_level,
166 };
167
168 let elapsed_time = last_time.elapsed().as_millis();
169
170 let store_drb_progress = store_drb_progress.clone();
171 tokio::spawn(async move {
172 tracing::warn!(
173 "Storing partial DRB progress: {:?}. Time elapsed since the previous iteration of \
174 {:?}: {:?}",
175 updated_drb_input,
176 last_iteration,
177 elapsed_time
178 );
179 if let Err(e) = store_drb_progress(updated_drb_input).await {
180 tracing::warn!("Failed to store DRB progress during calculation: {}", e);
181 }
182 });
183
184 last_time = Instant::now();
185 last_iteration = iteration;
186 }
187
188 let final_checkpoint_iteration = iteration;
189
190 hash = tokio::task::spawn_blocking(move || {
192 let mut hash_tmp = hash.clone();
193 for _ in final_checkpoint_iteration..drb_input.difficulty_level {
194 hash_tmp = Sha256::digest(&hash_tmp).to_vec();
197 }
198
199 hash_tmp
200 })
201 .await
202 .expect("DRB calculation failed: this should never happen");
203
204 let mut drb_result = [0u8; 32];
206 drb_result.copy_from_slice(&hash);
207
208 let final_drb_input = DrbInput {
209 epoch: drb_input.epoch,
210 iteration: drb_input.difficulty_level,
211 value: drb_result,
212 difficulty_level: drb_input.difficulty_level,
213 };
214
215 tracing::warn!("Completed DRB calculation. Result: {:?}", final_drb_input);
216
217 let store_drb_progress = store_drb_progress.clone();
218 tokio::spawn(async move {
219 if let Err(e) = store_drb_progress(final_drb_input).await {
220 tracing::warn!("Failed to store DRB progress during calculation: {}", e);
221 }
222 });
223
224 drb_result
225}
226
/// Cache of computed DRB results, keyed by epoch.
#[derive(Clone, Debug)]
pub struct DrbResults {
    /// Map from epoch number to the DRB result for that epoch
    pub results: BTreeMap<EpochNumber, DrbResult>,
}
233
234impl DrbResults {
235 #[must_use]
236 pub fn new() -> Self {
238 Self {
239 results: BTreeMap::from([
240 (EpochNumber::new(1), INITIAL_DRB_RESULT),
241 (EpochNumber::new(2), INITIAL_DRB_RESULT),
242 ]),
243 }
244 }
245
246 pub fn store_result(&mut self, epoch: EpochNumber, result: DrbResult) {
247 self.results.insert(epoch, result);
248 }
249
250 pub fn garbage_collect(&mut self, epoch: EpochNumber) {
252 if epoch.u64() < KEEP_PREVIOUS_RESULT_COUNT {
253 return;
254 }
255
256 let retain_epoch = epoch - KEEP_PREVIOUS_RESULT_COUNT;
257 self.results = self.results.split_off(&retain_epoch);
261 }
262}
263
impl Default for DrbResults {
    /// Equivalent to [`DrbResults::new`].
    fn default() -> Self {
        Self::new()
    }
}
269
270pub mod election {
286 use alloy::primitives::{U256, U512};
287 use sha2::{Digest, Sha256, Sha512};
288
289 use crate::traits::signature_key::{SignatureKey, StakeTableEntryType};
290
291 fn cyclic_xor(drb: [u8; 32], public_key: Vec<u8>) -> Vec<u8> {
293 let drb: Vec<u8> = drb.to_vec();
294
295 let mut result: Vec<u8> = vec![];
296
297 for (drb_byte, public_key_byte) in public_key.iter().zip(drb.iter().cycle()) {
298 result.push(drb_byte ^ public_key_byte);
299 }
300
301 result
302 }
303
304 pub fn generate_stake_cdf<Key: SignatureKey, Entry: StakeTableEntryType<Key>>(
306 mut stake_table: Vec<Entry>,
307 drb: [u8; 32],
308 ) -> RandomizedCommittee<Entry> {
309 stake_table.sort_by(|a, b| {
311 cyclic_xor(drb, a.public_key().to_bytes())
312 .cmp(&cyclic_xor(drb, b.public_key().to_bytes()))
313 });
314
315 let mut hasher = Sha256::new();
316
317 let mut cumulative_stake = U256::from(0);
318 let mut cdf = vec![];
319
320 for entry in stake_table {
321 cumulative_stake += entry.stake();
322 hasher.update(entry.public_key().to_bytes());
323
324 cdf.push((entry, cumulative_stake));
325 }
326
327 RandomizedCommittee {
328 cdf,
329 stake_table_hash: hasher.finalize().into(),
330 drb,
331 }
332 }
333
334 pub fn select_randomized_leader<
342 SignatureKey,
343 Entry: StakeTableEntryType<SignatureKey> + Clone,
344 >(
345 randomized_committee: &RandomizedCommittee<Entry>,
346 view: u64,
347 ) -> Entry {
348 let RandomizedCommittee {
349 cdf,
350 stake_table_hash,
351 drb,
352 } = randomized_committee;
353 let mut hasher = Sha512::new();
355 hasher.update(drb);
356 hasher.update(view.to_le_bytes());
357 hasher.update(stake_table_hash);
358 let raw_breakpoint: [u8; 64] = hasher.finalize().into();
359
360 let remainder: U512 =
362 U512::from_le_bytes(raw_breakpoint) % U512::from(cdf.last().unwrap().1);
363
364 let breakpoint: U256 = U256::from_le_slice(&remainder.to_le_bytes_vec()[0..32]);
366
367 let index = cdf.partition_point(|(_, cumulative_stake)| breakpoint >= *cumulative_stake);
373
374 cdf[index].0.clone()
376 }
377
378 #[derive(Clone, Debug)]
379 pub struct RandomizedCommittee<Entry> {
380 cdf: Vec<(Entry, U256)>,
382 stake_table_hash: [u8; 32],
384 drb: [u8; 32],
386 }
387
388 impl<Entry> RandomizedCommittee<Entry> {
389 pub fn drb_result(&self) -> [u8; 32] {
390 self.drb
391 }
392 }
393}
394
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use alloy::primitives::U256;
    use rand::RngCore;
    use sha2::{Digest, Sha256};

    use super::election::{generate_stake_cdf, select_randomized_leader};
    use crate::{
        signature_key::BLSPubKey,
        stake_table::StakeTableEntry,
        traits::signature_key::{BuilderSignatureKey, StakeTableEntryType},
    };

    /// Checks that randomized leader selection is approximately
    /// stake-weighted: over many views, each entry's selection frequency
    /// should track its share of the total stake.
    #[test]
    fn test_randomized_leader() {
        let mut rng = rand::thread_rng();
        // Arbitrary fixed DRB value derived from a hash.
        let drb: [u8; 32] = Sha256::digest(b"drb").into();
        // Ten distinct keys with random stakes in [1, 100].
        let stake_table_entries: Vec<_> = (0..10)
            .map(|i| StakeTableEntry {
                stake_key: BLSPubKey::generated_from_seed_indexed([0u8; 32], i).0,
                stake_amount: U256::from(rng.next_u64() % 100 + 1),
            })
            .collect();
        let randomized_committee = generate_stake_cdf(stake_table_entries.clone(), drb);

        // Tally how often each entry is selected as leader across views.
        let num_views = 100000;
        let mut selected = HashMap::<_, u64>::new();
        for i in 0..num_views {
            let leader = select_randomized_leader(&randomized_committee, i);
            *selected.entry(leader).or_insert(0) += 1;
        }

        // Total variation distance between the expected (stake-weighted)
        // distribution and the empirical selection frequencies.
        let mut tvd = 0.;
        let total_stakes = stake_table_entries
            .iter()
            .map(|e| e.stake())
            .sum::<U256>()
            .to::<u64>() as f64;
        for entry in stake_table_entries {
            let expected = entry.stake().to::<u64>() as f64 / total_stakes;
            let actual = *selected.get(&entry).unwrap_or(&0) as f64 / num_views as f64;
            tvd += (expected - actual).abs();
        }

        // TVD is a sum of absolute values, so this is a sanity check only.
        assert!(tvd >= 0.0);
        // Empirical closeness threshold for 100k samples.
        assert!(tvd < 0.03);
    }
}