1use std::{collections::BTreeMap, sync::Arc};
8
9use futures::future::BoxFuture;
10use serde::{Deserialize, Serialize};
11use sha2::{Digest, Sha256};
12
13use crate::{
14 message::UpgradeLock,
15 traits::{
16 node_implementation::{ConsensusTime, NodeType, Versions},
17 storage::{LoadDrbProgressFn, StoreDrbProgressFn},
18 },
19 HotShotConfig,
20};
21
/// Snapshot of a DRB computation's state, used both as the initial input and
/// as the persisted progress checkpoint (see `compute_drb_result`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct DrbInput {
    /// Epoch this DRB computation is for
    pub epoch: u64,
    /// Number of hash iterations already performed
    pub iteration: u64,
    /// Hash value after `iteration` iterations (the seed when `iteration` is 0)
    pub value: [u8; 32],
    /// Total number of hash iterations to perform for this computation
    pub difficulty_level: u64,
}
33
/// Async callback that reports the DRB difficulty level to use for a given view.
pub type DrbDifficultySelectorFn<TYPES> =
    Arc<dyn Fn(<TYPES as NodeType>::View) -> BoxFuture<'static, u64> + Send + Sync + 'static>;
36
37pub fn drb_difficulty_selector<TYPES: NodeType, V: Versions>(
38 upgrade_lock: UpgradeLock<TYPES, V>,
39 config: &HotShotConfig<TYPES>,
40) -> DrbDifficultySelectorFn<TYPES> {
41 let base_difficulty = config.drb_difficulty;
42 let upgrade_difficulty = config.drb_upgrade_difficulty;
43 Arc::new(move |view| {
44 let upgrade_lock = upgrade_lock.clone();
45 Box::pin(async move {
46 if upgrade_lock.upgraded_drb_and_header(view).await {
47 upgrade_difficulty
48 } else {
49 base_difficulty
50 }
51 })
52 })
53}
54
/// Arbitrary fixed difficulty level, pending a real value from benchmarking
/// (see the unimplemented `difficulty_level` below).
pub const DIFFICULTY_LEVEL: u64 = 10;

/// Number of hash iterations between persisted DRB progress checkpoints.
pub const DRB_CHECKPOINT_INTERVAL: u64 = 1000000;

/// All-zero seed input used before a real DRB seed exists.
pub const INITIAL_DRB_SEED_INPUT: [u8; 32] = [0; 32];
/// All-zero result used for the bootstrap epochs (1 and 2) before a real DRB
/// result exists (see `DrbResults::new`).
pub const INITIAL_DRB_RESULT: [u8; 32] = [0; 32];

/// Seed input for a DRB computation.
pub type DrbSeedInput = [u8; 32];

/// Output of a DRB computation.
pub type DrbResult = [u8; 32];

