hotshot_libp2p_networking/network/cbor.rs

1use std::{collections::TryReserveError, convert::Infallible, io, marker::PhantomData};
2
3use async_trait::async_trait;
4use cbor4ii::core::error::DecodeError;
5use futures::prelude::*;
6use libp2p::{
7    request_response::{self, Codec},
8    StreamProtocol,
9};
10use serde::{de::DeserializeOwned, Serialize};
11
12/// `Behaviour` type alias for the `Cbor` codec
13pub type Behaviour<Req, Resp> = request_response::Behaviour<Cbor<Req, Resp>>;
14
/// Default maximum size, in bytes, of a single request or response (20 MiB).
const DEFAULT_SIZE_MAXIMUM: u64 = 20 * 1024 * 1024;

/// Forked `cbor` codec with configurable request/response size limits.
///
/// The limits bound how many bytes are read from a stream for a single request or response.
pub struct Cbor<Req, Resp> {
    /// Ties the codec to its request/response types without storing values of them
    phantom: PhantomData<(Req, Resp)>,
    /// Maximum request size in bytes
    request_size_maximum: u64,
    /// Maximum response size in bytes
    response_size_maximum: u64,
}

impl<Req, Resp> Default for Cbor<Req, Resp> {
    /// Create a codec with a 20 MiB limit for both requests and responses.
    fn default() -> Self {
        Self::new(DEFAULT_SIZE_MAXIMUM, DEFAULT_SIZE_MAXIMUM)
    }
}

impl<Req, Resp> Cbor<Req, Resp> {
    /// Create a new `Cbor` codec with the given request and response size limits, in bytes.
    #[must_use]
    pub fn new(request_size_maximum: u64, response_size_maximum: u64) -> Self {
        Self {
            phantom: PhantomData,
            request_size_maximum,
            response_size_maximum,
        }
    }
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would require `Req: Clone` and
// `Resp: Clone`, even though only `PhantomData` of them is stored.
impl<Req, Resp> Clone for Cbor<Req, Resp> {
    /// Clone the codec, keeping its configured size limits.
    ///
    /// libp2p's `request_response::Behaviour` clones its codec for each connection handler.
    /// Resetting to `Default` here would therefore silently discard the limits passed to
    /// [`Cbor::new`].
    fn clone(&self) -> Self {
        Self::new(self.request_size_maximum, self.response_size_maximum)
    }
}
52
53#[async_trait]
54impl<Req, Resp> Codec for Cbor<Req, Resp>
55where
56    Req: Send + Serialize + DeserializeOwned,
57    Resp: Send + Serialize + DeserializeOwned,
58{
59    type Protocol = StreamProtocol;
60    type Request = Req;
61    type Response = Resp;
62
63    async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Req>
64    where
65        T: AsyncRead + Unpin + Send,
66    {
67        let mut vec = Vec::new();
68
69        io.take(self.request_size_maximum)
70            .read_to_end(&mut vec)
71            .await?;
72
73        cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
74    }
75
76    async fn read_response<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Resp>
77    where
78        T: AsyncRead + Unpin + Send,
79    {
80        let mut vec = Vec::new();
81
82        io.take(self.response_size_maximum)
83            .read_to_end(&mut vec)
84            .await?;
85
86        cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
87    }
88
89    async fn write_request<T>(
90        &mut self,
91        _: &Self::Protocol,
92        io: &mut T,
93        req: Self::Request,
94    ) -> io::Result<()>
95    where
96        T: AsyncWrite + Unpin + Send,
97    {
98        let data: Vec<u8> =
99            cbor4ii::serde::to_vec(Vec::new(), &req).map_err(encode_into_io_error)?;
100
101        io.write_all(data.as_ref()).await?;
102
103        Ok(())
104    }
105
106    async fn write_response<T>(
107        &mut self,
108        _: &Self::Protocol,
109        io: &mut T,
110        resp: Self::Response,
111    ) -> io::Result<()>
112    where
113        T: AsyncWrite + Unpin + Send,
114    {
115        let data: Vec<u8> =
116            cbor4ii::serde::to_vec(Vec::new(), &resp).map_err(encode_into_io_error)?;
117
118        io.write_all(data.as_ref()).await?;
119
120        Ok(())
121    }
122}
123
124/// Convert a `cbor4ii::serde::DecodeError` into an `io::Error`
125fn decode_into_io_error(err: cbor4ii::serde::DecodeError<Infallible>) -> io::Error {
126    match err {
127        cbor4ii::serde::DecodeError::Core(DecodeError::Read(e)) => io::Error::other(e.to_string()),
128        cbor4ii::serde::DecodeError::Core(e @ DecodeError::Unsupported { .. }) => {
129            io::Error::new(io::ErrorKind::Unsupported, e.to_string())
130        },
131        cbor4ii::serde::DecodeError::Core(e @ DecodeError::Eof { .. }) => {
132            io::Error::new(io::ErrorKind::UnexpectedEof, e.to_string())
133        },
134        cbor4ii::serde::DecodeError::Core(e) => {
135            io::Error::new(io::ErrorKind::InvalidData, e.to_string())
136        },
137        cbor4ii::serde::DecodeError::Custom(e) => io::Error::other(e.to_string()),
138    }
139}
140
141/// Convert a `cbor4ii::serde::EncodeError` into an `io::Error`
142fn encode_into_io_error(err: cbor4ii::serde::EncodeError<TryReserveError>) -> io::Error {
143    io::Error::other(err)
144}