// hotshot_testing/test_runner.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7#![allow(clippy::panic)]
8use std::{
9    collections::{BTreeMap, HashMap, HashSet},
10    marker::PhantomData,
11    sync::Arc,
12};
13
14use async_broadcast::{Receiver, Sender, broadcast};
15use async_lock::RwLock;
16use futures::future::join_all;
17use hotshot::{
18    HotShotInitializer, InitializerEpochInfo, SystemContext,
19    traits::TestableNodeImplementation,
20    types::{Event, SystemContextHandle},
21};
22use hotshot_example_types::{
23    block_types::TestBlockHeader,
24    state_types::{TestInstanceState, TestValidatedState},
25    storage_types::TestStorage,
26};
27use hotshot_task_impls::events::HotShotEvent;
28use hotshot_types::{
29    HotShotConfig, ValidatorConfig,
30    consensus::ConsensusMetricsValue,
31    constants::EVENT_CHANNEL_SIZE,
32    data::{EpochNumber, Leaf2, ViewNumber},
33    drb::INITIAL_DRB_RESULT,
34    epoch_membership::EpochMembershipCoordinator,
35    simple_certificate::QuorumCertificate2,
36    storage_metrics::StorageMetricsValue,
37    traits::{
38        election::Membership,
39        network::ConnectedNetwork,
40        node_implementation::{NodeImplementation, NodeType},
41    },
42};
43use tide_disco::Url;
44#[allow(deprecated)]
45use tracing::info;
46
47use super::{
48    completion_task::CompletionTask, consistency_task::ConsistencyTask, txn_task::TxnTask,
49};
50use crate::{
51    block_builder::{BuilderTask, TestBuilderImplementation},
52    completion_task::CompletionTaskDescription,
53    spinning_task::{ChangeNode, NodeAction, SpinningTask},
54    test_builder::create_test_handle,
55    test_launcher::{Network, TestLauncher},
56    test_task::{TestResult, TestTask, spawn_timeout_task},
57    txn_task::TxnTaskDescription,
58    view_sync_task::ViewSyncTask,
59};
60
/// Marker trait for errors returned by test tasks: any thread-safe
/// [`std::error::Error`] with a `'static` lifetime qualifies.
pub trait TaskErr: std::error::Error + Sync + Send + 'static {}