/// Number of epochs whose results are kept around by `DrbResults::garbage_collect`.
pub const KEEP_PREVIOUS_RESULT_COUNT: u64 = 8;
84
/// Intended to compute the DRB difficulty level from hash-timing benchmarks.
///
/// # Panics
/// Always panics: not yet implemented — use [`DIFFICULTY_LEVEL`] instead.
#[must_use]
pub fn difficulty_level() -> u64 {
    unimplemented!("Use an arbitrary `DIFFICULTY_LEVEL` for now before we bench the hash time.");
}
95
96#[must_use]
103pub async fn compute_drb_result(
104 drb_input: DrbInput,
105 store_drb_progress: StoreDrbProgressFn,
106 load_drb_progress: LoadDrbProgressFn,
107) -> DrbResult {
108 let mut drb_input = drb_input;
109
110 if let Ok(loaded_drb_input) = load_drb_progress(drb_input.epoch).await {
111 if loaded_drb_input.iteration >= drb_input.iteration {
112 drb_input = loaded_drb_input;
113 }
114 }
115
116 let mut hash = drb_input.value.to_vec();
117 let mut iteration = drb_input.iteration;
118 let remaining_iterations = drb_input.difficulty_level
119 .checked_sub(iteration)
120 .unwrap_or_else(||
121 panic!(
122 "DRB difficulty level {} exceeds the iteration {} of the input we were given. This is a fatal error",
123 drb_input.difficulty_level,
124 iteration
125 )
126 );
127
128 let final_checkpoint = remaining_iterations / DRB_CHECKPOINT_INTERVAL;
129
130 for _ in 0..final_checkpoint {
132 for _ in 0..DRB_CHECKPOINT_INTERVAL {
133 hash = Sha256::digest(hash).to_vec();
136 }
137
138 let mut partial_drb_result = [0u8; 32];
139 partial_drb_result.copy_from_slice(&hash);
140
141 iteration += DRB_CHECKPOINT_INTERVAL;
142
143 let updated_drb_input = DrbInput {
144 epoch: drb_input.epoch,
145 iteration,
146 value: partial_drb_result,
147 difficulty_level: drb_input.difficulty_level,
148 };
149
150 let store_drb_progress = store_drb_progress.clone();
151 tokio::spawn(async move {
152 if let Err(e) = store_drb_progress(updated_drb_input).await {
153 tracing::warn!("Failed to store DRB progress during calculation: {}", e);
154 }
155 });
156 }
157
158 let final_checkpoint_iteration = iteration;
159
160 for _ in final_checkpoint_iteration..drb_input.difficulty_level {
162 hash = Sha256::digest(hash).to_vec();
163 iteration += 1;
164 }
165
166 let mut drb_result = [0u8; 32];
168 drb_result.copy_from_slice(&hash);
169 drb_result
170}
171
/// Per-epoch store of computed DRB results.
#[derive(Clone, Debug)]
pub struct DrbResults<TYPES: NodeType> {
    /// DRB results keyed by the epoch they apply to
    pub results: BTreeMap<TYPES::Epoch, DrbResult>,
}
178
179impl<TYPES: NodeType> DrbResults<TYPES> {
180 #[must_use]
181 pub fn new() -> Self {
183 Self {
184 results: BTreeMap::from([
185 (TYPES::Epoch::new(1), INITIAL_DRB_RESULT),
186 (TYPES::Epoch::new(2), INITIAL_DRB_RESULT),
187 ]),
188 }
189 }
190
191 pub fn store_result(&mut self, epoch: TYPES::Epoch, result: DrbResult) {
192 self.results.insert(epoch, result);
193 }
194
195 pub fn garbage_collect(&mut self, epoch: TYPES::Epoch) {
197 if epoch.u64() < KEEP_PREVIOUS_RESULT_COUNT {
198 return;
199 }
200
201 let retain_epoch = epoch - KEEP_PREVIOUS_RESULT_COUNT;
202 self.results = self.results.split_off(&retain_epoch);
206 }
207}
208
impl<TYPES: NodeType> Default for DrbResults<TYPES> {
    /// Equivalent to [`DrbResults::new`].
    fn default() -> Self {
        Self::new()
    }
}
214
/// Stake-weighted leader election randomized by a DRB result.
pub mod election {
    use alloy::primitives::{U256, U512};
    use sha2::{Digest, Sha256, Sha512};

    use crate::traits::signature_key::{SignatureKey, StakeTableEntryType};

    /// XOR every byte of `public_key` against the bytes of `drb`, cycling
    /// `drb` when the key is longer than 32 bytes.
    fn cyclic_xor(drb: [u8; 32], public_key: Vec<u8>) -> Vec<u8> {
        let drb: Vec<u8> = drb.to_vec();

        let mut result: Vec<u8> = vec![];

        for (drb_byte, public_key_byte) in public_key.iter().zip(drb.iter().cycle()) {
            result.push(drb_byte ^ public_key_byte);
        }

        result
    }

    /// Build a [`RandomizedCommittee`] from a stake table and a DRB result:
    /// sort the entries by their DRB-masked public keys, then accumulate a
    /// cumulative-stake CDF over the sorted order while hashing the public
    /// keys (in that same order) into `stake_table_hash`.
    pub fn generate_stake_cdf<Key: SignatureKey, Entry: StakeTableEntryType<Key>>(
        mut stake_table: Vec<Entry>,
        drb: [u8; 32],
    ) -> RandomizedCommittee<Entry> {
        // The sort key XORs each public key with the DRB bytes, so the
        // resulting order is deterministic but depends on the DRB value.
        stake_table.sort_by(|a, b| {
            cyclic_xor(drb, a.public_key().to_bytes())
                .cmp(&cyclic_xor(drb, b.public_key().to_bytes()))
        });

        let mut hasher = Sha256::new();

        let mut cumulative_stake = U256::from(0);
        let mut cdf = vec![];

        // Each CDF entry pairs the stake-table entry with the running total
        // of stake up to and including it.
        for entry in stake_table {
            cumulative_stake += entry.stake();
            hasher.update(entry.public_key().to_bytes());

            cdf.push((entry, cumulative_stake));
        }

        RandomizedCommittee {
            cdf,
            stake_table_hash: hasher.finalize().into(),
            drb,
        }
    }

