use std::{collections::BTreeMap, sync::Arc, time::Instant};

use futures::future::BoxFuture;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use vbs::version::{StaticVersionType, Version};

use crate::{
    traits::{
        node_implementation::{ConsensusTime, NodeType, Versions},
        storage::{LoadDrbProgressFn, StoreDrbProgressFn},
    },
    HotShotConfig,
};

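/// Input to the DRB calculation: everything needed to start, resume, or checkpoint
/// the iterated-hash computation for one epoch.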
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct DrbInput {
    /// Epoch the DRB result is being calculated for
    pub epoch: u64,
    /// Number of hash iterations already applied to produce `value`
    pub iteration: u64,
    /// Value of the DRB calculation at `iteration`
    pub value: [u8; 32],
    /// Total number of hash iterations for this calculation
    pub difficulty_level: u64,
}

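/// Callback that asynchronously maps a protocol `Version` to the DRB difficulty to use.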
pub type DrbDifficultySelectorFn =
    Arc<dyn Fn(Version) -> BoxFuture<'static, u64> + Send + Sync + 'static>;

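/// Builds a `DrbDifficultySelectorFn` from the node's config: once the running version
/// reaches `V::DrbAndHeaderUpgrade::VERSION` it selects `drb_upgrade_difficulty`,
/// otherwise it selects `drb_difficulty`.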
pub fn drb_difficulty_selector<TYPES: NodeType, V: Versions>(
    config: &HotShotConfig<TYPES>,
) -> DrbDifficultySelectorFn {
    let base_difficulty = config.drb_difficulty;
    let upgrade_difficulty = config.drb_upgrade_difficulty;
    Arc::new(move |version| {
        Box::pin(async move {
            if version >= V::DrbAndHeaderUpgrade::VERSION {
                upgrade_difficulty
            } else {
                base_difficulty
            }
        })
    })
}

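/// Arbitrary placeholder difficulty level, used until the hash time has been benchmarked.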
pub const DIFFICULTY_LEVEL: u64 = 10;

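/// Number of hash iterations between persisted checkpoints of an in-progress DRB calculation.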
pub const DRB_CHECKPOINT_INTERVAL: u64 = 1_000_000_000;

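/// Seed input used for the initial epochs, before any real DRB seed is available.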
pub const INITIAL_DRB_SEED_INPUT: [u8; 32] = [0; 32];
/// Result used for the initial epochs, before any real DRB result has been computed.
pub const INITIAL_DRB_RESULT: [u8; 32] = [0; 32];

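/// 32-byte seed input to a DRB calculation.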
pub type DrbSeedInput = [u8; 32];

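/// 32-byte result of a DRB calculation.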
pub type DrbResult = [u8; 32];

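/// Number of previous epochs whose DRB results are retained during garbage collection.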
pub const KEEP_PREVIOUS_RESULT_COUNT: u64 = 8;

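/// Dynamically chosen difficulty level; not yet implemented, use `DIFFICULTY_LEVEL` instead.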
#[must_use]
pub fn difficulty_level() -> u64 {
    unimplemented!("Use the arbitrary `DIFFICULTY_LEVEL` constant until the hash time has been benchmarked.");
}

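/// Computes the DRB result for an epoch by iterating SHA-256 `difficulty_level` times over
/// the input value. Progress is persisted via `store_drb_progress` every
/// `DRB_CHECKPOINT_INTERVAL` iterations, and a previously stored checkpoint for the same
/// epoch (with the same difficulty level) is resumed from via `load_drb_progress`.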
#[must_use]
pub async fn compute_drb_result(
    drb_input: DrbInput,
    store_drb_progress: StoreDrbProgressFn,
    load_drb_progress: LoadDrbProgressFn,
) -> DrbResult {
    tracing::warn!("Beginning DRB calculation with input {:?}", drb_input);
    let mut drb_input = drb_input;

    if let Ok(loaded_drb_input) = load_drb_progress(drb_input.epoch).await {
        if loaded_drb_input.difficulty_level != drb_input.difficulty_level {
            tracing::error!(
                "We are calculating the DRB result with input {drb_input:?}, but we had \
                 previously stored {loaded_drb_input:?} with a different difficulty level for \
                 this epoch. Discarding the value from storage"
            );
        } else if loaded_drb_input.iteration >= drb_input.iteration {
            drb_input = loaded_drb_input;
        }
    }

    let mut hash = drb_input.value.to_vec();
    let mut iteration = drb_input.iteration;
    let remaining_iterations = drb_input
        .difficulty_level
        .checked_sub(iteration)
        .unwrap_or_else(|| {
            panic!(
                "The DRB input's iteration {} exceeds its difficulty level {}. This is a fatal \
                 error",
                iteration, drb_input.difficulty_level
            )
        });

    let final_checkpoint = remaining_iterations / DRB_CHECKPOINT_INTERVAL;

    let mut last_time = Instant::now();
    let mut last_iteration = iteration;

    for _ in 0..final_checkpoint {
        // Run one full checkpoint interval of hashing on a blocking thread.
        hash = tokio::task::spawn_blocking(move || {
            let mut hash_tmp = hash.clone();
            for _ in 0..DRB_CHECKPOINT_INTERVAL {
                hash_tmp = Sha256::digest(&hash_tmp).to_vec();
            }

            hash_tmp
        })
        .await
        .expect("DRB calculation failed: this should never happen");

        let mut partial_drb_result = [0u8; 32];
        partial_drb_result.copy_from_slice(&hash);

        iteration += DRB_CHECKPOINT_INTERVAL;

        let updated_drb_input = DrbInput {
            epoch: drb_input.epoch,
            iteration,
            value: partial_drb_result,
            difficulty_level: drb_input.difficulty_level,
        };

        let elapsed_time = last_time.elapsed().as_millis();

        // Persist the checkpoint in the background so hashing is not blocked on storage.
        let store_drb_progress = store_drb_progress.clone();
        tokio::spawn(async move {
            tracing::warn!(
                "Storing partial DRB progress: {:?}. Time elapsed since iteration {}: {} ms",
                updated_drb_input,
                last_iteration,
                elapsed_time
            );
            if let Err(e) = store_drb_progress(updated_drb_input).await {
                tracing::warn!("Failed to store DRB progress during calculation: {}", e);
            }
        });

        last_time = Instant::now();
        last_iteration = iteration;
    }

    let final_checkpoint_iteration = iteration;

    // Hash the remaining iterations (less than a full checkpoint interval) to reach the
    // difficulty level.
    hash = tokio::task::spawn_blocking(move || {
        let mut hash_tmp = hash.clone();
        for _ in final_checkpoint_iteration..drb_input.difficulty_level {
            hash_tmp = Sha256::digest(&hash_tmp).to_vec();
        }

        hash_tmp
    })
    .await
    .expect("DRB calculation failed: this should never happen");

    let mut drb_result = [0u8; 32];
    drb_result.copy_from_slice(&hash);

    let final_drb_input = DrbInput {
        epoch: drb_input.epoch,
        iteration: drb_input.difficulty_level,
        value: drb_result,
        difficulty_level: drb_input.difficulty_level,
    };

    tracing::warn!("Completed DRB calculation. Result: {:?}", final_drb_input);

    // Store the completed result so it can be loaded instead of recomputed later.
    let store_drb_progress = store_drb_progress.clone();
    tokio::spawn(async move {
        if let Err(e) = store_drb_progress(final_drb_input).await {
            tracing::warn!("Failed to store DRB progress during calculation: {}", e);
        }
    });

    drb_result
}

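/// Store of DRB results, keyed by epoch.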
#[derive(Clone, Debug)]
pub struct DrbResults<TYPES: NodeType> {
    /// Map from epoch to the DRB result for that epoch
    pub results: BTreeMap<TYPES::Epoch, DrbResult>,
}

impl<TYPES: NodeType> DrbResults<TYPES> {
    #[must_use]
    pub fn new() -> Self {
        // Seed the first two epochs with the placeholder result, since no real DRB result
        // exists before then.
        Self {
            results: BTreeMap::from([
                (TYPES::Epoch::new(1), INITIAL_DRB_RESULT),
                (TYPES::Epoch::new(2), INITIAL_DRB_RESULT),
            ]),
        }
    }

    pub fn store_result(&mut self, epoch: TYPES::Epoch, result: DrbResult) {
        self.results.insert(epoch, result);
    }

    /// Garbage collects DRB results, keeping only epochs from
    /// `epoch - KEEP_PREVIOUS_RESULT_COUNT` onward.
    pub fn garbage_collect(&mut self, epoch: TYPES::Epoch) {
        if epoch.u64() < KEEP_PREVIOUS_RESULT_COUNT {
            return;
        }

        let retain_epoch = epoch - KEEP_PREVIOUS_RESULT_COUNT;
        // Keep results for the last `KEEP_PREVIOUS_RESULT_COUNT` epochs and drop the rest.
        self.results = self.results.split_off(&retain_epoch);
    }
}

impl<TYPES: NodeType> Default for DrbResults<TYPES> {
    fn default() -> Self {
        Self::new()
    }
}

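/// Stake-weighted, DRB-randomized leader election over a stake table.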
pub mod election {
    use alloy::primitives::{U256, U512};
    use sha2::{Digest, Sha256, Sha512};

    use crate::traits::signature_key::{SignatureKey, StakeTableEntryType};

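    /// XORs a public key with the DRB result, cycling the 32-byte DRB value so every byte
    /// of the key is covered even when the key is longer than 32 bytes.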
    fn cyclic_xor(drb: [u8; 32], public_key: Vec<u8>) -> Vec<u8> {
        let drb: Vec<u8> = drb.to_vec();

        let mut result: Vec<u8> = vec![];

        for (public_key_byte, drb_byte) in public_key.iter().zip(drb.iter().cycle()) {
            result.push(drb_byte ^ public_key_byte);
        }

        result
    }

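    /// Builds a `RandomizedCommittee` from a stake table and a DRB result: entries are
    /// sorted by their DRB-XORed public keys, and a cumulative distribution of stake is
    /// recorded along with a hash of the public keys in that order.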
    pub fn generate_stake_cdf<Key: SignatureKey, Entry: StakeTableEntryType<Key>>(
        mut stake_table: Vec<Entry>,
        drb: [u8; 32],
    ) -> RandomizedCommittee<Entry> {
        stake_table.sort_by(|a, b| {
            cyclic_xor(drb, a.public_key().to_bytes())
                .cmp(&cyclic_xor(drb, b.public_key().to_bytes()))
        });

        let mut hasher = Sha256::new();

        let mut cumulative_stake = U256::from(0);
        let mut cdf = vec![];

        for entry in stake_table {
            cumulative_stake += entry.stake();
            hasher.update(entry.public_key().to_bytes());

            cdf.push((entry, cumulative_stake));
        }

        RandomizedCommittee {
            cdf,
            stake_table_hash: hasher.finalize().into(),
            drb,
        }
    }

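    /// Selects the leader for `view` from a `RandomizedCommittee`. The DRB result, view
    /// number, and stake table hash are hashed into a pseudorandom breakpoint in
    /// `[0, total_stake)`, and the first entry whose cumulative stake exceeds the
    /// breakpoint is chosen, so selection probability is proportional to stake.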
    pub fn select_randomized_leader<
        SignatureKey,
        Entry: StakeTableEntryType<SignatureKey> + Clone,
    >(
        randomized_committee: &RandomizedCommittee<Entry>,
        view: u64,
    ) -> Entry {
        let RandomizedCommittee {
            cdf,
            stake_table_hash,
            drb,
        } = randomized_committee;
        // Hash the DRB result, view number, and stake table hash into a 64-byte value.
        let mut hasher = Sha512::new();
        hasher.update(drb);
        hasher.update(view.to_le_bytes());
        hasher.update(stake_table_hash);
        let raw_breakpoint: [u8; 64] = hasher.finalize().into();

        // Reduce the hash modulo the total stake (the last cumulative value in the CDF).
        let remainder: U512 =
            U512::from_le_bytes(raw_breakpoint) % U512::from(cdf.last().unwrap().1);

        // The remainder fits in 256 bits, so the low 32 bytes carry the full value.
        let breakpoint: U256 = U256::from_le_slice(&remainder.to_le_bytes_vec()[0..32]);

        // First index whose cumulative stake strictly exceeds the breakpoint.
        let index = cdf.partition_point(|(_, cumulative_stake)| breakpoint >= *cumulative_stake);

        cdf[index].0.clone()
    }

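    /// A stake table that has been randomized by a DRB result, stored as a cumulative
    /// distribution of stake for leader selection.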
    #[derive(Clone, Debug)]
    pub struct RandomizedCommittee<Entry> {
        /// Entries paired with their cumulative stake, in DRB-randomized order
        cdf: Vec<(Entry, U256)>,
        /// Hash of the public keys in the randomized order
        stake_table_hash: [u8; 32],
        /// The DRB result used to randomize this committee
        drb: [u8; 32],
    }

    impl<Entry> RandomizedCommittee<Entry> {
        /// Returns the DRB result used to randomize this committee.
        pub fn drb_result(&self) -> [u8; 32] {
            self.drb
        }
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use alloy::primitives::U256;
    use rand::RngCore;
    use sha2::{Digest, Sha256};

    use super::election::{generate_stake_cdf, select_randomized_leader};
    use crate::{
        signature_key::BLSPubKey,
        stake_table::StakeTableEntry,
        traits::signature_key::{BuilderSignatureKey, StakeTableEntryType},
    };

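    /// Checks that randomized leader election picks leaders roughly in proportion to stake,
    /// by bounding the total variation distance between expected and observed frequencies.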
    #[test]
    fn test_randomized_leader() {
        let mut rng = rand::thread_rng();
        // Use an arbitrary DRB result for the test.
        let drb: [u8; 32] = Sha256::digest(b"drb").into();
        // Build a small stake table with random stake amounts.
        let stake_table_entries: Vec<_> = (0..10)
            .map(|i| StakeTableEntry {
                stake_key: BLSPubKey::generated_from_seed_indexed([0u8; 32], i).0,
                stake_amount: U256::from(rng.next_u64() % 100 + 1),
            })
            .collect();
        let randomized_committee = generate_stake_cdf(stake_table_entries.clone(), drb);

        // Count how often each entry is selected as leader over many views.
        let num_views = 100_000;
        let mut selected = HashMap::<_, u64>::new();
        for i in 0..num_views {
            let leader = select_randomized_leader(&randomized_committee, i);
            *selected.entry(leader).or_insert(0) += 1;
        }

        // Compute the total variation distance between the stake-weighted distribution and
        // the observed selection frequencies.
        let mut tvd = 0.;
        let total_stakes = stake_table_entries
            .iter()
            .map(|e| e.stake())
            .sum::<U256>()
            .to::<u64>() as f64;
        for entry in stake_table_entries {
            let expected = entry.stake().to::<u64>() as f64 / total_stakes;
            let actual = *selected.get(&entry).unwrap_or(&0) as f64 / num_views as f64;
            tvd += (expected - actual).abs();
        }

        // TVD is nonnegative by construction; assert it as a sanity check.
        assert!(tvd >= 0.0);
        // The observed distribution should be close to the stake-weighted one.
        assert!(tvd < 0.03);
    }
}