// hotshot_types/drb.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::{collections::BTreeMap, sync::Arc, time::Instant};
8
9use futures::future::BoxFuture;
10use serde::{Deserialize, Serialize};
11use sha2::{Digest, Sha256};
12use vbs::version::{StaticVersionType, Version};
13
14use crate::{
15    traits::{
16        node_implementation::{ConsensusTime, NodeType, Versions},
17        storage::{LoadDrbProgressFn, StoreDrbProgressFn},
18    },
19    HotShotConfig,
20};
21
/// Input to — and persisted checkpoint of — a DRB calculation.
///
/// `value` is the hash after `iteration` rounds, so the four fields together
/// fully describe where a calculation can resume from.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct DrbInput {
    /// The epoch we are calculating the result for
    pub epoch: u64,
    /// The iteration this seed is from. For fresh calculations, this should be `0`.
    pub iteration: u64,
    /// the value of the drb calculation at the current iteration
    pub value: [u8; 32],
    /// difficulty value for the DRB calculation
    pub difficulty_level: u64,
}
33
/// Callback that selects the DRB difficulty to use for a given protocol version.
pub type DrbDifficultySelectorFn =
    Arc<dyn Fn(Version) -> BoxFuture<'static, u64> + Send + Sync + 'static>;
36
37pub fn drb_difficulty_selector<TYPES: NodeType, V: Versions>(
38    config: &HotShotConfig<TYPES>,
39) -> DrbDifficultySelectorFn {
40    let base_difficulty = config.drb_difficulty;
41    let upgrade_difficulty = config.drb_upgrade_difficulty;
42    Arc::new(move |version| {
43        Box::pin(async move {
44            if version >= V::DrbAndHeaderUpgrade::VERSION {
45                upgrade_difficulty
46            } else {
47                base_difficulty
48            }
49        })
50    })
51}
52
// TODO: Add the following consts once we bench the hash time.
// <https://github.com/EspressoSystems/HotShot/issues/3880>
// /// Highest number of hashes that a hardware can complete in a second.
// const `HASHES_PER_SECOND`
// /// Time a DRB calculation will take, in terms of number of views.
// const `DRB_CALCULATION_NUM_VIEW`: u64 = 300;

// TODO: Replace this with an accurate number calculated by `fn difficulty_level()` once we bench
// the hash time.
// <https://github.com/EspressoSystems/HotShot/issues/3880>
/// Arbitrary number of times the hash function will be repeatedly called.
pub const DIFFICULTY_LEVEL: u64 = 10;

/// Interval, in hash iterations, at which to store partial DRB results.
pub const DRB_CHECKPOINT_INTERVAL: u64 = 1_000_000_000;

/// DRB seed input for epoch 1 and 2.
pub const INITIAL_DRB_SEED_INPUT: [u8; 32] = [0; 32];
/// DRB result for epoch 1 and 2.
pub const INITIAL_DRB_RESULT: [u8; 32] = [0; 32];

/// Alias for DRB seed input for `compute_drb_result`, serialized from the QC signature.
pub type DrbSeedInput = [u8; 32];

/// Alias for DRB result from `compute_drb_result`.
pub type DrbResult = [u8; 32];

