hotshot_task_impls/
harness.rs

// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
// This file is part of the HotShot repository.

// You should have received a copy of the MIT License
// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::{sync::Arc, time::Duration};
8
9use async_broadcast::broadcast;
10use hotshot_task::task::{ConsensusTaskRegistry, Task, TaskState};
11use hotshot_types::traits::node_implementation::NodeType;
12use tokio::time::timeout;
13
14use crate::events::{HotShotEvent, HotShotTaskCompleted};
15
16/// The state for the test harness task. Keeps track of which events and how many we expect to get
17pub struct TestHarnessState<TYPES: NodeType> {
18    /// The expected events we get from the test.  Maps an event to the number of times we expect to see it
19    expected_output: Vec<HotShotEvent<TYPES>>,
20    /// If true we won't fail the test if extra events come in
21    allow_extra_output: bool,
22}
23
24/// Runs a test by building the task using `build_fn` and then passing it the `input` events
25/// and testing the make sure all of the `expected_output` events are seen
26///
27/// # Arguments
28/// * `event_stream` - if given, will be used to register the task builder.
29/// * `allow_extra_output` - whether to allow an extra output after we've seen all expected
30///   outputs. Should be `false` in most cases.
31///
32/// # Panics
33/// Panics if any state the test expects is not set. Panicking causes a test failure
34#[allow(clippy::implicit_hasher)]
35#[allow(clippy::panic)]
36pub async fn run_harness<TYPES, S: TaskState<Event = HotShotEvent<TYPES>> + Send + 'static>(
37    input: Vec<HotShotEvent<TYPES>>,
38    expected_output: Vec<HotShotEvent<TYPES>>,
39    state: S,
40    allow_extra_output: bool,
41) where
42    TYPES: NodeType,
43{
44    let mut registry = ConsensusTaskRegistry::new();
45    // set up two broadcast channels so the test sends to the task and the task back to the test
46    let (to_task, from_test) = broadcast(1024);
47    let (to_test, mut from_task) = broadcast(1024);
48    let mut test_state = TestHarnessState {
49        expected_output,
50        allow_extra_output,
51    };
52
53    let task = Task::new(state, to_test.clone(), from_test.clone());
54
55    let handle = task.run();
56    let test_future = async move {
57        loop {
58            if let Ok(event) = from_task.recv_direct().await {
59                if let Some(HotShotTaskCompleted) = check_event(event, &mut test_state) {
60                    break;
61                }
62            }
63        }
64    };
65
66    registry.register(handle);
67
68    for event in input {
69        to_task.broadcast_direct(Arc::new(event)).await.unwrap();
70    }
71
72    assert!(
73        timeout(Duration::from_secs(2), test_future).await.is_ok(),
74        "Test timeout out before all all expected outputs received"
75    );
76}
77
78/// Handles an event for the Test Harness Task.  If the event is expected, remove it from
79/// the `expected_output` in state.  If unexpected fail test.
80///
81/// # Arguments
82/// * `allow_extra_output` - whether to allow an extra output after we've seen all expected
83///   outputs. Should be `false` in most cases.
84///
85///  # Panics
86/// Will panic to fail the test when it receives and unexpected event
87#[allow(clippy::needless_pass_by_value)]
88fn check_event<TYPES: NodeType>(
89    event: Arc<HotShotEvent<TYPES>>,
90    state: &mut TestHarnessState<TYPES>,
91) -> Option<HotShotTaskCompleted> {
92    // Check the output in either case:
93    // * We allow outputs only in our expected output set.
94    // * We haven't received all expected outputs yet.
95    if !state.allow_extra_output || !state.expected_output.is_empty() {
96        assert!(
97            state.expected_output.contains(&event),
98            "Got an unexpected event: {event:?}",
99        );
100    }
101
102    // NOTE: We only care about finding a single instance of the output event, and we just
103    // iteratively remove the entries until they're gone.
104    let idx = state
105        .expected_output
106        .iter()
107        .position(|x| *x == *event)
108        .unwrap();
109
110    state.expected_output.remove(idx);
111
112    if state.expected_output.is_empty() {
113        tracing::info!("test harness task completed");
114        return Some(HotShotTaskCompleted);
115    }
116
117    None
118}