request_response/
lib.rs

1//! This crate contains a general request-response protocol. It is used to send requests to
2//! a set of recipients and wait for responses.
3
4use std::{
5    any::Any,
6    collections::HashMap,
7    future::Future,
8    marker::PhantomData,
9    pin::Pin,
10    sync::{Arc, Weak},
11    time::{Duration, Instant},
12};
13
14use anyhow::{anyhow, Context, Result};
15use data_source::DataSource;
16use derive_more::derive::Deref;
17use hotshot_types::traits::signature_key::SignatureKey;
18use message::{Message, RequestMessage, ResponseMessage};
19use network::{Bytes, Receiver, Sender};
20use parking_lot::RwLock;
21use rand::seq::SliceRandom;
22use recipient_source::RecipientSource;
23use request::Request;
24use tokio::{
25    spawn,
26    time::{sleep, timeout},
27};
28use tokio_util::task::AbortOnDropHandle;
29use tracing::{debug, error, trace, warn};
30use util::{BoundedVecDeque, NamedSemaphore, NamedSemaphoreError};
31
/// The data source trait. This is what we use to derive the response data for a request
pub mod data_source;
/// The message type. This is the base type for all messages in the request-response protocol
pub mod message;
/// The network traits. These are what the protocol uses to send and receive messages over
/// the network
pub mod network;
/// The recipient source trait. This is what we use to get the recipients that a specific
/// message should expect responses from
pub mod recipient_source;
/// The request trait. This is what we use to define a request and its corresponding response type
pub mod request;
/// Utility types and functions
mod util;
46
/// A type alias for the hash of a request. In this crate it is computed as the `Blake3` hash of
/// the request's serialized bytes and is used as the key that matches responses to requests
pub type RequestHash = blake3::Hash;
49
/// A type alias for the outgoing requests map: a shared, lock-protected map from the request
/// hash to a weak reference to the in-flight outgoing request. The references are weak so that
/// the map itself does not keep a request alive once every requester has dropped it
pub type OutgoingRequestsMap<Req> =
    Arc<RwLock<HashMap<RequestHash, Weak<OutgoingRequestInner<Req>>>>>;
53
/// A type alias for the semaphore (keyed by the requester's key) that limits how many incoming
/// requests we are responding to at the same time
pub type IncomingRequests<K> = NamedSemaphore<K>;
56
/// A type alias for the semaphore that limits how many incoming responses we are validating
/// at the same time. It uses the unit type as its only key
pub type IncomingResponses = NamedSemaphore<()>;
59
/// The type of request to make. This decides how the request message is delivered
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum RequestType {
    /// A request that can be satisfied by a single participant,
    /// and as such will be sent to a few participants at a time (a "batch")
    /// until one succeeds or the request times out
    Batched,
    /// A request that needs most or all participants to respond,
    /// and as such will be broadcasted to all participants
    Broadcast,
}
71
72/// The errors that can occur when making a request for data
73#[derive(thiserror::Error, Debug)]
74pub enum RequestError {
75    /// The request timed out
76    #[error("request timed out")]
77    Timeout,
78    /// The request was invalid
79    #[error("request was invalid")]
80    InvalidRequest(anyhow::Error),
81    /// Other errors
82    #[error("other error")]
83    Other(anyhow::Error),
84}
85
/// A trait for serializing and deserializing a type to and from a byte array. [`Request`] types and
/// their `Response` types will need to implement this trait
pub trait Serializable: Sized {
    /// Serialize the type to a byte array. If this is for a [`Request`] and your [`Request`] type
    /// is represented as an enum, please make sure that you serialize it with a unique type ID.
    /// Otherwise, you may end up with collisions, because the hash of the serialized request is
    /// used as the unique identifier of the request
    ///
    /// # Errors
    /// - If the type cannot be serialized to a byte array
    fn to_bytes(&self) -> Result<Vec<u8>>;

    /// Deserialize the type from a byte array
    ///
    /// # Errors
    /// - If the byte array is not a valid representation of the type
    fn from_bytes(bytes: &[u8]) -> Result<Self>;
}
103
/// The underlying configuration for the request-response protocol
#[derive(Clone)]
pub struct RequestResponseConfig {
    /// The time-to-live for incoming requests. Do not respond to a request after this threshold
    /// has passed.
    pub incoming_request_ttl: Duration,
    /// The maximum amount of time we will spend trying to validate a request, derive a response
    /// for it, and send the response over the wire.
    pub incoming_request_timeout: Duration,
    /// The maximum amount of time we will spend trying to validate a response. This is used to prevent
    /// an attack where a malicious participant sends us a bunch of responses that take a long time to
    /// validate.
    pub incoming_response_timeout: Duration,
    /// The batch size for outgoing requests. This is the number of request messages that we will
    /// send out at a time for a single request before waiting for the `request_batch_interval`.
    pub request_batch_size: usize,
    /// The time to wait (per request) between sending out batches of request messages
    pub request_batch_interval: Duration,
    /// The maximum (global) number of incoming requests that can be processed at any given time.
    pub max_incoming_requests: usize,
    /// The maximum number of incoming requests that can be processed for a single key at any given time.
    pub max_incoming_requests_per_key: usize,
    /// The maximum (global) number of incoming responses that can be processed at any given time.
    /// We need this because incoming responses are validated asynchronously (each in its own
    /// task) to check that they satisfy the request they are responding to
    pub max_incoming_responses: usize,
}
131
/// A protocol that allows for request-response communication. Is cheaply cloneable (a clone only
/// bumps reference counts), so there is no need to wrap it in an `Arc`. Dereferences to the
/// inner implementation, which is where the request methods live
#[derive(Deref)]
pub struct RequestResponse<
    S: Sender<K>,
    R: Receiver,
    Req: Request,
    RS: RecipientSource<Req, K>,
    DS: DataSource<Req>,
    K: SignatureKey + 'static,
