request_response/lib.rs
1//! This crate contains a general request-response protocol. It is used to send requests to
2//! a set of recipients and wait for responses.
3
4use std::{
5 any::Any,
6 collections::HashMap,
7 future::Future,
8 marker::PhantomData,
9 pin::Pin,
10 sync::{Arc, Weak},
11 time::{Duration, Instant},
12};
13
14use anyhow::{anyhow, Context, Result};
15use data_source::DataSource;
16use derive_more::derive::Deref;
17use hotshot_types::traits::signature_key::SignatureKey;
18use message::{Message, RequestMessage, ResponseMessage};
19use network::{Bytes, Receiver, Sender};
20use parking_lot::RwLock;
21use rand::seq::SliceRandom;
22use recipient_source::RecipientSource;
23use request::Request;
24use tokio::{
25 spawn,
26 time::{sleep, timeout},
27};
28use tokio_util::task::AbortOnDropHandle;
29use tracing::{debug, error, trace, warn};
30use util::{BoundedVecDeque, NamedSemaphore, NamedSemaphoreError};
31
/// The data source trait, which is what we use to derive the response data for a request.
pub mod data_source;
/// The message type, which is the base type for all messages in the request-response protocol.
pub mod message;
/// The network traits, which are what we use to send and receive messages over the network as
/// the protocol.
pub mod network;
/// The recipient source trait, which is what we use to get the recipients that a specific
/// message should expect responses from.
pub mod recipient_source;
/// The request trait, which is what we use to define a request and a corresponding response type.
pub mod request;
/// Utility types and functions.
mod util;
46
/// A type alias for the hash of a request (a `blake3` hash).
pub type RequestHash = blake3::Hash;

/// A type alias for the outgoing requests map: a shared, lock-protected map from the hash of a
/// request to a weak reference to the corresponding in-flight outgoing request.
pub type OutgoingRequestsMap<Req> =
    Arc<RwLock<HashMap<RequestHash, Weak<OutgoingRequestInner<Req>>>>>;

/// A type alias for the semaphore (named by key `K`) that bounds the tasks that are responding
/// to requests.
pub type IncomingRequests<K> = NamedSemaphore<K>;

/// A type alias for the semaphore that bounds the tasks that are validating incoming responses.
pub type IncomingResponses = NamedSemaphore<()>;
59
/// The type of request to make. This decides how the request message is sent out.
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum RequestType {
    /// A request that can be satisfied by a single participant,
    /// and as such will be batched to a few participants at a time
    /// until one succeeds.
    Batched,
    /// A request that needs most or all participants to respond,
    /// and as such will be broadcasted to all participants.
    Broadcast,
}
71
72/// The errors that can occur when making a request for data
73#[derive(thiserror::Error, Debug)]
74pub enum RequestError {
75 /// The request timed out
76 #[error("request timed out")]
77 Timeout,
78 /// The request was invalid
79 #[error("request was invalid")]
80 InvalidRequest(anyhow::Error),
81 /// Other errors
82 #[error("other error")]
83 Other(anyhow::Error),
84}
85
/// A trait for serializing and deserializing a type to and from a byte array. [`Request`] types and
/// their associated `Response` types will need to implement this trait.
pub trait Serializable: Sized {
    /// Serialize the type to a byte array. If this is for a [`Request`] and your [`Request`] type
    /// is represented as an enum, please make sure that you serialize it with a unique type ID.
    /// Otherwise, you may end up with collisions, because the hash of the serialized request is
    /// used as its unique identifier.
    ///
    /// # Errors
    /// - If the type cannot be serialized to a byte array
    fn to_bytes(&self) -> Result<Vec<u8>>;

    /// Deserialize the type from a byte array.
    ///
    /// # Errors
    /// - If the byte array is not a valid representation of the type
    fn from_bytes(bytes: &[u8]) -> Result<Self>;
}
103
/// The underlying configuration for the request-response protocol
#[derive(Clone)]
pub struct RequestResponseConfig {
    /// The time-to-live for incoming requests. Do not respond to a request after this threshold
    /// has passed.
    pub incoming_request_ttl: Duration,
    /// The maximum amount of time we will spend trying to validate a request, derive a response
    /// for it, and send the response over the wire.
    pub incoming_request_timeout: Duration,
    /// The maximum amount of time we will spend trying to validate a response. This is used to prevent
    /// an attack where a malicious participant sends us a bunch of responses that take a long time to
    /// validate.
    pub incoming_response_timeout: Duration,
    /// The batch size for outgoing requests. This is the number of request messages that we will
    /// send out at a time for a single request before waiting for the `request_batch_interval`.
    pub request_batch_size: usize,
    /// The time to wait (per request) between sending out batches of request messages
    pub request_batch_interval: Duration,
    /// The maximum (global) number of incoming requests that can be processed at any given time.
    pub max_incoming_requests: usize,
    /// The maximum number of incoming requests that can be processed for a single key at any given time.
    pub max_incoming_requests_per_key: usize,
    /// The maximum (global) number of incoming responses that can be processed at any given time.
    /// We need this because incoming responses are validated asynchronously (in their own tasks)
    /// to check that they satisfy the request they are responding to.
    pub max_incoming_responses: usize,
}
131
/// A protocol that allows for request-response communication. Is cheaply cloneable (both fields
/// are `Arc`s), so there is no need to wrap it in an `Arc`. Dereferences to the inner
/// implementation, which is where the request methods live.
#[derive(Deref)]
pub struct RequestResponse<
    S: Sender<K>,
    R: Receiver,
    Req: Request,
    RS: RecipientSource<Req, K>,
    DS: DataSource<Req>,
    K: SignatureKey + 'static,
