hotshot_testing/
test_task.rs1use std::{sync::Arc, time::Duration};
8
9use async_broadcast::{Receiver, Sender};
10use async_lock::RwLock;
11use async_trait::async_trait;
12use futures::future::select_all;
13use hotshot::{
14 traits::TestableNodeImplementation,
15 types::{Event, Message},
16};
17use hotshot_task_impls::{events::HotShotEvent, network::NetworkMessageTaskState};
18use hotshot_types::{
19 message::UpgradeLock,
20 traits::{
21 network::ConnectedNetwork,
22 node_implementation::{NodeType, Versions},
23 },
24};
25use tokio::{
26 spawn,
27 task::JoinHandle,
28 time::{sleep, timeout},
29};
30use tracing::error;
31
32use crate::test_runner::Node;
33
/// Verdict of a test task, as produced by [`TestTaskState::check`].
pub enum TestResult {
    /// All of the task's checks succeeded.
    Pass,
    /// The task's checks failed; the payload describes the failure via `Debug`.
    Fail(Box<dyn std::fmt::Debug + Sync + Send>),
}
41
42pub fn spawn_timeout_task(test_sender: Sender<TestEvent>, timeout: Duration) -> JoinHandle<()> {
43 tokio::spawn(async move {
44 sleep(timeout).await;
45
46 error!("Decide timeout triggered.");
47 let _ = test_sender.broadcast(TestEvent::Shutdown).await;
48 })
49}
50
51#[async_trait]
52pub trait TestTaskState: Send {
54 type Event: Clone + Send + Sync;
56 type Error: std::fmt::Display;
58
59 async fn handle_event(
61 &mut self,
62 (event, id): (Self::Event, usize),
63 ) -> std::result::Result<(), Self::Error>;
64
65 async fn check(&self) -> TestResult;
67}
68
69pub type AnyTestTaskState<TYPES> = Box<
72 dyn TestTaskState<Event = hotshot_types::event::Event<TYPES>, Error = anyhow::Error>
73 + Send
74 + Sync,
75>;
76
77#[async_trait]
78impl<TYPES: NodeType> TestTaskState for AnyTestTaskState<TYPES> {
79 type Event = Event<TYPES>;
80 type Error = anyhow::Error;
81
82 async fn handle_event(
83 &mut self,
84 event: (Self::Event, usize),
85 ) -> std::result::Result<(), anyhow::Error> {
86 (**self).handle_event(event).await
87 }
88
89 async fn check(&self) -> TestResult {
90 (**self).check().await
91 }
92}
93
94#[async_trait]
95pub trait TestTaskStateSeed<TYPES, I, V>: Send
96where
97 TYPES: NodeType,
98 I: TestableNodeImplementation<TYPES>,
99 V: Versions,
100{
101 async fn into_state(
102 self: Box<Self>,
103 handles: Arc<RwLock<Vec<Node<TYPES, I, V>>>>,
104 ) -> AnyTestTaskState<TYPES>;
105}
106
107pub struct TestTask<S: TestTaskState> {
113 state: S,
117 receivers: Vec<Receiver<S::Event>>,
119 test_receiver: Receiver<TestEvent>,
121}
122
/// Control messages delivered to running test tasks.
#[derive(Debug, Clone)]
pub enum TestEvent {
    /// Ask the task to stop and report its verdict.
    Shutdown,
}
127
128impl<S: TestTaskState + Send + 'static> TestTask<S> {
129 pub fn new(
131 state: S,
132 receivers: Vec<Receiver<S::Event>>,
133 test_receiver: Receiver<TestEvent>,
134 ) -> Self {
135 TestTask {
136 state,
137 receivers,
138 test_receiver,
139 }
140 }
141
142 pub fn run(mut self) -> JoinHandle<TestResult> {
145 spawn(async move {
146 loop {
147 if let Ok(TestEvent::Shutdown) = self.test_receiver.try_recv() {
148 break self.state.check().await;
149 }
150
151 self.receivers.retain(|receiver| !receiver.is_closed());
152
153 let mut messages = Vec::new();
154
155 for receiver in &mut self.receivers {
156 messages.push(receiver.recv());
157 }
158
159 match timeout(Duration::from_millis(2500), select_all(messages)).await {
160 Ok((Ok(input), id, _)) => {
161 let _ = S::handle_event(&mut self.state, (input, id))
162 .await
163 .inspect_err(|e| tracing::error!("{e}"));
164 },
165 Ok((Err(e), _id, _)) => {
166 error!("Error from one channel in test task {e:?}");
167 sleep(Duration::from_millis(4000)).await;
168 },
169 _ => {},
170 };
171 }
172 })
173 }
174}
175
176pub async fn add_network_message_test_task<
178 TYPES: NodeType,
179 V: Versions,
180 NET: ConnectedNetwork<TYPES::SignatureKey>,
181>(
182 internal_event_stream: Sender<Arc<HotShotEvent<TYPES>>>,
183 external_event_stream: Sender<Event<TYPES>>,
184 upgrade_lock: UpgradeLock<TYPES, V>,
185 channel: Arc<NET>,
186 public_key: TYPES::SignatureKey,
187 id: u64,
188) -> JoinHandle<()> {
189 let net = Arc::clone(&channel);
190 let network_state: NetworkMessageTaskState<_, _> = NetworkMessageTaskState {
191 internal_event_stream: internal_event_stream.clone(),
192 external_event_stream: external_event_stream.clone(),
193 public_key,
194 upgrade_lock: upgrade_lock.clone(),
195 id,
196 };
197
198 let network = Arc::clone(&net);
199 let mut state = network_state.clone();
200
201 spawn(async move {
202 loop {
203 let message = match network.recv_message().await {
205 Ok(message) => message,
206 Err(e) => {
207 error!("Failed to receive message: {e:?}");
208 continue;
209 },
210 };
211
212 let deserialized_message: Message<TYPES> =
214 match upgrade_lock.deserialize(&message).await {
215 Ok((message, _)) => message,
216 Err(e) => {
217 tracing::error!("Failed to deserialize message: {e:?}");
218 continue;
219 },
220 };
221
222 state.handle_message(deserialized_message).await;
224 }
225 })
226}