1use std::{
2 io::{Cursor, Read, Write},
3 time::{Duration, SystemTime, UNIX_EPOCH},
4};
5
6use anyhow::{Context, Result};
7use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
8use hotshot_types::traits::signature_key::SignatureKey;
9
10use super::{request::Request, RequestHash, Serializable};
11
12#[derive(Clone, Debug)]
14#[cfg_attr(test, derive(PartialEq, Eq))]
15pub enum Message<R: Request, K: SignatureKey> {
16 Request(RequestMessage<R, K>),
18 Response(ResponseMessage<R>),
20}
21
22#[derive(Clone, Debug)]
24#[cfg_attr(test, derive(PartialEq, Eq))]
25pub struct RequestMessage<R: Request, K: SignatureKey> {
26 pub public_key: K,
28 pub signature: K::PureAssembledSignatureType,
30 pub timestamp_unix_seconds: u64,
33 pub request: R,
35}
36
37#[derive(Clone, Debug)]
39#[cfg_attr(test, derive(PartialEq, Eq))]
40pub struct ResponseMessage<R: Request> {
41 pub request_hash: RequestHash,
44 pub response: R::Response,
46}
47
48impl<R: Request, K: SignatureKey> RequestMessage<R, K> {
49 pub fn new_signed(public_key: &K, private_key: &K::PrivateKey, request: &R) -> Result<Self>
58 where
59 <K as SignatureKey>::SignError: 'static,
60 {
61 let timestamp_unix_seconds = SystemTime::now()
63 .duration_since(UNIX_EPOCH)
64 .expect("time went backwards")
65 .as_secs();
66
67 let content_to_sign = [
69 request
70 .to_bytes()
71 .with_context(|| "failed to serialize request content")?
72 .as_slice(),
73 timestamp_unix_seconds.to_le_bytes().as_slice(),
74 b"espresso-request-response",
75 ]
76 .concat();
77
78 let signature =
80 K::sign(private_key, &content_to_sign).with_context(|| "failed to sign message")?;
81
82 Ok(RequestMessage {
84 public_key: public_key.clone(),
85 signature,
86 timestamp_unix_seconds,
87 request: request.clone(),
88 })
89 }
90
91 pub async fn validate(&self, incoming_request_ttl: Duration) -> Result<()> {
101 if self
103 .timestamp_unix_seconds
104 .saturating_add(incoming_request_ttl.as_secs())
105 < SystemTime::now()
106 .duration_since(UNIX_EPOCH)
107 .expect("time went backwards")
108 .as_secs()
109 {
110 return Err(anyhow::anyhow!("request is too old"));
111 }
112 if !self.public_key.validate(
114 &self.signature,
115 &[
116 self.request.to_bytes()?,
117 self.timestamp_unix_seconds.to_le_bytes().to_vec(),
118 b"espresso-request-response".to_vec(),
119 ]
120 .concat(),
121 ) {
122 return Err(anyhow::anyhow!("invalid request signature"));
123 }
124
125 self.request.validate().await
127 }
128}
129
130impl<R: Request, K: SignatureKey> Serializable for Message<R, K> {
132 fn to_bytes(&self) -> Result<Vec<u8>> {
134 let mut bytes = Vec::new();
136
137 match self {
139 Message::Request(request_message) => {
140 bytes.push(0);
142
143 bytes.extend_from_slice(request_message.to_bytes()?.as_slice());
145 },
146 Message::Response(response_message) => {
147 bytes.push(1);
149
150 bytes.extend_from_slice(response_message.to_bytes()?.as_slice());
152 },
153 };
154
155 Ok(bytes)
156 }
157
158 fn from_bytes(bytes: &[u8]) -> Result<Self> {
160 let mut bytes = Cursor::new(bytes);
162
163 let type_byte = bytes.read_u8()?;
165
166 match type_byte {
168 0 => {
169 Ok(Message::Request(RequestMessage::from_bytes(&read_to_end(
171 &mut bytes,
172 )?)?))
173 },
174 1 => {
175 Ok(Message::Response(ResponseMessage::from_bytes(
177 &read_to_end(&mut bytes)?,
178 )?))
179 },
180 _ => Err(anyhow::anyhow!("invalid message type")),
181 }
182 }
183}
184
185impl<R: Request, K: SignatureKey> Serializable for RequestMessage<R, K> {
186 fn to_bytes(&self) -> Result<Vec<u8>> {
187 let mut bytes = Vec::new();
189
190 write_length_prefixed(&mut bytes, &self.public_key.to_bytes())?;
192
193 write_length_prefixed(&mut bytes, &bincode::serialize(&self.signature)?)?;
195
196 bytes.write_all(&self.timestamp_unix_seconds.to_le_bytes())?;
198
199 bytes.write_all(self.request.to_bytes()?.as_slice())?;
201
202 Ok(bytes)
203 }
204
205 fn from_bytes(bytes: &[u8]) -> Result<Self> {
206 let mut bytes = Cursor::new(bytes);
208
209 let public_key = K::from_bytes(&read_length_prefixed(&mut bytes)?)?;
211
212 let signature = bincode::deserialize(&read_length_prefixed(&mut bytes)?)?;
214
215 let timestamp = bytes.read_u64::<LittleEndian>()?;
217
218 let request = R::from_bytes(&read_to_end(&mut bytes)?)?;
220
221 Ok(Self {
222 public_key,
223 signature,
224 timestamp_unix_seconds: timestamp,
225 request,
226 })
227 }
228}
229
230impl<R: Request> Serializable for ResponseMessage<R> {
231 fn to_bytes(&self) -> Result<Vec<u8>> {
232 let mut bytes = Vec::new();
234
235 bytes.write_all(self.request_hash.as_bytes())?;
237
238 bytes.write_all(self.response.to_bytes()?.as_slice())?;
240
241 Ok(bytes)
242 }
243
244 fn from_bytes(bytes: &[u8]) -> Result<Self> {
245 let mut bytes = Cursor::new(bytes);
247
248 let mut request_hash_bytes = [0; 32];
250 bytes.read_exact(&mut request_hash_bytes)?;
251 let request_hash = RequestHash::from(request_hash_bytes);
252
253 let response = R::Response::from_bytes(&read_to_end(&mut bytes)?)?;
255
256 Ok(Self {
257 request_hash,
258 response,
259 })
260 }
261}
262
263fn write_length_prefixed<W: Write>(writer: &mut W, value: &[u8]) -> Result<()> {
265 writer.write_u32::<LittleEndian>(
267 u32::try_from(value.len()).with_context(|| "value was too large")?,
268 )?;
269
270 writer.write_all(value)?;
272 Ok(())
273}
274
275fn read_length_prefixed<R: Read>(reader: &mut R) -> Result<Vec<u8>> {
277 let length = reader.read_u32::<LittleEndian>()?;
279
280 let mut value = vec![0; length as usize];
282 reader.read_exact(&mut value)?;
283 Ok(value)
284}
285
286fn read_to_end<R: Read>(reader: &mut R) -> Result<Vec<u8>> {
288 let mut value = Vec::new();
289 reader.read_to_end(&mut value)?;
290 Ok(value)
291}
292
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use hotshot_types::signature_key::BLSPubKey;
    use rand::Rng;

    use super::*;

    impl Serializable for Vec<u8> {
        fn to_bytes(&self) -> Result<Vec<u8>> {
            Ok(self.clone())
        }
        fn from_bytes(bytes: &[u8]) -> Result<Self> {
            Ok(bytes.to_vec())
        }
    }

    #[async_trait]
    impl Request for Vec<u8> {
        type Response = Vec<u8>;

        async fn validate(&self) -> Result<()> {
            Ok(())
        }
    }

    /// Re-sign `request` with `private_key` over its *current* contents and timestamp.
    ///
    /// This deliberately re-derives the signed payload layout (request bytes, little-endian
    /// timestamp, domain tag) independently of the production code. An accidental change to
    /// that layout therefore also shows up as a test failure.
    fn resign(
        private_key: &<BLSPubKey as SignatureKey>::PrivateKey,
        request: &mut RequestMessage<Vec<u8>, BLSPubKey>,
    ) {
        let content = [
            request
                .request
                .to_bytes()
                .expect("Failed to serialize request"),
            request.timestamp_unix_seconds.to_le_bytes().to_vec(),
            b"espresso-request-response".to_vec(),
        ]
        .concat();

        request.signature = <BLSPubKey as SignatureKey>::sign(private_key, &content)
            .expect("Failed to re-sign request");
    }

    #[tokio::test]
    async fn test_request_validation() {
        let mut rng = rand::thread_rng();

        for _ in 0..100 {
            let (public_key, private_key) =
                BLSPubKey::generated_from_seed_indexed([1; 32], rng.gen::<u64>());

            let mut request = RequestMessage::new_signed(
                &public_key,
                &private_key,
                &vec![rng.gen::<u8>(); rng.gen_range(1..10000)],
            )
            .expect("Failed to create signed request");

            let (should_be_valid, request_ttl) = match rng.gen_range(0..5) {
                // Freshly signed request, validated immediately
                0 => (true, Duration::from_secs(1)),

                // Tampered request body: the signature no longer matches
                1 => {
                    request.request[0] = !request.request[0];

                    (false, Duration::from_secs(1))
                },

                // Tampered timestamp: the signature no longer matches
                2 => {
                    request.timestamp_unix_seconds += 1000;

                    (false, Duration::from_secs(1))
                },

                // Correctly signed but older than the TTL: rejected as too old.
                // (A TTL of 0 on a fresh request is not tested: its outcome depends on whether
                // a wall-clock second boundary falls between signing and validation.)
                3 => {
                    request.timestamp_unix_seconds -= 1000;
                    resign(&private_key, &mut request);

                    (false, Duration::from_secs(1))
                },

                // Correctly signed and backdated, but within a generous TTL: accepted
                4 => {
                    request.timestamp_unix_seconds -= 1000;
                    resign(&private_key, &mut request);

                    (true, Duration::from_secs(2000))
                },

                _ => unreachable!(),
            };

            assert_eq!(request.validate(request_ttl).await.is_ok(), should_be_valid);
        }
    }

    #[test]
    fn test_message_parity() {
        let mut rng = rand::thread_rng();

        for _ in 0..100 {
            let is_request = rng.gen::<u8>() % 2 == 0;

            let request = vec![rng.gen::<u8>(); rng.gen_range(0..10000)];

            let message = if is_request {
                let (public_key, private_key) =
                    BLSPubKey::generated_from_seed_indexed([1; 32], rng.gen::<u64>());

                let request = RequestMessage::new_signed(&public_key, &private_key, &request)
                    .expect("Failed to create signed request");

                Message::Request(request)
            } else {
                Message::Response(ResponseMessage {
                    request_hash: blake3::hash(&request),
                    response: vec![rng.gen::<u8>(); rng.gen_range(0..10000)],
                })
            };

            let serialized = message.to_bytes().expect("Failed to serialize message");

            let deserialized =
                Message::from_bytes(&serialized).expect("Failed to deserialize message");

            assert_eq!(message, deserialized);
        }
    }

    #[test]
    fn test_length_prefix_parity() {
        let mut rng = rand::thread_rng();

        for _ in 0..100 {
            let mut bytes = Vec::new();

            let value = vec![rng.gen::<u8>(); rng.gen_range(0..10000)];

            write_length_prefixed(&mut bytes, &value).unwrap();
            let encoded_len = bytes.len() as u64;

            let mut reader = Cursor::new(bytes);

            let decoded = read_length_prefixed(&mut reader).unwrap();

            // Compare against the original input (previously this compared `value` to itself)
            assert_eq!(decoded, value);
            // The prefix and payload must be consumed exactly
            assert_eq!(reader.position(), encoded_len);
        }
    }

    #[test]
    fn test_length_prefix_truncated() {
        // Declares a 5-byte value but only carries 2 bytes of payload
        let mut reader = Cursor::new(vec![5u8, 0, 0, 0, 1, 2]);

        assert!(read_length_prefixed(&mut reader).is_err());
    }
}