> {
    #[deref]
    /// The inner implementation of the request-response protocol
    pub inner: Arc<RequestResponseInner<S, R, Req, RS, DS, K>>,
    /// A handle to the receiving task. The handle is shared between all clones of the protocol,
    /// so the task is cancelled automatically once the last clone is dropped
    _receiving_task_handle: Arc<AbortOnDropHandle<()>>,
}
149
150/// We need to manually implement the `Clone` trait for this type because deriving
151/// `Deref` will cause an issue where it tries to clone the inner field instead
152impl<
153        S: Sender<K>,
154        R: Receiver,
155        Req: Request,
156        RS: RecipientSource<Req, K>,
157        DS: DataSource<Req>,
158        K: SignatureKey + 'static,
159    > Clone for RequestResponse<S, R, Req, RS, DS, K>
160{
161    fn clone(&self) -> Self {
162        Self {
163            inner: Arc::clone(&self.inner),
164            _receiving_task_handle: Arc::clone(&self._receiving_task_handle),
165        }
166    }
167}
168
169impl<
170        S: Sender<K>,
171        R: Receiver,
172        Req: Request,
173        RS: RecipientSource<Req, K>,
174        DS: DataSource<Req>,
175        K: SignatureKey + 'static,
176    > RequestResponse<S, R, Req, RS, DS, K>
177{
178    /// Create a new [`RequestResponseProtocol`]
179    pub fn new(
180        // The configuration for the protocol
181        config: RequestResponseConfig,
182        // The network sender that [`RequestResponseProtocol`] will use to send messages
183        sender: S,
184        // The network receiver that [`RequestResponseProtocol`] will use to receive messages
185        receiver: R,
186        // The recipient source that [`RequestResponseProtocol`] will use to get the recipients
187        // that a specific message should expect responses from
188        recipient_source: RS,
189        // The [response] data source that [`RequestResponseProtocol`] will use to derive the
190        // response data for a specific request
191        data_source: DS,
192    ) -> Self {
193        // Create the outgoing requests map
194        let outgoing_requests = OutgoingRequestsMap::default();
195
196        // Create the inner implementation
197        let inner = Arc::new(RequestResponseInner {
198            config,
199            sender,
200            recipient_source,
201            data_source,
202            outgoing_requests,
203            phantom_data: PhantomData,
204        });
205
206        // Start the task that receives messages and handles them. This will automatically get cancelled
207        // when the protocol is dropped
208        let inner_clone = Arc::clone(&inner);
209        let receive_task_handle =
210            AbortOnDropHandle::new(tokio::spawn(inner_clone.receiving_task(receiver)));
211
212        // Return the protocol
213        Self {
214            inner,
215            _receiving_task_handle: Arc::new(receive_task_handle),
216        }
217    }
218}
219
/// A type alias for an `Arc<dyn Any + Send + Sync + 'static>`. Validated responses are stored
/// in this type-erased form and downcast back to the caller's concrete type when returned
type ThreadSafeAny = Arc<dyn Any + Send + Sync + 'static>;
222
/// A type alias for the (boxed, pinned) future that validates a response. It resolves to the
/// type-erased validation output on success, or to an error if the response is invalid
type ResponseValidationFuture =
    Pin<Box<dyn Future<Output = Result<ThreadSafeAny, anyhow::Error>> + Send + Sync + 'static>>;
226
/// A type alias for the (boxed) function that, given a request and a response to it, returns
/// a [`ResponseValidationFuture`] that validates the response
type ResponseValidationFn<R> =
    Box<dyn Fn(&R, <R as Request>::Response) -> ResponseValidationFuture + Send + Sync + 'static>;
230
/// The inner implementation for the request-response protocol. This holds the state that is
/// shared (behind an `Arc`) between all clones of the protocol and its receiving task
pub struct RequestResponseInner<
    S: Sender<K>,
    R: Receiver,
    Req: Request,
    RS: RecipientSource<Req, K>,
    DS: DataSource<Req>,
    K: SignatureKey + 'static,
