1use std::{collections::BTreeMap, sync::Arc};
8
9use futures::future::BoxFuture;
10use serde::{Deserialize, Serialize};
11use sha2::{Digest, Sha256};
12
13use crate::{
14 message::UpgradeLock,
15 traits::{
16 node_implementation::{ConsensusTime, NodeType, Versions},
17 storage::{LoadDrbProgressFn, StoreDrbProgressFn},
18 },
19 HotShotConfig,
20};
21
/// Serializable snapshot of an in-progress DRB (Distributed Randomness Beacon)
/// calculation, used both as the computation input and as the persisted
/// checkpoint format.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct DrbInput {
    /// The epoch this DRB result is being computed for
    pub epoch: u64,
    /// The iteration of the hash chain this snapshot was taken at
    pub iteration: u64,
    /// The hash-chain state at `iteration` (the seed when `iteration == 0`)
    pub value: [u8; 32],
    /// Total number of hash iterations the full calculation performs
    pub difficulty_level: u64,
}
33
/// Callback that resolves the DRB difficulty to use for a given view,
/// returned as a boxed future so the decision can consult async state.
pub type DrbDifficultySelectorFn<TYPES> =
    Arc<dyn Fn(<TYPES as NodeType>::View) -> BoxFuture<'static, u64> + Send + Sync + 'static>;
36
37pub fn drb_difficulty_selector<TYPES: NodeType, V: Versions>(
38 upgrade_lock: UpgradeLock<TYPES, V>,
39 config: &HotShotConfig<TYPES>,
40) -> DrbDifficultySelectorFn<TYPES> {
41 let base_difficulty = config.drb_difficulty;
42 let upgrade_difficulty = config.drb_upgrade_difficulty;
43 Arc::new(move |view| {
44 let upgrade_lock = upgrade_lock.clone();
45 Box::pin(async move {
46 if upgrade_lock.upgraded_drb_and_header(view).await {
47 upgrade_difficulty
48 } else {
49 base_difficulty
50 }
51 })
52 })
53}
54
/// Placeholder DRB difficulty, arbitrary until hash time is benchmarked
/// (see [`difficulty_level`]).
pub const DIFFICULTY_LEVEL: u64 = 10;

/// Number of hash iterations between persisted DRB progress checkpoints
/// in [`compute_drb_result`].
pub const DRB_CHECKPOINT_INTERVAL: u64 = 1000000;

/// All-zero seed used before any real DRB seed is available.
pub const INITIAL_DRB_SEED_INPUT: [u8; 32] = [0; 32];
/// All-zero result used for bootstrap epochs before a real DRB result exists.
pub const INITIAL_DRB_RESULT: [u8; 32] = [0; 32];

/// 32-byte seed fed into the DRB hash chain.
pub type DrbSeedInput = [u8; 32];

/// 32-byte output of the DRB hash chain.
pub type DrbResult = [u8; 32];

