request_response/
message.rs

1use std::{
2    io::{Cursor, Read, Write},
3    time::{Duration, SystemTime, UNIX_EPOCH},
4};
5
6use anyhow::{Context, Result};
7use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
8use hotshot_types::traits::signature_key::SignatureKey;
9
10use super::{request::Request, RequestHash, Serializable};
11
12/// The outer message type for the request-response protocol. Can either be a request or a response
13#[derive(Clone, Debug)]
14#[cfg_attr(test, derive(PartialEq, Eq))]
15pub enum Message<R: Request, K: SignatureKey> {
16    /// A request
17    Request(RequestMessage<R, K>),
18    /// A response
19    Response(ResponseMessage<R>),
20}
21
22/// A request message, which includes the requester's public key, the request's signature, a timestamp, and the request itself
23#[derive(Clone, Debug)]
24#[cfg_attr(test, derive(PartialEq, Eq))]
25pub struct RequestMessage<R: Request, K: SignatureKey> {
26    /// The requester's public key
27    pub public_key: K,
28    /// The requester's signature over the [the actual request content + timestamp]
29    pub signature: K::PureAssembledSignatureType,
30    /// The timestamp of when the request was sent (in seconds since the Unix epoch). We use this to
31    /// ensure that the request is not old, which is useful for preventing replay attacks.
32    pub timestamp_unix_seconds: u64,
33    /// The actual request data. This is from the application
34    pub request: R,
35}
36
37/// A response message, which includes the hash of the request we're responding to and the response itself.
38#[derive(Clone, Debug)]
39#[cfg_attr(test, derive(PartialEq, Eq))]
40pub struct ResponseMessage<R: Request> {
41    /// The hash of the application-specific request we're responding to. The hash is a free way
42    /// to identify the request and weed out any potential incompatibilities
43    pub request_hash: RequestHash,
44    /// The actual response content
45    pub response: R::Response,
46}
47
48impl<R: Request, K: SignatureKey> RequestMessage<R, K> {
49    /// Create a new signed request message from a request
50    ///
51    /// # Errors
52    /// - If the request's content cannot be serialized
53    /// - If the request cannot be signed
54    ///
55    /// # Panics
56    /// - If time is not monotonic
57    pub fn new_signed(public_key: &K, private_key: &K::PrivateKey, request: &R) -> Result<Self>
58    where
59        <K as SignatureKey>::SignError: 'static,
60    {
61        // Get the current timestamp
62        let timestamp_unix_seconds = SystemTime::now()
63            .duration_since(UNIX_EPOCH)
64            .expect("time went backwards")
65            .as_secs();
66
67        // Concatenate the content and timestamp
68        let content_to_sign = [
69            request
70                .to_bytes()
71                .with_context(|| "failed to serialize request content")?
72                .as_slice(),
73            timestamp_unix_seconds.to_le_bytes().as_slice(),
74            b"espresso-request-response",
75        ]
76        .concat();
77
78        // Sign the actual request content (+ a namespace) with the private key
79        let signature =
80            K::sign(private_key, &content_to_sign).with_context(|| "failed to sign message")?;
81
82        // Return the newly signed request message
83        Ok(RequestMessage {
84            public_key: public_key.clone(),
85            signature,
86            timestamp_unix_seconds,
87            request: request.clone(),
88        })
89    }
90
91    /// Validate the [`RequestMessage`], checking the signature and the timestamp and
92    /// calling the request's application-specific validation function
93    ///
94    /// # Errors
95    /// - If the request's signature is invalid
96    /// - If the request is too old
97    ///
98    /// # Panics
99    /// - If time is not monotonic
100    pub async fn validate(&self, incoming_request_ttl: Duration) -> Result<()> {
101        // Make sure the request is not too old
102        if self
103            .timestamp_unix_seconds
104            .saturating_add(incoming_request_ttl.as_secs())
105            < SystemTime::now()
106                .duration_since(UNIX_EPOCH)
107                .expect("time went backwards")
108                .as_secs()
109        {
110            return Err(anyhow::anyhow!("request is too old"));
111        }
112        // Check the signature over the request content and timestamp
113        if !self.public_key.validate(
114            &self.signature,
115            &[
116                self.request.to_bytes()?,
117                self.timestamp_unix_seconds.to_le_bytes().to_vec(),
118                b"espresso-request-response".to_vec(),
119            ]
120            .concat(),
121        ) {
122            return Err(anyhow::anyhow!("invalid request signature"));
123        }
124
125        // Call the request's application-specific validation function
126        self.request.validate().await
127    }
128}
129
130/// A blanket implementation of the [`Serializable`] trait for any [`Message`]
131impl<R: Request, K: SignatureKey> Serializable for Message<R, K> {
132    /// Converts any [`Message`] to bytes if the content is also [`Serializable`]
133    fn to_bytes(&self) -> Result<Vec<u8>> {
134        // Create a buffer for the bytes
135        let mut bytes = Vec::new();
136
137        // Convert the message to bytes based on the type. By default it is just type-prefixed
138        match self {
139            Message::Request(request_message) => {
140                // Write the type (request)
141                bytes.push(0);
142
143                // Write the request content
144                bytes.extend_from_slice(request_message.to_bytes()?.as_slice());
145            },
146            Message::Response(response_message) => {
147                // Write the type (response)
148                bytes.push(1);
149
150                // Write the response content
151                bytes.extend_from_slice(response_message.to_bytes()?.as_slice());
152            },
153        };
154
155        Ok(bytes)
156    }
157
158    /// Convert bytes to a [`Message`]
159    fn from_bytes(bytes: &[u8]) -> Result<Self> {
160        // Create a cursor so we can easily read the bytes in order
161        let mut bytes = Cursor::new(bytes);
162
163        // Get the message type
164        let type_byte = bytes.read_u8()?;
165
166        // Deserialize the message based on the type
167        match type_byte {
168            0 => {
169                // Read the `RequestMessage`
170                Ok(Message::Request(RequestMessage::from_bytes(&read_to_end(
171                    &mut bytes,
172                )?)?))
173            },
174            1 => {
175                // Read the `ResponseMessage`
176                Ok(Message::Response(ResponseMessage::from_bytes(
177                    &read_to_end(&mut bytes)?,
178                )?))
179            },
180            _ => Err(anyhow::anyhow!("invalid message type")),
181        }
182    }
183}
184
185impl<R: Request, K: SignatureKey> Serializable for RequestMessage<R, K> {
186    fn to_bytes(&self) -> Result<Vec<u8>> {
187        // Create a buffer for the bytes
188        let mut bytes = Vec::new();
189
190        // Write the public key (length-prefixed)
191        write_length_prefixed(&mut bytes, &self.public_key.to_bytes())?;
192
193        // Write the signature (length-prefixed)
194        write_length_prefixed(&mut bytes, &bincode::serialize(&self.signature)?)?;
195
196        // Write the timestamp
197        bytes.write_all(&self.timestamp_unix_seconds.to_le_bytes())?;
198
199        // Write the actual request
200        bytes.write_all(self.request.to_bytes()?.as_slice())?;
201
202        Ok(bytes)
203    }
204
205    fn from_bytes(bytes: &[u8]) -> Result<Self> {
206        // Create a cursor so we can easily read the bytes in order
207        let mut bytes = Cursor::new(bytes);
208
209        // Read the public key (length-prefixed)
210        let public_key = K::from_bytes(&read_length_prefixed(&mut bytes)?)?;
211
212        // Read the signature (length-prefixed)
213        let signature = bincode::deserialize(&read_length_prefixed(&mut bytes)?)?;
214
215        // Read the timestamp as a [`u64`]
216        let timestamp = bytes.read_u64::<LittleEndian>()?;
217
218        // Deserialize the request
219        let request = R::from_bytes(&read_to_end(&mut bytes)?)?;
220
221        Ok(Self {
222            public_key,
223            signature,
224            timestamp_unix_seconds: timestamp,
225            request,
226        })
227    }
228}
229
230impl<R: Request> Serializable for ResponseMessage<R> {
231    fn to_bytes(&self) -> Result<Vec<u8>> {
232        // Create a buffer for the bytes
233        let mut bytes = Vec::new();
234
235        // Write the request hash as bytes
236        bytes.write_all(self.request_hash.as_bytes())?;
237
238        // Write the response content
239        bytes.write_all(self.response.to_bytes()?.as_slice())?;
240
241        Ok(bytes)
242    }
243
244    fn from_bytes(bytes: &[u8]) -> Result<Self> {
245        // Create a buffer for the bytes
246        let mut bytes = Cursor::new(bytes);
247
248        // Read the request hash as a [`blake3::Hash`]
249        let mut request_hash_bytes = [0; 32];
250        bytes.read_exact(&mut request_hash_bytes)?;
251        let request_hash = RequestHash::from(request_hash_bytes);
252
253        // Read the response content to the end
254        let response = R::Response::from_bytes(&read_to_end(&mut bytes)?)?;
255
256        Ok(Self {
257            request_hash,
258            response,
259        })
260    }
261}
262
263/// A helper function to write a length-prefixed value to a writer
264fn write_length_prefixed<W: Write>(writer: &mut W, value: &[u8]) -> Result<()> {
265    // Write the length of the value as a u32
266    writer.write_u32::<LittleEndian>(
267        u32::try_from(value.len()).with_context(|| "value was too large")?,
268    )?;
269
270    // Write the (already serialized) value
271    writer.write_all(value)?;
272    Ok(())
273}
274
275/// A helper function to read a length-prefixed value from a reader
276fn read_length_prefixed<R: Read>(reader: &mut R) -> Result<Vec<u8>> {
277    // Read the length of the value as a u32
278    let length = reader.read_u32::<LittleEndian>()?;
279
280    // Read the value
281    let mut value = vec![0; length as usize];
282    reader.read_exact(&mut value)?;
283    Ok(value)
284}
285
286/// A helper function to read to the end of the reader
287fn read_to_end<R: Read>(reader: &mut R) -> Result<Vec<u8>> {
288    let mut value = Vec::new();
289    reader.read_to_end(&mut value)?;
290    Ok(value)
291}
292
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use hotshot_types::signature_key::BLSPubKey;
    use rand::Rng;

    use super::*;

    // A testing implementation of the [`Serializable`] trait for [`Vec<u8>`]: the bytes are
    // their own serialization
    impl Serializable for Vec<u8> {
        fn to_bytes(&self) -> Result<Vec<u8>> {
            Ok(self.clone())
        }
        fn from_bytes(bytes: &[u8]) -> Result<Self> {
            Ok(bytes.to_vec())
        }
    }

    /// A testing implementation of the [`Request`] trait for [`Vec<u8>`] that accepts every
    /// request
    #[async_trait]
    impl Request for Vec<u8> {
        type Response = Vec<u8>;

        async fn validate(&self) -> Result<()> {
            Ok(())
        }
    }

    /// Generate `len` independently random bytes. (`vec![rng.gen(); len]` would instead repeat
    /// one random byte `len` times.)
    fn random_bytes(rng: &mut impl Rng, len: usize) -> Vec<u8> {
        (0..len).map(|_| rng.gen::<u8>()).collect()
    }

    /// Tests that properly signed requests are validated correctly and that tampered requests
    /// (modified content or timestamp) are rejected
    #[tokio::test]
    async fn test_request_validation() {
        // Create some RNG
        let mut rng = rand::thread_rng();

        for _ in 0..100 {
            // Create a random keypair
            let (public_key, private_key) =
                BLSPubKey::generated_from_seed_indexed([1; 32], rng.gen::<u64>());

            // Create a valid request with some random (non-empty) content
            let content_len = rng.gen_range(1..10000);
            let content = random_bytes(&mut rng, content_len);
            let mut request = RequestMessage::new_signed(&public_key, &private_key, &content)
                .expect("Failed to create signed request");

            let (should_be_valid, request_ttl) = match rng.gen_range(0..4) {
                // An untouched request well within its TTL
                0 => (true, Duration::from_secs(1)),

                1 => {
                    // Flip the first content byte; the signature no longer matches
                    request.request[0] = !request.request[0];
                    (false, Duration::from_secs(1))
                },

                2 => {
                    // Move the timestamp into the future. Future timestamps pass the age
                    // check, so this is caught only because the signature covers the timestamp
                    request.timestamp_unix_seconds += 1000;
                    (false, Duration::from_secs(1))
                },

                3 => {
                    // An enormous TTL must saturate rather than overflow the expiry check
                    (true, Duration::from_secs(u64::MAX))
                },

                _ => unreachable!(),
            };

            // Validate the request
            assert_eq!(request.validate(request_ttl).await.is_ok(), should_be_valid);
        }
    }

    /// Tests that a correctly signed request is rejected once it is older than the TTL
    #[tokio::test]
    async fn test_expired_request_rejected() {
        let (public_key, private_key) = BLSPubKey::generated_from_seed_indexed([1; 32], 0);
        let request = vec![1u8, 2, 3];

        // Sign the request as if it had been created an hour ago, using the same payload
        // layout as `RequestMessage::new_signed`
        let timestamp_unix_seconds = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards")
            .as_secs()
            - 3600;
        let content = [
            request.as_slice(),
            timestamp_unix_seconds.to_le_bytes().as_slice(),
            b"espresso-request-response",
        ]
        .concat();
        let signature = <BLSPubKey as SignatureKey>::sign(&private_key, &content)
            .expect("Failed to sign request");

        let message = RequestMessage {
            public_key,
            signature,
            timestamp_unix_seconds,
            request,
        };

        // Valid under a TTL that covers its age, rejected under a shorter one
        assert!(message.validate(Duration::from_secs(7200)).await.is_ok());
        assert!(message.validate(Duration::from_secs(60)).await.is_err());
    }

    /// Tests that messages are serialized and deserialized correctly
    #[test]
    fn test_message_parity() {
        // Create some RNG
        let mut rng = rand::thread_rng();

        for _ in 0..100 {
            // The request content will be a random vector of bytes
            let request_len = rng.gen_range(0..10000);
            let request = random_bytes(&mut rng, request_len);

            // Create a message of a random type
            let message = if rng.gen::<bool>() {
                // Create a random keypair
                let (public_key, private_key) =
                    BLSPubKey::generated_from_seed_indexed([1; 32], rng.gen::<u64>());

                // Create a new signed request
                let request = RequestMessage::new_signed(&public_key, &private_key, &request)
                    .expect("Failed to create signed request");

                Message::Request(request)
            } else {
                // Create a response message
                let response_len = rng.gen_range(0..10000);
                Message::Response(ResponseMessage {
                    request_hash: blake3::hash(&request),
                    response: random_bytes(&mut rng, response_len),
                })
            };

            // Round-trip the message through its byte representation
            let serialized = message.to_bytes().expect("Failed to serialize message");
            let deserialized =
                Message::from_bytes(&serialized).expect("Failed to deserialize message");

            // Assert that the deserialized message is the same as the original message
            assert_eq!(message, deserialized);
        }
    }

    /// Tests that malformed messages are rejected instead of being mis-parsed
    #[test]
    fn test_malformed_messages_rejected() {
        // No type byte at all
        assert!(Message::<Vec<u8>, BLSPubKey>::from_bytes(&[]).is_err());

        // Unknown type byte
        assert!(Message::<Vec<u8>, BLSPubKey>::from_bytes(&[2]).is_err());

        // A response too short to hold the 32-byte request hash
        assert!(Message::<Vec<u8>, BLSPubKey>::from_bytes(&[1; 10]).is_err());
    }

    /// Tests that length-prefixed values are read and written correctly
    #[test]
    fn test_length_prefix_parity() {
        // Create some RNG
        let mut rng = rand::thread_rng();

        for _ in 0..100 {
            // Generate the value to test over
            let value_len = rng.gen_range(0..10000);
            let value = random_bytes(&mut rng, value_len);

            // Write the length-prefixed value
            let mut bytes = Vec::new();
            write_length_prefixed(&mut bytes, &value).unwrap();

            // Read it back
            let mut reader = Cursor::new(bytes);
            let read_back = read_length_prefixed(&mut reader).unwrap();

            // The value read back must equal the value written
            assert_eq!(read_back, value);

            // Exactly the 4-byte prefix plus the value must have been consumed
            assert_eq!(reader.position(), 4 + value_len as u64);
        }
    }

    /// Tests that truncated length-prefixed input is reported as an error
    #[test]
    fn test_length_prefix_truncated() {
        // Value shorter than its declared length
        let mut bytes = Vec::new();
        write_length_prefixed(&mut bytes, &[1, 2, 3, 4]).unwrap();
        bytes.truncate(bytes.len() - 1);
        assert!(read_length_prefixed(&mut Cursor::new(bytes)).is_err());

        // Not even a complete length prefix
        assert!(read_length_prefixed(&mut Cursor::new(vec![0u8, 1])).is_err());
    }
}