hotshot_types/
drb.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::{collections::BTreeMap, sync::Arc, time::Instant};
8
9use futures::future::BoxFuture;
10use serde::{Deserialize, Serialize};
11use sha2::{Digest, Sha256};
12
13use crate::{
14    message::UpgradeLock,
15    traits::{
16        node_implementation::{ConsensusTime, NodeType, Versions},
17        storage::{LoadDrbProgressFn, StoreDrbProgressFn},
18    },
19    HotShotConfig,
20};
21
/// Input to, and resumable state of, a DRB calculation.
///
/// `compute_drb_result` consumes one of these as its starting point and hands updated copies
/// to the progress-storing callback as the calculation advances.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct DrbInput {
    /// The epoch we are calculating the result for
    pub epoch: u64,
    /// The iteration this seed is from. For fresh calculations, this should be `0`.
    pub iteration: u64,
    /// The value of the DRB calculation at the current iteration
    pub value: [u8; 32],
    /// Difficulty value for the DRB calculation, i.e. the iteration count at which the
    /// calculation is complete
    pub difficulty_level: u64,
}
33
/// Shared, thread-safe function that asynchronously maps a view to a DRB difficulty level
/// (a `u64`).
pub type DrbDifficultySelectorFn<TYPES> =
    Arc<dyn Fn(<TYPES as NodeType>::View) -> BoxFuture<'static, u64> + Send + Sync + 'static>;
36
37pub fn drb_difficulty_selector<TYPES: NodeType, V: Versions>(
38    upgrade_lock: UpgradeLock<TYPES, V>,
39    config: &HotShotConfig<TYPES>,
40) -> DrbDifficultySelectorFn<TYPES> {
41    let base_difficulty = config.drb_difficulty;
42    let upgrade_difficulty = config.drb_upgrade_difficulty;
43    Arc::new(move |view| {
44        let upgrade_lock = upgrade_lock.clone();
45        Box::pin(async move {
46            if upgrade_lock.upgraded_drb_and_header(view).await {
47                upgrade_difficulty
48            } else {
49                base_difficulty
50            }
51        })
52    })
53}
54
55// TODO: Add the following consts once we bench the hash time.
56// <https://github.com/EspressoSystems/HotShot/issues/3880>
57// /// Highest number of hashes that a hardware can complete in a second.
58// const `HASHES_PER_SECOND`
59// /// Time a DRB calculation will take, in terms of number of views.
60// const `DRB_CALCULATION_NUM_VIEW`: u64 = 300;
61
62// TODO: Replace this with an accurate number calculated by `fn difficulty_level()` once we bench
63// the hash time.
64// <https://github.com/EspressoSystems/HotShot/issues/3880>
/// Arbitrary number of times the hash function will be repeatedly called.
pub const DIFFICULTY_LEVEL: u64 = 10;

/// Number of hash iterations between two checkpoints at which `compute_drb_result` stores its
/// partial progress.
pub const DRB_CHECKPOINT_INTERVAL: u64 = 1_000_000_000;

/// DRB seed input for epoch 1 and 2.
pub const INITIAL_DRB_SEED_INPUT: [u8; 32] = [0; 32];
/// DRB result for epoch 1 and 2.
pub const INITIAL_DRB_RESULT: [u8; 32] = [0; 32];

/// Alias for DRB seed input for `compute_drb_result`, serialized from the QC signature.
pub type DrbSeedInput = [u8; 32];

/// Alias for DRB result from `compute_drb_result`.
pub type DrbResult = [u8; 32];