// Blanket impl: every qualifying error type is automatically a `TaskErr`.
impl<T: std::error::Error + Sync + Send + 'static> TaskErr for T {}
64
65impl<
66    TYPES: NodeType<
67            InstanceState = TestInstanceState,
68            ValidatedState = TestValidatedState,
69            BlockHeader = TestBlockHeader,
70        >,
71    I: TestableNodeImplementation<TYPES>,
72    N: ConnectedNetwork<TYPES::SignatureKey>,
73> TestRunner<TYPES, I, N>
74where
75    I: TestableNodeImplementation<TYPES>,
76    I: NodeImplementation<TYPES, Network = N, Storage = TestStorage<TYPES>>,
77    <TYPES as NodeType>::Membership: Membership<TYPES, Storage = TestStorage<TYPES>>,
78{
    /// Execute the test to completion.
    ///
    /// Spawns all configured nodes, wires up the auxiliary test tasks
    /// (transaction generation, completion timer, node spinning/restart,
    /// consistency checking, view-sync monitoring), starts consensus on every
    /// node that is not scheduled for a late start, then joins all test tasks
    /// and asserts that none of them reported a failure.
    ///
    /// # Panics
    /// if the test fails
    #[allow(clippy::too_many_lines)]
    pub async fn run_test<B: TestBuilderImplementation<TYPES>>(mut self) {
        // Broadcast channel used by the test tasks to signal shutdown/results.
        let (test_sender, test_receiver) = broadcast(EVENT_CHANNEL_SIZE);
        let spinning_changes = self
            .launcher
            .metadata
            .spinning_properties
            .node_changes
            .clone();

        // Nodes scheduled to come `Up` later or to be restarted must not be
        // started immediately below; collect their ids first.
        let mut late_start_nodes: HashSet<u64> = HashSet::new();
        let mut restart_nodes: HashSet<u64> = HashSet::new();
        for (_, changes) in &spinning_changes {
            for change in changes {
                if matches!(change.updown, NodeAction::Up) {
                    late_start_nodes.insert(change.idx.try_into().unwrap());
                }
                if matches!(change.updown, NodeAction::RestartDown(_)) {
                    restart_nodes.insert(change.idx.try_into().unwrap());
                }
            }
        }

        self.add_nodes::<B>(
            self.launcher
                .metadata
                .test_config
                .num_nodes_with_stake
                .into(),
            &late_start_nodes,
            &restart_nodes,
        )
        .await;
        // Collect the external and internal event streams of every node; the
        // event-processing test tasks subscribe to these.
        let mut event_rxs = vec![];
        let mut internal_event_rxs = vec![];

        for node in &self.nodes {
            let r = node.handle.event_stream_known_impl();
            event_rxs.push(r);
        }
        for node in &self.nodes {
            let r = node.handle.internal_event_stream_receiver_known_impl();
            internal_event_rxs.push(r);
        }

        // Decompose `self`; from here on ownership of the pieces is handed to
        // the individual test tasks.
        let TestRunner {
            launcher,
            nodes,
            late_start,
            next_node_id: _,
            _pd: _,
        } = self;

        let mut task_futs = vec![];
        let meta = launcher.metadata.clone();

        // Shared, lock-protected node-handle list; the spinning task mutates it
        // as nodes go up and down.
        let handles = Arc::new(RwLock::new(nodes));

        // Optional round-robin transaction-submission task.
        let txn_task =
            if let TxnTaskDescription::RoundRobinTimeBased(duration) = meta.txn_description {
                let txn_task = TxnTask {
                    handles: Arc::clone(&handles),
                    next_node_idx: Some(0),
                    duration,
                    shutdown_chan: test_receiver.clone(),
                };
                Some(txn_task)
            } else {
                None
            };

        // add completion task
        let CompletionTaskDescription::TimeBasedCompletionTaskBuilder(time_based) =
            meta.completion_task_description;
        let completion_task = CompletionTask {
            tx: test_sender.clone(),
            rx: test_receiver.clone(),
            duration: time_based.duration,
        };

        // add spinning task
        // map spinning to view
        let mut changes: BTreeMap<ViewNumber, Vec<ChangeNode>> = BTreeMap::new();
        for (view, mut change) in spinning_changes {
            changes
                .entry(ViewNumber::new(view))
                .or_default()
                .append(&mut change);
        }

        // The spinning task applies the scheduled up/down/restart actions,
        // seeded with genesis state so late nodes can be initialized.
        let spinning_task_state = SpinningTask {
            epoch_height: launcher.metadata.test_config.epoch_height,
            epoch_start_block: launcher.metadata.test_config.epoch_start_block,
            start_epoch_info: Vec::new(),
            handles: Arc::clone(&handles),
            late_start,
            latest_view: None,
            changes,
            last_decided_leaf: Leaf2::genesis(
                &TestValidatedState::default(),
                &TestInstanceState::default(),
                launcher.metadata.upgrade.base,
            )
            .await,
            high_qc: QuorumCertificate2::genesis(
                &TestValidatedState::default(),
                &TestInstanceState::default(),
                launcher.metadata.upgrade,
            )
            .await,
            next_epoch_high_qc: None,
            async_delay_config: launcher.metadata.async_delay_config,
            restart_contexts: HashMap::new(),
            channel_generator: launcher.resource_generators.channel_generator,
            state_cert: None,
            node_stakes: launcher.metadata.node_stakes.clone(),
            upgrade: launcher.metadata.upgrade,
        };
        let spinning_task = TestTask::<SpinningTask<TYPES, N, I>>::new(
            spinning_task_state,
            event_rxs.clone(),
            test_receiver.clone(),
        );

        // The consistency task checks decided leaves against the configured
        // safety properties and owns the decide-timeout watchdog.
        let consistency_task_state = ConsistencyTask {
            consensus_leaves: BTreeMap::new(),
            safety_properties: launcher.metadata.overall_safety_properties.clone(),
            test_sender: test_sender.clone(),
            errors: vec![],
            ensure_upgrade: launcher.metadata.upgrade_view.is_some(),
            validate_transactions: launcher.metadata.validate_transactions,
            timeout_task: spawn_timeout_task(
                test_sender.clone(),
                launcher.metadata.overall_safety_properties.decide_timeout,
            ),
        };

        let consistency_task = TestTask::<ConsistencyTask<TYPES>>::new(
            consistency_task_state,
            event_rxs.clone(),
            test_receiver.clone(),
        );

        // add view sync task
        let view_sync_task_state = ViewSyncTask {
            hit_view_sync: HashSet::new(),
            description: launcher.metadata.view_sync_properties,
            _pd: PhantomData,
        };

        // View-sync task listens on the *internal* event streams.
        let view_sync_task = TestTask::<ViewSyncTask<TYPES, I>>::new(
            view_sync_task_state,
            internal_event_rxs,
            test_receiver.clone(),
        );

        let nodes = handles.read().await;

        // wait for networks to be ready
        for node in &*nodes {
            node.network.wait_for_ready().await;
        }

        // Start hotshot on every node except those scheduled for a late start.
        for node in &*nodes {
            if !late_start_nodes.contains(&node.node_id) {
                node.handle.hotshot.start_consensus().await;
            }
        }

        // Release the read lock before the spinning task needs write access.
        drop(nodes);

        // Spawn any user-supplied extra test tasks.
        for seed in launcher.additional_test_tasks {
            let task = TestTask::new(
                seed.into_state(Arc::clone(&handles)).await,
                event_rxs.clone(),
                test_receiver.clone(),
            );
            task_futs.push(task.run());
        }

        task_futs.push(consistency_task.run());
        task_futs.push(view_sync_task.run());
        task_futs.push(spinning_task.run());

        // `generator` tasks that do not process events.
        let txn_handle = txn_task.map(|txn| txn.run());
        let completion_handle = completion_task.run();

        let mut error_list = vec![];

        // Wait for all event-processing test tasks to finish and collect any
        // reported failures.
        let results = join_all(task_futs).await;

        for result in results {
            match result {
                Ok(res) => match res {
                    TestResult::Pass => {
                        info!("Task shut down successfully");
                    },
                    TestResult::Fail(e) => error_list.push(e),
                },
                Err(e) => {
                    // A join error is logged but does not fail the test.
                    tracing::error!("Error Joining the test task {e:?}");
                },
            }
        }

        if let Some(handle) = txn_handle {
            handle.abort();
        }
        // Shutdown all of the servers at the end

        let mut nodes = handles.write().await;

        for node in &mut *nodes {
            node.handle.shut_down().await;
        }
        tracing::info!("Nodes shutdown");

        completion_handle.abort();

        // Fail the test with all accumulated task errors, if any.
        assert!(
            error_list.is_empty(),
            "{}",
            error_list
                .iter()
                .fold("TEST FAILED! Results:".to_string(), |acc, error| {
                    format!("{acc}\n\n{error:?}")
                })
        );
    }
314
315    pub async fn init_builders<B: TestBuilderImplementation<TYPES>>(
316        &self,
317    ) -> (Vec<Box<dyn BuilderTask<TYPES>>>, Vec<Url>) {
318        let mut builder_tasks = Vec::new();
319        let mut builder_urls = Vec::new();
320
321        let mut ports = Vec::new();
322        for _ in &self.launcher.metadata.builders {
323            let port =
324                test_utils::reserve_tcp_port().expect("OS should have ephemeral ports available");
325            ports.push(port);
326        }
327
328        for (metadata, builder_port) in self.launcher.metadata.builders.iter().zip(&ports) {
329            let builder_url =
330                Url::parse(&format!("http://localhost:{builder_port}")).expect("Invalid URL");
331            let builder_task = B::start(
332                0, // This field gets updated while the test is running, 0 is just to seed it
333                builder_url.clone(),
334                B::Config::default(),
335                metadata.changes.clone(),
336            )
337            .await;
338            builder_tasks.push(builder_task);
339            builder_urls.push(builder_url);
340        }
341
342        (builder_tasks, builder_urls)
343    }
344
    /// Add `total` nodes to the test.
    ///
    /// Nodes listed in `late_start` are not launched now: they are either
    /// recorded with an uninitialized context (when `skip_late` is set) or
    /// pre-built and stored for later startup. Nodes listed in `restart` get a
    /// `Restart` late-start entry after all networks are created. Returns the
    /// ids of all nodes added.
    ///
    /// # Panics
    /// Panics if unable to create a [`HotShotInitializer`]
    pub async fn add_nodes<B: TestBuilderImplementation<TYPES>>(
        &mut self,
        total: usize,
        late_start: &HashSet<u64>,
        restart: &HashSet<u64>,
    ) -> Vec<u64> {
        let mut results = vec![];
        let config = self.launcher.metadata.test_config.clone();

        // Num_nodes is updated on the fly now via claim_block_with_num_nodes. This stays around to seed num_nodes
        // in the builders for tests which don't update that field.
        let (mut builder_tasks, builder_urls) = self.init_builders::<B>().await;

        // Collect uninitialized nodes because we need to wait for all networks to be ready before starting the tasks
        let mut uninitialized_nodes = Vec::new();
        let mut networks_ready = Vec::new();

        for i in 0..total {
            // Per-node copy of the config, optionally with the upgrade view set.
            let mut config = config.clone();
            if let Some(upgrade_view) = self.launcher.metadata.upgrade_view {
                config.set_view_upgrade(upgrade_view);
            }
            let node_id = self.next_node_id;
            self.next_node_id += 1;
            tracing::debug!("launch node {i}");

            config.builder_urls = builder_urls
                .clone()
                .try_into()
                .expect("Non-empty by construction");

            let storage = (self.launcher.resource_generators.storage)(node_id);

            // See whether or not we should be DA: the first
            // `da_staked_committee_size` node ids are DA members.
            let is_da = node_id < config.da_staked_committee_size as u64;

            // We assign node's public key and stake value rather than read from config file since it's a test
            let validator_config = ValidatorConfig::<TYPES>::generated_from_seed_indexed(
                [0u8; 32],
                node_id,
                self.launcher.metadata.node_stakes.get(node_id),
                is_da,
            );

            let public_key = validator_config.public_key.clone();

            // Create the node's network now unless the node is a skipped
            // late-starter; readiness futures are awaited together below.
            let network = if late_start.contains(&node_id) && self.launcher.metadata.skip_late {
                None
            } else {
                let net = (self.launcher.resource_generators.channel_generator)(node_id).await;
                networks_ready.push({
                    let net = net.clone();
                    async move { net.wait_for_ready().await }
                });
                Some(net)
            };

            if late_start.contains(&node_id) {
                if self.launcher.metadata.skip_late {
                    // Defer all initialization: only store config + storage.
                    self.late_start.insert(
                        node_id,
                        LateStartNode {
                            network: None,
                            context: LateNodeContext::UninitializedContext(
                                LateNodeContextParameters {
                                    storage: storage.clone(),
                                    config,
                                },
                            ),
                        },
                    );
                } else {
                    // Build the full context now; the node simply starts later.
                    let initializer = HotShotInitializer::<TYPES>::from_genesis(
                        TestInstanceState::new(
                            self.launcher
                                .metadata
                                .async_delay_config
                                .get(&node_id)
                                .cloned()
                                .unwrap_or_default(),
                        ),
                        config.epoch_height,
                        config.epoch_start_block,
                        vec![InitializerEpochInfo::<TYPES> {
                            epoch: EpochNumber::new(1),
                            drb_result: INITIAL_DRB_RESULT,
                            block_header: None,
                        }],
                        self.launcher.metadata.upgrade,
                    )
                    .await
                    .unwrap();

                    let network = network.clone().expect("!skip_late => network created");

                    let hotshot = Self::add_node_with_config(
                        node_id,
                        network.clone(),
                        <TYPES as NodeType>::Membership::new::<I>(
                            config.known_nodes_with_stake.clone(),
                            config.known_da_nodes.clone(),
                            storage.clone(),
                            network.clone(),
                            public_key.clone(),
                            config.epoch_height,
                        ),
                        initializer,
                        config,
                        self.launcher.metadata.upgrade,
                        validator_config,
                        storage,
                    )
                    .await;
                    self.late_start.insert(
                        node_id,
                        LateStartNode {
                            network: Some(network),
                            context: LateNodeContext::InitializedContext(hotshot),
                        },
                    );
                }
            } else {
                // Normal node: queued for launch once all networks are ready.
                let network = network.expect("!late_start => network created");
                uninitialized_nodes.push((
                    node_id,
                    network.clone(),
                    <TYPES as NodeType>::Membership::new::<I>(
                        config.known_nodes_with_stake.clone(),
                        config.known_da_nodes.clone(),
                        storage.clone(),
                        network,
                        public_key.clone(),
                        config.epoch_height,
                    ),
                    config,
                    storage,
                ));
            }

            results.push(node_id);
        }

        // Add the restart nodes after the rest.  This must be done after all the original networks are
        // created because this will reset the bootstrap info for the restarted nodes
        for node_id in &results {
            if restart.contains(node_id) {
                self.late_start.insert(
                    *node_id,
                    LateStartNode {
                        network: None,
                        context: LateNodeContext::Restart,
                    },
                );
            }
        }

        // Wait for all networks to be ready
        join_all(networks_ready).await;

        // Then start the necessary tasks
        for (node_id, network, memberships, config, storage) in uninitialized_nodes {
            let handle = create_test_handle(
                self.launcher.metadata.clone(),
                node_id,
                network.clone(),
                Arc::new(RwLock::new(memberships)),
                config.clone(),
                storage,
            )
            .await;

            // Distribute builder tasks across DA nodes: each DA node before
            // the last takes one task; the last DA node takes all remaining.
            match node_id.cmp(&(config.da_staked_committee_size as u64 - 1)) {
                std::cmp::Ordering::Less => {
                    if let Some(task) = builder_tasks.pop() {
                        task.start(Box::new(handle.event_stream()))
                    }
                },
                std::cmp::Ordering::Equal => {
                    // If we have more builder tasks than DA nodes, pin them all on the last node.
                    while let Some(task) = builder_tasks.pop() {
                        task.start(Box::new(handle.event_stream()))
                    }
                },
                std::cmp::Ordering::Greater => {},
            }

            self.nodes.push(Node {
                node_id,
                network,
                handle,
            });
        }

        results
    }
544
    /// Add a specific node with a config, returning its [`SystemContext`].
    ///
    /// Builds the node's [`EpochMembershipCoordinator`] from the provided
    /// membership and hands everything to [`SystemContext::new`].
    ///
    /// # Panics
    /// if unable to initialize the node's `SystemContext` based on the config
    #[allow(clippy::too_many_arguments)]
    pub async fn add_node_with_config(
        node_id: u64,
        network: Network<TYPES, I>,
        memberships: TYPES::Membership,
        initializer: HotShotInitializer<TYPES>,
        config: HotShotConfig<TYPES>,
        upgrade: versions::Upgrade,
        validator_config: ValidatorConfig<TYPES>,
        storage: I::Storage,
    ) -> Arc<SystemContext<TYPES, I>> {
        // Get key pair for certificate aggregation
        let private_key = validator_config.private_key.clone();
        let public_key = validator_config.public_key.clone();
        let state_private_key = validator_config.state_private_key.clone();
        // Saved before `config` is moved into `SystemContext::new` below.
        let epoch_height = config.epoch_height;

        SystemContext::new(
            public_key,
            private_key,
            state_private_key,
            node_id,
            config,
            upgrade,
            EpochMembershipCoordinator::new(
                Arc::new(RwLock::new(memberships)),
                epoch_height,
                // NOTE(review): the clone makes a temporary so `storage` can
                // still be moved as a later argument; presumably `&storage`
                // would also borrow-check — confirm before simplifying.
                &storage.clone(),
            ),
            network,
            initializer,
            ConsensusMetricsValue::default(),
            storage,
            StorageMetricsValue::default(),
        )
        .await
    }
585
    /// Add a specific node with a config and pre-made event channels,
    /// returning its [`SystemContext`].
    ///
    /// Same as [`Self::add_node_with_config`], but uses
    /// [`SystemContext::new_from_channels`] so the caller supplies the
    /// internal and external event channels (useful for tests that inject or
    /// observe events directly), and takes an already-shared membership.
    ///
    /// # Panics
    /// if unable to initialize the node's `SystemContext` based on the config
    #[allow(clippy::too_many_arguments, clippy::type_complexity)]
    pub async fn add_node_with_config_and_channels(
        node_id: u64,
        network: Network<TYPES, I>,
        memberships: Arc<RwLock<TYPES::Membership>>,
        initializer: HotShotInitializer<TYPES>,
        config: HotShotConfig<TYPES>,
        upgrade: versions::Upgrade,
        validator_config: ValidatorConfig<TYPES>,
        storage: I::Storage,
        internal_channel: (
            Sender<Arc<HotShotEvent<TYPES>>>,
            Receiver<Arc<HotShotEvent<TYPES>>>,
        ),
        external_channel: (Sender<Event<TYPES>>, Receiver<Event<TYPES>>),
    ) -> Arc<SystemContext<TYPES, I>> {
        // Get key pair for certificate aggregation
        let private_key = validator_config.private_key.clone();
        let public_key = validator_config.public_key.clone();
        let state_private_key = validator_config.state_private_key.clone();
        // Saved before `config` is moved into `new_from_channels` below.
        let epoch_height = config.epoch_height;

        SystemContext::new_from_channels(
            public_key,
            private_key,
            state_private_key,
            node_id,
            config,
            upgrade,
            // NOTE(review): clone keeps `storage` movable as a later argument;
            // presumably `&storage` would also borrow-check — confirm.
            EpochMembershipCoordinator::new(memberships, epoch_height, &storage.clone()),
            network,
            initializer,
            ConsensusMetricsValue::default(),
            storage,
            StorageMetricsValue::default(),
            internal_channel,
            external_channel,
        )
        .await
    }
629}
630
/// A node participating in a test.
pub struct Node<TYPES: NodeType, I: TestableNodeImplementation<TYPES>> {
    /// The node's unique identifier
    pub node_id: u64,
    /// The underlying network belonging to the node
    pub network: Network<TYPES, I>,
    /// The handle to the node's internals
    pub handle: SystemContextHandle<TYPES, I>,
}
640
/// This type combines all of the parameters needed to build the context for a node that started
/// late during a unit test or integration test.
pub struct LateNodeContextParameters<TYPES: NodeType, I: TestableNodeImplementation<TYPES>> {
    /// The storage trait for Sequencer persistence.
    pub storage: I::Storage,

