1use std::{collections::BTreeMap, sync::Arc, time::Instant};
8
9use futures::future::BoxFuture;
10use serde::{Deserialize, Serialize};
11use sha2::{Digest, Sha256};
12
13use crate::{
14 message::UpgradeLock,
15 traits::{
16 node_implementation::{ConsensusTime, NodeType, Versions},
17 storage::{LoadDrbProgressFn, StoreDrbProgressFn},
18 },
19 HotShotConfig,
20};
21
/// State of a DRB computation: an iterated SHA-256 hash chain for one epoch.
///
/// [`compute_drb_result`] takes this as its starting point. It also hands updated copies of it
/// to a `StoreDrbProgressFn` at checkpoints, and it reloads such copies (keyed by `epoch`) to
/// resume interrupted work.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct DrbInput {
    /// Epoch this computation belongs to; used as the key when loading stored progress.
    pub epoch: u64,
    /// Number of SHA-256 iterations already applied to reach `value`.
    pub iteration: u64,
    /// Current hash-chain value after `iteration` iterations
    /// (presumably the raw seed when `iteration` is 0 — confirm with callers).
    pub value: [u8; 32],
    /// Total number of iterations the computation must reach to produce the final result.
    pub difficulty_level: u64,
}
33
/// Shared, thread-safe callback that maps a view to a boxed `'static` future resolving to the
/// DRB difficulty (iteration count) to use for that view.
pub type DrbDifficultySelectorFn<TYPES> =
    Arc<dyn Fn(<TYPES as NodeType>::View) -> BoxFuture<'static, u64> + Send + Sync + 'static>;
36
37pub fn drb_difficulty_selector<TYPES: NodeType, V: Versions>(
38 upgrade_lock: UpgradeLock<TYPES, V>,
39 config: &HotShotConfig<TYPES>,
40) -> DrbDifficultySelectorFn<TYPES> {
41 let base_difficulty = config.drb_difficulty;
42 let upgrade_difficulty = config.drb_upgrade_difficulty;
43 Arc::new(move |view| {
44 let upgrade_lock = upgrade_lock.clone();
45 Box::pin(async move {
46 if upgrade_lock.upgraded_drb_and_header(view).await {
47 upgrade_difficulty
48 } else {
49 base_difficulty
50 }
51 })
52 })
53}
54
/// Placeholder DRB difficulty. Per [`difficulty_level`], this value is arbitrary until the hash
/// time has been benchmarked.
pub const DIFFICULTY_LEVEL: u64 = 10;

/// Number of SHA-256 iterations [`compute_drb_result`] performs between progress checkpoints.
pub const DRB_CHECKPOINT_INTERVAL: u64 = 1_000_000_000;

/// All-zero 32-byte DRB seed input.
pub const INITIAL_DRB_SEED_INPUT: [u8; 32] = [0; 32];
/// All-zero 32-byte DRB result; [`DrbResults::new`] pre-populates epochs 1 and 2 with it.
pub const INITIAL_DRB_RESULT: [u8; 32] = [0; 32];

/// 32-byte DRB seed input.
pub type DrbSeedInput = [u8; 32];

/// 32-byte DRB output, as returned by [`compute_drb_result`].
pub type DrbResult = [u8; 32];