> {
    #[deref]
    /// The inner implementation of the request-response protocol
    pub inner: Arc<RequestResponseInner<S, R, Req, RS, DS, K>>,
    /// A handle to the receiving task. Because it is an abort-on-drop handle shared between all
    /// clones, the task gets cancelled once the last clone of the protocol is dropped.
    _receiving_task_handle: Arc<AbortOnDropHandle<()>>,
}
149
150/// We need to manually implement the `Clone` trait for this type because deriving
151/// `Deref` will cause an issue where it tries to clone the inner field instead
152impl<
153 S: Sender<K>,
154 R: Receiver,
155 Req: Request,
156 RS: RecipientSource<Req, K>,
157 DS: DataSource<Req>,
158 K: SignatureKey + 'static,
159 > Clone for RequestResponse<S, R, Req, RS, DS, K>
160{
161 fn clone(&self) -> Self {
162 Self {
163 inner: Arc::clone(&self.inner),
164 _receiving_task_handle: Arc::clone(&self._receiving_task_handle),
165 }
166 }
167}
168
169impl<
170 S: Sender<K>,
171 R: Receiver,
172 Req: Request,
173 RS: RecipientSource<Req, K>,
174 DS: DataSource<Req>,
175 K: SignatureKey + 'static,
176 > RequestResponse<S, R, Req, RS, DS, K>
177{
178 /// Create a new [`RequestResponseProtocol`]
179 pub fn new(
180 // The configuration for the protocol
181 config: RequestResponseConfig,
182 // The network sender that [`RequestResponseProtocol`] will use to send messages
183 sender: S,
184 // The network receiver that [`RequestResponseProtocol`] will use to receive messages
185 receiver: R,
186 // The recipient source that [`RequestResponseProtocol`] will use to get the recipients
187 // that a specific message should expect responses from
188 recipient_source: RS,
189 // The [response] data source that [`RequestResponseProtocol`] will use to derive the
190 // response data for a specific request
191 data_source: DS,
192 ) -> Self {
193 // Create the outgoing requests map
194 let outgoing_requests = OutgoingRequestsMap::default();
195
196 // Create the inner implementation
197 let inner = Arc::new(RequestResponseInner {
198 config,
199 sender,
200 recipient_source,
201 data_source,
202 outgoing_requests,
203 phantom_data: PhantomData,
204 });
205
206 // Start the task that receives messages and handles them. This will automatically get cancelled
207 // when the protocol is dropped
208 let inner_clone = Arc::clone(&inner);
209 let receive_task_handle =
210 AbortOnDropHandle::new(tokio::spawn(inner_clone.receiving_task(receiver)));
211
212 // Return the protocol
213 Self {
214 inner,
215 _receiving_task_handle: Arc::new(receive_task_handle),
216 }
217 }
218}
219
/// A type alias for an `Arc<dyn Any + Send + Sync + 'static>`. Validated responses are
/// type-erased into this so that requests with different output types can share one map.
type ThreadSafeAny = Arc<dyn Any + Send + Sync + 'static>;

/// A type alias for the (boxed, pinned) future that validates a response. Resolves to the
/// type-erased validation output, or to an error if the response was not valid.
type ResponseValidationFuture =
    Pin<Box<dyn Future<Output = Result<ThreadSafeAny, anyhow::Error>> + Send + Sync + 'static>>;

/// A type alias for the function that returns the above future. It takes the original request
/// and the received response.
type ResponseValidationFn<R> =
    Box<dyn Fn(&R, <R as Request>::Response) -> ResponseValidationFuture + Send + Sync + 'static>;
230
/// The inner implementation for the request-response protocol. This is the state shared (behind
/// an `Arc`) between all clones of [`RequestResponse`] and the receiving task.
pub struct RequestResponseInner<
    S: Sender<K>,
    R: Receiver,
    Req: Request,
    RS: RecipientSource<Req, K>,
    DS: DataSource<Req>,
    K: SignatureKey + 'static,
