// hotshot_types/drb.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::{collections::BTreeMap, sync::Arc, time::Instant};
8
9use futures::future::BoxFuture;
10use serde::{Deserialize, Serialize};
11use sha2::{Digest, Sha256};
12use vbs::version::Version;
13use versions::DRB_AND_HEADER_UPGRADE_VERSION;
14
15use crate::{
16    HotShotConfig,
17    data::EpochNumber,
18    traits::{
19        node_implementation::NodeType,
20        storage::{LoadDrbProgressFn, StoreDrbProgressFn},
21    },
22};
23
/// Checkpointable input/state for an iterated-hash DRB calculation.
///
/// A fresh calculation starts at iteration 0 with the seed as `value`; a
/// resumed calculation starts from a previously stored checkpoint with the
/// same `difficulty_level`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct DrbInput {
    /// The epoch we are calculating the result for
    pub epoch: u64,
    /// The iteration this seed is from. For fresh calculations, this should be `0`.
    pub iteration: u64,
    /// the value of the drb calculation at the current iteration
    pub value: [u8; 32],
    /// difficulty value for the DRB calculation
    pub difficulty_level: u64,
}
35
/// Callback that selects the DRB difficulty level to use for a given protocol version.
pub type DrbDifficultySelectorFn =
    Arc<dyn Fn(Version) -> BoxFuture<'static, u64> + Send + Sync + 'static>;
38
39pub fn drb_difficulty_selector<TYPES: NodeType>(
40    config: &HotShotConfig<TYPES>,
41) -> DrbDifficultySelectorFn {
42    let base_difficulty = config.drb_difficulty;
43    let upgrade_difficulty = config.drb_upgrade_difficulty;
44    Arc::new(move |version| {
45        Box::pin(async move {
46            if version >= DRB_AND_HEADER_UPGRADE_VERSION {
47                upgrade_difficulty
48            } else {
49                base_difficulty
50            }
51        })
52    })
53}
54
// TODO: Add the following consts once we bench the hash time.
// <https://github.com/EspressoSystems/HotShot/issues/3880>
// /// Highest number of hashes that a hardware can complete in a second.
// const `HASHES_PER_SECOND`
// /// Time a DRB calculation will take, in terms of number of views.
// const `DRB_CALCULATION_NUM_VIEW`: u64 = 300;

// TODO: Replace this with an accurate number calculated by `fn difficulty_level()` once we bench
// the hash time.
// <https://github.com/EspressoSystems/HotShot/issues/3880>
/// Arbitrary number of times the hash function will be repeatedly called.
pub const DIFFICULTY_LEVEL: u64 = 10;

/// Interval, in hash iterations, at which `compute_drb_result` persists a
/// partial checkpoint so the calculation can be resumed after a restart.
pub const DRB_CHECKPOINT_INTERVAL: u64 = 1_000_000_000;

/// DRB seed input for epoch 1 and 2.
pub const INITIAL_DRB_SEED_INPUT: [u8; 32] = [0; 32];
/// DRB result for epoch 1 and 2.
pub const INITIAL_DRB_RESULT: [u8; 32] = [0; 32];

/// Alias for DRB seed input for `compute_drb_result`, serialized from the QC signature.
pub type DrbSeedInput = [u8; 32];

/// Alias for DRB result from `compute_drb_result`.
pub type DrbResult = [u8; 32];

