hotshot_testing/test_runner.rs

// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
// This file is part of the HotShot repository.

// You should have received a copy of the MIT License
// along with the HotShot repository. If not, see <https://mit-license.org/>.

#![allow(clippy::panic)]
use std::{
    collections::{BTreeMap, HashMap, HashSet},
    marker::PhantomData,
    sync::Arc,
};

use async_broadcast::{broadcast, Receiver, Sender};
use async_lock::RwLock;
use futures::future::join_all;
use hotshot::{
    traits::TestableNodeImplementation,
    types::{Event, SystemContextHandle},
    HotShotInitializer, InitializerEpochInfo, SystemContext,
};
use hotshot_example_types::{
    block_types::TestBlockHeader,
    state_types::{TestInstanceState, TestValidatedState},
    storage_types::TestStorage,
};
use hotshot_task_impls::events::HotShotEvent;
use hotshot_types::{
    consensus::ConsensusMetricsValue,
    constants::EVENT_CHANNEL_SIZE,
    data::Leaf2,
    drb::INITIAL_DRB_RESULT,
    epoch_membership::EpochMembershipCoordinator,
    simple_certificate::QuorumCertificate2,
    storage_metrics::StorageMetricsValue,
    traits::{
        election::Membership,
        network::ConnectedNetwork,
        node_implementation::{ConsensusTime, NodeImplementation, NodeType, Versions},
    },
    HotShotConfig, ValidatorConfig,
};
use tide_disco::Url;
#[allow(deprecated)]
use tracing::info;

use super::{
    completion_task::CompletionTask, consistency_task::ConsistencyTask, txn_task::TxnTask,
};
use crate::{
    block_builder::{BuilderTask, TestBuilderImplementation},
    completion_task::CompletionTaskDescription,
    spinning_task::{ChangeNode, NodeAction, SpinningTask},
    test_builder::create_test_handle,
    test_launcher::{Network, TestLauncher},
    test_task::{spawn_timeout_task, TestResult, TestTask},
    txn_task::TxnTaskDescription,
    view_sync_task::ViewSyncTask,
};

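/// Marker trait for errors returned by test tasks; blanket-implemented below
/// for any error type that is `Send + Sync + 'static`.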
pub trait TaskErr: std::error::Error + Sync + Send + 'static {}
impl<T: std::error::Error + Sync + Send + 'static> TaskErr for T {}

impl<
        TYPES: NodeType<
            InstanceState = TestInstanceState,
            ValidatedState = TestValidatedState,
            BlockHeader = TestBlockHeader,
        >,
        I: TestableNodeImplementation<TYPES>,
        V: Versions,
        N: ConnectedNetwork<TYPES::SignatureKey>,
    > TestRunner<TYPES, I, V, N>
where
    I: NodeImplementation<TYPES, Network = N, Storage = TestStorage<TYPES>>,
    <TYPES as NodeType>::Membership: Membership<TYPES, Storage = TestStorage<TYPES>>,
{
    /// execute test
    ///
    /// # Panics
    /// if the test fails
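    ///
    /// Illustrative invocation from a test body (a sketch: `TestDescription`,
    /// `gen_launcher`, and `SimpleBuilderImplementation` are assumed names from
    /// the surrounding test harness, and the exact chain may differ by version):
    ///
    /// ```ignore
    /// TestDescription::default()
    ///     .gen_launcher()
    ///     .launch()
    ///     .run_test::<SimpleBuilderImplementation>()
    ///     .await;
    /// ```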
    #[allow(clippy::too_many_lines)]
    pub async fn run_test<B: TestBuilderImplementation<TYPES>>(mut self) {
        let (test_sender, test_receiver) = broadcast(EVENT_CHANNEL_SIZE);
        let spinning_changes = self
            .launcher
            .metadata
            .spinning_properties
            .node_changes
            .clone();

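        // Scan the scripted changes up front: nodes that come `Up` later must
        // not be started now, and nodes scheduled for `RestartDown` need
        // restart contexts prepared.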
        let mut late_start_nodes: HashSet<u64> = HashSet::new();
        let mut restart_nodes: HashSet<u64> = HashSet::new();
        for (_, changes) in &spinning_changes {
            for change in changes {
                if matches!(change.updown, NodeAction::Up) {
                    late_start_nodes.insert(change.idx.try_into().unwrap());
                }
                if matches!(change.updown, NodeAction::RestartDown(_)) {
                    restart_nodes.insert(change.idx.try_into().unwrap());
                }
            }
        }

        self.add_nodes::<B>(
            self.launcher
                .metadata
                .test_config
                .num_nodes_with_stake
                .into(),
            &late_start_nodes,
            &restart_nodes,
        )
        .await;
        let mut event_rxs = vec![];
        let mut internal_event_rxs = vec![];

        for node in &self.nodes {
            let r = node.handle.event_stream_known_impl();
            event_rxs.push(r);
        }
        for node in &self.nodes {
            let r = node.handle.internal_event_stream_receiver_known_impl();
            internal_event_rxs.push(r);
        }

        let TestRunner {
            launcher,
            nodes,
            late_start,
            next_node_id: _,
            _pd: _,
        } = self;

        let mut task_futs = vec![];
        let meta = launcher.metadata.clone();

        let handles = Arc::new(RwLock::new(nodes));

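        // add txn task, if the test description requests periodic transaction submission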
        let txn_task =
            if let TxnTaskDescription::RoundRobinTimeBased(duration) = meta.txn_description {
                let txn_task = TxnTask {
                    handles: Arc::clone(&handles),
                    next_node_idx: Some(0),
                    duration,
                    shutdown_chan: test_receiver.clone(),
                };
                Some(txn_task)
            } else {
                None
            };

        // add completion task
        let CompletionTaskDescription::TimeBasedCompletionTaskBuilder(time_based) =
            meta.completion_task_description;
        let completion_task = CompletionTask {
            tx: test_sender.clone(),
            rx: test_receiver.clone(),
            duration: time_based.duration,
        };

        // add spinning task
        // map spinning changes to views
        let mut changes: BTreeMap<TYPES::View, Vec<ChangeNode>> = BTreeMap::new();
        for (view, mut change) in spinning_changes {
            changes
                .entry(TYPES::View::new(view))
                .or_insert_with(Vec::new)
                .append(&mut change);
        }

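        // The spinning task starts from genesis state: seed it with the
        // genesis leaf and high QC, which it updates as the test progresses.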
        let spinning_task_state = SpinningTask {
            epoch_height: launcher.metadata.test_config.epoch_height,
            epoch_start_block: launcher.metadata.test_config.epoch_start_block,
            start_epoch_info: Vec::new(),
            handles: Arc::clone(&handles),
            late_start,
            latest_view: None,
            changes,
            last_decided_leaf: Leaf2::genesis::<V>(
                &TestValidatedState::default(),
                &TestInstanceState::default(),
            )
            .await,
            high_qc: QuorumCertificate2::genesis::<V>(
                &TestValidatedState::default(),
                &TestInstanceState::default(),
            )
            .await,
            next_epoch_high_qc: None,
            async_delay_config: launcher.metadata.async_delay_config,
            restart_contexts: HashMap::new(),
            channel_generator: launcher.resource_generators.channel_generator,
            state_cert: None,
            node_stakes: launcher.metadata.node_stakes.clone(),
        };
        let spinning_task = TestTask::<SpinningTask<TYPES, N, I, V>>::new(
            spinning_task_state,
            event_rxs.clone(),
            test_receiver.clone(),
        );

        let consistency_task_state = ConsistencyTask {
            consensus_leaves: BTreeMap::new(),
            safety_properties: launcher.metadata.overall_safety_properties.clone(),
            test_sender: test_sender.clone(),
            errors: vec![],
            ensure_upgrade: launcher.metadata.upgrade_view.is_some(),
            validate_transactions: launcher.metadata.validate_transactions,
            timeout_task: spawn_timeout_task(
                test_sender.clone(),
                launcher.metadata.overall_safety_properties.decide_timeout,
            ),
            _pd: PhantomData,
        };

        let consistency_task = TestTask::<ConsistencyTask<TYPES, V>>::new(
            consistency_task_state,
            event_rxs.clone(),
            test_receiver.clone(),
        );

        // add view sync task
        let view_sync_task_state = ViewSyncTask {
            hit_view_sync: HashSet::new(),
            description: launcher.metadata.view_sync_properties,
            _pd: PhantomData,
        };

        let view_sync_task = TestTask::<ViewSyncTask<TYPES, I>>::new(
            view_sync_task_state,
            internal_event_rxs,
            test_receiver.clone(),
        );

        let nodes = handles.read().await;

        // wait for networks to be ready
        for node in &*nodes {
            node.network.wait_for_ready().await;
        }

        // Start HotShot on every node that is not scheduled for a late start
        for node in &*nodes {
            if !late_start_nodes.contains(&node.node_id) {
                node.handle.hotshot.start_consensus().await;
            }
        }

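        // Release the read guard so tasks that need the write lock (the
        // spinning task and the shutdown code below) are not blocked.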
        drop(nodes);

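        // Spawn any additional test tasks supplied by the test author.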
        for seed in launcher.additional_test_tasks {
            let task = TestTask::new(
                seed.into_state(Arc::clone(&handles)).await,
                event_rxs.clone(),
                test_receiver.clone(),
            );
            task_futs.push(task.run());
        }

        task_futs.push(consistency_task.run());
        task_futs.push(view_sync_task.run());
        task_futs.push(spinning_task.run());

        // `generator` tasks that do not process events.
        let txn_handle = txn_task.map(|txn| txn.run());
        let completion_handle = completion_task.run();

        let mut error_list = vec![];

        let results = join_all(task_futs).await;

        for result in results {
            match result {
                Ok(res) => match res {
                    TestResult::Pass => {
                        info!("Task shut down successfully");
                    },
                    TestResult::Fail(e) => error_list.push(e),
                },
                Err(e) => {
                    tracing::error!("Error joining the test task: {e:?}");
                },
            }
        }

        if let Some(handle) = txn_handle {
            handle.abort();
        }

        // Shut down all of the nodes at the end
        let mut nodes = handles.write().await;

        for node in &mut *nodes {
            node.handle.shut_down().await;
        }
        tracing::info!("Nodes shut down");

        completion_handle.abort();

        assert!(
            error_list.is_empty(),
            "{}",
            error_list
                .iter()
                .fold("TEST FAILED! Results:".to_string(), |acc, error| {
                    format!("{acc}\n\n{error:?}")
                })
        );
    }

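    /// Start one builder task per builder description in the test metadata,
    /// binding each to a freshly picked local port. Returns the builder tasks
    /// together with the URLs they listen on.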
    pub async fn init_builders<B: TestBuilderImplementation<TYPES>>(
        &self,
    ) -> (Vec<Box<dyn BuilderTask<TYPES>>>, Vec<Url>) {
        let mut builder_tasks = Vec::new();
        let mut builder_urls = Vec::new();
        for metadata in &self.launcher.metadata.builders {
            let builder_port = portpicker::pick_unused_port().expect("No free ports");
            let builder_url =
                Url::parse(&format!("http://localhost:{builder_port}")).expect("Invalid URL");
            let builder_task = B::start(
                0, // This field gets updated while the test is running; 0 just seeds it
                builder_url.clone(),
                B::Config::default(),
                metadata.changes.clone(),
            )
            .await;
            builder_tasks.push(builder_task);
            builder_urls.push(builder_url);
        }

        (builder_tasks, builder_urls)
    }

    /// Add nodes.
    ///
    /// Returns the IDs of the nodes that were added.
    ///
    /// # Panics
    /// Panics if unable to create a [`HotShotInitializer`]
    pub async fn add_nodes<B: TestBuilderImplementation<TYPES>>(
        &mut self,
        total: usize,
        late_start: &HashSet<u64>,
        restart: &HashSet<u64>,
    ) -> Vec<u64> {
        let mut results = vec![];
        let config = self.launcher.metadata.test_config.clone();

        // `num_nodes` is now updated on the fly via `claim_block_with_num_nodes`. This value
        // only seeds `num_nodes` in the builders, for tests that never update that field.
        let (mut builder_tasks, builder_urls) = self.init_builders::<B>().await;

        // Collect uninitialized nodes because we need to wait for all networks to be ready
        // before starting the tasks
        let mut uninitialized_nodes = Vec::new();
        let mut networks_ready = Vec::new();

        for i in 0..total {
            let mut config = config.clone();
            if let Some(upgrade_view) = self.launcher.metadata.upgrade_view {
                config.set_view_upgrade(upgrade_view);
            }
            let node_id = self.next_node_id;
            self.next_node_id += 1;
            tracing::debug!("launch node {i}");

            config.builder_urls = builder_urls
                .clone()
                .try_into()
                .expect("Non-empty by construction");

            let network = (self.launcher.resource_generators.channel_generator)(node_id).await;
            let storage = (self.launcher.resource_generators.storage)(node_id);

            let network_clone = network.clone();
            let networks_ready_future = async move {
                network_clone.wait_for_ready().await;
            };

            networks_ready.push(networks_ready_future);

            // Determine whether this node should be a member of the DA committee
            let is_da = node_id < config.da_staked_committee_size as u64;

            // Assign the node's public key and stake value directly rather than reading them
            // from a config file, since this is a test
            let validator_config = ValidatorConfig::<TYPES>::generated_from_seed_indexed(
                [0u8; 32],
                node_id,
                self.launcher.metadata.node_stakes.get(node_id),
                is_da,
            );

            let public_key = validator_config.public_key.clone();

            if late_start.contains(&node_id) {
                if self.launcher.metadata.skip_late {
                    self.late_start.insert(
                        node_id,
                        LateStartNode {
                            network: None,
                            context: LateNodeContext::UninitializedContext(
                                LateNodeContextParameters {
                                    storage: storage.clone(),
                                    memberships: <TYPES as NodeType>::Membership::new::<I>(
                                        config.known_nodes_with_stake.clone(),
                                        config.known_da_nodes.clone(),
                                        storage.clone(),
                                        network.clone(),
                                        public_key.clone(),
                                        config.epoch_height,
                                    ),
                                    config,
                                },
                            ),
                        },
                    );
                } else {
                    let initializer = HotShotInitializer::<TYPES>::from_genesis::<V>(
                        TestInstanceState::new(
                            self.launcher
                                .metadata
                                .async_delay_config
                                .get(&node_id)
                                .cloned()
                                .unwrap_or_default(),
                        ),
                        config.epoch_height,
                        config.epoch_start_block,
                        vec![InitializerEpochInfo::<TYPES> {
                            epoch: TYPES::Epoch::new(1),
                            drb_result: INITIAL_DRB_RESULT,
                            block_header: None,
                        }],
                    )
                    .await
                    .unwrap();

                    let hotshot = Self::add_node_with_config(
                        node_id,
                        network.clone(),
                        <TYPES as NodeType>::Membership::new::<I>(
                            config.known_nodes_with_stake.clone(),
                            config.known_da_nodes.clone(),
                            storage.clone(),
                            network.clone(),
                            public_key.clone(),
                            config.epoch_height,
                        ),
                        initializer,
                        config,
                        validator_config,
                        storage,
                    )
                    .await;
                    self.late_start.insert(
                        node_id,
                        LateStartNode {
                            network: Some(network),
                            context: LateNodeContext::InitializedContext(hotshot),
                        },
                    );
                }
            } else {
                uninitialized_nodes.push((
                    node_id,
                    network.clone(),
                    <TYPES as NodeType>::Membership::new::<I>(
                        config.known_nodes_with_stake.clone(),
                        config.known_da_nodes.clone(),
                        storage.clone(),
                        network,
                        public_key.clone(),
                        config.epoch_height,
                    ),
                    config,
                    storage,
                ));
            }

            results.push(node_id);
        }

        // Add the restart nodes after the rest. This must be done after all the original
        // networks are created, because doing so resets the bootstrap info for the
        // restarted nodes.
        for node_id in &results {
            if restart.contains(node_id) {
                self.late_start.insert(
                    *node_id,
                    LateStartNode {
                        network: None,
                        context: LateNodeContext::Restart,
                    },
                );
            }
        }

        // Wait for all networks to be ready
        join_all(networks_ready).await;

        // Then start the necessary tasks
        for (node_id, network, memberships, config, storage) in uninitialized_nodes {
            let handle = create_test_handle(
                self.launcher.metadata.clone(),
                node_id,
                network.clone(),
                Arc::new(RwLock::new(memberships)),
                config.clone(),
                storage,
            )
            .await;

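            // Attach builder tasks to DA nodes: each of the first DA nodes takes
            // one task, and any leftover tasks are pinned to the last DA node.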
            match node_id.cmp(&(config.da_staked_committee_size as u64 - 1)) {
                std::cmp::Ordering::Less => {
                    if let Some(task) = builder_tasks.pop() {
                        task.start(Box::new(handle.event_stream()))
                    }
                },
                std::cmp::Ordering::Equal => {
                    // If we have more builder tasks than DA nodes, pin them all on the last node.
                    while let Some(task) = builder_tasks.pop() {
                        task.start(Box::new(handle.event_stream()))
                    }
                },
                std::cmp::Ordering::Greater => {},
            }

            self.nodes.push(Node {
                node_id,
                network,
                handle,
            });
        }

        results
    }

    /// add a specific node with a config
    /// # Panics
    /// if unable to initialize the node's `SystemContext` based on the config
    #[allow(clippy::too_many_arguments)]
    pub async fn add_node_with_config(
        node_id: u64,
        network: Network<TYPES, I>,
        memberships: TYPES::Membership,
        initializer: HotShotInitializer<TYPES>,
        config: HotShotConfig<TYPES>,
        validator_config: ValidatorConfig<TYPES>,
        storage: I::Storage,
    ) -> Arc<SystemContext<TYPES, I, V>> {
        // Get key pair for certificate aggregation
        let private_key = validator_config.private_key.clone();
        let public_key = validator_config.public_key.clone();
        let state_private_key = validator_config.state_private_key.clone();
        let epoch_height = config.epoch_height;

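        // Build the context directly; metrics use their default values in tests.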
        SystemContext::new(
            public_key,
            private_key,
            state_private_key,
            node_id,
            config,
            EpochMembershipCoordinator::new(
                Arc::new(RwLock::new(memberships)),
                epoch_height,
                &storage.clone(),
            ),
            network,
            initializer,
            ConsensusMetricsValue::default(),
            storage,
            StorageMetricsValue::default(),
        )
        .await
    }

    /// add a specific node with a config and pre-created internal/external event channels
    /// # Panics
    /// if unable to initialize the node's `SystemContext` based on the config
    #[allow(clippy::too_many_arguments, clippy::type_complexity)]
    pub async fn add_node_with_config_and_channels(
        node_id: u64,
        network: Network<TYPES, I>,
        memberships: Arc<RwLock<TYPES::Membership>>,
        initializer: HotShotInitializer<TYPES>,
        config: HotShotConfig<TYPES>,
        validator_config: ValidatorConfig<TYPES>,
        storage: I::Storage,
        internal_channel: (
            Sender<Arc<HotShotEvent<TYPES>>>,
            Receiver<Arc<HotShotEvent<TYPES>>>,
        ),
        external_channel: (Sender<Event<TYPES>>, Receiver<Event<TYPES>>),
    ) -> Arc<SystemContext<TYPES, I, V>> {
        // Get key pair for certificate aggregation
        let private_key = validator_config.private_key.clone();
        let public_key = validator_config.public_key.clone();
        let state_private_key = validator_config.state_private_key.clone();
        let epoch_height = config.epoch_height;

        SystemContext::new_from_channels(
            public_key,
            private_key,
            state_private_key,
            node_id,
            config,
            EpochMembershipCoordinator::new(memberships, epoch_height, &storage.clone()),
            network,
            initializer,
            ConsensusMetricsValue::default(),
            storage,
            StorageMetricsValue::default(),
            internal_channel,
            external_channel,
        )
        .await
    }
}

/// a node participating in a test
pub struct Node<TYPES: NodeType, I: TestableNodeImplementation<TYPES>, V: Versions> {
    /// The node's unique identifier
    pub node_id: u64,
    /// The underlying network belonging to the node
    pub network: Network<TYPES, I>,
    /// The handle to the node's internals
    pub handle: SystemContextHandle<TYPES, I, V>,
}

/// This type combines all of the parameters needed to build the context for a node that started
/// late during a unit test or integration test.
pub struct LateNodeContextParameters<TYPES: NodeType, I: TestableNodeImplementation<TYPES>> {
    /// The storage implementation backing the node's persistence.
    pub storage: I::Storage,

    /// The memberships of this particular node.
    pub memberships: TYPES::Membership,

    /// The config associated with this node.
    pub config: HotShotConfig<TYPES>,
}

/// The late node context dictates how we're building a node that started late during the test.
#[allow(clippy::large_enum_variant)]
pub enum LateNodeContext<TYPES: NodeType, I: TestableNodeImplementation<TYPES>, V: Versions> {
    /// The system context that we're passing directly to the node; this means the node has
    /// already been initialized successfully.
    InitializedContext(Arc<SystemContext<TYPES, I, V>>),

    /// The parameters for a node that is not yet initialized; the context will be built from
    /// the received leaf and these parameters.
    UninitializedContext(LateNodeContextParameters<TYPES, I>),
    /// The node is to be restarted, so we will build the context from the node that was
    /// already running.
    Restart,
}

/// A yet-to-be-started node that participates in tests
pub struct LateStartNode<TYPES: NodeType, I: TestableNodeImplementation<TYPES>, V: Versions> {
    /// The underlying network belonging to the node
    pub network: Option<Network<TYPES, I>>,
    /// Either the context we will use to launch HotShot for an already-initialized node when
    /// it's time, or the parameters that will be used to initialize the node and then launch
    /// HotShot.
    pub context: LateNodeContext<TYPES, I, V>,
}

/// The runner of a test network.
/// Spins nodes up and down and executes rounds.
pub struct TestRunner<
    TYPES: NodeType,
    I: TestableNodeImplementation<TYPES>,
    V: Versions,
    N: ConnectedNetwork<TYPES::SignatureKey>,
> {
    /// test launcher, contains a bunch of useful metadata and closures
    pub(crate) launcher: TestLauncher<TYPES, I, V>,
    /// nodes in the test
    pub(crate) nodes: Vec<Node<TYPES, I, V>>,
    /// nodes with a late start
    pub(crate) late_start: HashMap<u64, LateStartNode<TYPES, I, V>>,
    /// the next node unique identifier
    pub(crate) next_node_id: u64,
    /// Phantom for N
    pub(crate) _pd: PhantomData<N>,
}