// hotshot_types/drb.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::collections::BTreeMap;
8
9use sha2::{Digest, Sha256};
10
11use crate::traits::{
12    node_implementation::{ConsensusTime, NodeType},
13    storage::StoreDrbProgressFn,
14};
15
/// Input to a DRB calculation: the state needed to start (or resume) computing
/// the result for a given epoch.
#[derive(Clone, Copy, Debug)]
pub struct DrbInput {
    /// The epoch we are calculating the result for
    pub epoch: u64,
    /// The iteration this seed is from. For fresh calculations, this should be `0`.
    pub iteration: u64,
    /// the value of the drb calculation at the current iteration
    pub value: [u8; 32],
}
24
25// TODO: Add the following consts once we bench the hash time.
26// <https://github.com/EspressoSystems/HotShot/issues/3880>
27// /// Highest number of hashes that a hardware can complete in a second.
28// const `HASHES_PER_SECOND`
29// /// Time a DRB calculation will take, in terms of number of views.
30// const `DRB_CALCULATION_NUM_VIEW`: u64 = 300;
31
32// TODO: Replace this with an accurate number calculated by `fn difficulty_level()` once we bench
33// the hash time.
34// <https://github.com/EspressoSystems/HotShot/issues/3880>
35/// Arbitrary number of times the hash function will be repeatedly called.
36pub const DIFFICULTY_LEVEL: u64 = 10;
37
38/// Interval at which to store the results
39pub const DRB_CHECKPOINT_INTERVAL: u64 = 3;
40
41/// DRB seed input for epoch 1 and 2.
42pub const INITIAL_DRB_SEED_INPUT: [u8; 32] = [0; 32];
43/// DRB result for epoch 1 and 2.
44pub const INITIAL_DRB_RESULT: [u8; 32] = [0; 32];
45
46/// Alias for DRB seed input for `compute_drb_result`, serialized from the QC signature.
47pub type DrbSeedInput = [u8; 32];
48
49/// Alias for DRB result from `compute_drb_result`.
50pub type DrbResult = [u8; 32];
51
52/// Number of previous results and seeds to keep
53pub const KEEP_PREVIOUS_RESULT_COUNT: u64 = 8;
54
55// TODO: Use `HASHES_PER_SECOND` * `VIEW_TIMEOUT` * `DRB_CALCULATION_NUM_VIEW` to calculate this
56// once we bench the hash time.
57// <https://github.com/EspressoSystems/HotShot/issues/3880>
/// Difficulty level of the DRB calculation.
///
/// Represents the number of times the hash function will be repeatedly called.
///
/// # Panics
/// Always panics: this is not yet implemented. Use the arbitrary
/// [`DIFFICULTY_LEVEL`] constant until the hash time has been benchmarked.
#[must_use]
pub fn difficulty_level() -> u64 {
    unimplemented!("Use an arbitrary `DIFFICULTY_LEVEL` for now before we bench the hash time.");
}
65
66/// Compute the DRB result for the leader rotation.
67///
68/// This is to be started two epochs in advance and spawned in a non-blocking thread.
69///
70/// # Arguments
71/// * `drb_seed_input` - Serialized QC signature.
72#[must_use]
73pub fn compute_drb_result(
74    drb_input: DrbInput,
75    store_drb_progress: StoreDrbProgressFn,
76) -> DrbResult {
77    let mut hash = drb_input.value.to_vec();
78    let mut iteration = drb_input.iteration;
79    let remaining_iterations = DIFFICULTY_LEVEL
80      .checked_sub(iteration)
81      .unwrap_or_else( ||
82        panic!(
83          "DRB difficulty level {} exceeds the iteration {} of the input we were given. This is a fatal error", 
84          DIFFICULTY_LEVEL,
85          iteration
86        )
87      );
88
89    let final_checkpoint = remaining_iterations / DRB_CHECKPOINT_INTERVAL;
90
91    // loop up to, but not including, the `final_checkpoint`
92    for _ in 0..final_checkpoint {
93        for _ in 0..DRB_CHECKPOINT_INTERVAL {
94            // TODO: This may be optimized to avoid memcopies after we bench the hash time.
95            // <https://github.com/EspressoSystems/HotShot/issues/3880>
96            hash = Sha256::digest(hash).to_vec();
97        }
98
99        let mut partial_drb_result = [0u8; 32];
100        partial_drb_result.copy_from_slice(&hash);
101
102        iteration += DRB_CHECKPOINT_INTERVAL;
103
104        let storage = store_drb_progress.clone();
105        tokio::spawn(async move {
106            storage(drb_input.epoch, iteration, partial_drb_result).await;
107        });
108    }
109
110    let final_checkpoint_iteration = iteration;
111
112    // perform the remaining iterations
113    for _ in final_checkpoint_iteration..DIFFICULTY_LEVEL {
114        hash = Sha256::digest(hash).to_vec();
115        iteration += 1;
116    }
117
118    // Convert the hash to the DRB result.
119    let mut drb_result = [0u8; 32];
120    drb_result.copy_from_slice(&hash);
121    drb_result
122}
123
124/// Seeds for DRB computation and computed results.
/// Seeds for DRB computation and computed results.
#[derive(Clone, Debug)]
pub struct DrbResults<TYPES: NodeType> {
    /// Stored results from computations, keyed by the epoch each result is for
    pub results: BTreeMap<TYPES::Epoch, DrbResult>,
}
130
impl<TYPES: NodeType> DrbResults<TYPES> {
    #[must_use]
    /// Constructor with initial values for epochs 1 and 2.
    pub fn new() -> Self {
        Self {
            results: BTreeMap::from([
                (TYPES::Epoch::new(1), INITIAL_DRB_RESULT),
                (TYPES::Epoch::new(2), INITIAL_DRB_RESULT),
            ]),
        }
    }