/// How many epochs before the given one [`DrbResults::garbage_collect`] retains results for.
pub const KEEP_PREVIOUS_RESULT_COUNT: u64 = 8;
84
/// Intended to return the DRB difficulty level; not implemented yet.
///
/// # Panics
///
/// Always panics via `unimplemented!`. Use [`DIFFICULTY_LEVEL`] until the hash time has been
/// benchmarked.
#[must_use]
pub fn difficulty_level() -> u64 {
    unimplemented!("Use an arbitrary `DIFFICULTY_LEVEL` for now before we bench the hash time.");
}
95
96#[must_use]
103pub async fn compute_drb_result(
104 drb_input: DrbInput,
105 store_drb_progress: StoreDrbProgressFn,
106 load_drb_progress: LoadDrbProgressFn,
107) -> DrbResult {
108 let mut drb_input = drb_input;
109
110 if let Ok(loaded_drb_input) = load_drb_progress(drb_input.epoch).await {
111 if loaded_drb_input.iteration >= drb_input.iteration {
112 drb_input = loaded_drb_input;
113 }
114 }
115
116 let mut hash = drb_input.value.to_vec();
117 let mut iteration = drb_input.iteration;
118 let remaining_iterations = drb_input
119 .difficulty_level
120 .checked_sub(iteration)
121 .unwrap_or_else(|| {
122 panic!(
123 "DRB difficulty level {} exceeds the iteration {} of the input we were given. \
124 This is a fatal error",
125 drb_input.difficulty_level, iteration
126 )
127 });
128
129 let final_checkpoint = remaining_iterations / DRB_CHECKPOINT_INTERVAL;
130
131 let mut last_time = Instant::now();
132 let mut last_iteration = iteration;
133
134 for _ in 0..final_checkpoint {
136 hash = tokio::task::spawn_blocking(move || {
137 let mut hash_tmp = hash.clone();
138 for _ in 0..DRB_CHECKPOINT_INTERVAL {
139 hash_tmp = Sha256::digest(&hash_tmp).to_vec();
142 }
143
144 hash_tmp
145 })
146 .await
147 .expect("DRB calculation failed: this should never happen");
148
149 let mut partial_drb_result = [0u8; 32];
150 partial_drb_result.copy_from_slice(&hash);
151
152 iteration += DRB_CHECKPOINT_INTERVAL;
153
154 let updated_drb_input = DrbInput {
155 epoch: drb_input.epoch,
156 iteration,
157 value: partial_drb_result,
158 difficulty_level: drb_input.difficulty_level,
159 };
160
161 let elapsed_time = last_time.elapsed().as_millis();
162
163 let store_drb_progress = store_drb_progress.clone();
164 tokio::spawn(async move {
165 tracing::warn!(
166 "Storing partial DRB progress: {:?}. Time elapsed since the previous iteration of \
167 {:?}: {:?}",
168 updated_drb_input,
169 last_iteration,
170 elapsed_time
171 );
172 if let Err(e) = store_drb_progress(updated_drb_input).await {
173 tracing::warn!("Failed to store DRB progress during calculation: {}", e);
174 }
175 });
176
177 last_time = Instant::now();
178 last_iteration = iteration;
179 }
180
181 let final_checkpoint_iteration = iteration;
182
183 hash = tokio::task::spawn_blocking(move || {
185 let mut hash_tmp = hash.clone();
186 for _ in final_checkpoint_iteration..drb_input.difficulty_level {
187 hash_tmp = Sha256::digest(&hash_tmp).to_vec();
190 }
191
192 hash_tmp
193 })
194 .await
195 .expect("DRB calculation failed: this should never happen");
196
197 let mut drb_result = [0u8; 32];
199 drb_result.copy_from_slice(&hash);
200
201 let final_drb_input = DrbInput {
202 epoch: drb_input.epoch,
203 iteration: drb_input.difficulty_level,
204 value: drb_result,
205 difficulty_level: drb_input.difficulty_level,
206 };
207
208 let store_drb_progress = store_drb_progress.clone();
209 tokio::spawn(async move {
210 if let Err(e) = store_drb_progress(final_drb_input).await {
211 tracing::warn!("Failed to store DRB progress during calculation: {}", e);
212 }
213 });
214
215 drb_result
216}
217
218#[derive(Clone, Debug)]
220pub struct DrbResults<TYPES: NodeType> {
221 pub results: BTreeMap<TYPES::Epoch, DrbResult>,
223}
224
225impl<TYPES: NodeType> DrbResults<TYPES> {
226 #[must_use]
227 pub fn new() -> Self {
229 Self {
230 results: BTreeMap::from([
231 (TYPES::Epoch::new(1), INITIAL_DRB_RESULT),
232 (TYPES::Epoch::new(2), INITIAL_DRB_RESULT),
233 ]),
234 }
235 }
236
237 pub fn store_result(&mut self, epoch: TYPES::Epoch, result: DrbResult) {
238 self.results.insert(epoch, result);
239 }
240
241 pub fn garbage_collect(&mut self, epoch: TYPES::Epoch) {
243 if epoch.u64() < KEEP_PREVIOUS_RESULT_COUNT {
244 return;
245 }
246
247 let retain_epoch = epoch - KEEP_PREVIOUS_RESULT_COUNT;
248 self.results = self.results.split_off(&retain_epoch);
252 }
253}
254
255impl<TYPES: NodeType> Default for DrbResults<TYPES> {
256 fn default() -> Self {
257 Self::new()
258 }
259}
260
261pub mod election {
277 use alloy::primitives::{U256, U512};
278 use sha2::{Digest, Sha256, Sha512};
279
280 use crate::traits::signature_key::{SignatureKey, StakeTableEntryType};
281
282 fn cyclic_xor(drb: [u8; 32], public_key: Vec<u8>) -> Vec<u8> {
284 let drb: Vec<u8> = drb.to_vec();
285
286 let mut result: Vec<u8> = vec![];
287
288 for (drb_byte, public_key_byte) in public_key.iter().zip(drb.iter().cycle()) {
289 result.push(drb_byte ^ public_key_byte);
290 }
291
292 result
293 }
294
295 pub fn generate_stake_cdf<Key: SignatureKey, Entry: StakeTableEntryType<Key>>(
297 mut stake_table: Vec<Entry>,
298 drb: [u8; 32],
299 ) -> RandomizedCommittee<Entry> {
300 stake_table.sort_by(|a, b| {
302 cyclic_xor(drb, a.public_key().to_bytes())
303 .cmp(&cyclic_xor(drb, b.public_key().to_bytes()))
304 });
305
306 let mut hasher = Sha256::new();
307
308 let mut cumulative_stake = U256::from(0);
309 let mut cdf = vec![];
310
311 for entry in stake_table {
312 cumulative_stake += entry.stake();
313 hasher.update(entry.public_key().to_bytes());
314
315 cdf.push((entry, cumulative_stake));
316 }
317
318 RandomizedCommittee {
319 cdf,
320 stake_table_hash: hasher.finalize().into(),
321 drb,
322 }
323 }
324
325 pub fn select_randomized_leader<
333 SignatureKey,
334 Entry: StakeTableEntryType<SignatureKey> + Clone,
335 >(
336 randomized_committee: &RandomizedCommittee<Entry>,
337 view: u64,
338 ) -> Entry {
339 let RandomizedCommittee {
340 cdf,
341 stake_table_hash,
342 drb,
343 } = randomized_committee;
344 let mut hasher = Sha512::new();
346 hasher.update(drb);
347 hasher.update(view.to_le_bytes());
348 hasher.update(stake_table_hash);
349 let raw_breakpoint: [u8; 64] = hasher.finalize().into();
350
351 let remainder: U512 =
353 U512::from_le_bytes(raw_breakpoint) % U512::from(cdf.last().unwrap().1);
354
355 let breakpoint: U256 = U256::from_le_slice(&remainder.to_le_bytes_vec()[0..32]);
357
358 let index = cdf.partition_point(|(_, cumulative_stake)| breakpoint >= *cumulative_stake);
364
365 cdf[index].0.clone()
367 }
368
369 #[derive(Clone, Debug)]
370 pub struct RandomizedCommittee<Entry> {
371 cdf: Vec<(Entry, U256)>,
373 stake_table_hash: [u8; 32],
375 drb: [u8; 32],
377 }
378
379 impl<Entry> RandomizedCommittee<Entry> {
380 pub fn drb_result(&self) -> [u8; 32] {
381 self.drb
382 }
383 }
384}
385
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use alloy::primitives::U256;
    use rand::RngCore;
    use sha2::{Digest, Sha256};

    use super::election::{generate_stake_cdf, select_randomized_leader};
    use crate::{
        signature_key::BLSPubKey,
        stake_table::StakeTableEntry,
        traits::signature_key::{BuilderSignatureKey, StakeTableEntryType},
    };

    /// Over many views, each validator's share of leader slots should track its share of stake.
    ///
    /// The summed absolute deviation between the two distributions must stay below 0.03.
    #[test]
    fn test_randomized_leader() {
        let num_views: u64 = 100000;
        let mut rng = rand::thread_rng();
        let drb: [u8; 32] = Sha256::digest(b"drb").into();

        // Ten keys derived from a fixed seed, each given a random stake in 1..=100.
        let stake_table_entries: Vec<_> = (0..10)
            .map(|index| {
                let stake_key = BLSPubKey::generated_from_seed_indexed([0u8; 32], index).0;
                let stake_amount = U256::from(rng.next_u64() % 100 + 1);
                StakeTableEntry {
                    stake_key,
                    stake_amount,
                }
            })
            .collect();
        let randomized_committee = generate_stake_cdf(stake_table_entries.clone(), drb);

        let mut selected = HashMap::<_, u64>::new();
        for view in 0..num_views {
            let leader = select_randomized_leader(&randomized_committee, view);
            *selected.entry(leader).or_default() += 1;
        }

        let total_stakes = stake_table_entries
            .iter()
            .map(|e| e.stake())
            .sum::<U256>()
            .to::<u64>() as f64;

        let tvd: f64 = stake_table_entries
            .iter()
            .map(|entry| {
                let expected = entry.stake().to::<u64>() as f64 / total_stakes;
                let actual = selected.get(entry).copied().unwrap_or(0) as f64 / num_views as f64;
                (expected - actual).abs()
            })
            .sum();

        assert!(tvd >= 0.0);
        assert!(tvd < 0.03);
    }
}