hotshot_testing/
test_task.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::{sync::Arc, time::Duration};
8
9use async_broadcast::{Receiver, Sender};
10use async_lock::RwLock;
11use async_trait::async_trait;
12use futures::future::select_all;
13use hotshot::{
14    traits::TestableNodeImplementation,
15    types::{Event, Message},
16};
17use hotshot_task_impls::{events::HotShotEvent, network::NetworkMessageTaskState};
18use hotshot_types::{
19    message::UpgradeLock,
20    traits::{
21        network::ConnectedNetwork,
22        node_implementation::{NodeType, Versions},
23    },
24};
25use tokio::{
26    spawn,
27    task::JoinHandle,
28    time::{sleep, timeout},
29};
30use tracing::error;
31
32use crate::test_runner::Node;
33
/// Final outcome reported by a test task.
pub enum TestResult {
    /// The test task passed.
    Pass,
    /// The test task failed; the boxed value describes the error.
    Fail(Box<dyn std::fmt::Debug + Sync + Send>),
}
41
42pub fn spawn_timeout_task(test_sender: Sender<TestEvent>, timeout: Duration) -> JoinHandle<()> {
43    tokio::spawn(async move {
44        sleep(timeout).await;
45
46        error!("Decide timeout triggered.");
47        let _ = test_sender.broadcast(TestEvent::Shutdown).await;
48    })
49}
50
51#[async_trait]
52/// Type for mutable task state that can be used as the state for a `Task`
53pub trait TestTaskState: Send {
54    /// Type of event sent and received by the task
55    type Event: Clone + Send + Sync;
56    /// Type of error produced by the task
57    type Error: std::fmt::Display;
58
59    /// Handles an event from one of multiple receivers.
60    async fn handle_event(
61        &mut self,
62        (event, id): (Self::Event, usize),
63    ) -> std::result::Result<(), Self::Error>;
64
65    /// Check the result of the test.
66    async fn check(&self) -> TestResult;
67}
68
69/// Type alias for type-erased [`TestTaskState`] to be used for
70/// dynamic dispatch
71pub type AnyTestTaskState<TYPES> = Box<
72    dyn TestTaskState<Event = hotshot_types::event::Event<TYPES>, Error = anyhow::Error>
73        + Send
74        + Sync,
75>;
76
77#[async_trait]
78impl<TYPES: NodeType> TestTaskState for AnyTestTaskState<TYPES> {
79    type Event = Event<TYPES>;
80    type Error = anyhow::Error;
81
82    async fn handle_event(
83        &mut self,
84        event: (Self::Event, usize),
85    ) -> std::result::Result<(), anyhow::Error> {
86        (**self).handle_event(event).await
87    }
88
89    async fn check(&self) -> TestResult {
90        (**self).check().await
91    }
92}
93
94#[async_trait]
95pub trait TestTaskStateSeed<TYPES, I, V>: Send
96where
97    TYPES: NodeType,
98    I: TestableNodeImplementation<TYPES>,
99    V: Versions,
100{
101    async fn into_state(
102        self: Box<Self>,
103        handles: Arc<RwLock<Vec<Node<TYPES, I, V>>>>,
104    ) -> AnyTestTaskState<TYPES>;
105}
106
107/// A basic task which loops waiting for events to come from `event_receiver`
108/// and then handles them using it's state
109/// It sends events to other `Task`s through `event_sender`
110/// This should be used as the primary building block for long running
111/// or medium running tasks (i.e. anything that can't be described as a dependency task)
112pub struct TestTask<S: TestTaskState> {
113    /// The state of the task.  It is fed events from `event_sender`
114    /// and mutates it state ocordingly.  Also it signals the task
115    /// if it is complete/should shutdown
116    state: S,
117    /// Receives events that are broadcast from any task, including itself
118    receivers: Vec<Receiver<S::Event>>,
119    /// Receiver for test events, used for communication between test tasks.
120    test_receiver: Receiver<TestEvent>,
121}
122
/// Control messages exchanged between test tasks.
#[derive(Debug, Clone)]
pub enum TestEvent {
    /// Ask listening test tasks to stop and report their result.
    Shutdown,
}
127
128impl<S: TestTaskState + Send + 'static> TestTask<S> {
129    /// Create a new task
130    pub fn new(
131        state: S,
132        receivers: Vec<Receiver<S::Event>>,
133        test_receiver: Receiver<TestEvent>,
134    ) -> Self {
135        TestTask {
136            state,
137            receivers,
138            test_receiver,
139        }
140    }
141
142    /// Spawn the task loop, consuming self.  Will continue until
143    /// the task reaches some shutdown condition
144    pub fn run(mut self) -> JoinHandle<TestResult> {
145        spawn(async move {
146            loop {
147                if let Ok(TestEvent::Shutdown) = self.test_receiver.try_recv() {
148                    break self.state.check().await;
149                }
150
151                self.receivers.retain(|receiver| !receiver.is_closed());
152
153                let mut messages = Vec::new();
154
155                for receiver in &mut self.receivers {
156                    messages.push(receiver.recv());
157                }
158
159                match timeout(Duration::from_millis(2500), select_all(messages)).await {
160                    Ok((Ok(input), id, _)) => {
161                        let _ = S::handle_event(&mut self.state, (input, id))
162                            .await
163                            .inspect_err(|e| tracing::error!("{e}"));
164                    },
165                    Ok((Err(e), _id, _)) => {
166                        error!("Error from one channel in test task {e:?}");
167                        sleep(Duration::from_millis(4000)).await;
168                    },
169                    _ => {},
170                };
171            }
172        })
173    }
174}
175
176/// Add the network task to handle messages and publish events.
177pub async fn add_network_message_test_task<
178    TYPES: NodeType,
179    V: Versions,
180    NET: ConnectedNetwork<TYPES::SignatureKey>,
181>(
182    internal_event_stream: Sender<Arc<HotShotEvent<TYPES>>>,
183    external_event_stream: Sender<Event<TYPES>>,
184    upgrade_lock: UpgradeLock<TYPES, V>,
185    channel: Arc<NET>,
186    public_key: TYPES::SignatureKey,
187    id: u64,
188) -> JoinHandle<()> {
189    let net = Arc::clone(&channel);
190    let network_state: NetworkMessageTaskState<_, _> = NetworkMessageTaskState {
191        internal_event_stream: internal_event_stream.clone(),
192        external_event_stream: external_event_stream.clone(),
193        public_key,
194        upgrade_lock: upgrade_lock.clone(),
195        id,
196    };
197
198    let network = Arc::clone(&net);
199    let mut state = network_state.clone();
200
201    spawn(async move {
202        loop {
203            // Get the next message from the network
204            let message = match network.recv_message().await {
205                Ok(message) => message,
206                Err(e) => {
207                    error!("Failed to receive message: {e:?}");
208                    continue;
209                },
210            };
211
212            // Deserialize the message
213            let deserialized_message: Message<TYPES> =
214                match upgrade_lock.deserialize(&message).await {
215                    Ok((message, _)) => message,
216                    Err(e) => {
217                        tracing::error!("Failed to deserialize message: {e:?}");
218                        continue;
219                    },
220                };
221
222            // Handle the message
223            state.handle_message(deserialized_message).await;
224        }
225    })
226}