// hotshot_types/drb.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::{collections::BTreeMap, sync::Arc};
8
9use futures::future::BoxFuture;
10use serde::{Deserialize, Serialize};
11use sha2::{Digest, Sha256};
12
13use crate::{
14    message::UpgradeLock,
15    traits::{
16        node_implementation::{ConsensusTime, NodeType, Versions},
17        storage::{LoadDrbProgressFn, StoreDrbProgressFn},
18    },
19    HotShotConfig,
20};
21
/// A checkpointed input to the iterated-hash DRB calculation.
///
/// Serializable so in-progress calculations can be persisted via
/// `StoreDrbProgressFn` and resumed later from the stored iteration.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct DrbInput {
    /// The epoch we are calculating the result for
    pub epoch: u64,
    /// The iteration this seed is from. For fresh calculations, this should be `0`.
    pub iteration: u64,
    /// The value of the DRB calculation at the current iteration
    pub value: [u8; 32],
    /// Difficulty level (total number of hash iterations) for the DRB calculation
    pub difficulty_level: u64,
}
33
/// Callback that asynchronously resolves the DRB difficulty to use for a given view.
pub type DrbDifficultySelectorFn<TYPES> =
    Arc<dyn Fn(<TYPES as NodeType>::View) -> BoxFuture<'static, u64> + Send + Sync + 'static>;
36
/// Build a [`DrbDifficultySelectorFn`] from the node's config.
///
/// The returned closure checks, for a given view, whether the DRB/header upgrade
/// has taken effect (via `upgrade_lock.upgraded_drb_and_header`) and yields the
/// config's `drb_upgrade_difficulty` if so, falling back to `drb_difficulty`.
pub fn drb_difficulty_selector<TYPES: NodeType, V: Versions>(
    upgrade_lock: UpgradeLock<TYPES, V>,
    config: &HotShotConfig<TYPES>,
) -> DrbDifficultySelectorFn<TYPES> {
    // Copy the two difficulty values out of the config so the closure can be 'static.
    let base_difficulty = config.drb_difficulty;
    let upgrade_difficulty = config.drb_upgrade_difficulty;
    Arc::new(move |view| {
        // Clone per call so each returned future owns its own handle.
        let upgrade_lock = upgrade_lock.clone();
        Box::pin(async move {
            if upgrade_lock.upgraded_drb_and_header(view).await {
                upgrade_difficulty
            } else {
                base_difficulty
            }
        })
    })
}
54
// TODO: Add the following consts once we bench the hash time.
// <https://github.com/EspressoSystems/HotShot/issues/3880>
// /// Highest number of hashes that a hardware can complete in a second.
// const `HASHES_PER_SECOND`
// /// Time a DRB calculation will take, in terms of number of views.
// const `DRB_CALCULATION_NUM_VIEW`: u64 = 300;

// TODO: Replace this with an accurate number calculated by `fn difficulty_level()` once we bench
// the hash time.
// <https://github.com/EspressoSystems/HotShot/issues/3880>
/// Arbitrary number of times the hash function will be repeatedly called.
pub const DIFFICULTY_LEVEL: u64 = 10;

/// Number of hash iterations between persisted checkpoints of an in-progress DRB calculation.
pub const DRB_CHECKPOINT_INTERVAL: u64 = 1_000_000;

/// DRB seed input for epoch 1 and 2.
pub const INITIAL_DRB_SEED_INPUT: [u8; 32] = [0u8; 32];
/// DRB result for epoch 1 and 2.
pub const INITIAL_DRB_RESULT: [u8; 32] = [0u8; 32];

/// Alias for DRB seed input for `compute_drb_result`, serialized from the QC signature.
pub type DrbSeedInput = [u8; 32];

/// Alias for DRB result from `compute_drb_result`.
pub type DrbResult = [u8; 32];