> {
    /// The configuration of the protocol
    config: RequestResponseConfig,
    /// The sender to use for the protocol
    pub sender: S,
    /// The recipient source to use for the protocol
    pub recipient_source: RS,
    /// The data source to use for the protocol
    data_source: DS,
    /// The map of currently active, outgoing requests (keyed by request hash)
    outgoing_requests: OutgoingRequestsMap<Req>,
    /// Phantom data that ties the otherwise unused type parameters to this type
    phantom_data: PhantomData<(K, R, Req, DS)>,
}
253impl<
254        S: Sender<K>,
255        R: Receiver,
256        Req: Request,
257        RS: RecipientSource<Req, K>,
258        DS: DataSource<Req>,
259        K: SignatureKey + 'static,
260    > RequestResponseInner<S, R, Req, RS, DS, K>
261{
262    /// Request something from the protocol indefinitely until we get a response
263    /// or there was a critical error (e.g. the request could not be signed)
264    ///
265    /// # Errors
266    /// - If the request was invalid
267    /// - If there was a critical error (e.g. the channel was closed)
268    pub async fn request_indefinitely<F, Fut, O>(
269        self: &Arc<Self>,
270        public_key: &K,
271        private_key: &K::PrivateKey,
272        // The type of request to make
273        request_type: RequestType,
274        // The estimated TTL of other participants. This is used to decide when to
275        // stop making requests and sign a new one
276        estimated_request_ttl: Duration,
277        // The request to make
278        request: Req,
279        // The response validation function
280        response_validation_fn: F,
281    ) -> std::result::Result<O, RequestError>
282    where
283        F: Fn(&Req, Req::Response) -> Fut + Send + Sync + 'static + Clone,
284        Fut: Future<Output = anyhow::Result<O>> + Send + Sync + 'static,
285        O: Send + Sync + 'static + Clone,
286    {
287        loop {
288            // Sign a request message
289            let request_message = RequestMessage::new_signed(public_key, private_key, &request)
290                .map_err(|e| {
291                    RequestError::InvalidRequest(anyhow::anyhow!(
292                        "failed to sign request message: {e}"
293                    ))
294                })?;
295
296            // Request the data, handling the errors appropriately
297            match self
298                .request(
299                    request_message,
300                    request_type,
301                    estimated_request_ttl,
302                    response_validation_fn.clone(),
303                )
304                .await
305            {
306                Ok(response) => return Ok(response),
307                Err(RequestError::Timeout) => continue,
308                Err(e) => return Err(e),
309            }
310        }
311    }
312
313    /// Request something from the protocol and wait for the response. This function
314    /// will join with an existing request for the same data (determined by `Blake3` hash),
315    /// however both will make requests until the timeout is reached
316    ///
317    /// # Errors
318    /// - If the request times out
319    /// - If the channel is closed (this is an internal error)
320    /// - If the request we sign is invalid
321    pub async fn request<F, Fut, O>(
322        self: &Arc<Self>,
323        request_message: RequestMessage<Req, K>,
324        request_type: RequestType,
325        timeout_duration: Duration,
326        response_validation_fn: F,
327    ) -> std::result::Result<O, RequestError>
328    where
329        F: Fn(&Req, Req::Response) -> Fut + Send + Sync + 'static + Clone,
330        Fut: Future<Output = anyhow::Result<O>> + Send + Sync + 'static,
331        O: Send + Sync + 'static + Clone,
332    {
333        timeout(timeout_duration, async move {
334            // Calculate the hash of the request
335            let request_hash = blake3::hash(&request_message.request.to_bytes().map_err(|e| {
336                RequestError::InvalidRequest(anyhow::anyhow!(
337                    "failed to serialize request message: {e}"
338                ))
339            })?);
340
341            let request = {
342                // Get a write lock on the outgoing requests map
343                let mut outgoing_requests_write = self.outgoing_requests.write();
344
345                // Conditionally get the outgoing request, creating a new one if it doesn't exist or if
346                // the existing one has been dropped and not yet removed
347                if let Some(outgoing_request) = outgoing_requests_write
348                    .get(&request_hash)
349                    .and_then(Weak::upgrade)
350                {
351                    OutgoingRequest(outgoing_request)
352                } else {
353                    // Create a new broadcast channel for the response
354                    let (sender, receiver) = async_broadcast::broadcast(1);
355
356                    // Modify the response validation function to return an `Arc<dyn Any>`
357                    let response_validation_fn =
358                        Box::new(move |request: &Req, response: Req::Response| {
359                            let fut = response_validation_fn(request, response);
360                            Box::pin(
361                                async move { fut.await.map(|ok| Arc::new(ok) as ThreadSafeAny) },
362                            ) as ResponseValidationFuture
363                        });
364
365                    // Create a new outgoing request
366                    let outgoing_request = OutgoingRequest(Arc::new(OutgoingRequestInner {
367                        sender,
368                        receiver,
369                        response_validation_fn,
370                        request: request_message.request.clone(),
371                        outgoing_requests: Arc::clone(&self.outgoing_requests),
372                        request_hash,
373                    }));
374
375                    // Write the new outgoing request to the map
376                    outgoing_requests_write
377                        .insert(request_hash, Arc::downgrade(&outgoing_request.0));
378
379                    // Return the new outgoing request
380                    outgoing_request
381                }
382            };
383
384            // Create a request message and serialize it
385            let message = Bytes::from(
386                Message::Request(request_message.clone())
387                    .to_bytes()
388                    .map_err(|e| {
389                        RequestError::InvalidRequest(anyhow::anyhow!(
390                            "failed to serialize request message: {e}"
391                        ))
392                    })?,
393            );
394
395            // Create a place to put the handle for the batched sending task. We need this because
396            // otherwise it gets dropped when the closure goes out of scope, instead of when the function
397            // gets cancelled or returns
398            let mut _batched_sending_task = None;
399
400            // Match on the type of request
401            if request_type == RequestType::Broadcast {
402                trace!("Sending request {request_message:?} to all participants");
403
404                // If the message is a broadcast request, just send it to all participants
405                self.sender
406                    .send_broadcast_message(&message)
407                    .await
408                    .map_err(|e| {
409                        RequestError::Other(anyhow::anyhow!(
410                            "failed to send broadcast message: {e}"
411                        ))
412                    })?;
413            } else {
414                // If the message is a batched request, we need to batch it with other requests
415
416                // Get the recipients that the request should expect responses from. Shuffle them so
417                // that we don't always send to the same recipients in the same order
418                let mut recipients = self
419                    .recipient_source
420                    .get_expected_responders(&request_message.request)
421                    .await
422                    .map_err(|e| {
423                        RequestError::InvalidRequest(anyhow::anyhow!(
424                            "failed to get expected responders for request: {e}"
425                        ))
426                    })?;
427                recipients.shuffle(&mut rand::thread_rng());
428
429                // Get the current time so we can check when the timeout has elapsed
430                let start_time = Instant::now();
431
432                // Spawn a task that sends out requests to the network
433                let self_clone = Arc::clone(self);
434                let batched_sending_handle = AbortOnDropHandle::new(spawn(async move {
435                    // Create a bounded queue for the outgoing requests. We use this to make sure
436                    // we have less than [`config.request_batch_size`] requests in flight at any time.
437                    //
438                    // When newer requests are added, older ones are removed from the queue. Because we use
439                    // `AbortOnDropHandle`, the older ones will automatically get cancelled
440                    let mut outgoing_requests =
441                        BoundedVecDeque::new(self_clone.config.request_batch_size);
442
443                    // While the timeout hasn't elapsed, send out requests to the network
444                    while start_time.elapsed() < timeout_duration {
445                        // Send out requests to the network in their own separate tasks
446                        for recipient_batch in
447                            recipients.chunks(self_clone.config.request_batch_size)
448                        {
449                            for recipient in recipient_batch {
450                                // Clone ourselves, the message, and the recipient so they can be moved
451                                let self_clone = Arc::clone(&self_clone);
452                                let request_message_clone = request_message.clone();
453                                let recipient_clone = recipient.clone();
454                                let message_clone = Arc::clone(&message);
455
456                                // Spawn the task that sends the request to the participant
457                                let individual_sending_task = spawn(async move {
458                                    trace!(
459                                        "Sending request {request_message_clone:?} to \
460                                         {recipient_clone:?}"
461                                    );
462
463                                    let _ = self_clone
464                                        .sender
465                                        .send_direct_message(&message_clone, recipient_clone)
466                                        .await;
467                                });
468
469                                // Add the sending task to the queue
470                                outgoing_requests
471                                    .push(AbortOnDropHandle::new(individual_sending_task));
472                            }
473
474                            // After we send the batch out, wait the [`config.request_batch_interval`]
475                            // before sending the next one
476                            sleep(self_clone.config.request_batch_interval).await;
477                        }
478                    }
479                }));
480
481                // Store the handle so it doesn't get dropped
482                _batched_sending_task = Some(batched_sending_handle);
483            }
484
485            // Wait for a response on the channel
486            request
487                .receiver
488                .clone()
489                .recv()
490                .await
491                .map_err(|_| RequestError::Other(anyhow!("channel was closed")))
492        })
493        .await
494        .map_err(|_| RequestError::Timeout)
495        .and_then(|result| result)
496        .and_then(|result| {
497            result.downcast::<O>().map_err(|e| {
498                RequestError::Other(anyhow::anyhow!(
499                    "failed to downcast response to expected type: {e:?}"
500                ))
501            })
502        })
503        .map(|result| Arc::unwrap_or_clone(result))
504    }
505
506    /// The task responsible for receiving messages from the receiver and handling them
507    async fn receiving_task(self: Arc<Self>, mut receiver: R) {
508        // Upper bound the number of outgoing and incoming responses
509        let mut incoming_requests = NamedSemaphore::new(
510            self.config.max_incoming_requests_per_key,
511            Some(self.config.max_incoming_requests),
512        );
513        let mut incoming_responses = NamedSemaphore::new(self.config.max_incoming_responses, None);
514
515        // While the receiver is open, we receive messages and handle them
516        loop {
517            // Try to receive a message
518            match receiver.receive_message().await {
519                Ok(message) => {
520                    // Deserialize the message, warning if it fails
521                    let message = match Message::from_bytes(&message) {
522                        Ok(message) => message,
523                        Err(e) => {
524                            warn!("Received invalid message: {e:#}");
525                            continue;
526                        },
527                    };
528
529                    // Handle the message based on its type
530                    match message {
531                        Message::Request(request_message) => {
532                            self.handle_request(request_message, &mut incoming_requests);
533                        },
534                        Message::Response(response_message) => {
535                            self.handle_response(response_message, &mut incoming_responses);
536                        },
537                    }
538                },
539                // An error here means the receiver will _NEVER_ receive any more messages
540                Err(e) => {
541                    error!("Request/response receive task exited: {e:#}");
542                    return;
543                },
544            }
545        }
546    }
547
548    /// Handle a request sent to us
549    fn handle_request(
550        self: &Arc<Self>,
551        request_message: RequestMessage<Req, K>,
552        incoming_requests: &mut IncomingRequests<K>,
553    ) {
554        trace!("Handling request {:?}", request_message);
555
556        // Spawn a task to:
557        // - Validate the request
558        // - Derive the response data (check if we have it)
559        // - Send the response to the requester
560        let self_clone = Arc::clone(self);
561
562        // Attempt to acquire a permit for the request. Warn if there are too many requests currently being processed
563        // either globally or per-key
564        let permit = incoming_requests.try_acquire(request_message.public_key.clone());
565        match permit {
566            Ok(ref permit) => permit,
567            Err(NamedSemaphoreError::PerKeyLimitReached) => {
568                warn!(
569                    "Failed to process request from {}: too many requests from the same key are \
570                     already being processed",
571                    request_message.public_key
572                );
573                return;
574            },
575            Err(NamedSemaphoreError::GlobalLimitReached) => {
576                warn!(
577                    "Failed to process request from {}: too many requests are already being \
578                     processed",
579                    request_message.public_key
580                );
581                return;
582            },
583        };
584
585        tokio::spawn(async move {
586            let result = timeout(self_clone.config.incoming_request_timeout, async move {
587                // Validate the request message. This includes:
588                // - Checking the signature and making sure it's valid
589                // - Checking the timestamp and making sure it's not too old
590                // - Calling the request's application-specific validation function
591                request_message
592                    .validate(self_clone.config.incoming_request_ttl)
593                    .await
594                    .with_context(|| "failed to validate request")?;
595
596                // Try to fetch the response data from the data source
597                let response = self_clone
598                    .data_source
599                    .derive_response_for(&request_message.request)
600                    .await
601                    .with_context(|| "failed to derive response for request")?;
602
603                // Create the response message and serialize it
604                let response = Bytes::from(
605                    Message::Response::<Req, K>(ResponseMessage {
606                        request_hash: blake3::hash(&request_message.request.to_bytes()?),
607                        response,
608                    })
609                    .to_bytes()
610                    .with_context(|| "failed to serialize response message")?,
611                );
612
613                // Send the response to the requester
614                self_clone
615                    .sender
616                    .send_direct_message(&response, request_message.public_key)
617                    .await
618                    .with_context(|| "failed to send response to requester")?;
619
620                // Drop the permit
621                _ = permit;
622                drop(permit);
623
624                Ok::<(), anyhow::Error>(())
625            })
626            .await
627            .map_err(|_| anyhow::anyhow!("timed out while sending response"))
628            .and_then(|result| result);
629
630            if let Err(e) = result {
631                debug!("Failed to send response to requester: {e:#}");
632            }
633        });
634    }
635
636    /// Handle a response sent to us
637    fn handle_response(
638        self: &Arc<Self>,
639        response: ResponseMessage<Req>,
640        incoming_responses: &mut IncomingResponses,
641    ) {
642        trace!("Handling response {response:?}");
643
644        // Get the entry in the map, ignoring it if it doesn't exist
645        let Some(outgoing_request) = self
646            .outgoing_requests
647            .read()
648            .get(&response.request_hash)
649            .cloned()
650            .and_then(|r| r.upgrade())
651        else {
652            return;
653        };
654
655        // Attempt to acquire a permit for the request. Warn if there are too many responses currently being processed
656        let permit = incoming_responses.try_acquire(());
657        let Ok(permit) = permit else {
658            warn!("Failed to process response: too many responses are already being processed");
659            return;
660        };
661
662        // Spawn a task to validate the response and send it to the requester (us)
663        let response_validate_timeout = self.config.incoming_response_timeout;
664        tokio::spawn(async move {
665            if timeout(response_validate_timeout, async move {
666                // Make sure the response is valid for the given request
667                let validation_result = match (outgoing_request.response_validation_fn)(
668                    &outgoing_request.request,
669                    response.response,
670                )
671                .await
672                {
673                    Ok(validation_result) => validation_result,
674                    Err(e) => {
675                        debug!("Received invalid response: {e:#}");
676                        return;
677                    },
678                };
679
680                // Send the response to the requester (the user of [`RequestResponse::request`])
681                let _ = outgoing_request.sender.try_broadcast(validation_result);
682
683                // Drop the permit
684                _ = permit;
685                drop(permit);
686            })
687            .await
688            .is_err()
689            {
690                warn!("Timed out while validating response");
691            }
692        });
693    }
694}
695
/// An outgoing request. This is what we use to track a request and its corresponding response
/// in the protocol. It is a reference-counted handle that dereferences to the inner
/// implementation; cloning it only bumps the reference count
#[derive(Clone, Deref)]
pub struct OutgoingRequest<R: Request>(Arc<OutgoingRequestInner<R>>);
700
/// The inner implementation of an outgoing request. When it is dropped, it removes its entry
/// from the map of outgoing requests
pub struct OutgoingRequestInner<R: Request> {
    /// The sending half of the channel that validated responses are broadcast on
    sender: async_broadcast::Sender<ThreadSafeAny>,
    /// The receiving half of the channel. Requesters wait for the response on a clone of it
    receiver: async_broadcast::Receiver<ThreadSafeAny>,

