1#![allow(clippy::panic)]
8use std::{
9 collections::{BTreeMap, HashMap, HashSet},
10 marker::PhantomData,
11 sync::Arc,
12};
13
14use async_broadcast::{Receiver, Sender, broadcast};
15use async_lock::RwLock;
16use futures::future::join_all;
17use hotshot::{
18 HotShotInitializer, InitializerEpochInfo, SystemContext,
19 traits::TestableNodeImplementation,
20 types::{Event, SystemContextHandle},
21};
22use hotshot_example_types::{
23 block_types::TestBlockHeader,
24 state_types::{TestInstanceState, TestValidatedState},
25 storage_types::TestStorage,
26};
27use hotshot_task_impls::events::HotShotEvent;
28use hotshot_types::{
29 HotShotConfig, ValidatorConfig,
30 consensus::ConsensusMetricsValue,
31 constants::EVENT_CHANNEL_SIZE,
32 data::{EpochNumber, Leaf2, ViewNumber},
33 drb::INITIAL_DRB_RESULT,
34 epoch_membership::EpochMembershipCoordinator,
35 simple_certificate::QuorumCertificate2,
36 storage_metrics::StorageMetricsValue,
37 traits::{
38 election::Membership,
39 network::ConnectedNetwork,
40 node_implementation::{NodeImplementation, NodeType},
41 },
42};
43use tide_disco::Url;
44#[allow(deprecated)]
45use tracing::info;
46
47use super::{
48 completion_task::CompletionTask, consistency_task::ConsistencyTask, txn_task::TxnTask,
49};
50use crate::{
51 block_builder::{BuilderTask, TestBuilderImplementation},
52 completion_task::CompletionTaskDescription,
53 spinning_task::{ChangeNode, NodeAction, SpinningTask},
54 test_builder::create_test_handle,
55 test_launcher::{Network, TestLauncher},
56 test_task::{TestResult, TestTask, spawn_timeout_task},
57 txn_task::TxnTaskDescription,
58 view_sync_task::ViewSyncTask,
59};
60
/// Marker trait for error types that test tasks can report.
///
/// The bounds require the error to be shareable across threads (`Send + Sync`)
/// and free of borrowed data (`'static`).
pub trait TaskErr: std::error::Error + Sync + Send + 'static {}

