hotshot_types/
drb.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::{collections::BTreeMap, sync::Arc};
8
9use futures::future::BoxFuture;
10use serde::{Deserialize, Serialize};
11use sha2::{Digest, Sha256};
12
13use crate::{
14    message::UpgradeLock,
15    traits::{
16        node_implementation::{ConsensusTime, NodeType, Versions},
17        storage::{LoadDrbProgressFn, StoreDrbProgressFn},
18    },
19    HotShotConfig,
20};
21
22#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
23pub struct DrbInput {
24    /// The epoch we are calculating the result for
25    pub epoch: u64,
26    /// The iteration this seed is from. For fresh calculations, this should be `0`.
27    pub iteration: u64,
28    /// the value of the drb calculation at the current iteration
29    pub value: [u8; 32],
30    /// difficulty value for the DRB calculation
31    pub difficulty_level: u64,
32}
33
34pub type DrbDifficultySelectorFn<TYPES> =
35    Arc<dyn Fn(<TYPES as NodeType>::View) -> BoxFuture<'static, u64> + Send + Sync + 'static>;
36
37pub fn drb_difficulty_selector<TYPES: NodeType, V: Versions>(
38    upgrade_lock: UpgradeLock<TYPES, V>,
39    config: &HotShotConfig<TYPES>,
40) -> DrbDifficultySelectorFn<TYPES> {
41    let base_difficulty = config.drb_difficulty;
42    let upgrade_difficulty = config.drb_upgrade_difficulty;
43    Arc::new(move |view| {
44        let upgrade_lock = upgrade_lock.clone();
45        Box::pin(async move {
46            if upgrade_lock.upgraded_drb_and_header(view).await {
47                upgrade_difficulty
48            } else {
49                base_difficulty
50            }
51        })
52    })
53}
54
55// TODO: Add the following consts once we bench the hash time.
56// <https://github.com/EspressoSystems/HotShot/issues/3880>
// /// Highest number of hashes that the hardware can complete in a second.
58// const `HASHES_PER_SECOND`
59// /// Time a DRB calculation will take, in terms of number of views.
60// const `DRB_CALCULATION_NUM_VIEW`: u64 = 300;
61
// TODO: Replace this with an accurate number calculated by `fn difficulty_level()` once we bench
// the hash time.
// <https://github.com/EspressoSystems/HotShot/issues/3880>
/// Placeholder difficulty: an arbitrary number of repeated hash function calls.
pub const DIFFICULTY_LEVEL: u64 = 10;

/// Number of hash iterations between two stored progress checkpoints.
pub const DRB_CHECKPOINT_INTERVAL: u64 = 1_000_000;

/// DRB seed input used for epochs 1 and 2.
pub const INITIAL_DRB_SEED_INPUT: [u8; 32] = [0u8; 32];
/// DRB result used for epochs 1 and 2.
pub const INITIAL_DRB_RESULT: [u8; 32] = [0u8; 32];

/// Seed input to `compute_drb_result`, serialized from the QC signature.
pub type DrbSeedInput = [u8; 32];

/// Output of `compute_drb_result`.
pub type DrbResult = [u8; 32];

