hotshot_task/task.rs
1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::sync::Arc;
8
9use async_broadcast::{Receiver, RecvError, Sender};
10use async_trait::async_trait;
11use futures::future::try_join_all;
12use hotshot_utils::anytrace::*;
13use tokio::task::{spawn, JoinHandle};
14
/// An event type that long-running tasks receive and react to.
pub trait TaskEvent
where
    Self: PartialEq,
{
    /// Builds the event value that signals shutdown for this event type.
    ///
    /// The signal is necessarily the same for every task using this event
    /// type. Leaving the task loop is the job of the task spawner, which
    /// compares incoming events against this value, not of each individual task.
    fn shutdown_event() -> Self;
}
23
#[async_trait]
/// Type for mutable task state that can be used as the state for a `Task`.
///
/// A `Task` owns one value of this type, feeds it every event it receives via
/// `handle_event`, and hands it back (boxed) when the task loop exits.
pub trait TaskState: Send {
    /// Type of event sent and received by the task
    type Event: TaskEvent + Clone + Send + Sync;

    /// Cancels the subtasks owned by this state; what that entails is up to
    /// the implementor.
    ///
    /// `Task::run` calls this when it receives the shutdown event, and
    /// `ConsensusTaskRegistry::shutdown` calls it again on the state returned
    /// by each joined task, so it can be invoked more than once on the same
    /// state.
    fn cancel_subtasks(&mut self);

    /// Handles an event, providing direct access to the specific channel we received the event on.
    ///
    /// `event` is the received event; `_sender` and `_receiver` are the two
    /// ends of the channel held by the task that received it. The task loop
    /// passes the returned `Result` to the `log!` macro and keeps running.
    async fn handle_event(
        &mut self,
        event: Arc<Self::Event>,
        _sender: &Sender<Arc<Self::Event>>,
        _receiver: &Receiver<Arc<Self::Event>>,
    ) -> Result<()>;
}
41
42/// A basic task which loops waiting for events to come from `event_receiver`
43/// and then handles them using its state
44/// It sends events to other `Task`s through `sender`
45/// This should be used as the primary building block for long running
46/// or medium running tasks (i.e. anything that can't be described as a dependency task)
47pub struct Task<S: TaskState> {
48 /// The state of the task. It is fed events from `receiver`
49 /// and mutated via `handle_event`.
50 state: S,
51 /// Sends events all tasks including itself
52 sender: Sender<Arc<S::Event>>,
53 /// Receives events that are broadcast from any task, including itself
54 receiver: Receiver<Arc<S::Event>>,
55}
56
57impl<S: TaskState + Send + 'static> Task<S> {
58 /// Create a new task
59 pub fn new(state: S, sender: Sender<Arc<S::Event>>, receiver: Receiver<Arc<S::Event>>) -> Self {
60 Task {
61 state,
62 sender,
63 receiver,
64 }
65 }
66
67 /// The state of the task, as a boxed dynamic trait object.
68 fn boxed_state(self) -> Box<dyn TaskState<Event = S::Event>> {
69 Box::new(self.state) as Box<dyn TaskState<Event = S::Event>>
70 }
71
72 /// Spawn the task loop, consuming self. Will continue until
73 /// the task reaches some shutdown condition
74 pub fn run(mut self) -> JoinHandle<Box<dyn TaskState<Event = S::Event>>> {
75 spawn(async move {
76 loop {
77 match self.receiver.recv_direct().await {
78 Ok(input) => {
79 if *input == S::Event::shutdown_event() {
80 self.state.cancel_subtasks();
81
82 break self.boxed_state();
83 }
84
85 log!(
86 S::handle_event(&mut self.state, input, &self.sender, &self.receiver)
87 .await
88 );
89 },
90 Err(RecvError::Closed) => {
91 break self.boxed_state();
92 },
93 Err(e) => {
94 tracing::error!("Failed to receive from event stream Error: {}", e);
95 },
96 }
97 }
98 })
99 }
100}
101
102#[derive(Default)]
103/// A collection of tasks which can handle shutdown
104pub struct ConsensusTaskRegistry<EVENT> {
105 /// Tasks this registry controls
106 task_handles: Vec<JoinHandle<Box<dyn TaskState<Event = EVENT>>>>,
107}
108
109impl<EVENT: Send + Sync + Clone + TaskEvent> ConsensusTaskRegistry<EVENT> {
110 #[must_use]
111 /// Create a new task registry
112 pub fn new() -> Self {
113 ConsensusTaskRegistry {
114 task_handles: vec![],
115 }
116 }
117 /// Add a task to the registry
118 pub fn register(&mut self, handle: JoinHandle<Box<dyn TaskState<Event = EVENT>>>) {
119 self.task_handles.push(handle);
120 }
121 /// Try to cancel/abort the task this registry has
122 ///
123 /// # Panics
124 ///
125 /// Should not panic, unless awaiting on the JoinHandle in tokio fails.
126 pub async fn shutdown(&mut self) {
127 let handles = &mut self.task_handles;
128
129 while let Some(handle) = handles.pop() {
130 let _ = handle
131 .await
132 .map(|mut task_state| task_state.cancel_subtasks());
133 }
134 }
135 /// Take a task, run it, and register it
136 pub fn run_task<S>(&mut self, task: Task<S>)
137 where
138 S: TaskState<Event = EVENT> + Send + 'static,
139 {
140 self.register(task.run());
141 }
142
143 /// Wait for the results of all the tasks registered
144 /// # Panics
145 /// Panics if one of the tasks panicked
146 pub async fn join_all(self) -> Vec<Box<dyn TaskState<Event = EVENT>>> {
147 try_join_all(self.task_handles).await.unwrap()
148 }
149}
150
151#[derive(Default)]
152/// A collection of tasks which can handle shutdown
153pub struct NetworkTaskRegistry {
154 /// Tasks this registry controls
155 pub handles: Vec<JoinHandle<()>>,
156}
157
158impl NetworkTaskRegistry {
159 #[must_use]
160 /// Create a new task registry
161 pub fn new() -> Self {
162 NetworkTaskRegistry { handles: vec![] }
163 }
164
165 #[allow(clippy::unused_async)]
166 /// Shuts down all tasks managed by this instance.
167 ///
168 /// This function waits for all tasks to complete before returning.
169 ///
170 /// # Panics
171 ///
172 /// When using the tokio executor, this function will panic if any of the
173 /// tasks being joined return an error.
174 pub async fn shutdown(&mut self) {
175 let handles = std::mem::take(&mut self.handles);
176 try_join_all(handles)
177 .await
178 .expect("Failed to join all tasks during shutdown");
179 }
180
181 /// Add a task to the registry
182 pub fn register(&mut self, handle: JoinHandle<()>) {
183 self.handles.push(handle);
184 }
185}