hotshot_task/
task.rs

// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
// This file is part of the HotShot repository.

// You should have received a copy of the MIT License
// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::sync::Arc;
8
9use async_broadcast::{Receiver, RecvError, Sender};
10use async_trait::async_trait;
11use futures::future::try_join_all;
12use hotshot_utils::anytrace::*;
13use tokio::task::{spawn, JoinHandle};
14
/// Behaviour shared by every event type that flows through a long-running task.
///
/// Events must be comparable so that a task loop can recognise the shutdown signal.
pub trait TaskEvent
where
    Self: PartialEq,
{
    /// Build the event that tells a task loop to stop.
    ///
    /// This value is the same for every task listening on a channel. Leaving the loop is
    /// the job of the task spawner (the loop compares each incoming event against this
    /// value), not of individual task implementations.
    fn shutdown_event() -> Self;
}
23
24#[async_trait]
25/// Type for mutable task state that can be used as the state for a `Task`
26pub trait TaskState: Send {
27    /// Type of event sent and received by the task
28    type Event: TaskEvent + Clone + Send + Sync;
29
30    /// Joins all subtasks.
31    fn cancel_subtasks(&mut self);
32
33    /// Handles an event, providing direct access to the specific channel we received the event on.
34    async fn handle_event(
35        &mut self,
36        event: Arc<Self::Event>,
37        _sender: &Sender<Arc<Self::Event>>,
38        _receiver: &Receiver<Arc<Self::Event>>,
39    ) -> Result<()>;
40}
41
42/// A basic task which loops waiting for events to come from `event_receiver`
43/// and then handles them using its state
44/// It sends events to other `Task`s through `sender`
45/// This should be used as the primary building block for long running
46/// or medium running tasks (i.e. anything that can't be described as a dependency task)
47pub struct Task<S: TaskState> {
48    /// The state of the task.  It is fed events from `receiver`
49    /// and mutated via `handle_event`.
50    state: S,
51    /// Sends events all tasks including itself
52    sender: Sender<Arc<S::Event>>,
53    /// Receives events that are broadcast from any task, including itself
54    receiver: Receiver<Arc<S::Event>>,
55}
56
57impl<S: TaskState + Send + 'static> Task<S> {
58    /// Create a new task
59    pub fn new(state: S, sender: Sender<Arc<S::Event>>, receiver: Receiver<Arc<S::Event>>) -> Self {
60        Task {
61            state,
62            sender,
63            receiver,
64        }
65    }
66
67    /// The state of the task, as a boxed dynamic trait object.
68    fn boxed_state(self) -> Box<dyn TaskState<Event = S::Event>> {
69        Box::new(self.state) as Box<dyn TaskState<Event = S::Event>>
70    }
71
72    /// Spawn the task loop, consuming self.  Will continue until
73    /// the task reaches some shutdown condition
74    pub fn run(mut self) -> JoinHandle<Box<dyn TaskState<Event = S::Event>>> {
75        spawn(async move {
76            loop {
77                match self.receiver.recv_direct().await {
78                    Ok(input) => {
79                        if *input == S::Event::shutdown_event() {
80                            self.state.cancel_subtasks();
81
82                            break self.boxed_state();
83                        }
84
85                        log!(
86                            S::handle_event(&mut self.state, input, &self.sender, &self.receiver)
87                                .await
88                        );
89                    },
90                    Err(RecvError::Closed) => {
91                        break self.boxed_state();
92                    },
93                    Err(e) => {
94                        tracing::error!("Failed to receive from event stream Error: {}", e);
95                    },
96                }
97            }
98        })
99    }
100}
101
102#[derive(Default)]
103/// A collection of tasks which can handle shutdown
104pub struct ConsensusTaskRegistry<EVENT> {
105    /// Tasks this registry controls
106    task_handles: Vec<JoinHandle<Box<dyn TaskState<Event = EVENT>>>>,
107}
108
109impl<EVENT: Send + Sync + Clone + TaskEvent> ConsensusTaskRegistry<EVENT> {
110    #[must_use]
111    /// Create a new task registry
112    pub fn new() -> Self {
113        ConsensusTaskRegistry {
114            task_handles: vec![],
115        }
116    }
117    /// Add a task to the registry
118    pub fn register(&mut self, handle: JoinHandle<Box<dyn TaskState<Event = EVENT>>>) {
119        self.task_handles.push(handle);
120    }
121    /// Try to cancel/abort the task this registry has
122    ///
123    /// # Panics
124    ///
125    /// Should not panic, unless awaiting on the JoinHandle in tokio fails.
126    pub async fn shutdown(&mut self) {
127        let handles = &mut self.task_handles;
128
129        while let Some(handle) = handles.pop() {
130            let _ = handle
131                .await
132                .map(|mut task_state| task_state.cancel_subtasks());
133        }
134    }
135    /// Take a task, run it, and register it
136    pub fn run_task<S>(&mut self, task: Task<S>)
137    where
138        S: TaskState<Event = EVENT> + Send + 'static,
139    {
140        self.register(task.run());
141    }
142
143    /// Wait for the results of all the tasks registered
144    /// # Panics
145    /// Panics if one of the tasks panicked
146    pub async fn join_all(self) -> Vec<Box<dyn TaskState<Event = EVENT>>> {
147        try_join_all(self.task_handles).await.unwrap()
148    }
149}
150
151#[derive(Default)]
152/// A collection of tasks which can handle shutdown
153pub struct NetworkTaskRegistry {
154    /// Tasks this registry controls
155    pub handles: Vec<JoinHandle<()>>,
156}
157
158impl NetworkTaskRegistry {
159    #[must_use]
160    /// Create a new task registry
161    pub fn new() -> Self {
162        NetworkTaskRegistry { handles: vec![] }
163    }
164
165    #[allow(clippy::unused_async)]
166    /// Shuts down all tasks managed by this instance.
167    ///
168    /// This function waits for all tasks to complete before returning.
169    ///
170    /// # Panics
171    ///
172    /// When using the tokio executor, this function will panic if any of the
173    /// tasks being joined return an error.
174    pub async fn shutdown(&mut self) {
175        let handles = std::mem::take(&mut self.handles);
176        try_join_all(handles)
177            .await
178            .expect("Failed to join all tasks during shutdown");
179    }
180
181    /// Add a task to the registry
182    pub fn register(&mut self, handle: JoinHandle<()>) {
183        self.handles.push(handle);
184    }
185}