    /// The request that we are waiting for a response to
    request: R,

    /// The function used to validate the response
    response_validation_fn: ResponseValidationFn<R>,

    /// A handle to the (shared) map of currently active, outgoing requests
    outgoing_requests: OutgoingRequestsMap<R>,
    /// The hash of the request. We need this so we can remove ourselves from the map
    request_hash: RequestHash,
}
719
720impl<R: Request> Drop for OutgoingRequestInner<R> {
721    fn drop(&mut self) {
722        self.outgoing_requests.write().remove(&self.request_hash);
723    }
724}
725
726#[cfg(test)]
727mod tests {
728    use std::{
729        collections::HashMap,
730        sync::{atomic::AtomicBool, Mutex},
731    };
732
733    use async_trait::async_trait;
734    use hotshot_types::signature_key::{BLSPrivKey, BLSPubKey};
735    use rand::Rng;
736    use tokio::{sync::mpsc, task::JoinSet};
737
738    use super::*;
739
740    /// This test makes sure that when all references to an outgoing request are dropped, it is
741    /// removed from the outgoing requests map
742    #[test]
743    fn test_outgoing_request_drop() {
744        // Create an outgoing requests map
745        let outgoing_requests = OutgoingRequestsMap::default();
746
747        // Create an outgoing request
748        let (sender, receiver) = async_broadcast::broadcast(1);
749        let outgoing_request = OutgoingRequest(Arc::new(OutgoingRequestInner {
750            sender,
751            receiver,
752            request: TestRequest(vec![1, 2, 3]),
753            response_validation_fn: Box::new(|_request, _response| {
754                Box::pin(async move { Ok(Arc::new(()) as ThreadSafeAny) })
755                    as ResponseValidationFuture
756            }),
757            outgoing_requests: Arc::clone(&outgoing_requests),
758            request_hash: blake3::hash(&[1, 2, 3]),
759        }));
760
761        // Insert the outgoing request into the map
762        outgoing_requests.write().insert(
763            outgoing_request.request_hash,
764            Arc::downgrade(&outgoing_request.0),
765        );
766
767        // Clone the outgoing request
768        let outgoing_request_clone = outgoing_request.clone();
769
770        // Drop the outgoing request
771        drop(outgoing_request);
772
773        // Make sure nothing has been removed
774        assert_eq!(outgoing_requests.read().len(), 1);
775
776        // Drop the clone
777        drop(outgoing_request_clone);
778
779        // Make sure it has been removed
780        assert_eq!(outgoing_requests.read().len(), 0);
781    }
782
783    /// A test sender that has a list of all the participants in the network
784    #[derive(Clone)]
785    pub struct TestSender {
786        network: Arc<HashMap<BLSPubKey, mpsc::Sender<Bytes>>>,
787    }
788
789    /// An implementation of the [`Sender`] trait for the [`TestSender`] type
790    #[async_trait]
791    impl Sender<BLSPubKey> for TestSender {
792        async fn send_direct_message(&self, message: &Bytes, recipient: BLSPubKey) -> Result<()> {
793            self.network
794                .get(&recipient)
795                .ok_or(anyhow::anyhow!("recipient not found"))?
796                .send(Arc::clone(message))
797                .await
798                .map_err(|_| anyhow::anyhow!("failed to send message"))?;
799
800            Ok(())
801        }
802
803        async fn send_broadcast_message(&self, message: &Bytes) -> Result<()> {
804            for sender in self.network.values() {
805                sender
806                    .send(Arc::clone(message))
807                    .await
808                    .map_err(|_| anyhow::anyhow!("failed to send message"))?;
809            }
810            Ok(())
811        }
812    }
813
814    // Implement the [`RecipientSource`] trait for the [`TestSender`] type
815    #[async_trait]
816    impl RecipientSource<TestRequest, BLSPubKey> for TestSender {
817        async fn get_expected_responders(&self, _request: &TestRequest) -> Result<Vec<BLSPubKey>> {
818            // Get all the participants in the network
819            Ok(self.network.keys().copied().collect())
820        }
821    }
822
823    // Create a test request that is just some bytes
824    #[derive(Clone, Debug)]
825    struct TestRequest(Vec<u8>);
826
827    // Implement the [`Serializable`] trait for the [`TestRequest`] type
828    impl Serializable for TestRequest {
829        fn to_bytes(&self) -> Result<Vec<u8>> {
830            Ok(self.0.clone())
831        }
832
833        fn from_bytes(bytes: &[u8]) -> Result<Self> {
834            Ok(TestRequest(bytes.to_vec()))
835        }
836    }
837
838    // Implement the [`Request`] trait for the [`TestRequest`] type
839    #[async_trait]
840    impl Request for TestRequest {
841        type Response = Vec<u8>;
842        async fn validate(&self) -> Result<()> {
843            Ok(())
844        }
845    }
846
847    // Create a test data source that pretends to have the data or not
848    #[derive(Clone)]
849    struct TestDataSource {
850        /// Whether we have the data or not
851        has_data: bool,
852        /// The time at which the data will be available if we have it
853        data_available_time: Instant,
854
855        /// Whether or not the data will be taken once served
856        take_data: bool,
857        /// Whether or not the data has been taken
858        taken: Arc<AtomicBool>,
859    }
860
861    #[async_trait]
862    impl DataSource<TestRequest> for TestDataSource {
863        async fn derive_response_for(&self, request: &TestRequest) -> Result<Vec<u8>> {
864            // Return a response if we hit the hit rate
865            if self.has_data && Instant::now() >= self.data_available_time {
866                if self.take_data && !self.taken.swap(true, std::sync::atomic::Ordering::Relaxed) {
867                    return Err(anyhow::anyhow!("data already taken"));
868                }
869                Ok(blake3::hash(&request.0).as_bytes().to_vec())
870            } else {
871                Err(anyhow::anyhow!("did not have the data"))
872            }
873        }
874    }
875
876    /// Create and return a default protocol configuration
877    fn default_protocol_config() -> RequestResponseConfig {
878        RequestResponseConfig {
879            incoming_request_ttl: Duration::from_secs(40),
880            incoming_request_timeout: Duration::from_secs(40),
881            request_batch_size: 10,
882            request_batch_interval: Duration::from_millis(100),
883            max_incoming_requests: 10,
884            max_incoming_requests_per_key: 1,
885            incoming_response_timeout: Duration::from_secs(1),
886            max_incoming_responses: 5,
887        }
888    }
889
890    /// Create fully connected test networks with `num_participants` participants
891    fn create_participants(
892        num: usize,
893    ) -> Vec<(TestSender, mpsc::Receiver<Bytes>, (BLSPubKey, BLSPrivKey))> {
894        // The entire network
895        let mut network = HashMap::new();
896
897        // All receivers in the network
898        let mut receivers = Vec::new();
899
900        // All keypairs in the network
901        let mut keypairs = Vec::new();
902
903        // For each participant,
904        for i in 0..num {
905            // Create a unique `BLSPubKey`
906            let (public_key, private_key) =
907                BLSPubKey::generated_from_seed_indexed([2; 32], i.try_into().unwrap());
908
909            // Add the keypair to the list
910            keypairs.push((public_key, private_key));
911
912            // Create a channel for sending and receiving messages
913            let (sender, receiver) = mpsc::channel::<Bytes>(100);
914
915            // Add the participant to the network
916            network.insert(public_key, sender);
917
918            // Add the receiver to the list of receivers
919            receivers.push(receiver);
920        }
921
922        // Create a test sender from the network
923        let sender = TestSender {
924            network: Arc::new(network),
925        };
926
927        // Return all senders and receivers
928        receivers
929            .into_iter()
930            .zip(keypairs)
931            .map(|(r, k)| (sender.clone(), r, k))
932            .collect()
933    }
934
935    /// The configuration for an integration test
936    #[derive(Clone)]
937    struct IntegrationTestConfig {
938        /// The request response protocol configuration
939        request_response_config: RequestResponseConfig,
940        /// The number of participants in the network
941        num_participants: usize,
942        /// The number of participants that have the data
943        num_participants_with_data: usize,
944        /// The timeout for the requests
945        request_timeout: Duration,
946        /// The delay before the nodes have the data available
947        data_available_delay: Duration,
948    }
949
950    /// The result of an integration test
951    struct IntegrationTestResult {
952        /// The number of nodes that received a response
953        num_succeeded: usize,
954    }
955
956    /// Run an integration test with the given parameters
957    async fn run_integration_test(config: IntegrationTestConfig) -> IntegrationTestResult {
958        // Create a fully connected network with `num_participants` participants
959        let participants = create_participants(config.num_participants);
960
961        // Create a join set to wait for all the tasks to finish
962        let mut join_set = JoinSet::new();
963
964        // We need to keep these here so they don't get dropped
965        let handles = Arc::new(Mutex::new(Vec::new()));
966
967        // For each one, create a new [`RequestResponse`] protocol
968        for (i, (sender, receiver, (public_key, private_key))) in
969            participants.into_iter().enumerate()
970        {
971            let config_clone = config.request_response_config.clone();
972            let handles_clone = Arc::clone(&handles);
973            join_set.spawn(async move {
974                let protocol = RequestResponse::new(
975                    config_clone,
976                    sender.clone(),
977                    receiver,
978                    sender,
979                    TestDataSource {
980                        has_data: i < config.num_participants_with_data,
981                        data_available_time: Instant::now() + config.data_available_delay,
982                        take_data: false,
983                        taken: Arc::new(AtomicBool::new(false)),
984                    },
985                );
986
987                // Add the handle to the handles list so it doesn't get dropped and
988                // cancelled
989                #[allow(clippy::used_underscore_binding)]
990                handles_clone
991                    .lock()
992                    .unwrap()
993                    .push(Arc::clone(&protocol._receiving_task_handle));
994
995                // Create a random request
996                let request = TestRequest(vec![rand::thread_rng().gen(); 100]);
997
998                // Get the hash of the request
999                let request_hash = blake3::hash(&request.0).as_bytes().to_vec();
1000
1001                // Create a new request message
1002                let request = RequestMessage::new_signed(&public_key, &private_key, &request)
1003                    .expect("failed to create request message");
1004
1005                // Request the data from the protocol
1006                let response = protocol
1007                    .request(
1008                        request,
1009                        RequestType::Batched,
1010                        config.request_timeout,
1011                        |_request, response| async move { Ok(response) },
1012                    )
1013                    .await?;
1014
1015                // Make sure the response is the hash of the request
1016                assert_eq!(response, request_hash);
1017
1018                Ok::<(), anyhow::Error>(())
1019            });
1020        }
1021
1022        // Wait for all the tasks to finish
1023        let mut num_succeeded = config.num_participants;
1024        while let Some(result) = join_set.join_next().await {
1025            if result.is_err() || result.unwrap().is_err() {
1026                num_succeeded -= 1;
1027            }
1028        }
1029
1030        IntegrationTestResult { num_succeeded }
1031    }
1032
1033    /// Test the integration of the protocol with 50% of the participants having the data
1034    #[tokio::test(flavor = "multi_thread")]
1035    async fn test_integration_50_0s() {
1036        // Build a config
1037        let config = IntegrationTestConfig {
1038            request_response_config: default_protocol_config(),
1039            num_participants: 100,
1040            num_participants_with_data: 50,
1041            request_timeout: Duration::from_secs(40),
1042            data_available_delay: Duration::from_secs(0),
1043        };
1044
1045        // Run the test, making sure all the requests succeed
1046        let result = run_integration_test(config).await;
1047        assert_eq!(result.num_succeeded, 100);
1048    }
1049
1050    /// Test the integration of the protocol when nobody has the data. Make sure we don't
1051    /// get any responses
1052    #[tokio::test(flavor = "multi_thread")]
1053    async fn test_integration_0() {
1054        // Build a config
1055        let config = IntegrationTestConfig {
1056            request_response_config: default_protocol_config(),
1057            num_participants: 100,
1058            num_participants_with_data: 0,
1059            request_timeout: Duration::from_secs(40),
1060            data_available_delay: Duration::from_secs(0),
1061        };
1062
1063        // Run the test
1064        let result = run_integration_test(config).await;
1065
1066        // Make sure all the requests succeeded
1067        assert_eq!(result.num_succeeded, 0);
1068    }
1069
1070    /// Test the integration of the protocol when one node has the data after
1071    /// a delay of 1s
1072    #[tokio::test(flavor = "multi_thread")]
1073    async fn test_integration_1_1s() {
1074        // Build a config
1075        let config = IntegrationTestConfig {
1076            request_response_config: default_protocol_config(),
1077            num_participants: 100,
1078            num_participants_with_data: 1,
1079            request_timeout: Duration::from_secs(40),
1080            data_available_delay: Duration::from_secs(2),
1081        };
1082
1083        // Run the test
1084        let result = run_integration_test(config).await;
1085
1086        // Make sure all the requests succeeded
1087        assert_eq!(result.num_succeeded, 100);
1088    }
1089
1090    /// Test that we can join an existing request for the same data and get the same (single) response
1091    #[tokio::test(flavor = "multi_thread")]
1092    async fn test_join_existing_request() {
1093        // Build a config
1094        let config = default_protocol_config();
1095
1096        // Create two participants
1097        let mut participants = Vec::new();
1098
1099        for (sender, receiver, (public_key, private_key)) in create_participants(2) {
1100            // For each, create a new [`RequestResponse`] protocol
1101            let protocol = RequestResponse::new(
1102                config.clone(),
1103                sender.clone(),
1104                receiver,
1105                sender,
1106                TestDataSource {
1107                    take_data: true,
1108                    has_data: true,
1109                    data_available_time: Instant::now() + Duration::from_secs(2),
1110                    taken: Arc::new(AtomicBool::new(false)),
1111                },
1112            );
1113
1114            // Add the participants to the list
1115            participants.push((protocol, public_key, private_key));
1116        }
1117
1118        // Take the first participant
1119        let one = Arc::new(participants.remove(0));
1120
1121        // Create the request that they should all be able to join on
1122        let request = TestRequest(vec![rand::thread_rng().gen(); 100]);
1123
1124        // Create a join set to wait for all the tasks to finish
1125        let mut join_set = JoinSet::new();
1126
1127        // Make 10 requests with the same hash
1128        for _ in 0..10 {
1129            // Clone the first participant
1130            let one_clone = Arc::clone(&one);
1131
1132            // Clone the request
1133            let request_clone = request.clone();
1134
1135            // Spawn a task to request the data
1136            join_set.spawn(async move {
1137                // Create a new, signed request message
1138                let request_message =
1139                    RequestMessage::new_signed(&one_clone.1, &one_clone.2, &request_clone)?;
1140
1141                // Start requesting it
1142                one_clone
1143                    .0
1144                    .request(
1145                        request_message,
1146                        RequestType::Batched,
1147                        Duration::from_secs(20),
1148                        |_request, response| async move { Ok(response) },
1149                    )
1150                    .await?;
1151
1152                Ok::<(), anyhow::Error>(())
1153            });
1154        }
1155
1156        // Wait for all the tasks to finish, making sure they all succeed
1157        while let Some(result) = join_set.join_next().await {
1158            result
1159                .expect("failed to join task")
1160                .expect("failed to request data");
1161        }
1162    }
1163}