> {
    /// The configuration of the protocol
    config: RequestResponseConfig,
    /// The sender to use for the protocol
    pub sender: S,
    /// The recipient source to use for the protocol
    pub recipient_source: RS,
    /// The data source to use for the protocol
    data_source: DS,
    /// The map of currently active, outgoing requests (keyed by request hash)
    outgoing_requests: OutgoingRequestsMap<Req>,
    /// Phantom data to help with type inference
    phantom_data: PhantomData<(K, R, Req, DS)>,
}
253impl<
254 S: Sender<K>,
255 R: Receiver,
256 Req: Request,
257 RS: RecipientSource<Req, K>,
258 DS: DataSource<Req>,
259 K: SignatureKey + 'static,
260 > RequestResponseInner<S, R, Req, RS, DS, K>
261{
262 /// Request something from the protocol indefinitely until we get a response
263 /// or there was a critical error (e.g. the request could not be signed)
264 ///
265 /// # Errors
266 /// - If the request was invalid
267 /// - If there was a critical error (e.g. the channel was closed)
268 pub async fn request_indefinitely<F, Fut, O>(
269 self: &Arc<Self>,
270 public_key: &K,
271 private_key: &K::PrivateKey,
272 // The type of request to make
273 request_type: RequestType,
274 // The estimated TTL of other participants. This is used to decide when to
275 // stop making requests and sign a new one
276 estimated_request_ttl: Duration,
277 // The request to make
278 request: Req,
279 // The response validation function
280 response_validation_fn: F,
281 ) -> std::result::Result<O, RequestError>
282 where
283 F: Fn(&Req, Req::Response) -> Fut + Send + Sync + 'static + Clone,
284 Fut: Future<Output = anyhow::Result<O>> + Send + Sync + 'static,
285 O: Send + Sync + 'static + Clone,
286 {
287 loop {
288 // Sign a request message
289 let request_message = RequestMessage::new_signed(public_key, private_key, &request)
290 .map_err(|e| {
291 RequestError::InvalidRequest(anyhow::anyhow!(
292 "failed to sign request message: {e}"
293 ))
294 })?;
295
296 // Request the data, handling the errors appropriately
297 match self
298 .request(
299 request_message,
300 request_type,
301 estimated_request_ttl,
302 response_validation_fn.clone(),
303 )
304 .await
305 {
306 Ok(response) => return Ok(response),
307 Err(RequestError::Timeout) => continue,
308 Err(e) => return Err(e),
309 }
310 }
311 }
312
313 /// Request something from the protocol and wait for the response. This function
314 /// will join with an existing request for the same data (determined by `Blake3` hash),
315 /// however both will make requests until the timeout is reached
316 ///
317 /// # Errors
318 /// - If the request times out
319 /// - If the channel is closed (this is an internal error)
320 /// - If the request we sign is invalid
321 pub async fn request<F, Fut, O>(
322 self: &Arc<Self>,
323 request_message: RequestMessage<Req, K>,
324 request_type: RequestType,
325 timeout_duration: Duration,
326 response_validation_fn: F,
327 ) -> std::result::Result<O, RequestError>
328 where
329 F: Fn(&Req, Req::Response) -> Fut + Send + Sync + 'static + Clone,
330 Fut: Future<Output = anyhow::Result<O>> + Send + Sync + 'static,
331 O: Send + Sync + 'static + Clone,
332 {
333 timeout(timeout_duration, async move {
334 // Calculate the hash of the request
335 let request_hash = blake3::hash(&request_message.request.to_bytes().map_err(|e| {
336 RequestError::InvalidRequest(anyhow::anyhow!(
337 "failed to serialize request message: {e}"
338 ))
339 })?);
340
341 let request = {
342 // Get a write lock on the outgoing requests map
343 let mut outgoing_requests_write = self.outgoing_requests.write();
344
345 // Conditionally get the outgoing request, creating a new one if it doesn't exist or if
346 // the existing one has been dropped and not yet removed
347 if let Some(outgoing_request) = outgoing_requests_write
348 .get(&request_hash)
349 .and_then(Weak::upgrade)
350 {
351 OutgoingRequest(outgoing_request)
352 } else {
353 // Create a new broadcast channel for the response
354 let (sender, receiver) = async_broadcast::broadcast(1);
355
356 // Modify the response validation function to return an `Arc<dyn Any>`
357 let response_validation_fn =
358 Box::new(move |request: &Req, response: Req::Response| {
359 let fut = response_validation_fn(request, response);
360 Box::pin(
361 async move { fut.await.map(|ok| Arc::new(ok) as ThreadSafeAny) },
362 ) as ResponseValidationFuture
363 });
364
365 // Create a new outgoing request
366 let outgoing_request = OutgoingRequest(Arc::new(OutgoingRequestInner {
367 sender,
368 receiver,
369 response_validation_fn,
370 request: request_message.request.clone(),
371 outgoing_requests: Arc::clone(&self.outgoing_requests),
372 request_hash,
373 }));
374
375 // Write the new outgoing request to the map
376 outgoing_requests_write
377 .insert(request_hash, Arc::downgrade(&outgoing_request.0));
378
379 // Return the new outgoing request
380 outgoing_request
381 }
382 };
383
384 // Create a request message and serialize it
385 let message = Bytes::from(
386 Message::Request(request_message.clone())
387 .to_bytes()
388 .map_err(|e| {
389 RequestError::InvalidRequest(anyhow::anyhow!(
390 "failed to serialize request message: {e}"
391 ))
392 })?,
393 );
394
395 // Create a place to put the handle for the batched sending task. We need this because
396 // otherwise it gets dropped when the closure goes out of scope, instead of when the function
397 // gets cancelled or returns
398 let mut _batched_sending_task = None;
399
400 // Match on the type of request
401 if request_type == RequestType::Broadcast {
402 trace!("Sending request {request_message:?} to all participants");
403
404 // If the message is a broadcast request, just send it to all participants
405 self.sender
406 .send_broadcast_message(&message)
407 .await
408 .map_err(|e| {
409 RequestError::Other(anyhow::anyhow!(
410 "failed to send broadcast message: {e}"
411 ))
412 })?;
413 } else {
414 // If the message is a batched request, we need to batch it with other requests
415
416 // Get the recipients that the request should expect responses from. Shuffle them so
417 // that we don't always send to the same recipients in the same order
418 let mut recipients = self
419 .recipient_source
420 .get_expected_responders(&request_message.request)
421 .await
422 .map_err(|e| {
423 RequestError::InvalidRequest(anyhow::anyhow!(
424 "failed to get expected responders for request: {e}"
425 ))
426 })?;
427 recipients.shuffle(&mut rand::thread_rng());
428
429 // Get the current time so we can check when the timeout has elapsed
430 let start_time = Instant::now();
431
432 // Spawn a task that sends out requests to the network
433 let self_clone = Arc::clone(self);
434 let batched_sending_handle = AbortOnDropHandle::new(spawn(async move {
435 // Create a bounded queue for the outgoing requests. We use this to make sure
436 // we have less than [`config.request_batch_size`] requests in flight at any time.
437 //
438 // When newer requests are added, older ones are removed from the queue. Because we use
439 // `AbortOnDropHandle`, the older ones will automatically get cancelled
440 let mut outgoing_requests =
441 BoundedVecDeque::new(self_clone.config.request_batch_size);
442
443 // While the timeout hasn't elapsed, send out requests to the network
444 while start_time.elapsed() < timeout_duration {
445 // Send out requests to the network in their own separate tasks
446 for recipient_batch in
447 recipients.chunks(self_clone.config.request_batch_size)
448 {
449 for recipient in recipient_batch {
450 // Clone ourselves, the message, and the recipient so they can be moved
451 let self_clone = Arc::clone(&self_clone);
452 let request_message_clone = request_message.clone();
453 let recipient_clone = recipient.clone();
454 let message_clone = Arc::clone(&message);
455
456 // Spawn the task that sends the request to the participant
457 let individual_sending_task = spawn(async move {
458 trace!(
459 "Sending request {request_message_clone:?} to \
460 {recipient_clone:?}"
461 );
462
463 let _ = self_clone
464 .sender
465 .send_direct_message(&message_clone, recipient_clone)
466 .await;
467 });
468
469 // Add the sending task to the queue
470 outgoing_requests
471 .push(AbortOnDropHandle::new(individual_sending_task));
472 }
473
474 // After we send the batch out, wait the [`config.request_batch_interval`]
475 // before sending the next one
476 sleep(self_clone.config.request_batch_interval).await;
477 }
478 }
479 }));
480
481 // Store the handle so it doesn't get dropped
482 _batched_sending_task = Some(batched_sending_handle);
483 }
484
485 // Wait for a response on the channel
486 request
487 .receiver
488 .clone()
489 .recv()
490 .await
491 .map_err(|_| RequestError::Other(anyhow!("channel was closed")))
492 })
493 .await
494 .map_err(|_| RequestError::Timeout)
495 .and_then(|result| result)
496 .and_then(|result| {
497 result.downcast::<O>().map_err(|e| {
498 RequestError::Other(anyhow::anyhow!(
499 "failed to downcast response to expected type: {e:?}"
500 ))
501 })
502 })
503 .map(|result| Arc::unwrap_or_clone(result))
504 }
505
506 /// The task responsible for receiving messages from the receiver and handling them
507 async fn receiving_task(self: Arc<Self>, mut receiver: R) {
508 // Upper bound the number of outgoing and incoming responses
509 let mut incoming_requests = NamedSemaphore::new(
510 self.config.max_incoming_requests_per_key,
511 Some(self.config.max_incoming_requests),
512 );
513 let mut incoming_responses = NamedSemaphore::new(self.config.max_incoming_responses, None);
514
515 // While the receiver is open, we receive messages and handle them
516 loop {
517 // Try to receive a message
518 match receiver.receive_message().await {
519 Ok(message) => {
520 // Deserialize the message, warning if it fails
521 let message = match Message::from_bytes(&message) {
522 Ok(message) => message,
523 Err(e) => {
524 warn!("Received invalid message: {e:#}");
525 continue;
526 },
527 };
528
529 // Handle the message based on its type
530 match message {
531 Message::Request(request_message) => {
532 self.handle_request(request_message, &mut incoming_requests);
533 },
534 Message::Response(response_message) => {
535 self.handle_response(response_message, &mut incoming_responses);
536 },
537 }
538 },
539 // An error here means the receiver will _NEVER_ receive any more messages
540 Err(e) => {
541 error!("Request/response receive task exited: {e:#}");
542 return;
543 },
544 }
545 }
546 }
547
548 /// Handle a request sent to us
549 fn handle_request(
550 self: &Arc<Self>,
551 request_message: RequestMessage<Req, K>,
552 incoming_requests: &mut IncomingRequests<K>,
553 ) {
554 trace!("Handling request {:?}", request_message);
555
556 // Spawn a task to:
557 // - Validate the request
558 // - Derive the response data (check if we have it)
559 // - Send the response to the requester
560 let self_clone = Arc::clone(self);
561
562 // Attempt to acquire a permit for the request. Warn if there are too many requests currently being processed
563 // either globally or per-key
564 let permit = incoming_requests.try_acquire(request_message.public_key.clone());
565 match permit {
566 Ok(ref permit) => permit,
567 Err(NamedSemaphoreError::PerKeyLimitReached) => {
568 warn!(
569 "Failed to process request from {}: too many requests from the same key are \
570 already being processed",
571 request_message.public_key
572 );
573 return;
574 },
575 Err(NamedSemaphoreError::GlobalLimitReached) => {
576 warn!(
577 "Failed to process request from {}: too many requests are already being \
578 processed",
579 request_message.public_key
580 );
581 return;
582 },
583 };
584
585 tokio::spawn(async move {
586 let result = timeout(self_clone.config.incoming_request_timeout, async move {
587 // Validate the request message. This includes:
588 // - Checking the signature and making sure it's valid
589 // - Checking the timestamp and making sure it's not too old
590 // - Calling the request's application-specific validation function
591 request_message
592 .validate(self_clone.config.incoming_request_ttl)
593 .await
594 .with_context(|| "failed to validate request")?;
595
596 // Try to fetch the response data from the data source
597 let response = self_clone
598 .data_source
599 .derive_response_for(&request_message.request)
600 .await
601 .with_context(|| "failed to derive response for request")?;
602
603 // Create the response message and serialize it
604 let response = Bytes::from(
605 Message::Response::<Req, K>(ResponseMessage {
606 request_hash: blake3::hash(&request_message.request.to_bytes()?),
607 response,
608 })
609 .to_bytes()
610 .with_context(|| "failed to serialize response message")?,
611 );
612
613 // Send the response to the requester
614 self_clone
615 .sender
616 .send_direct_message(&response, request_message.public_key)
617 .await
618 .with_context(|| "failed to send response to requester")?;
619
620 // Drop the permit
621 _ = permit;
622 drop(permit);
623
624 Ok::<(), anyhow::Error>(())
625 })
626 .await
627 .map_err(|_| anyhow::anyhow!("timed out while sending response"))
628 .and_then(|result| result);
629
630 if let Err(e) = result {
631 debug!("Failed to send response to requester: {e:#}");
632 }
633 });
634 }
635
636 /// Handle a response sent to us
637 fn handle_response(
638 self: &Arc<Self>,
639 response: ResponseMessage<Req>,
640 incoming_responses: &mut IncomingResponses,
641 ) {
642 trace!("Handling response {response:?}");
643
644 // Get the entry in the map, ignoring it if it doesn't exist
645 let Some(outgoing_request) = self
646 .outgoing_requests
647 .read()
648 .get(&response.request_hash)
649 .cloned()
650 .and_then(|r| r.upgrade())
651 else {
652 return;
653 };
654
655 // Attempt to acquire a permit for the request. Warn if there are too many responses currently being processed
656 let permit = incoming_responses.try_acquire(());
657 let Ok(permit) = permit else {
658 warn!("Failed to process response: too many responses are already being processed");
659 return;
660 };
661
662 // Spawn a task to validate the response and send it to the requester (us)
663 let response_validate_timeout = self.config.incoming_response_timeout;
664 tokio::spawn(async move {
665 if timeout(response_validate_timeout, async move {
666 // Make sure the response is valid for the given request
667 let validation_result = match (outgoing_request.response_validation_fn)(
668 &outgoing_request.request,
669 response.response,
670 )
671 .await
672 {
673 Ok(validation_result) => validation_result,
674 Err(e) => {
675 debug!("Received invalid response: {e:#}");
676 return;
677 },
678 };
679
680 // Send the response to the requester (the user of [`RequestResponse::request`])
681 let _ = outgoing_request.sender.try_broadcast(validation_result);
682
683 // Drop the permit
684 _ = permit;
685 drop(permit);
686 })
687 .await
688 .is_err()
689 {
690 warn!("Timed out while validating response");
691 }
692 });
693 }
694}
695
/// An outgoing request. This is what we use to track a request and its corresponding response
/// in the protocol. It is a cloneable, shared handle that dereferences to
/// [`OutgoingRequestInner`].
#[derive(Clone, Deref)]
pub struct OutgoingRequest<R: Request>(Arc<OutgoingRequestInner<R>>);
700
/// The inner implementation of an outgoing request. When it is dropped, it removes itself from
/// the outgoing requests map.
pub struct OutgoingRequestInner<R: Request> {
    /// The sending half of the broadcast channel that validated responses are published on
    sender: async_broadcast::Sender<ThreadSafeAny>,
    /// The receiving half of the broadcast channel. Callers waiting for the response clone this
    receiver: async_broadcast::Receiver<ThreadSafeAny>,