// Blanket impl: any type satisfying the bounds above is automatically a `TaskErr`.
impl<T> TaskErr for T where T: std::error::Error + Sync + Send + 'static {}
64
65impl<
66 TYPES: NodeType<
67 InstanceState = TestInstanceState,
68 ValidatedState = TestValidatedState,
69 BlockHeader = TestBlockHeader,
70 >,
71 I: TestableNodeImplementation<TYPES>,
72 N: ConnectedNetwork<TYPES::SignatureKey>,
73> TestRunner<TYPES, I, N>
74where
75 I: TestableNodeImplementation<TYPES>,
76 I: NodeImplementation<TYPES, Network = N, Storage = TestStorage<TYPES>>,
77 <TYPES as NodeType>::Membership: Membership<TYPES, Storage = TestStorage<TYPES>>,
78{
79 #[allow(clippy::too_many_lines)]
84 pub async fn run_test<B: TestBuilderImplementation<TYPES>>(mut self) {
85 let (test_sender, test_receiver) = broadcast(EVENT_CHANNEL_SIZE);
86 let spinning_changes = self
87 .launcher
88 .metadata
89 .spinning_properties
90 .node_changes
91 .clone();
92
93 let mut late_start_nodes: HashSet<u64> = HashSet::new();
94 let mut restart_nodes: HashSet<u64> = HashSet::new();
95 for (_, changes) in &spinning_changes {
96 for change in changes {
97 if matches!(change.updown, NodeAction::Up) {
98 late_start_nodes.insert(change.idx.try_into().unwrap());
99 }
100 if matches!(change.updown, NodeAction::RestartDown(_)) {
101 restart_nodes.insert(change.idx.try_into().unwrap());
102 }
103 }
104 }
105
106 self.add_nodes::<B>(
107 self.launcher
108 .metadata
109 .test_config
110 .num_nodes_with_stake
111 .into(),
112 &late_start_nodes,
113 &restart_nodes,
114 )
115 .await;
116 let mut event_rxs = vec![];
117 let mut internal_event_rxs = vec![];
118
119 for node in &self.nodes {
120 let r = node.handle.event_stream_known_impl();
121 event_rxs.push(r);
122 }
123 for node in &self.nodes {
124 let r = node.handle.internal_event_stream_receiver_known_impl();
125 internal_event_rxs.push(r);
126 }
127
128 let TestRunner {
129 launcher,
130 nodes,
131 late_start,
132 next_node_id: _,
133 _pd: _,
134 } = self;
135
136 let mut task_futs = vec![];
137 let meta = launcher.metadata.clone();
138
139 let handles = Arc::new(RwLock::new(nodes));
140
141 let txn_task =
142 if let TxnTaskDescription::RoundRobinTimeBased(duration) = meta.txn_description {
143 let txn_task = TxnTask {
144 handles: Arc::clone(&handles),
145 next_node_idx: Some(0),
146 duration,
147 shutdown_chan: test_receiver.clone(),
148 };
149 Some(txn_task)
150 } else {
151 None
152 };
153
154 let CompletionTaskDescription::TimeBasedCompletionTaskBuilder(time_based) =
156 meta.completion_task_description;
157 let completion_task = CompletionTask {
158 tx: test_sender.clone(),
159 rx: test_receiver.clone(),
160 duration: time_based.duration,
161 };
162
163 let mut changes: BTreeMap<ViewNumber, Vec<ChangeNode>> = BTreeMap::new();
166 for (view, mut change) in spinning_changes {
167 changes
168 .entry(ViewNumber::new(view))
169 .or_default()
170 .append(&mut change);
171 }
172
173 let spinning_task_state = SpinningTask {
174 epoch_height: launcher.metadata.test_config.epoch_height,
175 epoch_start_block: launcher.metadata.test_config.epoch_start_block,
176 start_epoch_info: Vec::new(),
177 handles: Arc::clone(&handles),
178 late_start,
179 latest_view: None,
180 changes,
181 last_decided_leaf: Leaf2::genesis(
182 &TestValidatedState::default(),
183 &TestInstanceState::default(),
184 launcher.metadata.upgrade.base,
185 )
186 .await,
187 high_qc: QuorumCertificate2::genesis(
188 &TestValidatedState::default(),
189 &TestInstanceState::default(),
190 launcher.metadata.upgrade,
191 )
192 .await,
193 next_epoch_high_qc: None,
194 async_delay_config: launcher.metadata.async_delay_config,
195 restart_contexts: HashMap::new(),
196 channel_generator: launcher.resource_generators.channel_generator,
197 state_cert: None,
198 node_stakes: launcher.metadata.node_stakes.clone(),
199 upgrade: launcher.metadata.upgrade,
200 };
201 let spinning_task = TestTask::<SpinningTask<TYPES, N, I>>::new(
202 spinning_task_state,
203 event_rxs.clone(),
204 test_receiver.clone(),
205 );
206
207 let consistency_task_state = ConsistencyTask {
208 consensus_leaves: BTreeMap::new(),
209 safety_properties: launcher.metadata.overall_safety_properties.clone(),
210 test_sender: test_sender.clone(),
211 errors: vec![],
212 ensure_upgrade: launcher.metadata.upgrade_view.is_some(),
213 validate_transactions: launcher.metadata.validate_transactions,
214 timeout_task: spawn_timeout_task(
215 test_sender.clone(),
216 launcher.metadata.overall_safety_properties.decide_timeout,
217 ),
218 };
219
220 let consistency_task = TestTask::<ConsistencyTask<TYPES>>::new(
221 consistency_task_state,
222 event_rxs.clone(),
223 test_receiver.clone(),
224 );
225
226 let view_sync_task_state = ViewSyncTask {
228 hit_view_sync: HashSet::new(),
229 description: launcher.metadata.view_sync_properties,
230 _pd: PhantomData,
231 };
232
233 let view_sync_task = TestTask::<ViewSyncTask<TYPES, I>>::new(
234 view_sync_task_state,
235 internal_event_rxs,
236 test_receiver.clone(),
237 );
238
239 let nodes = handles.read().await;
240
241 for node in &*nodes {
243 node.network.wait_for_ready().await;
244 }
245
246 for node in &*nodes {
248 if !late_start_nodes.contains(&node.node_id) {
249 node.handle.hotshot.start_consensus().await;
250 }
251 }
252
253 drop(nodes);
254
255 for seed in launcher.additional_test_tasks {
256 let task = TestTask::new(
257 seed.into_state(Arc::clone(&handles)).await,
258 event_rxs.clone(),
259 test_receiver.clone(),
260 );
261 task_futs.push(task.run());
262 }
263
264 task_futs.push(consistency_task.run());
265 task_futs.push(view_sync_task.run());
266 task_futs.push(spinning_task.run());
267
268 let txn_handle = txn_task.map(|txn| txn.run());
270 let completion_handle = completion_task.run();
271
272 let mut error_list = vec![];
273
274 let results = join_all(task_futs).await;
275
276 for result in results {
277 match result {
278 Ok(res) => match res {
279 TestResult::Pass => {
280 info!("Task shut down successfully");
281 },
282 TestResult::Fail(e) => error_list.push(e),
283 },
284 Err(e) => {
285 tracing::error!("Error Joining the test task {e:?}");
286 },
287 }
288 }
289
290 if let Some(handle) = txn_handle {
291 handle.abort();
292 }
293 let mut nodes = handles.write().await;
296
297 for node in &mut *nodes {
298 node.handle.shut_down().await;
299 }
300 tracing::info!("Nodes shutdown");
301
302 completion_handle.abort();
303
304 assert!(
305 error_list.is_empty(),
306 "{}",
307 error_list
308 .iter()
309 .fold("TEST FAILED! Results:".to_string(), |acc, error| {
310 format!("{acc}\n\n{error:?}")
311 })
312 );
313 }
314
315 pub async fn init_builders<B: TestBuilderImplementation<TYPES>>(
316 &self,
317 ) -> (Vec<Box<dyn BuilderTask<TYPES>>>, Vec<Url>) {
318 let mut builder_tasks = Vec::new();
319 let mut builder_urls = Vec::new();
320
321 let mut ports = Vec::new();
322 for _ in &self.launcher.metadata.builders {
323 let port =
324 test_utils::reserve_tcp_port().expect("OS should have ephemeral ports available");
325 ports.push(port);
326 }
327
328 for (metadata, builder_port) in self.launcher.metadata.builders.iter().zip(&ports) {
329 let builder_url =
330 Url::parse(&format!("http://localhost:{builder_port}")).expect("Invalid URL");
331 let builder_task = B::start(
332 0, builder_url.clone(),
334 B::Config::default(),
335 metadata.changes.clone(),
336 )
337 .await;
338 builder_tasks.push(builder_task);
339 builder_urls.push(builder_url);
340 }
341
342 (builder_tasks, builder_urls)
343 }
344
345 pub async fn add_nodes<B: TestBuilderImplementation<TYPES>>(
350 &mut self,
351 total: usize,
352 late_start: &HashSet<u64>,
353 restart: &HashSet<u64>,
354 ) -> Vec<u64> {
355 let mut results = vec![];
356 let config = self.launcher.metadata.test_config.clone();
357
358 let (mut builder_tasks, builder_urls) = self.init_builders::<B>().await;
361
362 let mut uninitialized_nodes = Vec::new();
364 let mut networks_ready = Vec::new();
365
366 for i in 0..total {
367 let mut config = config.clone();
368 if let Some(upgrade_view) = self.launcher.metadata.upgrade_view {
369 config.set_view_upgrade(upgrade_view);
370 }
371 let node_id = self.next_node_id;
372 self.next_node_id += 1;
373 tracing::debug!("launch node {i}");
374
375 config.builder_urls = builder_urls
376 .clone()
377 .try_into()
378 .expect("Non-empty by construction");
379
380 let storage = (self.launcher.resource_generators.storage)(node_id);
381
382 let is_da = node_id < config.da_staked_committee_size as u64;
384
385 let validator_config = ValidatorConfig::<TYPES>::generated_from_seed_indexed(
387 [0u8; 32],
388 node_id,
389 self.launcher.metadata.node_stakes.get(node_id),
390 is_da,
391 );
392
393 let public_key = validator_config.public_key.clone();
394
395 let network = if late_start.contains(&node_id) && self.launcher.metadata.skip_late {
396 None
397 } else {
398 let net = (self.launcher.resource_generators.channel_generator)(node_id).await;
399 networks_ready.push({
400 let net = net.clone();
401 async move { net.wait_for_ready().await }
402 });
403 Some(net)
404 };
405
406 if late_start.contains(&node_id) {
407 if self.launcher.metadata.skip_late {
408 self.late_start.insert(
409 node_id,
410 LateStartNode {
411 network: None,
412 context: LateNodeContext::UninitializedContext(
413 LateNodeContextParameters {
414 storage: storage.clone(),
415 config,
416 },
417 ),
418 },
419 );
420 } else {
421 let initializer = HotShotInitializer::<TYPES>::from_genesis(
422 TestInstanceState::new(
423 self.launcher
424 .metadata
425 .async_delay_config
426 .get(&node_id)
427 .cloned()
428 .unwrap_or_default(),
429 ),
430 config.epoch_height,
431 config.epoch_start_block,
432 vec![InitializerEpochInfo::<TYPES> {
433 epoch: EpochNumber::new(1),
434 drb_result: INITIAL_DRB_RESULT,
435 block_header: None,
436 }],
437 self.launcher.metadata.upgrade,
438 )
439 .await
440 .unwrap();
441
442 let network = network.clone().expect("!skip_late => network created");
443
444 let hotshot = Self::add_node_with_config(
445 node_id,
446 network.clone(),
447 <TYPES as NodeType>::Membership::new::<I>(
448 config.known_nodes_with_stake.clone(),
449 config.known_da_nodes.clone(),
450 storage.clone(),
451 network.clone(),
452 public_key.clone(),
453 config.epoch_height,
454 ),
455 initializer,
456 config,
457 self.launcher.metadata.upgrade,
458 validator_config,
459 storage,
460 )
461 .await;
462 self.late_start.insert(
463 node_id,
464 LateStartNode {
465 network: Some(network),
466 context: LateNodeContext::InitializedContext(hotshot),
467 },
468 );
469 }
470 } else {
471 let network = network.expect("!late_start => network created");
472 uninitialized_nodes.push((
473 node_id,
474 network.clone(),
475 <TYPES as NodeType>::Membership::new::<I>(
476 config.known_nodes_with_stake.clone(),
477 config.known_da_nodes.clone(),
478 storage.clone(),
479 network,
480 public_key.clone(),
481 config.epoch_height,
482 ),
483 config,
484 storage,
485 ));
486 }
487
488 results.push(node_id);
489 }
490
491 for node_id in &results {
494 if restart.contains(node_id) {
495 self.late_start.insert(
496 *node_id,
497 LateStartNode {
498 network: None,
499 context: LateNodeContext::Restart,
500 },
501 );
502 }
503 }
504
505 join_all(networks_ready).await;
507
508 for (node_id, network, memberships, config, storage) in uninitialized_nodes {
510 let handle = create_test_handle(
511 self.launcher.metadata.clone(),
512 node_id,
513 network.clone(),
514 Arc::new(RwLock::new(memberships)),
515 config.clone(),
516 storage,
517 )
518 .await;
519
520 match node_id.cmp(&(config.da_staked_committee_size as u64 - 1)) {
521 std::cmp::Ordering::Less => {
522 if let Some(task) = builder_tasks.pop() {
523 task.start(Box::new(handle.event_stream()))
524 }
525 },
526 std::cmp::Ordering::Equal => {
527 while let Some(task) = builder_tasks.pop() {
529 task.start(Box::new(handle.event_stream()))
530 }
531 },
532 std::cmp::Ordering::Greater => {},
533 }
534
535 self.nodes.push(Node {
536 node_id,
537 network,
538 handle,
539 });
540 }
541
542 results
543 }
544
545 #[allow(clippy::too_many_arguments)]
549 pub async fn add_node_with_config(
550 node_id: u64,
551 network: Network<TYPES, I>,
552 memberships: TYPES::Membership,
553 initializer: HotShotInitializer<TYPES>,
554 config: HotShotConfig<TYPES>,
555 upgrade: versions::Upgrade,
556 validator_config: ValidatorConfig<TYPES>,
557 storage: I::Storage,
558 ) -> Arc<SystemContext<TYPES, I>> {
559 let private_key = validator_config.private_key.clone();
561 let public_key = validator_config.public_key.clone();
562 let state_private_key = validator_config.state_private_key.clone();
563 let epoch_height = config.epoch_height;
564
565 SystemContext::new(
566 public_key,
567 private_key,
568 state_private_key,
569 node_id,
570 config,
571 upgrade,
572 EpochMembershipCoordinator::new(
573 Arc::new(RwLock::new(memberships)),
574 epoch_height,
575 &storage.clone(),
576 ),
577 network,
578 initializer,
579 ConsensusMetricsValue::default(),
580 storage,
581 StorageMetricsValue::default(),
582 )
583 .await
584 }
585
586 #[allow(clippy::too_many_arguments, clippy::type_complexity)]
590 pub async fn add_node_with_config_and_channels(
591 node_id: u64,
592 network: Network<TYPES, I>,
593 memberships: Arc<RwLock<TYPES::Membership>>,
594 initializer: HotShotInitializer<TYPES>,
595 config: HotShotConfig<TYPES>,
596 upgrade: versions::Upgrade,
597 validator_config: ValidatorConfig<TYPES>,
598 storage: I::Storage,
599 internal_channel: (
600 Sender<Arc<HotShotEvent<TYPES>>>,
601 Receiver<Arc<HotShotEvent<TYPES>>>,
602 ),
603 external_channel: (Sender<Event<TYPES>>, Receiver<Event<TYPES>>),
604 ) -> Arc<SystemContext<TYPES, I>> {
605 let private_key = validator_config.private_key.clone();
607 let public_key = validator_config.public_key.clone();
608 let state_private_key = validator_config.state_private_key.clone();
609 let epoch_height = config.epoch_height;
610
611 SystemContext::new_from_channels(
612 public_key,
613 private_key,
614 state_private_key,
615 node_id,
616 config,
617 upgrade,
618 EpochMembershipCoordinator::new(memberships, epoch_height, &storage.clone()),
619 network,
620 initializer,
621 ConsensusMetricsValue::default(),
622 storage,
623 StorageMetricsValue::default(),
624 internal_channel,
625 external_channel,
626 )
627 .await
628 }
629}
630
631pub struct Node<TYPES: NodeType, I: TestableNodeImplementation<TYPES>> {
633 pub node_id: u64,
635 pub network: Network<TYPES, I>,
637 pub handle: SystemContextHandle<TYPES, I>,
639}
640
641pub struct LateNodeContextParameters<TYPES: NodeType, I: TestableNodeImplementation<TYPES>> {
644 pub storage: I::Storage,
646
647 pub config: HotShotConfig<TYPES>,
649}
650
651#[allow(clippy::large_enum_variant)]
653pub enum LateNodeContext<TYPES: NodeType, I: TestableNodeImplementation<TYPES>> {
654 InitializedContext(Arc<SystemContext<TYPES, I>>),
657
658 UninitializedContext(LateNodeContextParameters<TYPES, I>),
661 Restart,
663}
664
665pub struct LateStartNode<TYPES: NodeType, I: TestableNodeImplementation<TYPES>> {
667 pub network: Option<Network<TYPES, I>>,
669 pub context: LateNodeContext<TYPES, I>,
672}
673
674pub struct TestRunner<
677 TYPES: NodeType,
678 I: TestableNodeImplementation<TYPES>,
679 N: ConnectedNetwork<TYPES::SignatureKey>,
680> {
681 pub(crate) launcher: TestLauncher<TYPES, I>,
683 pub(crate) nodes: Vec<Node<TYPES, I>>,
685 pub(crate) late_start: HashMap<u64, LateStartNode<TYPES, I>>,
687 pub(crate) next_node_id: u64,
689 pub(crate) _pd: PhantomData<N>,
691}