hotshot_testing/
test_task.rs1use std::{sync::Arc, time::Duration};
8
9use async_broadcast::{Receiver, Sender};
10use async_lock::RwLock;
11use async_trait::async_trait;
12use futures::future::select_all;
13use hotshot::{
14 traits::TestableNodeImplementation,
15 types::{Event, Message},
16};
17use hotshot_task_impls::{events::HotShotEvent, network::NetworkMessageTaskState};
18use hotshot_types::{
19 message::UpgradeLock,
20 traits::{network::ConnectedNetwork, node_implementation::NodeType},
21};
22use tokio::{
23 spawn,
24 task::JoinHandle,
25 time::{sleep, timeout},
26};
27use tracing::error;
28
29use crate::test_runner::Node;
30
/// Outcome reported by a test task when its state is checked.
pub enum TestResult {
    /// No failure was detected.
    Pass,
    /// A failure was detected; the payload describes it through its `Debug` output.
    Fail(Box<dyn std::fmt::Debug + Send + Sync>),
}
38
39pub fn spawn_timeout_task(test_sender: Sender<TestEvent>, timeout: Duration) -> JoinHandle<()> {
40 tokio::spawn(async move {
41 sleep(timeout).await;
42
43 error!("Decide timeout triggered.");
44 let _ = test_sender.broadcast(TestEvent::Shutdown).await;
45 })
46}
47
48#[async_trait]
49pub trait TestTaskState: Send {
51 type Event: Clone + Send + Sync;
53 type Error: std::fmt::Display;
55
56 async fn handle_event(
58 &mut self,
59 (event, id): (Self::Event, usize),
60 ) -> std::result::Result<(), Self::Error>;
61
62 async fn check(&self) -> TestResult;
64}
65
66pub type AnyTestTaskState<TYPES> = Box<
69 dyn TestTaskState<Event = hotshot_types::event::Event<TYPES>, Error = anyhow::Error>
70 + Send
71 + Sync,
72>;
73
74#[async_trait]
75impl<TYPES: NodeType> TestTaskState for AnyTestTaskState<TYPES> {
76 type Event = Event<TYPES>;
77 type Error = anyhow::Error;
78
79 async fn handle_event(
80 &mut self,
81 event: (Self::Event, usize),
82 ) -> std::result::Result<(), anyhow::Error> {
83 (**self).handle_event(event).await
84 }
85
86 async fn check(&self) -> TestResult {
87 (**self).check().await
88 }
89}
90
91#[async_trait]
92pub trait TestTaskStateSeed<TYPES, I>: Send
93where
94 TYPES: NodeType,
95 I: TestableNodeImplementation<TYPES>,
96{
97 async fn into_state(
98 self: Box<Self>,
99 handles: Arc<RwLock<Vec<Node<TYPES, I>>>>,
100 ) -> AnyTestTaskState<TYPES>;
101}
102
103pub struct TestTask<S: TestTaskState> {
109 state: S,
113 receivers: Vec<Receiver<S::Event>>,
115 test_receiver: Receiver<TestEvent>,
117}
118
/// Control messages delivered to running test tasks.
#[derive(Debug, Clone)]
pub enum TestEvent {
    /// Stop processing events, run the final check, and return its result.
    Shutdown,
}
123
124impl<S: TestTaskState + Send + 'static> TestTask<S> {
125 pub fn new(
127 state: S,
128 receivers: Vec<Receiver<S::Event>>,
129 test_receiver: Receiver<TestEvent>,
130 ) -> Self {
131 TestTask {
132 state,
133 receivers,
134 test_receiver,
135 }
136 }
137
138 pub fn run(mut self) -> JoinHandle<TestResult> {
141 spawn(async move {
142 loop {
143 if let Ok(TestEvent::Shutdown) = self.test_receiver.try_recv() {
144 break self.state.check().await;
145 }
146
147 self.receivers.retain(|receiver| !receiver.is_closed());
148
149 let mut messages = Vec::new();
150
151 for receiver in &mut self.receivers {
152 messages.push(receiver.recv());
153 }
154
155 match timeout(Duration::from_millis(2500), select_all(messages)).await {
156 Ok((Ok(input), id, _)) => {
157 let _ = S::handle_event(&mut self.state, (input, id))
158 .await
159 .inspect_err(|e| tracing::error!("{e}"));
160 },
161 Ok((Err(e), _id, _)) => {
162 error!("Error from one channel in test task {e:?}");
163 sleep(Duration::from_millis(4000)).await;
164 },
165 _ => {},
166 };
167 }
168 })
169 }
170}
171
172pub async fn add_network_message_test_task<
174 TYPES: NodeType,
175 NET: ConnectedNetwork<TYPES::SignatureKey>,
176>(
177 internal_event_stream: Sender<Arc<HotShotEvent<TYPES>>>,
178 external_event_stream: Sender<Event<TYPES>>,
179 upgrade_lock: UpgradeLock<TYPES>,
180 channel: Arc<NET>,
181 public_key: TYPES::SignatureKey,
182 id: u64,
183) -> JoinHandle<()> {
184 let net = Arc::clone(&channel);
185 let network_state: NetworkMessageTaskState<_> = NetworkMessageTaskState {
186 internal_event_stream: internal_event_stream.clone(),
187 external_event_stream: external_event_stream.clone(),
188 public_key,
189 upgrade_lock: upgrade_lock.clone(),
190 id,
191 };
192
193 let network = Arc::clone(&net);
194 let mut state = network_state.clone();
195
196 spawn(async move {
197 loop {
198 let message = match network.recv_message().await {
200 Ok(message) => message,
201 Err(e) => {
202 error!("Failed to receive message: {e:?}");
203 continue;
204 },
205 };
206
207 let deserialized_message: Message<TYPES> = match upgrade_lock.deserialize(&message) {
209 Ok((message, _)) => message,
210 Err(e) => {
211 tracing::error!("Failed to deserialize message: {e:?}");
212 continue;
213 },
214 };
215
216 state.handle_message(deserialized_message).await;
218 }
219 })
220}