/// Number of previous results and seeds to keep when garbage-collecting `DrbResults`.
pub const KEEP_PREVIOUS_RESULT_COUNT: u64 = 8;
84
// TODO: Use `HASHES_PER_SECOND` * `VIEW_TIMEOUT` * `DRB_CALCULATION_NUM_VIEW` to calculate this
// once we bench the hash time.
// <https://github.com/EspressoSystems/HotShot/issues/3880>
/// Difficulty level of the DRB calculation.
///
/// Represents the number of times the hash function will be repeatedly called.
///
/// # Panics
/// Always panics: not yet implemented. Use the placeholder [`DIFFICULTY_LEVEL`]
/// constant until hash-time benchmarking is done.
#[must_use]
pub fn difficulty_level() -> u64 {
    unimplemented!("Use an arbitrary `DIFFICULTY_LEVEL` for now before we bench the hash time.");
}
95
96/// Compute the DRB result for the leader rotation.
97///
98/// This is to be started two epochs in advance and spawned in a non-blocking thread.
99///
100/// # Arguments
101/// * `drb_seed_input` - Serialized QC signature.
102#[must_use]
103pub async fn compute_drb_result(
104    drb_input: DrbInput,
105    store_drb_progress: StoreDrbProgressFn,
106    load_drb_progress: LoadDrbProgressFn,
107) -> DrbResult {
108    tracing::warn!("Beginning DRB calculation with input {:?}", drb_input);
109    let mut drb_input = drb_input;
110
111    if let Ok(loaded_drb_input) = load_drb_progress(drb_input.epoch).await {
112        if loaded_drb_input.difficulty_level != drb_input.difficulty_level {
113            tracing::error!(
114                "We are calculating the DRB result with input {drb_input:?}, but we had \
115                 previously stored {loaded_drb_input:?} with a different difficulty level for \
116                 this epoch. Discarding the value from storage"
117            );
118        } else if loaded_drb_input.iteration >= drb_input.iteration {
119            drb_input = loaded_drb_input;
120        }
121    }
122
123    let mut hash = drb_input.value.to_vec();
124    let mut iteration = drb_input.iteration;
125    let remaining_iterations = drb_input
126        .difficulty_level
127        .checked_sub(iteration)
128        .unwrap_or_else(|| {
129            panic!(
130                "DRB difficulty level {} exceeds the iteration {} of the input we were given. \
131                 This is a fatal error",
132                drb_input.difficulty_level, iteration
133            )
134        });
135
136    let final_checkpoint = remaining_iterations / DRB_CHECKPOINT_INTERVAL;
137
138    let mut last_time = Instant::now();
139    let mut last_iteration = iteration;
140
141    // loop up to, but not including, the `final_checkpoint`
142    for _ in 0..final_checkpoint {
143        hash = tokio::task::spawn_blocking(move || {
144            let mut hash_tmp = hash.clone();
145            for _ in 0..DRB_CHECKPOINT_INTERVAL {
146                // TODO: This may be optimized to avoid memcopies after we bench the hash time.
147                // <https://github.com/EspressoSystems/HotShot/issues/3880>
148                hash_tmp = Sha256::digest(&hash_tmp).to_vec();
149            }
150
151            hash_tmp
152        })
153        .await
154        .expect("DRB calculation failed: this should never happen");
155
156        let mut partial_drb_result = [0u8; 32];
157        partial_drb_result.copy_from_slice(&hash);
158
159        iteration += DRB_CHECKPOINT_INTERVAL;
160
161        let updated_drb_input = DrbInput {
162            epoch: drb_input.epoch,
163            iteration,
164            value: partial_drb_result,
165            difficulty_level: drb_input.difficulty_level,
166        };
167
168        let elapsed_time = last_time.elapsed().as_millis();
169
170        let store_drb_progress = store_drb_progress.clone();
171        tokio::spawn(async move {
172            tracing::warn!(
173                "Storing partial DRB progress: {:?}. Time elapsed since the previous iteration of \
174                 {:?}: {:?}",
175                updated_drb_input,
176                last_iteration,
177                elapsed_time
178            );
179            if let Err(e) = store_drb_progress(updated_drb_input).await {
180                tracing::warn!("Failed to store DRB progress during calculation: {}", e);
181            }
182        });
183
184        last_time = Instant::now();
185        last_iteration = iteration;
186    }
187
188    let final_checkpoint_iteration = iteration;
189
190    // perform the remaining iterations
191    hash = tokio::task::spawn_blocking(move || {
192        let mut hash_tmp = hash.clone();
193        for _ in final_checkpoint_iteration..drb_input.difficulty_level {
194            // TODO: This may be optimized to avoid memcopies after we bench the hash time.
195            // <https://github.com/EspressoSystems/HotShot/issues/3880>
196            hash_tmp = Sha256::digest(&hash_tmp).to_vec();
197        }
198
199        hash_tmp
200    })
201    .await
202    .expect("DRB calculation failed: this should never happen");
203
204    // Convert the hash to the DRB result.
205    let mut drb_result = [0u8; 32];
206    drb_result.copy_from_slice(&hash);
207
208    let final_drb_input = DrbInput {
209        epoch: drb_input.epoch,
210        iteration: drb_input.difficulty_level,
211        value: drb_result,
212        difficulty_level: drb_input.difficulty_level,
213    };
214
215    tracing::warn!("Completed DRB calculation. Result: {:?}", final_drb_input);
216
217    let store_drb_progress = store_drb_progress.clone();
218    tokio::spawn(async move {
219        if let Err(e) = store_drb_progress(final_drb_input).await {
220            tracing::warn!("Failed to store DRB progress during calculation: {}", e);
221        }
222    });
223
224    drb_result
225}
226
/// Seeds for DRB computation and computed results.
#[derive(Clone, Debug)]
pub struct DrbResults {
    /// Stored results from computations, keyed by the epoch they apply to
    pub results: BTreeMap<EpochNumber, DrbResult>,
}
233
234impl DrbResults {
235    #[must_use]
236    /// Constructor with initial values for epochs 1 and 2.
237    pub fn new() -> Self {
238        Self {
239            results: BTreeMap::from([
240                (EpochNumber::new(1), INITIAL_DRB_RESULT),
241                (EpochNumber::new(2), INITIAL_DRB_RESULT),
242            ]),
243        }
244    }
245
246    pub fn store_result(&mut self, epoch: EpochNumber, result: DrbResult) {
247        self.results.insert(epoch, result);
248    }
249
250    /// Garbage collects internal data structures
251    pub fn garbage_collect(&mut self, epoch: EpochNumber) {
252        if epoch.u64() < KEEP_PREVIOUS_RESULT_COUNT {
253            return;
254        }
255
256        let retain_epoch = epoch - KEEP_PREVIOUS_RESULT_COUNT;
257        // N.B. x.split_off(y) returns the part of the map where key >= y
258
259        // Remove result entries older than EPOCH
260        self.results = self.results.split_off(&retain_epoch);
261    }
262}
263
264impl Default for DrbResults {
265    fn default() -> Self {
266        Self::new()
267    }
268}
269
270/// Functions for leader selection based on the DRB.
271///
272/// The algorithm we use is:
273///
274/// Initialization:
275/// - obtain `drb: [u8; 32]` from the DRB calculation
276/// - sort the stake table for a given epoch by `xor(drb, public_key)`
277/// - generate a cdf of the cumulative stake using this newly-sorted table,
278///   along with a hash of the stake table entries
279///
280/// Selecting a leader:
281/// - calculate the SHA512 hash of the `drb_result`, `view_number` and `stake_table_hash`
282/// - find the first index in the cdf for which the remainder of this hash modulo the `total_stake`
283///   is strictly smaller than the cdf entry
284/// - return the corresponding node as the leader for that view
285pub mod election {
286    use alloy::primitives::{U256, U512};
287    use sha2::{Digest, Sha256, Sha512};
288
289    use crate::traits::signature_key::{SignatureKey, StakeTableEntryType};
290
291    /// Calculate `xor(drb.cycle(), public_key)`, returning the result as a vector of bytes
292    fn cyclic_xor(drb: [u8; 32], public_key: Vec<u8>) -> Vec<u8> {
293        let drb: Vec<u8> = drb.to_vec();
294
295        let mut result: Vec<u8> = vec![];
296
297        for (drb_byte, public_key_byte) in public_key.iter().zip(drb.iter().cycle()) {
298            result.push(drb_byte ^ public_key_byte);
299        }
300
301        result
302    }
303
304    /// Generate the stake table CDF, as well as a hash of the resulting stake table
305    pub fn generate_stake_cdf<Key: SignatureKey, Entry: StakeTableEntryType<Key>>(
306        mut stake_table: Vec<Entry>,
307        drb: [u8; 32],
308    ) -> RandomizedCommittee<Entry> {
309        // sort by xor(public_key, drb_result)
310        stake_table.sort_by(|a, b| {
311            cyclic_xor(drb, a.public_key().to_bytes())
312                .cmp(&cyclic_xor(drb, b.public_key().to_bytes()))
313        });
314
315        let mut hasher = Sha256::new();
316
317        let mut cumulative_stake = U256::from(0);
318        let mut cdf = vec![];
319
320        for entry in stake_table {
321            cumulative_stake += entry.stake();
322            hasher.update(entry.public_key().to_bytes());
323
324            cdf.push((entry, cumulative_stake));
325        }
326
327        RandomizedCommittee {
328            cdf,
329            stake_table_hash: hasher.finalize().into(),
330            drb,
331        }
332    }
333
    /// select the leader for a view
    ///
    /// # Panics
    /// Panics if `cdf` is empty. Results in undefined behaviour if `cdf` is not ordered.
    ///
    /// Note that we try to downcast a U512 to a U256,
    /// but this should never panic because the U512 should be strictly smaller than U256::MAX by construction.
    pub fn select_randomized_leader<
        SignatureKey,
        Entry: StakeTableEntryType<SignatureKey> + Clone,
    >(
        randomized_committee: &RandomizedCommittee<Entry>,
        view: u64,
    ) -> Entry {
        let RandomizedCommittee {
            cdf,
            stake_table_hash,
            drb,
        } = randomized_committee;
        // We hash the concatenated drb, view and stake table hash.
        let mut hasher = Sha512::new();
        hasher.update(drb);
        hasher.update(view.to_le_bytes());
        hasher.update(stake_table_hash);
        let raw_breakpoint: [u8; 64] = hasher.finalize().into();

        // then calculate the remainder modulo the total stake as a U512
        // (the total stake is the last CDF entry; this `unwrap` is the documented
        // panic on an empty `cdf`)
        let remainder: U512 =
            U512::from_le_bytes(raw_breakpoint) % U512::from(cdf.last().unwrap().1);

        // and drop the top 32 bytes, downcasting to a U256
        // (safe because `remainder < total_stake <= U256::MAX`, so in little-endian
        // layout the high 32 bytes are zero)
        let breakpoint: U256 = U256::from_le_slice(&remainder.to_le_bytes_vec()[0..32]);

        // now find the first index where the breakpoint is strictly smaller than the cdf
        //
        // in principle, this may result in an index larger than `cdf.len()`.
        // however, we have ensured by construction that `breakpoint < total_stake`
        // and so the largest index we can actually return is `cdf.len() - 1`
        let index = cdf.partition_point(|(_, cumulative_stake)| breakpoint >= *cumulative_stake);

        // and return the corresponding entry
        cdf[index].0.clone()
    }
377
    /// A stake table preprocessed for randomized leader election: entries
    /// sorted by the DRB result, paired with their cumulative stake (a CDF).
    #[derive(Clone, Debug)]
    pub struct RandomizedCommittee<Entry> {
        /// cdf of nodes by cumulative stake
        cdf: Vec<(Entry, U256)>,
        /// Hash of the stake table
        stake_table_hash: [u8; 32],
        /// DRB result
        drb: [u8; 32],
    }
387
    impl<Entry> RandomizedCommittee<Entry> {
        /// Returns the DRB result this committee was randomized with.
        pub fn drb_result(&self) -> [u8; 32] {
            self.drb
        }
    }
393}
394
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use alloy::primitives::U256;
    use rand::RngCore;
    use sha2::{Digest, Sha256};

    use super::election::{generate_stake_cdf, select_randomized_leader};
    use crate::{
        signature_key::BLSPubKey,
        stake_table::StakeTableEntry,
        traits::signature_key::{BuilderSignatureKey, StakeTableEntryType},
    };

    /// Checks that leader-selection frequencies roughly match each node's
    /// share of the total stake, using total variation distance as the metric.
    #[test]
    fn test_randomized_leader() {
        let mut rng = rand::thread_rng();
        // use an arbitrary Sha256 output.
        let drb: [u8; 32] = Sha256::digest(b"drb").into();
        // a stake table with 10 nodes, each with a stake of 1-100
        let stake_table_entries: Vec<_> = (0..10)
            .map(|i| StakeTableEntry {
                stake_key: BLSPubKey::generated_from_seed_indexed([0u8; 32], i).0,
                stake_amount: U256::from(rng.next_u64() % 100 + 1),
            })
            .collect();
        let randomized_committee = generate_stake_cdf(stake_table_entries.clone(), drb);

        // Number of views to test
        let num_views = 100000;
        // Count how many times each node is selected as leader.
        let mut selected = HashMap::<_, u64>::new();
        // Test the leader election for 100000 views.
        for i in 0..num_views {
            let leader = select_randomized_leader(&randomized_committee, i);
            *selected.entry(leader).or_insert(0) += 1;
        }

        // Total variation distance between the expected (stake-weighted)
        // distribution and the observed selection frequencies.
        let mut tvd = 0.;
        let total_stakes = stake_table_entries
            .iter()
            .map(|e| e.stake())
            .sum::<U256>()
            .to::<u64>() as f64;
        for entry in stake_table_entries {
            let expected = entry.stake().to::<u64>() as f64 / total_stakes;
            let actual = *selected.get(&entry).unwrap_or(&0) as f64 / num_views as f64;
            tvd += (expected - actual).abs();
        }

        // sanity check
        assert!(tvd >= 0.0);
        // Allow a small margin of error
        assert!(tvd < 0.03);
    }
}
451}