hotshot_testing/
test_task.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::{sync::Arc, time::Duration};
8
9use async_broadcast::{Receiver, Sender};
10use async_lock::RwLock;
11use async_trait::async_trait;
12use futures::future::select_all;
13use hotshot::{
14    traits::TestableNodeImplementation,
15    types::{Event, Message},
16};
17use hotshot_task_impls::{events::HotShotEvent, network::NetworkMessageTaskState};
18use hotshot_types::{
19    message::UpgradeLock,
20    traits::{network::ConnectedNetwork, node_implementation::NodeType},
21};
22use tokio::{
23    spawn,
24    task::JoinHandle,
25    time::{sleep, timeout},
26};
27use tracing::error;
28
29use crate::test_runner::Node;
30
/// Outcome reported by a test task once it has finished.
///
/// `Debug` is derived so results can be logged or included in assertion messages.
#[derive(Debug)]
pub enum TestResult {
    /// The test task passed.
    Pass,
    /// The test task failed; the payload is a debug-printable description of the failure.
    Fail(Box<dyn std::fmt::Debug + Send + Sync>),
}
38
39pub fn spawn_timeout_task(test_sender: Sender<TestEvent>, timeout: Duration) -> JoinHandle<()> {
40    tokio::spawn(async move {
41        sleep(timeout).await;
42
43        error!("Decide timeout triggered.");
44        let _ = test_sender.broadcast(TestEvent::Shutdown).await;
45    })
46}
47
48#[async_trait]
49/// Type for mutable task state that can be used as the state for a `Task`
50pub trait TestTaskState: Send {
51    /// Type of event sent and received by the task
52    type Event: Clone + Send + Sync;
53    /// Type of error produced by the task
54    type Error: std::fmt::Display;
55
56    /// Handles an event from one of multiple receivers.
57    async fn handle_event(
58        &mut self,
59        (event, id): (Self::Event, usize),
60    ) -> std::result::Result<(), Self::Error>;
61
62    /// Check the result of the test.
63    async fn check(&self) -> TestResult;
64}
65
/// Boxed, type-erased [`TestTaskState`] used for dynamic dispatch.
///
/// The erased state consumes `hotshot_types::event::Event<TYPES>` events and reports
/// failures as `anyhow::Error`. The trait object is additionally required to be
/// `Send + Sync`.
pub type AnyTestTaskState<TYPES> = Box<
    dyn TestTaskState<Event = hotshot_types::event::Event<TYPES>, Error = anyhow::Error>
        + Send
        + Sync,
>;
73
74#[async_trait]
75impl<TYPES: NodeType> TestTaskState for AnyTestTaskState<TYPES> {
76    type Event = Event<TYPES>;
77    type Error = anyhow::Error;
78
79    async fn handle_event(
80        &mut self,
81        event: (Self::Event, usize),
82    ) -> std::result::Result<(), anyhow::Error> {
83        (**self).handle_event(event).await
84    }
85
86    async fn check(&self) -> TestResult {
87        (**self).check().await
88    }
89}
90
91#[async_trait]
92pub trait TestTaskStateSeed<TYPES, I>: Send
93where
94    TYPES: NodeType,
95    I: TestableNodeImplementation<TYPES>,
96{
97    async fn into_state(
98        self: Box<Self>,
99        handles: Arc<RwLock<Vec<Node<TYPES, I>>>>,
100    ) -> AnyTestTaskState<TYPES>;
101}
102
103/// A basic task which loops waiting for events to come from `event_receiver`
104/// and then handles them using it's state
105/// It sends events to other `Task`s through `event_sender`
106/// This should be used as the primary building block for long running
107/// or medium running tasks (i.e. anything that can't be described as a dependency task)
108pub struct TestTask<S: TestTaskState> {
109    /// The state of the task.  It is fed events from `event_sender`
110    /// and mutates it state ocordingly.  Also it signals the task
111    /// if it is complete/should shutdown
112    state: S,
113    /// Receives events that are broadcast from any task, including itself
114    receivers: Vec<Receiver<S::Event>>,
115    /// Receiver for test events, used for communication between test tasks.
116    test_receiver: Receiver<TestEvent>,
117}
118
/// Control events exchanged between test tasks.
#[derive(Clone, Debug)]
pub enum TestEvent {
    /// Asks a test task to stop. [`TestTask::run`] reacts by returning the result of its
    /// state's `check`.
    Shutdown,
}
123
124impl<S: TestTaskState + Send + 'static> TestTask<S> {
125    /// Create a new task
126    pub fn new(
127        state: S,
128        receivers: Vec<Receiver<S::Event>>,
129        test_receiver: Receiver<TestEvent>,
130    ) -> Self {
131        TestTask {
132            state,
133            receivers,
134            test_receiver,
135        }
136    }
137
138    /// Spawn the task loop, consuming self.  Will continue until
139    /// the task reaches some shutdown condition
140    pub fn run(mut self) -> JoinHandle<TestResult> {
141        spawn(async move {
142            loop {
143                if let Ok(TestEvent::Shutdown) = self.test_receiver.try_recv() {
144                    break self.state.check().await;
145                }
146
147                self.receivers.retain(|receiver| !receiver.is_closed());
148
149                let mut messages = Vec::new();
150
151                for receiver in &mut self.receivers {
152                    messages.push(receiver.recv());
153                }
154
155                match timeout(Duration::from_millis(2500), select_all(messages)).await {
156                    Ok((Ok(input), id, _)) => {
157                        let _ = S::handle_event(&mut self.state, (input, id))
158                            .await
159                            .inspect_err(|e| tracing::error!("{e}"));
160                    },
161                    Ok((Err(e), _id, _)) => {
162                        error!("Error from one channel in test task {e:?}");
163                        sleep(Duration::from_millis(4000)).await;
164                    },
165                    _ => {},
166                };
167            }
168        })
169    }
170}
171
172/// Add the network task to handle messages and publish events.
173pub async fn add_network_message_test_task<
174    TYPES: NodeType,
175    NET: ConnectedNetwork<TYPES::SignatureKey>,
176>(
177    internal_event_stream: Sender<Arc<HotShotEvent<TYPES>>>,
178    external_event_stream: Sender<Event<TYPES>>,
179    upgrade_lock: UpgradeLock<TYPES>,
180    channel: Arc<NET>,
181    public_key: TYPES::SignatureKey,
182    id: u64,
183) -> JoinHandle<()> {
184    let net = Arc::clone(&channel);
185    let network_state: NetworkMessageTaskState<_> = NetworkMessageTaskState {
186        internal_event_stream: internal_event_stream.clone(),
187        external_event_stream: external_event_stream.clone(),
188        public_key,
189        upgrade_lock: upgrade_lock.clone(),
190        id,
191    };
192
193    let network = Arc::clone(&net);
194    let mut state = network_state.clone();
195
196    spawn(async move {
197        loop {
198            // Get the next message from the network
199            let message = match network.recv_message().await {
200                Ok(message) => message,
201                Err(e) => {
202                    error!("Failed to receive message: {e:?}");
203                    continue;
204                },
205            };
206
207            // Deserialize the message
208            let deserialized_message: Message<TYPES> = match upgrade_lock.deserialize(&message) {
209                Ok((message, _)) => message,
210                Err(e) => {
211                    tracing::error!("Failed to deserialize message: {e:?}");
212                    continue;
213                },
214            };
215
216            // Handle the message
217            state.handle_message(deserialized_message).await;
218        }
219    })
220}