hotshot_example_types/membership/
strict_membership.rs1use std::{collections::HashSet, fmt, fmt::Debug, sync::Arc};
2
3use alloy::primitives::U256;
4use async_broadcast::Receiver;
5use async_lock::RwLock;
6use hotshot_types::{
7 PeerConfig,
8 data::{EpochNumber, Leaf2, ViewNumber},
9 drb::DrbResult,
10 event::Event,
11 stake_table::HSStakeTable,
12 traits::{
13 block_contents::BlockHeader,
14 election::{Membership, NoStakeTableHash},
15 node_implementation::{NodeImplementation, NodeType},
16 signature_key::StakeTableEntryType,
17 },
18 utils::{epoch_from_block_number, root_block_in_epoch, transition_block_for_epoch},
19};
20
21use crate::{
22 membership::{fetcher::Leaf2Fetcher, stake_table::TestStakeTable},
23 storage_types::TestStorage,
24};
25
26#[derive(Clone)]
27pub struct StrictMembership<
28 TYPES: NodeType,
29 StakeTable: TestStakeTable<TYPES::SignatureKey, TYPES::StateSignatureKey>,
30> {
31 inner: StakeTable,
32 epochs: HashSet<EpochNumber>,
33 drbs: HashSet<EpochNumber>,
34 fetcher: Arc<RwLock<Leaf2Fetcher<TYPES>>>,
35 epoch_height: u64,
36}
37
38impl<TYPES, StakeTable> Debug for StrictMembership<TYPES, StakeTable>
39where
40 TYPES: NodeType,
41 StakeTable: TestStakeTable<TYPES::SignatureKey, TYPES::StateSignatureKey>,
42{
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
44 f.debug_struct("StrictMembership")
45 .field("inner", &self.inner)
46 .field("epochs", &self.epochs)
47 .field("drbs", &self.drbs)
48 .finish()
49 }
50}
51
52impl<TYPES: NodeType, StakeTable: TestStakeTable<TYPES::SignatureKey, TYPES::StateSignatureKey>>
53 StrictMembership<TYPES, StakeTable>
54{
55 fn assert_has_stake_table(&self, epoch: Option<EpochNumber>) {
56 let Some(epoch) = epoch else {
57 return;
58 };
59 assert!(
60 self.epochs.contains(&epoch),
61 "Failed stake table check for epoch {epoch}"
62 );
63 }
64 fn assert_has_randomized_stake_table(&self, epoch: Option<EpochNumber>) {
65 let Some(epoch) = epoch else {
66 return;
67 };
68 assert!(
69 self.drbs.contains(&epoch),
70 "Failed drb check for epoch {epoch}"
71 );
72 }
73}
74
75impl<TYPES: NodeType, StakeTable: TestStakeTable<TYPES::SignatureKey, TYPES::StateSignatureKey>>
76 Membership<TYPES> for StrictMembership<TYPES, StakeTable>
77{
78 type Error = anyhow::Error;
79 type StakeTableHash = NoStakeTableHash;
80 type Storage = TestStorage<TYPES>;
81
82 fn new<I: NodeImplementation<TYPES>>(
83 quorum_members: Vec<hotshot_types::PeerConfig<TYPES>>,
84 da_members: Vec<hotshot_types::PeerConfig<TYPES>>,
85 storage: Self::Storage,
86 network: Arc<<I as NodeImplementation<TYPES>>::Network>,
87 public_key: TYPES::SignatureKey,
88 epoch_height: u64,
89 ) -> Self {
90 let fetcher = Leaf2Fetcher::new::<I>(network, storage, public_key);
91
92 Self {
93 inner: TestStakeTable::new(
94 quorum_members.into_iter().map(Into::into).collect(),
95 da_members.into_iter().map(Into::into).collect(),
96 ),
97 epochs: HashSet::new(),
98 drbs: HashSet::new(),
99 fetcher: RwLock::new(fetcher).into(),
100 epoch_height,
101 }
102 }
103
104 async fn set_external_channel(&mut self, external_channel: Receiver<Event<TYPES>>) {
105 self.fetcher
106 .write()
107 .await
108 .set_external_channel(external_channel)
109 }
110
111 fn stake_table(&self, epoch: Option<EpochNumber>) -> HSStakeTable<TYPES> {
112 self.assert_has_stake_table(epoch);
113 let peer_configs = self
114 .inner
115 .stake_table(epoch.map(|e| *e))
116 .into_iter()
117 .map(Into::into)
118 .collect();
119 HSStakeTable(peer_configs)
120 }
121
122 fn da_stake_table(&self, epoch: Option<EpochNumber>) -> HSStakeTable<TYPES> {
123 self.assert_has_stake_table(epoch);
124 let peer_configs = self
125 .inner
126 .da_stake_table(epoch.map(|e| *e))
127 .into_iter()
128 .map(Into::into)
129 .collect();
130 HSStakeTable(peer_configs)
131 }
132
133 fn committee_members(
134 &self,
135 _view_number: ViewNumber,
136 epoch: Option<EpochNumber>,
137 ) -> std::collections::BTreeSet<TYPES::SignatureKey> {
138 self.assert_has_stake_table(epoch);
139 self.inner
140 .stake_table(epoch.map(|e| *e))
141 .into_iter()
142 .map(|entry| entry.signature_key)
143 .collect()
144 }
145
146 fn da_committee_members(
147 &self,
148 _view_number: ViewNumber,
149 epoch: Option<EpochNumber>,
150 ) -> std::collections::BTreeSet<TYPES::SignatureKey> {
151 self.assert_has_stake_table(epoch);
152 self.inner
153 .da_stake_table(epoch.map(|e| *e))
154 .into_iter()
155 .map(|entry| entry.signature_key)
156 .collect()
157 }
158
159 fn stake(
160 &self,
161 pub_key: &TYPES::SignatureKey,
162 epoch: Option<EpochNumber>,
163 ) -> Option<hotshot_types::PeerConfig<TYPES>> {
164 self.assert_has_stake_table(epoch);
165 self.inner
166 .stake(pub_key.clone(), epoch.map(|e| *e))
167 .map(Into::into)
168 }
169
170 fn da_stake(
171 &self,
172 pub_key: &TYPES::SignatureKey,
173 epoch: Option<EpochNumber>,
174 ) -> Option<hotshot_types::PeerConfig<TYPES>> {
175 self.assert_has_stake_table(epoch);
176 self.inner
177 .da_stake(pub_key.clone(), epoch.map(|e| *e))
178 .map(Into::into)
179 }
180
181 fn has_stake(
183 &self,
184 pub_key: &<TYPES as NodeType>::SignatureKey,
185 epoch: Option<EpochNumber>,
186 ) -> bool {
187 self.assert_has_stake_table(epoch);
188
189 self.stake(pub_key, epoch)
190 .is_some_and(|x| x.stake_table_entry.stake() > U256::ZERO)
191 }
192
193 fn has_da_stake(
195 &self,
196 pub_key: &<TYPES as NodeType>::SignatureKey,
197 epoch: Option<EpochNumber>,
198 ) -> bool {
199 self.assert_has_stake_table(epoch);
200
201 self.da_stake(pub_key, epoch)
202 .is_some_and(|x| x.stake_table_entry.stake() > U256::ZERO)
203 }
204
205 fn lookup_leader(
206 &self,
207 view: ViewNumber,
208 epoch: Option<EpochNumber>,
209 ) -> anyhow::Result<TYPES::SignatureKey> {
210 self.assert_has_randomized_stake_table(epoch);
211 self.inner.lookup_leader(*view, epoch.map(|e| *e))
212 }
213
214 fn total_nodes(&self, epoch: Option<EpochNumber>) -> usize {
215 self.assert_has_stake_table(epoch);
216 self.inner.stake_table(epoch.map(|e| *e)).len()
217 }
218
219 fn da_total_nodes(&self, epoch: Option<EpochNumber>) -> usize {
220 self.assert_has_stake_table(epoch);
221 self.inner.da_stake_table(epoch.map(|e| *e)).len()
222 }
223
224 fn has_stake_table(&self, epoch: EpochNumber) -> bool {
225 let has_stake_table = self.inner.has_stake_table(*epoch);
226
227 assert_eq!(has_stake_table, self.epochs.contains(&epoch));
228
229 has_stake_table
230 }
231
232 fn has_randomized_stake_table(&self, epoch: EpochNumber) -> anyhow::Result<bool> {
233 if !self.has_stake_table(epoch) {
234 return Ok(false);
235 }
236 let has_randomized_stake_table = self.inner.has_randomized_stake_table(*epoch);
237
238 if let Ok(result) = has_randomized_stake_table {
239 assert_eq!(result, self.drbs.contains(&epoch));
240 } else {
241 assert!(!self.drbs.contains(&epoch));
242 }
243
244 has_randomized_stake_table
245 }
246
247 fn add_drb_result(&mut self, epoch: EpochNumber, drb_result: hotshot_types::drb::DrbResult) {
248 self.assert_has_stake_table(Some(epoch));
249
250 self.drbs.insert(epoch);
251 self.inner.add_drb_result(*epoch, drb_result);
252 }
253
254 fn first_epoch(&self) -> Option<EpochNumber> {
255 self.inner.first_epoch().map(EpochNumber::new)
256 }
257
258 fn set_first_epoch(&mut self, epoch: EpochNumber, initial_drb_result: DrbResult) {
259 self.epochs.insert(epoch);
260 self.epochs.insert(epoch + 1);
261
262 self.drbs.insert(epoch);
263 self.drbs.insert(epoch + 1);
264
265 self.inner.set_first_epoch(*epoch, initial_drb_result);
266 }
267
268 async fn add_epoch_root(
269 membership: Arc<RwLock<Self>>,
270 block_header: TYPES::BlockHeader,
271 ) -> anyhow::Result<()> {
272 let mut membership_writer = membership.write().await;
273
274 let epoch =
275 epoch_from_block_number(block_header.block_number(), membership_writer.epoch_height)
276 + 2;
277
278 membership_writer.epochs.insert(EpochNumber::new(epoch));
279
280 membership_writer.inner.add_epoch_root(epoch);
281
282 Ok(())
283 }
284
285 async fn get_epoch_root(
286 membership: Arc<RwLock<Self>>,
287 epoch: EpochNumber,
288 ) -> anyhow::Result<Leaf2<TYPES>> {
289 let membership_reader = membership.read().await;
290
291 let block_height = root_block_in_epoch(*epoch, membership_reader.epoch_height);
292
293 let stake_table = membership_reader.inner.stake_table(Some(*epoch));
294 let fetcher = membership_reader.fetcher.clone();
295
296 drop(membership_reader);
297
298 for node in stake_table {
299 if let Ok(leaf) = fetcher
300 .read()
301 .await
302 .fetch_leaf(block_height, node.signature_key)
303 .await
304 {
305 return Ok(leaf);
306 }
307 }
308
309 anyhow::bail!("Failed to fetch epoch root from any peer");
310 }
311
312 async fn get_epoch_drb(
313 membership: Arc<RwLock<Self>>,
314 epoch: EpochNumber,
315 ) -> anyhow::Result<DrbResult> {
316 let membership_reader = membership.read().await;
317
318 let epoch_height = membership_reader.epoch_height;
319 let epoch_drb = membership_reader.inner.get_epoch_drb(*epoch);
320 let fetcher = membership_reader.fetcher.clone();
321
322 drop(membership_reader);
323
324 if let Ok(drb_result) = epoch_drb {
325 Ok(drb_result)
326 } else {
327 let previous_epoch = match epoch.checked_sub(1) {
328 Some(epoch) => epoch,
329 None => {
330 anyhow::bail!("Missing initial DRB result for epoch {epoch:?}");
331 },
332 };
333
334 let drb_block_height = transition_block_for_epoch(previous_epoch, epoch_height);
335
336 let membership_reader = membership.read().await;
337 let stake_table = membership_reader.inner.stake_table(Some(previous_epoch));
338 drop(membership_reader);
339
340 let mut drb_leaf = None;
341
342 for node in stake_table {
343 if let Ok(leaf) = fetcher
344 .read()
345 .await
346 .fetch_leaf(drb_block_height, node.signature_key)
347 .await
348 {
349 drb_leaf = Some(leaf);
350 break;
351 }
352 }
353
354 match drb_leaf {
355 Some(leaf) => Ok(leaf.next_drb_result.expect(
356 "We fetched a leaf that is missing a DRB result. This should be impossible.",
357 )),
358 None => {
359 anyhow::bail!(
360 "Failed to fetch leaf from all nodes. Height: {drb_block_height}"
361 );
362 },
363 }
364 }
365 }
366
367 fn add_da_committee(&mut self, first_epoch: u64, committee: Vec<PeerConfig<TYPES>>) {
368 self.inner
369 .add_da_committee(first_epoch, committee.into_iter().map(Into::into).collect());
370 }
371}