    /// Store the DRB result for `epoch`, overwriting any previously stored result.
    pub fn store_result(&mut self, epoch: TYPES::Epoch, result: DrbResult) {
        self.results.insert(epoch, result);
    }

    /// Garbage collects internal data structures
    ///
    /// Keeps results for the `KEEP_PREVIOUS_RESULT_COUNT` epochs preceding `epoch`
    /// (plus `epoch` itself and anything newer); older entries are dropped.
    pub fn garbage_collect(&mut self, epoch: TYPES::Epoch) {
        // Nothing to collect until we are at least `KEEP_PREVIOUS_RESULT_COUNT` epochs
        // in; this also guards the subtraction below against underflow.
        if epoch.u64() < KEEP_PREVIOUS_RESULT_COUNT {
            return;
        }

        let retain_epoch = epoch - KEEP_PREVIOUS_RESULT_COUNT;
        // N.B. x.split_off(y) returns the part of the map where key >= y

        // Remove result entries older than EPOCH
        self.results = self.results.split_off(&retain_epoch);
    }
}
160
impl<TYPES: NodeType> Default for DrbResults<TYPES> {
    /// Equivalent to [`DrbResults::new`]: seeded with the initial results for epochs 1 and 2.
    fn default() -> Self {
        Self::new()
    }
}
166
167/// Functions for leader selection based on the DRB.
168///
169/// The algorithm we use is:
170///
171/// Initialization:
172/// - obtain `drb: [u8; 32]` from the DRB calculation
173/// - sort the stake table for a given epoch by `xor(drb, public_key)`
174/// - generate a cdf of the cumulative stake using this newly-sorted table,
175///   along with a hash of the stake table entries
176///
177/// Selecting a leader:
178/// - calculate the SHA512 hash of the `drb_result`, `view_number` and `stake_table_hash`
179/// - find the first index in the cdf for which the remainder of this hash modulo the `total_stake`
180///   is strictly smaller than the cdf entry
181/// - return the corresponding node as the leader for that view
182pub mod election {
183    use alloy::primitives::{U256, U512};
184    use sha2::{Digest, Sha256, Sha512};
185
186    use crate::traits::signature_key::{SignatureKey, StakeTableEntryType};
187
188    /// Calculate `xor(drb.cycle(), public_key)`, returning the result as a vector of bytes
189    fn cyclic_xor(drb: [u8; 32], public_key: Vec<u8>) -> Vec<u8> {
190        let drb: Vec<u8> = drb.to_vec();
191
192        let mut result: Vec<u8> = vec![];
193
194        for (drb_byte, public_key_byte) in public_key.iter().zip(drb.iter().cycle()) {
195            result.push(drb_byte ^ public_key_byte);
196        }
197
198        result
199    }
200
201    /// Generate the stake table CDF, as well as a hash of the resulting stake table
202    pub fn generate_stake_cdf<Key: SignatureKey, Entry: StakeTableEntryType<Key>>(
203        mut stake_table: Vec<Entry>,
204        drb: [u8; 32],
205    ) -> RandomizedCommittee<Entry> {
206        // sort by xor(public_key, drb_result)
207        stake_table.sort_by(|a, b| {
208            cyclic_xor(drb, a.public_key().to_bytes())
209                .cmp(&cyclic_xor(drb, b.public_key().to_bytes()))
210        });
211
212        let mut hasher = Sha256::new();
213
214        let mut cumulative_stake = U256::from(0);
215        let mut cdf = vec![];
216
217        for entry in stake_table {
218            cumulative_stake += entry.stake();
219            hasher.update(entry.public_key().to_bytes());
220
221            cdf.push((entry, cumulative_stake));
222        }
223
224        RandomizedCommittee {
225            cdf,
226            stake_table_hash: hasher.finalize().into(),
227            drb,
228        }
229    }
230
    /// select the leader for a view
    ///
    /// Hashes the DRB result, view number and stake table hash together into a
    /// pseudorandom "breakpoint" in `[0, total_stake)`, then returns the first entry
    /// in the CDF whose cumulative stake exceeds that breakpoint. Each node is thus
    /// selected with probability proportional to its stake.
    ///
    /// # Panics
    /// Panics if `cdf` is empty. Results in undefined behaviour if `cdf` is not ordered.
    ///
    /// Note that we try to downcast a U512 to a U256,
    /// but this should never panic because the U512 should be strictly smaller than U256::MAX by construction.
    pub fn select_randomized_leader<
        SignatureKey,
        Entry: StakeTableEntryType<SignatureKey> + Clone,
    >(
        randomized_committee: &RandomizedCommittee<Entry>,
        view: u64,
    ) -> Entry {
        let RandomizedCommittee {
            cdf,
            stake_table_hash,
            drb,
        } = randomized_committee;
        // We hash the concatenated drb, view and stake table hash.
        let mut hasher = Sha512::new();
        hasher.update(drb);
        hasher.update(view.to_le_bytes());
        hasher.update(stake_table_hash);
        let raw_breakpoint: [u8; 64] = hasher.finalize().into();

        // then calculate the remainder modulo the total stake as a U512
        // (the last cdf entry holds the total stake; the `unwrap` is covered by
        // the `# Panics` note above)
        let remainder: U512 =
            U512::from_le_bytes(raw_breakpoint) % U512::from(cdf.last().unwrap().1);

        // and drop the top 32 bytes, downcasting to a U256
        // (safe since `remainder < total_stake <= U256::MAX`, so the top bytes are zero)
        let breakpoint: U256 = U256::from_le_slice(&remainder.to_le_bytes_vec()[0..32]);

        // now find the first index where the breakpoint is strictly smaller than the cdf
        //
        // in principle, this may result in an index larger than `cdf.len()`.
        // however, we have ensured by construction that `breakpoint < total_stake`
        // and so the largest index we can actually return is `cdf.len() - 1`
        let index = cdf.partition_point(|(_, cumulative_stake)| breakpoint >= *cumulative_stake);

        // and return the corresponding entry
        cdf[index].0.clone()
    }
274
    /// A stake table prepared for randomized leader election: the entries sorted by
    /// `xor(drb, public_key)` together with their cumulative-stake CDF, plus the data
    /// needed to derive per-view breakpoints.
    #[derive(Clone, Debug)]
    pub struct RandomizedCommittee<Entry> {
        /// cdf of nodes by cumulative stake
        cdf: Vec<(Entry, U256)>,
        /// Hash of the stake table
        stake_table_hash: [u8; 32],
        /// DRB result
        drb: [u8; 32],
    }
284}
285
286#[cfg(test)]
287mod tests {
288    use std::collections::HashMap;
289
290    use alloy::primitives::U256;
291    use rand::RngCore;
292    use sha2::{Digest, Sha256};
293
294    use super::election::{generate_stake_cdf, select_randomized_leader};
295    use crate::{
296        signature_key::BLSPubKey,
297        stake_table::StakeTableEntry,
298        traits::signature_key::{BuilderSignatureKey, StakeTableEntryType},
299    };
300
    #[test]
    fn test_randomized_leader() {
        let mut rng = rand::thread_rng();
        // use an arbitrary Sha256 output.
        let drb: [u8; 32] = Sha256::digest(b"drb").into();
        // a stake table with 10 nodes, each with a stake of 1-100
        let stake_table_entries: Vec<_> = (0..10)
            .map(|i| StakeTableEntry {
                stake_key: BLSPubKey::generated_from_seed_indexed([0u8; 32], i).0,
                stake_amount: U256::from(rng.next_u64() % 100 + 1),
            })
            .collect();
        let randomized_committee = generate_stake_cdf(stake_table_entries.clone(), drb);

        // Number of views to test
        let num_views = 100000;
        // Count how many times each node is selected as leader.
        let mut selected = HashMap::<_, u64>::new();
        // Test the leader election for 100000 views.
        for i in 0..num_views {
            let leader = select_randomized_leader(&randomized_committee, i);
            *selected.entry(leader).or_insert(0) += 1;
        }

        // Total variation distance between the expected (stake-proportional)
        // distribution and the observed leader-selection frequencies.
        let mut tvd = 0.;
        let total_stakes = stake_table_entries
            .iter()
            .map(|e| e.stake())
            .sum::<U256>()
            .to::<u64>() as f64;
        for entry in stake_table_entries {
            let expected = entry.stake().to::<u64>() as f64 / total_stakes;
            let actual = *selected.get(&entry).unwrap_or(&0) as f64 / num_views as f64;
            tvd += (expected - actual).abs();
        }

        // sanity check: tvd is a sum of absolute values, so this is trivially true
        assert!(tvd >= 0.0);
        // Allow a small margin of error
        assert!(tvd < 0.03);
    }
342}