    /// The request that we are waiting for a response to
    request: R,

    /// The function used to validate the response
    response_validation_fn: ResponseValidationFn<R>,

    /// A copy of the map of currently active, outgoing requests
    outgoing_requests: OutgoingRequestsMap<R>,
    /// The hash of the request. We need this so we can remove ourselves from the map
    request_hash: RequestHash,
}
719
720impl<R: Request> Drop for OutgoingRequestInner<R> {
721 fn drop(&mut self) {
722 self.outgoing_requests.write().remove(&self.request_hash);
723 }
724}
725
726#[cfg(test)]
727mod tests {
728 use std::{
729 collections::HashMap,
730 sync::{atomic::AtomicBool, Mutex},
731 };
732
733 use async_trait::async_trait;
734 use hotshot_types::signature_key::{BLSPrivKey, BLSPubKey};
735 use rand::Rng;
736 use tokio::{sync::mpsc, task::JoinSet};
737
738 use super::*;
739
740 /// This test makes sure that when all references to an outgoing request are dropped, it is
741 /// removed from the outgoing requests map
742 #[test]
743 fn test_outgoing_request_drop() {
744 // Create an outgoing requests map
745 let outgoing_requests = OutgoingRequestsMap::default();
746
747 // Create an outgoing request
748 let (sender, receiver) = async_broadcast::broadcast(1);
749 let outgoing_request = OutgoingRequest(Arc::new(OutgoingRequestInner {
750 sender,
751 receiver,
752 request: TestRequest(vec![1, 2, 3]),
753 response_validation_fn: Box::new(|_request, _response| {
754 Box::pin(async move { Ok(Arc::new(()) as ThreadSafeAny) })
755 as ResponseValidationFuture
756 }),
757 outgoing_requests: Arc::clone(&outgoing_requests),
758 request_hash: blake3::hash(&[1, 2, 3]),
759 }));
760
761 // Insert the outgoing request into the map
762 outgoing_requests.write().insert(
763 outgoing_request.request_hash,
764 Arc::downgrade(&outgoing_request.0),
765 );
766
767 // Clone the outgoing request
768 let outgoing_request_clone = outgoing_request.clone();
769
770 // Drop the outgoing request
771 drop(outgoing_request);
772
773 // Make sure nothing has been removed
774 assert_eq!(outgoing_requests.read().len(), 1);
775
776 // Drop the clone
777 drop(outgoing_request_clone);
778
779 // Make sure it has been removed
780 assert_eq!(outgoing_requests.read().len(), 0);
781 }
782
    /// A test sender that has a list of all the participants in the network
    #[derive(Clone)]
    pub struct TestSender {
        /// The channel to every participant in the network, by their public key
        network: Arc<HashMap<BLSPubKey, mpsc::Sender<Bytes>>>,
    }
788
789 /// An implementation of the [`Sender`] trait for the [`TestSender`] type
790 #[async_trait]
791 impl Sender<BLSPubKey> for TestSender {
792 async fn send_direct_message(&self, message: &Bytes, recipient: BLSPubKey) -> Result<()> {
793 self.network
794 .get(&recipient)
795 .ok_or(anyhow::anyhow!("recipient not found"))?
796 .send(Arc::clone(message))
797 .await
798 .map_err(|_| anyhow::anyhow!("failed to send message"))?;
799
800 Ok(())
801 }
802
803 async fn send_broadcast_message(&self, message: &Bytes) -> Result<()> {
804 for sender in self.network.values() {
805 sender
806 .send(Arc::clone(message))
807 .await
808 .map_err(|_| anyhow::anyhow!("failed to send message"))?;
809 }
810 Ok(())
811 }
812 }
813
814 // Implement the [`RecipientSource`] trait for the [`TestSender`] type
815 #[async_trait]
816 impl RecipientSource<TestRequest, BLSPubKey> for TestSender {
817 async fn get_expected_responders(&self, _request: &TestRequest) -> Result<Vec<BLSPubKey>> {
818 // Get all the participants in the network
819 Ok(self.network.keys().copied().collect())
820 }
821 }
822
823 // Create a test request that is just some bytes
824 #[derive(Clone, Debug)]
825 struct TestRequest(Vec<u8>);
826
827 // Implement the [`Serializable`] trait for the [`TestRequest`] type
828 impl Serializable for TestRequest {
829 fn to_bytes(&self) -> Result<Vec<u8>> {
830 Ok(self.0.clone())
831 }
832
833 fn from_bytes(bytes: &[u8]) -> Result<Self> {
834 Ok(TestRequest(bytes.to_vec()))
835 }
836 }
837
838 // Implement the [`Request`] trait for the [`TestRequest`] type
839 #[async_trait]
840 impl Request for TestRequest {
841 type Response = Vec<u8>;
842 async fn validate(&self) -> Result<()> {
843 Ok(())
844 }
845 }
846
    // A test data source that pretends to have the data or not
    #[derive(Clone)]
    struct TestDataSource {
        /// Whether we have the data or not
        has_data: bool,
        /// The time at which the data will be available if we have it
        data_available_time: Instant,

        /// Whether or not the data will be taken once served
        take_data: bool,
        /// Whether or not the data has been taken. Shared between clones of the data source
        taken: Arc<AtomicBool>,
    }
860
861 #[async_trait]
862 impl DataSource<TestRequest> for TestDataSource {
863 async fn derive_response_for(&self, request: &TestRequest) -> Result<Vec<u8>> {
864 // Return a response if we hit the hit rate
865 if self.has_data && Instant::now() >= self.data_available_time {
866 if self.take_data && !self.taken.swap(true, std::sync::atomic::Ordering::Relaxed) {
867 return Err(anyhow::anyhow!("data already taken"));
868 }
869 Ok(blake3::hash(&request.0).as_bytes().to_vec())
870 } else {
871 Err(anyhow::anyhow!("did not have the data"))
872 }
873 }
874 }
875
876 /// Create and return a default protocol configuration
877 fn default_protocol_config() -> RequestResponseConfig {
878 RequestResponseConfig {
879 incoming_request_ttl: Duration::from_secs(40),
880 incoming_request_timeout: Duration::from_secs(40),
881 request_batch_size: 10,
882 request_batch_interval: Duration::from_millis(100),
883 max_incoming_requests: 10,
884 max_incoming_requests_per_key: 1,
885 incoming_response_timeout: Duration::from_secs(1),
886 max_incoming_responses: 5,
887 }
888 }
889
890 /// Create fully connected test networks with `num_participants` participants
891 fn create_participants(
892 num: usize,
893 ) -> Vec<(TestSender, mpsc::Receiver<Bytes>, (BLSPubKey, BLSPrivKey))> {
894 // The entire network
895 let mut network = HashMap::new();
896
897 // All receivers in the network
898 let mut receivers = Vec::new();
899
900 // All keypairs in the network
901 let mut keypairs = Vec::new();
902
903 // For each participant,
904 for i in 0..num {
905 // Create a unique `BLSPubKey`
906 let (public_key, private_key) =
907 BLSPubKey::generated_from_seed_indexed([2; 32], i.try_into().unwrap());
908
909 // Add the keypair to the list
910 keypairs.push((public_key, private_key));
911
912 // Create a channel for sending and receiving messages
913 let (sender, receiver) = mpsc::channel::<Bytes>(100);
914
915 // Add the participant to the network
916 network.insert(public_key, sender);
917
918 // Add the receiver to the list of receivers
919 receivers.push(receiver);
920 }
921
922 // Create a test sender from the network
923 let sender = TestSender {
924 network: Arc::new(network),
925 };
926
927 // Return all senders and receivers
928 receivers
929 .into_iter()
930 .zip(keypairs)
931 .map(|(r, k)| (sender.clone(), r, k))
932 .collect()
933 }
934
    /// The configuration for an integration test
    #[derive(Clone)]
    struct IntegrationTestConfig {
        /// The request response protocol configuration (every participant gets a clone of it)
        request_response_config: RequestResponseConfig,
        /// The number of participants in the network
        num_participants: usize,
        /// The number of participants that have the data
        num_participants_with_data: usize,
        /// The timeout for the requests
        request_timeout: Duration,
        /// The delay before the nodes have the data available
        data_available_delay: Duration,
    }
949
/// Outcome of a single integration test run.
struct IntegrationTestResult {
    /// How many participants had their request answered successfully
    num_succeeded: usize,
}
955
956 /// Run an integration test with the given parameters
957 async fn run_integration_test(config: IntegrationTestConfig) -> IntegrationTestResult {
958 // Create a fully connected network with `num_participants` participants
959 let participants = create_participants(config.num_participants);
960
961 // Create a join set to wait for all the tasks to finish
962 let mut join_set = JoinSet::new();
963
964 // We need to keep these here so they don't get dropped
965 let handles = Arc::new(Mutex::new(Vec::new()));
966
967 // For each one, create a new [`RequestResponse`] protocol
968 for (i, (sender, receiver, (public_key, private_key))) in
969 participants.into_iter().enumerate()
970 {
971 let config_clone = config.request_response_config.clone();
972 let handles_clone = Arc::clone(&handles);
973 join_set.spawn(async move {
974 let protocol = RequestResponse::new(
975 config_clone,
976 sender.clone(),
977 receiver,
978 sender,
979 TestDataSource {
980 has_data: i < config.num_participants_with_data,
981 data_available_time: Instant::now() + config.data_available_delay,
982 take_data: false,
983 taken: Arc::new(AtomicBool::new(false)),
984 },
985 );
986
987 // Add the handle to the handles list so it doesn't get dropped and
988 // cancelled
989 #[allow(clippy::used_underscore_binding)]
990 handles_clone
991 .lock()
992 .unwrap()
993 .push(Arc::clone(&protocol._receiving_task_handle));
994
995 // Create a random request
996 let request = TestRequest(vec![rand::thread_rng().gen(); 100]);
997
998 // Get the hash of the request
999 let request_hash = blake3::hash(&request.0).as_bytes().to_vec();
1000
1001 // Create a new request message
1002 let request = RequestMessage::new_signed(&public_key, &private_key, &request)
1003 .expect("failed to create request message");
1004
1005 // Request the data from the protocol
1006 let response = protocol
1007 .request(
1008 request,
1009 RequestType::Batched,
1010 config.request_timeout,
1011 |_request, response| async move { Ok(response) },
1012 )
1013 .await?;
1014
1015 // Make sure the response is the hash of the request
1016 assert_eq!(response, request_hash);
1017
1018 Ok::<(), anyhow::Error>(())
1019 });
1020 }
1021
1022 // Wait for all the tasks to finish
1023 let mut num_succeeded = config.num_participants;
1024 while let Some(result) = join_set.join_next().await {
1025 if result.is_err() || result.unwrap().is_err() {
1026 num_succeeded -= 1;
1027 }
1028 }
1029
1030 IntegrationTestResult { num_succeeded }
1031 }
1032
1033 /// Test the integration of the protocol with 50% of the participants having the data
1034 #[tokio::test(flavor = "multi_thread")]
1035 async fn test_integration_50_0s() {
1036 // Build a config
1037 let config = IntegrationTestConfig {
1038 request_response_config: default_protocol_config(),
1039 num_participants: 100,
1040 num_participants_with_data: 50,
1041 request_timeout: Duration::from_secs(40),
1042 data_available_delay: Duration::from_secs(0),
1043 };
1044
1045 // Run the test, making sure all the requests succeed
1046 let result = run_integration_test(config).await;
1047 assert_eq!(result.num_succeeded, 100);
1048 }
1049
1050 /// Test the integration of the protocol when nobody has the data. Make sure we don't
1051 /// get any responses
1052 #[tokio::test(flavor = "multi_thread")]
1053 async fn test_integration_0() {
1054 // Build a config
1055 let config = IntegrationTestConfig {
1056 request_response_config: default_protocol_config(),
1057 num_participants: 100,
1058 num_participants_with_data: 0,
1059 request_timeout: Duration::from_secs(40),
1060 data_available_delay: Duration::from_secs(0),
1061 };
1062
1063 // Run the test
1064 let result = run_integration_test(config).await;
1065
1066 // Make sure all the requests succeeded
1067 assert_eq!(result.num_succeeded, 0);
1068 }
1069
1070 /// Test the integration of the protocol when one node has the data after
1071 /// a delay of 1s
1072 #[tokio::test(flavor = "multi_thread")]
1073 async fn test_integration_1_1s() {
1074 // Build a config
1075 let config = IntegrationTestConfig {
1076 request_response_config: default_protocol_config(),
1077 num_participants: 100,
1078 num_participants_with_data: 1,
1079 request_timeout: Duration::from_secs(40),
1080 data_available_delay: Duration::from_secs(2),
1081 };
1082
1083 // Run the test
1084 let result = run_integration_test(config).await;
1085
1086 // Make sure all the requests succeeded
1087 assert_eq!(result.num_succeeded, 100);
1088 }
1089
1090 /// Test that we can join an existing request for the same data and get the same (single) response
1091 #[tokio::test(flavor = "multi_thread")]
1092 async fn test_join_existing_request() {
1093 // Build a config
1094 let config = default_protocol_config();
1095
1096 // Create two participants
1097 let mut participants = Vec::new();
1098
1099 for (sender, receiver, (public_key, private_key)) in create_participants(2) {
1100 // For each, create a new [`RequestResponse`] protocol
1101 let protocol = RequestResponse::new(
1102 config.clone(),
1103 sender.clone(),
1104 receiver,
1105 sender,
1106 TestDataSource {
1107 take_data: true,
1108 has_data: true,
1109 data_available_time: Instant::now() + Duration::from_secs(2),
1110 taken: Arc::new(AtomicBool::new(false)),
1111 },
1112 );
1113
1114 // Add the participants to the list
1115 participants.push((protocol, public_key, private_key));
1116 }
1117
1118 // Take the first participant
1119 let one = Arc::new(participants.remove(0));
1120
1121 // Create the request that they should all be able to join on
1122 let request = TestRequest(vec![rand::thread_rng().gen(); 100]);
1123
1124 // Create a join set to wait for all the tasks to finish
1125 let mut join_set = JoinSet::new();
1126
1127 // Make 10 requests with the same hash
1128 for _ in 0..10 {
1129 // Clone the first participant
1130 let one_clone = Arc::clone(&one);
1131
1132 // Clone the request
1133 let request_clone = request.clone();
1134
1135 // Spawn a task to request the data
1136 join_set.spawn(async move {
1137 // Create a new, signed request message
1138 let request_message =
1139 RequestMessage::new_signed(&one_clone.1, &one_clone.2, &request_clone)?;
1140
1141 // Start requesting it
1142 one_clone
1143 .0
1144 .request(
1145 request_message,
1146 RequestType::Batched,
1147 Duration::from_secs(20),
1148 |_request, response| async move { Ok(response) },
1149 )
1150 .await?;
1151
1152 Ok::<(), anyhow::Error>(())
1153 });
1154 }
1155
1156 // Wait for all the tasks to finish, making sure they all succeed
1157 while let Some(result) = join_set.join_next().await {
1158 result
1159 .expect("failed to join task")
1160 .expect("failed to request data");
1161 }
1162 }
1163}