    /// Select the leader for `view`, weighted by stake.
    ///
    /// Hashes `(drb, view, stake_table_hash)` into a 512-bit digest, reduces
    /// it modulo the total stake to obtain a breakpoint, and returns the
    /// first CDF entry whose cumulative stake exceeds that breakpoint.
    ///
    /// # Panics
    /// Panics if the committee is empty (`cdf.last().unwrap()`).
    pub fn select_randomized_leader<
        SignatureKey,
        Entry: StakeTableEntryType<SignatureKey> + Clone,
    >(
        randomized_committee: &RandomizedCommittee<Entry>,
        view: u64,
    ) -> Entry {
        let RandomizedCommittee {
            cdf,
            stake_table_hash,
            drb,
        } = randomized_committee;
        // Fold the DRB result, view number, and stake-table hash into one
        // 64-byte digest.
        let mut hasher = Sha512::new();
        hasher.update(drb);
        hasher.update(view.to_le_bytes());
        hasher.update(stake_table_hash);
        let raw_breakpoint: [u8; 64] = hasher.finalize().into();

        // Reduce modulo the total stake (the last cumulative value).
        // NOTE(review): the 512-bit intermediate presumably keeps modulo
        // bias negligible relative to a 256-bit total — confirm intent.
        let remainder: U512 =
            U512::from_le_bytes(raw_breakpoint) % U512::from(cdf.last().unwrap().1);

        // remainder < total stake, which fits in a U256, so the low 32
        // little-endian bytes carry the whole value.
        let breakpoint: U256 = U256::from_le_slice(&remainder.to_le_bytes_vec()[0..32]);

        // First index whose cumulative stake is strictly greater than the
        // breakpoint; in range because breakpoint < total stake.
        let index = cdf.partition_point(|(_, cumulative_stake)| breakpoint >= *cumulative_stake);

        cdf[index].0.clone()
    }

    /// A stake table randomized by a DRB result, with a precomputed
    /// cumulative-stake distribution for weighted leader selection.
    #[derive(Clone, Debug)]
    pub struct RandomizedCommittee<Entry> {
        /// Entries in randomized order, each paired with the cumulative
        /// stake up to and including that entry
        cdf: Vec<(Entry, U256)>,
        /// SHA-256 hash of the public keys in randomized order
        stake_table_hash: [u8; 32],
        /// The DRB result this committee was randomized with
        drb: [u8; 32],
    }
}
333
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use alloy::primitives::U256;
    use rand::RngCore;
    use sha2::{Digest, Sha256};

    use super::election::{generate_stake_cdf, select_randomized_leader};
    use crate::{
        signature_key::BLSPubKey,
        stake_table::StakeTableEntry,
        traits::signature_key::{BuilderSignatureKey, StakeTableEntryType},
    };

    /// Statistical check that `select_randomized_leader` picks leaders in
    /// proportion to stake: over many views, the total variation distance
    /// between observed selection frequencies and the stake-weighted
    /// distribution should be small.
    #[test]
    fn test_randomized_leader() {
        let mut rng = rand::thread_rng();
        // Fixed DRB value so the committee randomization is deterministic.
        let drb: [u8; 32] = Sha256::digest(b"drb").into();
        // 10 entries with random stakes in 1..=100.
        let stake_table_entries: Vec<_> = (0..10)
            .map(|i| StakeTableEntry {
                stake_key: BLSPubKey::generated_from_seed_indexed([0u8; 32], i).0,
                stake_amount: U256::from(rng.next_u64() % 100 + 1),
            })
            .collect();
        let randomized_committee = generate_stake_cdf(stake_table_entries.clone(), drb);

        // Tally how often each entry is chosen as leader.
        let num_views = 100000;
        let mut selected = HashMap::<_, u64>::new();
        for i in 0..num_views {
            let leader = select_randomized_leader(&randomized_committee, i);
            *selected.entry(leader).or_insert(0) += 1;
        }

        // Total variation distance between observed and expected
        // (stake-proportional) selection frequencies.
        let mut tvd = 0.;
        let total_stakes = stake_table_entries
            .iter()
            .map(|e| e.stake())
            .sum::<U256>()
            .to::<u64>() as f64;
        for entry in stake_table_entries {
            let expected = entry.stake().to::<u64>() as f64 / total_stakes;
            let actual = *selected.get(&entry).unwrap_or(&0) as f64 / num_views as f64;
            tvd += (expected - actual).abs();
        }

        // Sanity bound, then the real closeness threshold.
        assert!(tvd >= 0.0);
        assert!(tvd < 0.03);
    }
}