/// How many epochs back stored results are retained when garbage collecting.
pub const KEEP_PREVIOUS_RESULT_COUNT: u64 = 8;
84
// TODO: Use `HASHES_PER_SECOND` * `VIEW_TIMEOUT` * `DRB_CALCULATION_NUM_VIEW` to calculate this
// once we bench the hash time.
// <https://github.com/EspressoSystems/HotShot/issues/3880>
/// Difficulty level of the DRB calculation, i.e. the number of times the hash
/// function will be repeatedly called.
///
/// # Panics
/// Always panics: this is a stub until the hash time has been benchmarked.
/// Use the arbitrary `DIFFICULTY_LEVEL` constant in the meantime.
#[must_use]
pub fn difficulty_level() -> u64 {
    unimplemented!("Use an arbitrary `DIFFICULTY_LEVEL` for now before we bench the hash time.")
}
95
96/// Compute the DRB result for the leader rotation.
97///
98/// This is to be started two epochs in advance and spawned in a non-blocking thread.
99///
100/// # Arguments
101/// * `drb_seed_input` - Serialized QC signature.
102#[must_use]
103pub async fn compute_drb_result(
104    drb_input: DrbInput,
105    store_drb_progress: StoreDrbProgressFn,
106    load_drb_progress: LoadDrbProgressFn,
107) -> DrbResult {
108    let mut drb_input = drb_input;
109
110    if let Ok(loaded_drb_input) = load_drb_progress(drb_input.epoch).await {
111        if loaded_drb_input.iteration >= drb_input.iteration {
112            drb_input = loaded_drb_input;
113        }
114    }
115
116    let mut hash = drb_input.value.to_vec();
117    let mut iteration = drb_input.iteration;
118    let remaining_iterations = drb_input
119        .difficulty_level
120        .checked_sub(iteration)
121        .unwrap_or_else(|| {
122            panic!(
123                "DRB difficulty level {} exceeds the iteration {} of the input we were given. \
124                 This is a fatal error",
125                drb_input.difficulty_level, iteration
126            )
127        });
128
129    let final_checkpoint = remaining_iterations / DRB_CHECKPOINT_INTERVAL;
130
131    // loop up to, but not including, the `final_checkpoint`
132    for _ in 0..final_checkpoint {
133        for _ in 0..DRB_CHECKPOINT_INTERVAL {
134            // TODO: This may be optimized to avoid memcopies after we bench the hash time.
135            // <https://github.com/EspressoSystems/HotShot/issues/3880>
136            hash = Sha256::digest(hash).to_vec();
137        }
138
139        let mut partial_drb_result = [0u8; 32];
140        partial_drb_result.copy_from_slice(&hash);
141
142        iteration += DRB_CHECKPOINT_INTERVAL;
143
144        let updated_drb_input = DrbInput {
145            epoch: drb_input.epoch,
146            iteration,
147            value: partial_drb_result,
148            difficulty_level: drb_input.difficulty_level,
149        };
150
151        let store_drb_progress = store_drb_progress.clone();
152        tokio::spawn(async move {
153            if let Err(e) = store_drb_progress(updated_drb_input).await {
154                tracing::warn!("Failed to store DRB progress during calculation: {}", e);
155            }
156        });
157    }
158
159    let final_checkpoint_iteration = iteration;
160
161    // perform the remaining iterations
162    for _ in final_checkpoint_iteration..drb_input.difficulty_level {
163        hash = Sha256::digest(hash).to_vec();
164        iteration += 1;
165    }
166
167    // Convert the hash to the DRB result.
168    let mut drb_result = [0u8; 32];
169    drb_result.copy_from_slice(&hash);
170    drb_result
171}
172
173/// Seeds for DRB computation and computed results.
174#[derive(Clone, Debug)]
175pub struct DrbResults<TYPES: NodeType> {
176    /// Stored results from computations
177    pub results: BTreeMap<TYPES::Epoch, DrbResult>,
178}
179
180impl<TYPES: NodeType> DrbResults<TYPES> {
181    #[must_use]
182    /// Constructor with initial values for epochs 1 and 2.
183    pub fn new() -> Self {
184        Self {
185            results: BTreeMap::from([
186                (TYPES::Epoch::new(1), INITIAL_DRB_RESULT),
187                (TYPES::Epoch::new(2), INITIAL_DRB_RESULT),
188            ]),
189        }
190    }
191
192    pub fn store_result(&mut self, epoch: TYPES::Epoch, result: DrbResult) {
193        self.results.insert(epoch, result);
194    }
195
196    /// Garbage collects internal data structures
197    pub fn garbage_collect(&mut self, epoch: TYPES::Epoch) {
198        if epoch.u64() < KEEP_PREVIOUS_RESULT_COUNT {
199            return;
200        }
201
202        let retain_epoch = epoch - KEEP_PREVIOUS_RESULT_COUNT;
203        // N.B. x.split_off(y) returns the part of the map where key >= y
204
205        // Remove result entries older than EPOCH
206        self.results = self.results.split_off(&retain_epoch);
207    }
208}
209
210impl<TYPES: NodeType> Default for DrbResults<TYPES> {
211    fn default() -> Self {
212        Self::new()
213    }
214}
215
216/// Functions for leader selection based on the DRB.
217///
218/// The algorithm we use is:
219///
220/// Initialization:
221/// - obtain `drb: [u8; 32]` from the DRB calculation
222/// - sort the stake table for a given epoch by `xor(drb, public_key)`
223/// - generate a cdf of the cumulative stake using this newly-sorted table,
224///   along with a hash of the stake table entries
225///
226/// Selecting a leader:
227/// - calculate the SHA512 hash of the `drb_result`, `view_number` and `stake_table_hash`
228/// - find the first index in the cdf for which the remainder of this hash modulo the `total_stake`
229///   is strictly smaller than the cdf entry
230/// - return the corresponding node as the leader for that view
231pub mod election {
232    use alloy::primitives::{U256, U512};
233    use sha2::{Digest, Sha256, Sha512};
234
235    use crate::traits::signature_key::{SignatureKey, StakeTableEntryType};
236
237    /// Calculate `xor(drb.cycle(), public_key)`, returning the result as a vector of bytes
238    fn cyclic_xor(drb: [u8; 32], public_key: Vec<u8>) -> Vec<u8> {
239        let drb: Vec<u8> = drb.to_vec();
240
241        let mut result: Vec<u8> = vec![];
242
243        for (drb_byte, public_key_byte) in public_key.iter().zip(drb.iter().cycle()) {
244            result.push(drb_byte ^ public_key_byte);
245        }
246
247        result
248    }
249
250    /// Generate the stake table CDF, as well as a hash of the resulting stake table
251    pub fn generate_stake_cdf<Key: SignatureKey, Entry: StakeTableEntryType<Key>>(
252        mut stake_table: Vec<Entry>,
253        drb: [u8; 32],
254    ) -> RandomizedCommittee<Entry> {
255        // sort by xor(public_key, drb_result)
256        stake_table.sort_by(|a, b| {
257            cyclic_xor(drb, a.public_key().to_bytes())
258                .cmp(&cyclic_xor(drb, b.public_key().to_bytes()))
259        });
260
261        let mut hasher = Sha256::new();
262
263        let mut cumulative_stake = U256::from(0);
264        let mut cdf = vec![];
265
266        for entry in stake_table {
267            cumulative_stake += entry.stake();
268            hasher.update(entry.public_key().to_bytes());
269
270            cdf.push((entry, cumulative_stake));
271        }
272
273        RandomizedCommittee {
274            cdf,
275            stake_table_hash: hasher.finalize().into(),
276            drb,
277        }
278    }
279
280    /// select the leader for a view
281    ///
282    /// # Panics
283    /// Panics if `cdf` is empty. Results in undefined behaviour if `cdf` is not ordered.
284    ///
285    /// Note that we try to downcast a U512 to a U256,
286    /// but this should never panic because the U512 should be strictly smaller than U256::MAX by construction.
287    pub fn select_randomized_leader<
288        SignatureKey,
289        Entry: StakeTableEntryType<SignatureKey> + Clone,
290    >(
291        randomized_committee: &RandomizedCommittee<Entry>,
292        view: u64,
293    ) -> Entry {
294        let RandomizedCommittee {
295            cdf,
296            stake_table_hash,
297            drb,
298        } = randomized_committee;
299        // We hash the concatenated drb, view and stake table hash.
300        let mut hasher = Sha512::new();
301        hasher.update(drb);
302        hasher.update(view.to_le_bytes());
303        hasher.update(stake_table_hash);
304        let raw_breakpoint: [u8; 64] = hasher.finalize().into();
305
306        // then calculate the remainder modulo the total stake as a U512
307        let remainder: U512 =
308            U512::from_le_bytes(raw_breakpoint) % U512::from(cdf.last().unwrap().1);
309
310        // and drop the top 32 bytes, downcasting to a U256
311        let breakpoint: U256 = U256::from_le_slice(&remainder.to_le_bytes_vec()[0..32]);
312
313        // now find the first index where the breakpoint is strictly smaller than the cdf
314        //
315        // in principle, this may result in an index larger than `cdf.len()`.
316        // however, we have ensured by construction that `breakpoint < total_stake`
317        // and so the largest index we can actually return is `cdf.len() - 1`
318        let index = cdf.partition_point(|(_, cumulative_stake)| breakpoint >= *cumulative_stake);
319
320        // and return the corresponding entry
321        cdf[index].0.clone()
322    }
323
324    #[derive(Clone, Debug)]
325    pub struct RandomizedCommittee<Entry> {
326        /// cdf of nodes by cumulative stake
327        cdf: Vec<(Entry, U256)>,
328        /// Hash of the stake table
329        stake_table_hash: [u8; 32],
330        /// DRB result
331        drb: [u8; 32],
332    }
333}
334
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use alloy::primitives::U256;
    use rand::RngCore;
    use sha2::{Digest, Sha256};

    use super::election::{generate_stake_cdf, select_randomized_leader};
    use crate::{
        signature_key::BLSPubKey,
        stake_table::StakeTableEntry,
        traits::signature_key::{BuilderSignatureKey, StakeTableEntryType},
    };

    /// Checks that, over many views, each node is elected roughly in proportion to its stake.
    #[test]
    fn test_randomized_leader() {
        // Number of views to run the leader election for.
        const NUM_VIEWS: u64 = 100_000;

        let mut rng = rand::thread_rng();
        // use an arbitrary Sha256 output.
        let drb: [u8; 32] = Sha256::digest(b"drb").into();

        // a stake table with 10 nodes, each with a stake of 1-100
        let mut stake_table_entries = Vec::with_capacity(10);
        for i in 0..10 {
            stake_table_entries.push(StakeTableEntry {
                stake_key: BLSPubKey::generated_from_seed_indexed([0u8; 32], i).0,
                stake_amount: U256::from(rng.next_u64() % 100 + 1),
            });
        }
        let randomized_committee = generate_stake_cdf(stake_table_entries.clone(), drb);

        // Count how often each entry is elected.
        let mut selected = HashMap::<_, u64>::new();
        (0..NUM_VIEWS).for_each(|view| {
            let leader = select_randomized_leader(&randomized_committee, view);
            *selected.entry(leader).or_default() += 1;
        });

        let total_stakes = stake_table_entries
            .iter()
            .map(|e| e.stake())
            .sum::<U256>()
            .to::<u64>() as f64;

        // Total variation distance between the stake shares and the observed frequencies.
        let tvd = stake_table_entries.iter().fold(0.0_f64, |acc, entry| {
            let expected = entry.stake().to::<u64>() as f64 / total_stakes;
            let actual = selected.get(entry).copied().unwrap_or(0) as f64 / NUM_VIEWS as f64;
            acc + (expected - actual).abs()
        });

        // sanity check
        assert!(tvd >= 0.0);
        // Allow a small margin of error
        assert!(tvd < 0.03);
    }
}