/// Number of previous epochs' results and seeds to retain during garbage collection.
pub const KEEP_PREVIOUS_RESULT_COUNT: u64 = 8;
84
// TODO: Use `HASHES_PER_SECOND` * `VIEW_TIMEOUT` * `DRB_CALCULATION_NUM_VIEW` to calculate this
// once we bench the hash time.
// <https://github.com/EspressoSystems/HotShot/issues/3880>
/// Difficulty level of the DRB calculation.
///
/// Represents the number of times the hash function will be repeatedly called.
///
/// # Panics
/// Always panics — not yet implemented; use [`DIFFICULTY_LEVEL`] until the hash
/// time has been benchmarked.
#[must_use]
pub fn difficulty_level() -> u64 {
    unimplemented!("Use an arbitrary `DIFFICULTY_LEVEL` for now before we bench the hash time.");
}
95
96/// Compute the DRB result for the leader rotation.
97///
98/// This is to be started two epochs in advance and spawned in a non-blocking thread.
99///
100/// # Arguments
101/// * `drb_seed_input` - Serialized QC signature.
102#[must_use]
103pub async fn compute_drb_result(
104    drb_input: DrbInput,
105    store_drb_progress: StoreDrbProgressFn,
106    load_drb_progress: LoadDrbProgressFn,
107) -> DrbResult {
108    let mut drb_input = drb_input;
109
110    if let Ok(loaded_drb_input) = load_drb_progress(drb_input.epoch).await {
111        if loaded_drb_input.iteration >= drb_input.iteration {
112            drb_input = loaded_drb_input;
113        }
114    }
115
116    let mut hash = drb_input.value.to_vec();
117    let mut iteration = drb_input.iteration;
118    let remaining_iterations = drb_input.difficulty_level
119      .checked_sub(iteration)
120      .unwrap_or_else(||
121        panic!(
122          "DRB difficulty level {} exceeds the iteration {} of the input we were given. This is a fatal error", 
123          drb_input.difficulty_level,
124          iteration
125        )
126      );
127
128    let final_checkpoint = remaining_iterations / DRB_CHECKPOINT_INTERVAL;
129
130    // loop up to, but not including, the `final_checkpoint`
131    for _ in 0..final_checkpoint {
132        for _ in 0..DRB_CHECKPOINT_INTERVAL {
133            // TODO: This may be optimized to avoid memcopies after we bench the hash time.
134            // <https://github.com/EspressoSystems/HotShot/issues/3880>
135            hash = Sha256::digest(hash).to_vec();
136        }
137
138        let mut partial_drb_result = [0u8; 32];
139        partial_drb_result.copy_from_slice(&hash);
140
141        iteration += DRB_CHECKPOINT_INTERVAL;
142
143        let updated_drb_input = DrbInput {
144            epoch: drb_input.epoch,
145            iteration,
146            value: partial_drb_result,
147            difficulty_level: drb_input.difficulty_level,
148        };
149
150        let store_drb_progress = store_drb_progress.clone();
151        tokio::spawn(async move {
152            if let Err(e) = store_drb_progress(updated_drb_input).await {
153                tracing::warn!("Failed to store DRB progress during calculation: {}", e);
154            }
155        });
156    }
157
158    let final_checkpoint_iteration = iteration;
159
160    // perform the remaining iterations
161    for _ in final_checkpoint_iteration..drb_input.difficulty_level {
162        hash = Sha256::digest(hash).to_vec();
163        iteration += 1;
164    }
165
166    // Convert the hash to the DRB result.
167    let mut drb_result = [0u8; 32];
168    drb_result.copy_from_slice(&hash);
169    drb_result
170}
171
/// Seeds for DRB computation and computed results.
///
/// Keyed by epoch; old entries are pruned by [`DrbResults::garbage_collect`].
#[derive(Clone, Debug)]
pub struct DrbResults<TYPES: NodeType> {
    /// Stored results from computations
    pub results: BTreeMap<TYPES::Epoch, DrbResult>,
}
178
179impl<TYPES: NodeType> DrbResults<TYPES> {
180    #[must_use]
181    /// Constructor with initial values for epochs 1 and 2.
182    pub fn new() -> Self {
183        Self {
184            results: BTreeMap::from([
185                (TYPES::Epoch::new(1), INITIAL_DRB_RESULT),
186                (TYPES::Epoch::new(2), INITIAL_DRB_RESULT),
187            ]),
188        }
189    }
190
191    pub fn store_result(&mut self, epoch: TYPES::Epoch, result: DrbResult) {
192        self.results.insert(epoch, result);
193    }
194
195    /// Garbage collects internal data structures
196    pub fn garbage_collect(&mut self, epoch: TYPES::Epoch) {
197        if epoch.u64() < KEEP_PREVIOUS_RESULT_COUNT {
198            return;
199        }
200
201        let retain_epoch = epoch - KEEP_PREVIOUS_RESULT_COUNT;
202        // N.B. x.split_off(y) returns the part of the map where key >= y
203
204        // Remove result entries older than EPOCH
205        self.results = self.results.split_off(&retain_epoch);
206    }
207}
208
impl<TYPES: NodeType> Default for DrbResults<TYPES> {
    /// Defaults to [`DrbResults::new`], i.e. seeded with the initial results for epochs 1 and 2.
    fn default() -> Self {
        Self::new()
    }
}
214
215/// Functions for leader selection based on the DRB.
216///
217/// The algorithm we use is:
218///
219/// Initialization:
220/// - obtain `drb: [u8; 32]` from the DRB calculation
221/// - sort the stake table for a given epoch by `xor(drb, public_key)`
222/// - generate a cdf of the cumulative stake using this newly-sorted table,
223///   along with a hash of the stake table entries
224///
225/// Selecting a leader:
226/// - calculate the SHA512 hash of the `drb_result`, `view_number` and `stake_table_hash`
227/// - find the first index in the cdf for which the remainder of this hash modulo the `total_stake`
228///   is strictly smaller than the cdf entry
229/// - return the corresponding node as the leader for that view
230pub mod election {
231    use alloy::primitives::{U256, U512};
232    use sha2::{Digest, Sha256, Sha512};
233
234    use crate::traits::signature_key::{SignatureKey, StakeTableEntryType};
235
236    /// Calculate `xor(drb.cycle(), public_key)`, returning the result as a vector of bytes
237    fn cyclic_xor(drb: [u8; 32], public_key: Vec<u8>) -> Vec<u8> {
238        let drb: Vec<u8> = drb.to_vec();
239
240        let mut result: Vec<u8> = vec![];
241
242        for (drb_byte, public_key_byte) in public_key.iter().zip(drb.iter().cycle()) {
243            result.push(drb_byte ^ public_key_byte);
244        }
245
246        result
247    }
248
249    /// Generate the stake table CDF, as well as a hash of the resulting stake table
250    pub fn generate_stake_cdf<Key: SignatureKey, Entry: StakeTableEntryType<Key>>(
251        mut stake_table: Vec<Entry>,
252        drb: [u8; 32],
253    ) -> RandomizedCommittee<Entry> {
254        // sort by xor(public_key, drb_result)
255        stake_table.sort_by(|a, b| {
256            cyclic_xor(drb, a.public_key().to_bytes())
257                .cmp(&cyclic_xor(drb, b.public_key().to_bytes()))
258        });
259
260        let mut hasher = Sha256::new();
261
262        let mut cumulative_stake = U256::from(0);
263        let mut cdf = vec![];
264
265        for entry in stake_table {
266            cumulative_stake += entry.stake();
267            hasher.update(entry.public_key().to_bytes());
268
269            cdf.push((entry, cumulative_stake));
270        }
271
272        RandomizedCommittee {
273            cdf,
274            stake_table_hash: hasher.finalize().into(),
275            drb,
276        }
277    }
278
279    /// select the leader for a view
280    ///
281    /// # Panics
282    /// Panics if `cdf` is empty. Results in undefined behaviour if `cdf` is not ordered.
283    ///
284    /// Note that we try to downcast a U512 to a U256,
285    /// but this should never panic because the U512 should be strictly smaller than U256::MAX by construction.
286    pub fn select_randomized_leader<
287        SignatureKey,
288        Entry: StakeTableEntryType<SignatureKey> + Clone,
289    >(
290        randomized_committee: &RandomizedCommittee<Entry>,
291        view: u64,
292    ) -> Entry {
293        let RandomizedCommittee {
294            cdf,
295            stake_table_hash,
296            drb,
297        } = randomized_committee;
298        // We hash the concatenated drb, view and stake table hash.
299        let mut hasher = Sha512::new();
300        hasher.update(drb);
301        hasher.update(view.to_le_bytes());
302        hasher.update(stake_table_hash);
303        let raw_breakpoint: [u8; 64] = hasher.finalize().into();
304
305        // then calculate the remainder modulo the total stake as a U512
306        let remainder: U512 =
307            U512::from_le_bytes(raw_breakpoint) % U512::from(cdf.last().unwrap().1);
308
309        // and drop the top 32 bytes, downcasting to a U256
310        let breakpoint: U256 = U256::from_le_slice(&remainder.to_le_bytes_vec()[0..32]);
311
312        // now find the first index where the breakpoint is strictly smaller than the cdf
313        //
314        // in principle, this may result in an index larger than `cdf.len()`.
315        // however, we have ensured by construction that `breakpoint < total_stake`
316        // and so the largest index we can actually return is `cdf.len() - 1`
317        let index = cdf.partition_point(|(_, cumulative_stake)| breakpoint >= *cumulative_stake);
318
319        // and return the corresponding entry
320        cdf[index].0.clone()
321    }
322
323    #[derive(Clone, Debug)]
324    pub struct RandomizedCommittee<Entry> {
325        /// cdf of nodes by cumulative stake
326        cdf: Vec<(Entry, U256)>,
327        /// Hash of the stake table
328        stake_table_hash: [u8; 32],
329        /// DRB result
330        drb: [u8; 32],
331    }
332}
333
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use alloy::primitives::U256;
    use rand::RngCore;
    use sha2::{Digest, Sha256};

    use super::election::{generate_stake_cdf, select_randomized_leader};
    use crate::{
        signature_key::BLSPubKey,
        stake_table::StakeTableEntry,
        traits::signature_key::{BuilderSignatureKey, StakeTableEntryType},
    };

    /// Statistical sanity check: over many views, each node should be selected
    /// as leader with frequency roughly proportional to its stake.
    #[test]
    fn test_randomized_leader() {
        let mut rng = rand::thread_rng();
        // use an arbitrary Sha256 output.
        let drb: [u8; 32] = Sha256::digest(b"drb").into();
        // a stake table with 10 nodes, each with a stake of 1-100
        let stake_table_entries: Vec<_> = (0..10)
            .map(|i| StakeTableEntry {
                stake_key: BLSPubKey::generated_from_seed_indexed([0u8; 32], i).0,
                stake_amount: U256::from(rng.next_u64() % 100 + 1),
            })
            .collect();
        let randomized_committee = generate_stake_cdf(stake_table_entries.clone(), drb);

        // Number of views to test
        let num_views = 100000;
        // Tally of how many times each entry was selected as leader.
        let mut selected = HashMap::<_, u64>::new();
        // Test the leader election for 100000 views.
        for i in 0..num_views {
            let leader = select_randomized_leader(&randomized_committee, i);
            *selected.entry(leader).or_insert(0) += 1;
        }

        // Total variation distance between the stake-weighted distribution and
        // the empirical selection frequencies.
        let mut tvd = 0.;
        let total_stakes = stake_table_entries
            .iter()
            .map(|e| e.stake())
            .sum::<U256>()
            .to::<u64>() as f64;
        for entry in stake_table_entries {
            let expected = entry.stake().to::<u64>() as f64 / total_stakes;
            let actual = *selected.get(&entry).unwrap_or(&0) as f64 / num_views as f64;
            tvd += (expected - actual).abs();
        }

        // sanity check
        assert!(tvd >= 0.0);
        // Allow a small margin of error
        assert!(tvd < 0.03);
    }
}