/// Number of previous epochs' DRB results retained by
/// `DrbResults::garbage_collect`.
pub const KEEP_PREVIOUS_RESULT_COUNT: u64 = 8;
84
/// Intended to return the benchmarked DRB difficulty level.
///
/// # Panics
///
/// Always panics: callers should use the placeholder [`DIFFICULTY_LEVEL`]
/// constant until hash time has been benchmarked.
#[must_use]
pub fn difficulty_level() -> u64 {
    unimplemented!("Use an arbitrary `DIFFICULTY_LEVEL` for now before we bench the hash time.");
}
95
96#[must_use]
103pub async fn compute_drb_result(
104 drb_input: DrbInput,
105 store_drb_progress: StoreDrbProgressFn,
106 load_drb_progress: LoadDrbProgressFn,
107) -> DrbResult {
108 let mut drb_input = drb_input;
109
110 if let Ok(loaded_drb_input) = load_drb_progress(drb_input.epoch).await {
111 if loaded_drb_input.iteration >= drb_input.iteration {
112 drb_input = loaded_drb_input;
113 }
114 }
115
116 let mut hash = drb_input.value.to_vec();
117 let mut iteration = drb_input.iteration;
118 let remaining_iterations = drb_input
119 .difficulty_level
120 .checked_sub(iteration)
121 .unwrap_or_else(|| {
122 panic!(
123 "DRB difficulty level {} exceeds the iteration {} of the input we were given. \
124 This is a fatal error",
125 drb_input.difficulty_level, iteration
126 )
127 });
128
129 let final_checkpoint = remaining_iterations / DRB_CHECKPOINT_INTERVAL;
130
131 for _ in 0..final_checkpoint {
133 for _ in 0..DRB_CHECKPOINT_INTERVAL {
134 hash = Sha256::digest(hash).to_vec();
137 }
138
139 let mut partial_drb_result = [0u8; 32];
140 partial_drb_result.copy_from_slice(&hash);
141
142 iteration += DRB_CHECKPOINT_INTERVAL;
143
144 let updated_drb_input = DrbInput {
145 epoch: drb_input.epoch,
146 iteration,
147 value: partial_drb_result,
148 difficulty_level: drb_input.difficulty_level,
149 };
150
151 let store_drb_progress = store_drb_progress.clone();
152 tokio::spawn(async move {
153 if let Err(e) = store_drb_progress(updated_drb_input).await {
154 tracing::warn!("Failed to store DRB progress during calculation: {}", e);
155 }
156 });
157 }
158
159 let final_checkpoint_iteration = iteration;
160
161 for _ in final_checkpoint_iteration..drb_input.difficulty_level {
163 hash = Sha256::digest(hash).to_vec();
164 iteration += 1;
165 }
166
167 let mut drb_result = [0u8; 32];
169 drb_result.copy_from_slice(&hash);
170 drb_result
171}
172
/// Per-epoch store of computed DRB results.
#[derive(Clone, Debug)]
pub struct DrbResults<TYPES: NodeType> {
    /// Map from epoch number to the DRB result computed for that epoch
    pub results: BTreeMap<TYPES::Epoch, DrbResult>,
}
179
180impl<TYPES: NodeType> DrbResults<TYPES> {
181 #[must_use]
182 pub fn new() -> Self {
184 Self {
185 results: BTreeMap::from([
186 (TYPES::Epoch::new(1), INITIAL_DRB_RESULT),
187 (TYPES::Epoch::new(2), INITIAL_DRB_RESULT),
188 ]),
189 }
190 }
191
192 pub fn store_result(&mut self, epoch: TYPES::Epoch, result: DrbResult) {
193 self.results.insert(epoch, result);
194 }
195
196 pub fn garbage_collect(&mut self, epoch: TYPES::Epoch) {
198 if epoch.u64() < KEEP_PREVIOUS_RESULT_COUNT {
199 return;
200 }
201
202 let retain_epoch = epoch - KEEP_PREVIOUS_RESULT_COUNT;
203 self.results = self.results.split_off(&retain_epoch);
207 }
208}
209
impl<TYPES: NodeType> Default for DrbResults<TYPES> {
    /// Equivalent to [`DrbResults::new`].
    fn default() -> Self {
        Self::new()
    }
}
215
216pub mod election {
232 use alloy::primitives::{U256, U512};
233 use sha2::{Digest, Sha256, Sha512};
234
235 use crate::traits::signature_key::{SignatureKey, StakeTableEntryType};
236
237 fn cyclic_xor(drb: [u8; 32], public_key: Vec<u8>) -> Vec<u8> {
239 let drb: Vec<u8> = drb.to_vec();
240
241 let mut result: Vec<u8> = vec![];
242
243 for (drb_byte, public_key_byte) in public_key.iter().zip(drb.iter().cycle()) {
244 result.push(drb_byte ^ public_key_byte);
245 }
246
247 result
248 }
249
250 pub fn generate_stake_cdf<Key: SignatureKey, Entry: StakeTableEntryType<Key>>(
252 mut stake_table: Vec<Entry>,
253 drb: [u8; 32],
254 ) -> RandomizedCommittee<Entry> {
255 stake_table.sort_by(|a, b| {
257 cyclic_xor(drb, a.public_key().to_bytes())
258 .cmp(&cyclic_xor(drb, b.public_key().to_bytes()))
259 });
260
261 let mut hasher = Sha256::new();
262
263 let mut cumulative_stake = U256::from(0);
264 let mut cdf = vec![];
265
266 for entry in stake_table {
267 cumulative_stake += entry.stake();
268 hasher.update(entry.public_key().to_bytes());
269
270 cdf.push((entry, cumulative_stake));
271 }
272
273 RandomizedCommittee {
274 cdf,
275 stake_table_hash: hasher.finalize().into(),
276 drb,
277 }
278 }
279
280 pub fn select_randomized_leader<
288 SignatureKey,
289 Entry: StakeTableEntryType<SignatureKey> + Clone,
290 >(
291 randomized_committee: &RandomizedCommittee<Entry>,
292 view: u64,
293 ) -> Entry {
294 let RandomizedCommittee {
295 cdf,
296 stake_table_hash,
297 drb,
298 } = randomized_committee;
299 let mut hasher = Sha512::new();
301 hasher.update(drb);
302 hasher.update(view.to_le_bytes());
303 hasher.update(stake_table_hash);
304 let raw_breakpoint: [u8; 64] = hasher.finalize().into();
305
306 let remainder: U512 =
308 U512::from_le_bytes(raw_breakpoint) % U512::from(cdf.last().unwrap().1);
309
310 let breakpoint: U256 = U256::from_le_slice(&remainder.to_le_bytes_vec()[0..32]);
312
313 let index = cdf.partition_point(|(_, cumulative_stake)| breakpoint >= *cumulative_stake);
319
320 cdf[index].0.clone()
322 }
323
324 #[derive(Clone, Debug)]
325 pub struct RandomizedCommittee<Entry> {
326 cdf: Vec<(Entry, U256)>,
328 stake_table_hash: [u8; 32],
330 drb: [u8; 32],
332 }
333}
334
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use alloy::primitives::U256;
    use rand::RngCore;
    use sha2::{Digest, Sha256};

    use super::election::{generate_stake_cdf, select_randomized_leader};
    use crate::{
        signature_key::BLSPubKey,
        stake_table::StakeTableEntry,
        traits::signature_key::{BuilderSignatureKey, StakeTableEntryType},
    };

    /// Samples leaders over many views and checks the observed frequencies
    /// stay close to the stake-proportional distribution (small total
    /// variation distance).
    #[test]
    fn test_randomized_leader() {
        let mut rng = rand::thread_rng();
        let drb: [u8; 32] = Sha256::digest(b"drb").into();

        // Ten entries with random stakes in [1, 100].
        let stake_table_entries: Vec<_> = (0..10)
            .map(|seed_index| {
                let stake_key =
                    BLSPubKey::generated_from_seed_indexed([0u8; 32], seed_index).0;
                let stake_amount = U256::from(rng.next_u64() % 100 + 1);
                StakeTableEntry {
                    stake_key,
                    stake_amount,
                }
            })
            .collect();
        let randomized_committee = generate_stake_cdf(stake_table_entries.clone(), drb);

        // Tally how often each entry is elected across the views.
        let num_views = 100000;
        let mut selected = HashMap::<_, u64>::new();
        for view in 0..num_views {
            let leader = select_randomized_leader(&randomized_committee, view);
            *selected.entry(leader).or_default() += 1;
        }

        // Total variation distance between expected and observed frequencies.
        let total_stakes = stake_table_entries
            .iter()
            .map(|entry| entry.stake())
            .sum::<U256>()
            .to::<u64>() as f64;
        let mut tvd = 0.;
        for entry in stake_table_entries {
            let expected = entry.stake().to::<u64>() as f64 / total_stakes;
            let actual = *selected.get(&entry).unwrap_or(&0) as f64 / num_views as f64;
            tvd += (expected - actual).abs();
        }

        assert!(tvd >= 0.0);
        assert!(tvd < 0.03);
    }
}