/// Number of previous results and seeds to keep when garbage collecting.
pub const KEEP_PREVIOUS_RESULT_COUNT: u64 = 8;
82
// TODO: Use `HASHES_PER_SECOND` * `VIEW_TIMEOUT` * `DRB_CALCULATION_NUM_VIEW` to calculate this
// once we bench the hash time.
// <https://github.com/EspressoSystems/HotShot/issues/3880>
/// Difficulty level of the DRB calculation.
///
/// Represents the number of times the hash function will be repeatedly called.
///
/// # Panics
/// Always panics: this is a placeholder until the hash time is benchmarked.
/// Use the constant [`DIFFICULTY_LEVEL`] instead.
#[must_use]
pub fn difficulty_level() -> u64 {
    unimplemented!("Use an arbitrary `DIFFICULTY_LEVEL` for now before we bench the hash time.");
}
93
94/// Compute the DRB result for the leader rotation.
95///
96/// This is to be started two epochs in advance and spawned in a non-blocking thread.
97///
98/// # Arguments
99/// * `drb_seed_input` - Serialized QC signature.
100#[must_use]
101pub async fn compute_drb_result(
102    drb_input: DrbInput,
103    store_drb_progress: StoreDrbProgressFn,
104    load_drb_progress: LoadDrbProgressFn,
105) -> DrbResult {
106    tracing::warn!("Beginning DRB calculation with input {:?}", drb_input);
107    let mut drb_input = drb_input;
108
109    if let Ok(loaded_drb_input) = load_drb_progress(drb_input.epoch).await {
110        if loaded_drb_input.difficulty_level != drb_input.difficulty_level {
111            tracing::error!(
112                "We are calculating the DRB result with input {drb_input:?}, but we had \
113                 previously stored {loaded_drb_input:?} with a different difficulty level for \
114                 this epoch. Discarding the value from storage"
115            );
116        } else if loaded_drb_input.iteration >= drb_input.iteration {
117            drb_input = loaded_drb_input;
118        }
119    }
120
121    let mut hash = drb_input.value.to_vec();
122    let mut iteration = drb_input.iteration;
123    let remaining_iterations = drb_input
124        .difficulty_level
125        .checked_sub(iteration)
126        .unwrap_or_else(|| {
127            panic!(
128                "DRB difficulty level {} exceeds the iteration {} of the input we were given. \
129                 This is a fatal error",
130                drb_input.difficulty_level, iteration
131            )
132        });
133
134    let final_checkpoint = remaining_iterations / DRB_CHECKPOINT_INTERVAL;
135
136    let mut last_time = Instant::now();
137    let mut last_iteration = iteration;
138
139    // loop up to, but not including, the `final_checkpoint`
140    for _ in 0..final_checkpoint {
141        hash = tokio::task::spawn_blocking(move || {
142            let mut hash_tmp = hash.clone();
143            for _ in 0..DRB_CHECKPOINT_INTERVAL {
144                // TODO: This may be optimized to avoid memcopies after we bench the hash time.
145                // <https://github.com/EspressoSystems/HotShot/issues/3880>
146                hash_tmp = Sha256::digest(&hash_tmp).to_vec();
147            }
148
149            hash_tmp
150        })
151        .await
152        .expect("DRB calculation failed: this should never happen");
153
154        let mut partial_drb_result = [0u8; 32];
155        partial_drb_result.copy_from_slice(&hash);
156
157        iteration += DRB_CHECKPOINT_INTERVAL;
158
159        let updated_drb_input = DrbInput {
160            epoch: drb_input.epoch,
161            iteration,
162            value: partial_drb_result,
163            difficulty_level: drb_input.difficulty_level,
164        };
165
166        let elapsed_time = last_time.elapsed().as_millis();
167
168        let store_drb_progress = store_drb_progress.clone();
169        tokio::spawn(async move {
170            tracing::warn!(
171                "Storing partial DRB progress: {:?}. Time elapsed since the previous iteration of \
172                 {:?}: {:?}",
173                updated_drb_input,
174                last_iteration,
175                elapsed_time
176            );
177            if let Err(e) = store_drb_progress(updated_drb_input).await {
178                tracing::warn!("Failed to store DRB progress during calculation: {}", e);
179            }
180        });
181
182        last_time = Instant::now();
183        last_iteration = iteration;
184    }
185
186    let final_checkpoint_iteration = iteration;
187
188    // perform the remaining iterations
189    hash = tokio::task::spawn_blocking(move || {
190        let mut hash_tmp = hash.clone();
191        for _ in final_checkpoint_iteration..drb_input.difficulty_level {
192            // TODO: This may be optimized to avoid memcopies after we bench the hash time.
193            // <https://github.com/EspressoSystems/HotShot/issues/3880>
194            hash_tmp = Sha256::digest(&hash_tmp).to_vec();
195        }
196
197        hash_tmp
198    })
199    .await
200    .expect("DRB calculation failed: this should never happen");
201
202    // Convert the hash to the DRB result.
203    let mut drb_result = [0u8; 32];
204    drb_result.copy_from_slice(&hash);
205
206    let final_drb_input = DrbInput {
207        epoch: drb_input.epoch,
208        iteration: drb_input.difficulty_level,
209        value: drb_result,
210        difficulty_level: drb_input.difficulty_level,
211    };
212
213    tracing::warn!("Completed DRB calculation. Result: {:?}", final_drb_input);
214
215    let store_drb_progress = store_drb_progress.clone();
216    tokio::spawn(async move {
217        if let Err(e) = store_drb_progress(final_drb_input).await {
218            tracing::warn!("Failed to store DRB progress during calculation: {}", e);
219        }
220    });
221
222    drb_result
223}
224
/// Seeds for DRB computation and computed results.
///
/// Keyed by epoch; entries are pruned by `garbage_collect`.
#[derive(Clone, Debug)]
pub struct DrbResults<TYPES: NodeType> {
    /// Stored results from computations
    pub results: BTreeMap<TYPES::Epoch, DrbResult>,
}
231
232impl<TYPES: NodeType> DrbResults<TYPES> {
233    #[must_use]
234    /// Constructor with initial values for epochs 1 and 2.
235    pub fn new() -> Self {
236        Self {
237            results: BTreeMap::from([
238                (TYPES::Epoch::new(1), INITIAL_DRB_RESULT),
239                (TYPES::Epoch::new(2), INITIAL_DRB_RESULT),
240            ]),
241        }
242    }
243
244    pub fn store_result(&mut self, epoch: TYPES::Epoch, result: DrbResult) {
245        self.results.insert(epoch, result);
246    }
247
248    /// Garbage collects internal data structures
249    pub fn garbage_collect(&mut self, epoch: TYPES::Epoch) {
250        if epoch.u64() < KEEP_PREVIOUS_RESULT_COUNT {
251            return;
252        }
253
254        let retain_epoch = epoch - KEEP_PREVIOUS_RESULT_COUNT;
255        // N.B. x.split_off(y) returns the part of the map where key >= y
256
257        // Remove result entries older than EPOCH
258        self.results = self.results.split_off(&retain_epoch);
259    }
260}
261
/// Defaults to the initial results for epochs 1 and 2 (same as [`DrbResults::new`]).
impl<TYPES: NodeType> Default for DrbResults<TYPES> {
    fn default() -> Self {
        Self::new()
    }
}
267
268/// Functions for leader selection based on the DRB.
269///
270/// The algorithm we use is:
271///
272/// Initialization:
273/// - obtain `drb: [u8; 32]` from the DRB calculation
274/// - sort the stake table for a given epoch by `xor(drb, public_key)`
275/// - generate a cdf of the cumulative stake using this newly-sorted table,
276///   along with a hash of the stake table entries
277///
278/// Selecting a leader:
279/// - calculate the SHA512 hash of the `drb_result`, `view_number` and `stake_table_hash`
280/// - find the first index in the cdf for which the remainder of this hash modulo the `total_stake`
281///   is strictly smaller than the cdf entry
282/// - return the corresponding node as the leader for that view
283pub mod election {
284    use alloy::primitives::{U256, U512};
285    use sha2::{Digest, Sha256, Sha512};
286
287    use crate::traits::signature_key::{SignatureKey, StakeTableEntryType};
288
289    /// Calculate `xor(drb.cycle(), public_key)`, returning the result as a vector of bytes
290    fn cyclic_xor(drb: [u8; 32], public_key: Vec<u8>) -> Vec<u8> {
291        let drb: Vec<u8> = drb.to_vec();
292
293        let mut result: Vec<u8> = vec![];
294
295        for (drb_byte, public_key_byte) in public_key.iter().zip(drb.iter().cycle()) {
296            result.push(drb_byte ^ public_key_byte);
297        }
298
299        result
300    }
301
302    /// Generate the stake table CDF, as well as a hash of the resulting stake table
303    pub fn generate_stake_cdf<Key: SignatureKey, Entry: StakeTableEntryType<Key>>(
304        mut stake_table: Vec<Entry>,
305        drb: [u8; 32],
306    ) -> RandomizedCommittee<Entry> {
307        // sort by xor(public_key, drb_result)
308        stake_table.sort_by(|a, b| {
309            cyclic_xor(drb, a.public_key().to_bytes())
310                .cmp(&cyclic_xor(drb, b.public_key().to_bytes()))
311        });
312
313        let mut hasher = Sha256::new();
314
315        let mut cumulative_stake = U256::from(0);
316        let mut cdf = vec![];
317
318        for entry in stake_table {
319            cumulative_stake += entry.stake();
320            hasher.update(entry.public_key().to_bytes());
321
322            cdf.push((entry, cumulative_stake));
323        }
324
325        RandomizedCommittee {
326            cdf,
327            stake_table_hash: hasher.finalize().into(),
328            drb,
329        }
330    }
331
332    /// select the leader for a view
333    ///
334    /// # Panics
335    /// Panics if `cdf` is empty. Results in undefined behaviour if `cdf` is not ordered.
336    ///
337    /// Note that we try to downcast a U512 to a U256,
338    /// but this should never panic because the U512 should be strictly smaller than U256::MAX by construction.
339    pub fn select_randomized_leader<
340        SignatureKey,
341        Entry: StakeTableEntryType<SignatureKey> + Clone,
342    >(
343        randomized_committee: &RandomizedCommittee<Entry>,
344        view: u64,
345    ) -> Entry {
346        let RandomizedCommittee {
347            cdf,
348            stake_table_hash,
349            drb,
350        } = randomized_committee;
351        // We hash the concatenated drb, view and stake table hash.
352        let mut hasher = Sha512::new();
353        hasher.update(drb);
354        hasher.update(view.to_le_bytes());
355        hasher.update(stake_table_hash);
356        let raw_breakpoint: [u8; 64] = hasher.finalize().into();
357
358        // then calculate the remainder modulo the total stake as a U512
359        let remainder: U512 =
360            U512::from_le_bytes(raw_breakpoint) % U512::from(cdf.last().unwrap().1);
361
362        // and drop the top 32 bytes, downcasting to a U256
363        let breakpoint: U256 = U256::from_le_slice(&remainder.to_le_bytes_vec()[0..32]);
364
365        // now find the first index where the breakpoint is strictly smaller than the cdf
366        //
367        // in principle, this may result in an index larger than `cdf.len()`.
368        // however, we have ensured by construction that `breakpoint < total_stake`
369        // and so the largest index we can actually return is `cdf.len() - 1`
370        let index = cdf.partition_point(|(_, cumulative_stake)| breakpoint >= *cumulative_stake);
371
372        // and return the corresponding entry
373        cdf[index].0.clone()
374    }
375
    /// A stake table randomized by a DRB result: the entries sorted by
    /// `xor(drb, public_key)` together with their cumulative-stake CDF.
    #[derive(Clone, Debug)]
    pub struct RandomizedCommittee<Entry> {
        /// cdf of nodes by cumulative stake
        cdf: Vec<(Entry, U256)>,
        /// Hash of the stake table
        stake_table_hash: [u8; 32],
        /// DRB result
        drb: [u8; 32],
    }
385
    impl<Entry> RandomizedCommittee<Entry> {
        /// Returns the DRB result this committee was randomized with.
        pub fn drb_result(&self) -> [u8; 32] {
            self.drb
        }
    }
391}
392
393#[cfg(test)]
394mod tests {
395    use std::collections::HashMap;
396
397    use alloy::primitives::U256;
398    use rand::RngCore;
399    use sha2::{Digest, Sha256};
400
401    use super::election::{generate_stake_cdf, select_randomized_leader};
402    use crate::{
403        signature_key::BLSPubKey,
404        stake_table::StakeTableEntry,
405        traits::signature_key::{BuilderSignatureKey, StakeTableEntryType},
406    };
407
408    #[test]
409    fn test_randomized_leader() {
410        let mut rng = rand::thread_rng();
411        // use an arbitrary Sha256 output.
412        let drb: [u8; 32] = Sha256::digest(b"drb").into();
413        // a stake table with 10 nodes, each with a stake of 1-100
414        let stake_table_entries: Vec<_> = (0..10)
415            .map(|i| StakeTableEntry {
416                stake_key: BLSPubKey::generated_from_seed_indexed([0u8; 32], i).0,
417                stake_amount: U256::from(rng.next_u64() % 100 + 1),
418            })
419            .collect();
420        let randomized_committee = generate_stake_cdf(stake_table_entries.clone(), drb);
421
422        // Number of views to test
423        let num_views = 100000;
424        let mut selected = HashMap::<_, u64>::new();
425        // Test the leader election for 100000 views.
426        for i in 0..num_views {
427            let leader = select_randomized_leader(&randomized_committee, i);
428            *selected.entry(leader).or_insert(0) += 1;
429        }
430
431        // Total variation distance
432        let mut tvd = 0.;
433        let total_stakes = stake_table_entries
434            .iter()
435            .map(|e| e.stake())
436            .sum::<U256>()
437            .to::<u64>() as f64;
438        for entry in stake_table_entries {
439            let expected = entry.stake().to::<u64>() as f64 / total_stakes;
440            let actual = *selected.get(&entry).unwrap_or(&0) as f64 / num_views as f64;
441            tvd += (expected - actual).abs();
442        }
443
444        // sanity check
445        assert!(tvd >= 0.0);
446        // Allow a small margin of error
447        assert!(tvd < 0.03);
448    }
449}