    /// The config associated with this node.
    pub config: HotShotConfig<TYPES>,
}
650
/// The late node context dictates how we're building a node that started late during the test.
#[allow(clippy::large_enum_variant)]
pub enum LateNodeContext<TYPES: NodeType, I: TestableNodeImplementation<TYPES>> {
    /// The system context that we're passing directly to the node, this means the node is already
    /// initialized successfully.
    InitializedContext(Arc<SystemContext<TYPES, I>>),

    /// The system context that we're passing to the node when it is not yet initialized, so we're
    /// initializing it based on the received leaf and init parameters.
    UninitializedContext(LateNodeContextParameters<TYPES, I>),

    /// The node is to be restarted, so we will build the context from the node that was already
    /// running.
    Restart,
}
664
/// A yet-to-be-started node that participates in tests.
pub struct LateStartNode<TYPES: NodeType, I: TestableNodeImplementation<TYPES>> {
    /// The underlying network belonging to the node, if it has been created yet.
    pub network: Option<Network<TYPES, I>>,
    /// Either the context to which we will use to launch HotShot for initialized node when it's
    /// time, or the parameters that will be used to initialize the node and launch HotShot.
    pub context: LateNodeContext<TYPES, I>,
}
673
/// The runner of a test network:
/// spins up and down nodes and executes rounds.
pub struct TestRunner<
    TYPES: NodeType,
    I: TestableNodeImplementation<TYPES>,
    N: ConnectedNetwork<TYPES::SignatureKey>,
> {
    /// test launcher, contains a bunch of useful metadata and closures
    pub(crate) launcher: TestLauncher<TYPES, I>,
    /// nodes in the test
    pub(crate) nodes: Vec<Node<TYPES, I>>,
    /// nodes with a late start, keyed by node id
    pub(crate) late_start: HashMap<u64, LateStartNode<TYPES, I>>,
    /// the next node unique identifier
    pub(crate) next_node_id: u64,
    /// Phantom for N
    pub(crate) _pd: PhantomData<N>,
}