/// Number of epochs before the current one whose results are kept by
/// `DrbResults::garbage_collect`.
pub const KEEP_PREVIOUS_RESULT_COUNT: u64 = 8;
84
85// TODO: Use `HASHES_PER_SECOND` * `VIEW_TIMEOUT` * `DRB_CALCULATION_NUM_VIEW` to calculate this
86// once we bench the hash time.
87// <https://github.com/EspressoSystems/HotShot/issues/3880>
/// Difficulty level of the DRB calculation.
///
/// Represents the number of times the hash function will be repeatedly called.
///
/// # Panics
/// Always panics: this function is not implemented yet.
#[must_use]
pub fn difficulty_level() -> u64 {
    unimplemented!("Use an arbitrary `DIFFICULTY_LEVEL` for now before we bench the hash time.");
}
95
96/// Compute the DRB result for the leader rotation.
97///
98/// This is to be started two epochs in advance and spawned in a non-blocking thread.
99///
100/// # Arguments
101/// * `drb_seed_input` - Serialized QC signature.
102#[must_use]
103pub async fn compute_drb_result(
104    drb_input: DrbInput,
105    store_drb_progress: StoreDrbProgressFn,
106    load_drb_progress: LoadDrbProgressFn,
107) -> DrbResult {
108    let mut drb_input = drb_input;
109
110    if let Ok(loaded_drb_input) = load_drb_progress(drb_input.epoch).await {
111        if loaded_drb_input.iteration >= drb_input.iteration {
112            drb_input = loaded_drb_input;
113        }
114    }
115
116    let mut hash = drb_input.value.to_vec();
117    let mut iteration = drb_input.iteration;
118    let remaining_iterations = drb_input
119        .difficulty_level
120        .checked_sub(iteration)
121        .unwrap_or_else(|| {
122            panic!(
123                "DRB difficulty level {} exceeds the iteration {} of the input we were given. \
124                 This is a fatal error",
125                drb_input.difficulty_level, iteration
126            )
127        });
128
129    let final_checkpoint = remaining_iterations / DRB_CHECKPOINT_INTERVAL;
130
131    let mut last_time = Instant::now();
132    let mut last_iteration = iteration;
133
134    // loop up to, but not including, the `final_checkpoint`
135    for _ in 0..final_checkpoint {
136        hash = tokio::task::spawn_blocking(move || {
137            let mut hash_tmp = hash.clone();
138            for _ in 0..DRB_CHECKPOINT_INTERVAL {
139                // TODO: This may be optimized to avoid memcopies after we bench the hash time.
140                // <https://github.com/EspressoSystems/HotShot/issues/3880>
141                hash_tmp = Sha256::digest(&hash_tmp).to_vec();
142            }
143
144            hash_tmp
145        })
146        .await
147        .expect("DRB calculation failed: this should never happen");
148
149        let mut partial_drb_result = [0u8; 32];
150        partial_drb_result.copy_from_slice(&hash);
151
152        iteration += DRB_CHECKPOINT_INTERVAL;
153
154        let updated_drb_input = DrbInput {
155            epoch: drb_input.epoch,
156            iteration,
157            value: partial_drb_result,
158            difficulty_level: drb_input.difficulty_level,
159        };
160
161        let elapsed_time = last_time.elapsed().as_millis();
162
163        let store_drb_progress = store_drb_progress.clone();
164        tokio::spawn(async move {
165            tracing::warn!(
166                "Storing partial DRB progress: {:?}. Time elapsed since the previous iteration of \
167                 {:?}: {:?}",
168                updated_drb_input,
169                last_iteration,
170                elapsed_time
171            );
172            if let Err(e) = store_drb_progress(updated_drb_input).await {
173                tracing::warn!("Failed to store DRB progress during calculation: {}", e);
174            }
175        });
176
177        last_time = Instant::now();
178        last_iteration = iteration;
179    }
180
181    let final_checkpoint_iteration = iteration;
182
183    // perform the remaining iterations
184    hash = tokio::task::spawn_blocking(move || {
185        let mut hash_tmp = hash.clone();
186        for _ in final_checkpoint_iteration..drb_input.difficulty_level {
187            // TODO: This may be optimized to avoid memcopies after we bench the hash time.
188            // <https://github.com/EspressoSystems/HotShot/issues/3880>
189            hash_tmp = Sha256::digest(&hash_tmp).to_vec();
190        }
191
192        hash_tmp
193    })
194    .await
195    .expect("DRB calculation failed: this should never happen");
196
197    // Convert the hash to the DRB result.
198    let mut drb_result = [0u8; 32];
199    drb_result.copy_from_slice(&hash);
200
201    let final_drb_input = DrbInput {
202        epoch: drb_input.epoch,
203        iteration: drb_input.difficulty_level,
204        value: drb_result,
205        difficulty_level: drb_input.difficulty_level,
206    };
207
208    let store_drb_progress = store_drb_progress.clone();
209    tokio::spawn(async move {
210        if let Err(e) = store_drb_progress(final_drb_input).await {
211            tracing::warn!("Failed to store DRB progress during calculation: {}", e);
212        }
213    });
214
215    drb_result
216}
217
/// Computed DRB results, keyed by the epoch they belong to.
#[derive(Clone, Debug)]
pub struct DrbResults<TYPES: NodeType> {
    /// Stored results from computations
    pub results: BTreeMap<TYPES::Epoch, DrbResult>,
}
224
225impl<TYPES: NodeType> DrbResults<TYPES> {
226    #[must_use]
227    /// Constructor with initial values for epochs 1 and 2.
228    pub fn new() -> Self {
229        Self {
230            results: BTreeMap::from([
231                (TYPES::Epoch::new(1), INITIAL_DRB_RESULT),
232                (TYPES::Epoch::new(2), INITIAL_DRB_RESULT),
233            ]),
234        }
235    }
236
237    pub fn store_result(&mut self, epoch: TYPES::Epoch, result: DrbResult) {
238        self.results.insert(epoch, result);
239    }
240
241    /// Garbage collects internal data structures
242    pub fn garbage_collect(&mut self, epoch: TYPES::Epoch) {
243        if epoch.u64() < KEEP_PREVIOUS_RESULT_COUNT {
244            return;
245        }
246
247        let retain_epoch = epoch - KEEP_PREVIOUS_RESULT_COUNT;
248        // N.B. x.split_off(y) returns the part of the map where key >= y
249
250        // Remove result entries older than EPOCH
251        self.results = self.results.split_off(&retain_epoch);
252    }
253}
254
impl<TYPES: NodeType> Default for DrbResults<TYPES> {
    /// Equivalent to [`DrbResults::new`].
    fn default() -> Self {
        Self::new()
    }
}
260
261/// Functions for leader selection based on the DRB.
262///
263/// The algorithm we use is:
264///
265/// Initialization:
266/// - obtain `drb: [u8; 32]` from the DRB calculation
267/// - sort the stake table for a given epoch by `xor(drb, public_key)`
268/// - generate a cdf of the cumulative stake using this newly-sorted table,
269///   along with a hash of the stake table entries
270///
271/// Selecting a leader:
272/// - calculate the SHA512 hash of the `drb_result`, `view_number` and `stake_table_hash`
273/// - find the first index in the cdf for which the remainder of this hash modulo the `total_stake`
274///   is strictly smaller than the cdf entry
275/// - return the corresponding node as the leader for that view
276pub mod election {
277    use alloy::primitives::{U256, U512};
278    use sha2::{Digest, Sha256, Sha512};
279
280    use crate::traits::signature_key::{SignatureKey, StakeTableEntryType};
281
282    /// Calculate `xor(drb.cycle(), public_key)`, returning the result as a vector of bytes
283    fn cyclic_xor(drb: [u8; 32], public_key: Vec<u8>) -> Vec<u8> {
284        let drb: Vec<u8> = drb.to_vec();
285
286        let mut result: Vec<u8> = vec![];
287
288        for (drb_byte, public_key_byte) in public_key.iter().zip(drb.iter().cycle()) {
289            result.push(drb_byte ^ public_key_byte);
290        }
291
292        result
293    }
294
295    /// Generate the stake table CDF, as well as a hash of the resulting stake table
296    pub fn generate_stake_cdf<Key: SignatureKey, Entry: StakeTableEntryType<Key>>(
297        mut stake_table: Vec<Entry>,
298        drb: [u8; 32],
299    ) -> RandomizedCommittee<Entry> {
300        // sort by xor(public_key, drb_result)
301        stake_table.sort_by(|a, b| {
302            cyclic_xor(drb, a.public_key().to_bytes())
303                .cmp(&cyclic_xor(drb, b.public_key().to_bytes()))
304        });
305
306        let mut hasher = Sha256::new();
307
308        let mut cumulative_stake = U256::from(0);
309        let mut cdf = vec![];
310
311        for entry in stake_table {
312            cumulative_stake += entry.stake();
313            hasher.update(entry.public_key().to_bytes());
314
315            cdf.push((entry, cumulative_stake));
316        }
317
318        RandomizedCommittee {
319            cdf,
320            stake_table_hash: hasher.finalize().into(),
321            drb,
322        }
323    }
324
325    /// select the leader for a view
326    ///
327    /// # Panics
328    /// Panics if `cdf` is empty. Results in undefined behaviour if `cdf` is not ordered.
329    ///
330    /// Note that we try to downcast a U512 to a U256,
331    /// but this should never panic because the U512 should be strictly smaller than U256::MAX by construction.
332    pub fn select_randomized_leader<
333        SignatureKey,
334        Entry: StakeTableEntryType<SignatureKey> + Clone,
335    >(
336        randomized_committee: &RandomizedCommittee<Entry>,
337        view: u64,
338    ) -> Entry {
339        let RandomizedCommittee {
340            cdf,
341            stake_table_hash,
342            drb,
343        } = randomized_committee;
344        // We hash the concatenated drb, view and stake table hash.
345        let mut hasher = Sha512::new();
346        hasher.update(drb);
347        hasher.update(view.to_le_bytes());
348        hasher.update(stake_table_hash);
349        let raw_breakpoint: [u8; 64] = hasher.finalize().into();
350
351        // then calculate the remainder modulo the total stake as a U512
352        let remainder: U512 =
353            U512::from_le_bytes(raw_breakpoint) % U512::from(cdf.last().unwrap().1);
354
355        // and drop the top 32 bytes, downcasting to a U256
356        let breakpoint: U256 = U256::from_le_slice(&remainder.to_le_bytes_vec()[0..32]);
357
358        // now find the first index where the breakpoint is strictly smaller than the cdf
359        //
360        // in principle, this may result in an index larger than `cdf.len()`.
361        // however, we have ensured by construction that `breakpoint < total_stake`
362        // and so the largest index we can actually return is `cdf.len() - 1`
363        let index = cdf.partition_point(|(_, cumulative_stake)| breakpoint >= *cumulative_stake);
364
365        // and return the corresponding entry
366        cdf[index].0.clone()
367    }
368
369    #[derive(Clone, Debug)]
370    pub struct RandomizedCommittee<Entry> {
371        /// cdf of nodes by cumulative stake
372        cdf: Vec<(Entry, U256)>,
373        /// Hash of the stake table
374        stake_table_hash: [u8; 32],
375        /// DRB result
376        drb: [u8; 32],
377    }
378
379    impl<Entry> RandomizedCommittee<Entry> {
380        pub fn drb_result(&self) -> [u8; 32] {
381            self.drb
382        }
383    }
384}
385
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use alloy::primitives::U256;
    use rand::RngCore;
    use sha2::{Digest, Sha256};

    use super::election::{generate_stake_cdf, select_randomized_leader};
    use crate::{
        signature_key::BLSPubKey,
        stake_table::StakeTableEntry,
        traits::signature_key::{BuilderSignatureKey, StakeTableEntryType},
    };

    /// Checks that, over many views, every node is selected as leader with a frequency close to
    /// its share of the total stake.
    #[test]
    fn test_randomized_leader() {
        let mut rng = rand::thread_rng();
        // Any Sha256 output will do as the DRB result.
        let drb: [u8; 32] = Sha256::digest(b"drb").into();

        // Ten nodes, each holding a random stake between 1 and 100.
        let mut stake_table_entries = Vec::new();
        for i in 0..10 {
            stake_table_entries.push(StakeTableEntry {
                stake_key: BLSPubKey::generated_from_seed_indexed([0u8; 32], i).0,
                stake_amount: U256::from(rng.next_u64() % 100 + 1),
            });
        }
        let committee = generate_stake_cdf(stake_table_entries.clone(), drb);

        // Count how often each node is elected over 100000 views.
        let num_views = 100000;
        let mut selected = HashMap::<_, u64>::new();
        for view in 0..num_views {
            *selected
                .entry(select_randomized_leader(&committee, view))
                .or_default() += 1;
        }

        let total_stake = stake_table_entries
            .iter()
            .map(|entry| entry.stake())
            .sum::<U256>()
            .to::<u64>() as f64;

        // Sum over all nodes of |expected share - observed share|.
        let distance = stake_table_entries.iter().fold(0.0_f64, |acc, entry| {
            let expected = entry.stake().to::<u64>() as f64 / total_stake;
            let actual = selected.get(entry).copied().unwrap_or(0) as f64 / num_views as f64;
            acc + (expected - actual).abs()
        });

        // sanity check
        assert!(distance >= 0.0);
        // Allow a small margin of error
        assert!(distance < 0.03);
    }
}