hotshot_task_impls/
harness.rs1use std::{sync::Arc, time::Duration};
8
9use async_broadcast::broadcast;
10use hotshot_task::task::{ConsensusTaskRegistry, Task, TaskState};
11use hotshot_types::traits::node_implementation::NodeType;
12use tokio::time::timeout;
13
14use crate::events::{HotShotEvent, HotShotTaskCompleted};
15
16pub struct TestHarnessState<TYPES: NodeType> {
18 expected_output: Vec<HotShotEvent<TYPES>>,
20 allow_extra_output: bool,
22}
23
24#[allow(clippy::implicit_hasher)]
35#[allow(clippy::panic)]
36pub async fn run_harness<TYPES, S: TaskState<Event = HotShotEvent<TYPES>> + Send + 'static>(
37 input: Vec<HotShotEvent<TYPES>>,
38 expected_output: Vec<HotShotEvent<TYPES>>,
39 state: S,
40 allow_extra_output: bool,
41) where
42 TYPES: NodeType,
43{
44 let mut registry = ConsensusTaskRegistry::new();
45 let (to_task, from_test) = broadcast(1024);
47 let (to_test, mut from_task) = broadcast(1024);
48 let mut test_state = TestHarnessState {
49 expected_output,
50 allow_extra_output,
51 };
52
53 let task = Task::new(state, to_test.clone(), from_test.clone());
54
55 let handle = task.run();
56 let test_future = async move {
57 loop {
58 if let Ok(event) = from_task.recv_direct().await {
59 if let Some(HotShotTaskCompleted) = check_event(event, &mut test_state) {
60 break;
61 }
62 }
63 }
64 };
65
66 registry.register(handle);
67
68 for event in input {
69 to_task.broadcast_direct(Arc::new(event)).await.unwrap();
70 }
71
72 assert!(
73 timeout(Duration::from_secs(2), test_future).await.is_ok(),
74 "Test timeout out before all all expected outputs received"
75 );
76}
77
78#[allow(clippy::needless_pass_by_value)]
88fn check_event<TYPES: NodeType>(
89 event: Arc<HotShotEvent<TYPES>>,
90 state: &mut TestHarnessState<TYPES>,
91) -> Option<HotShotTaskCompleted> {
92 if !state.allow_extra_output || !state.expected_output.is_empty() {
96 assert!(
97 state.expected_output.contains(&event),
98 "Got an unexpected event: {event:?}",
99 );
100 }
101
102 let idx = state
105 .expected_output
106 .iter()
107 .position(|x| *x == *event)
108 .unwrap();
109
110 state.expected_output.remove(idx);
111
112 if state.expected_output.is_empty() {
113 tracing::info!("test harness task completed");
114 return Some(HotShotTaskCompleted);
115 }
116
117 None
118}