hotshot_types/
addr.rs

1use std::{
2    borrow::Cow,
3    fmt,
4    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
5};
6
7use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
8
/// A network address.
///
/// Either an IP address and port number or else a hostname and port number.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum NetAddr {
    /// A literal IP address (v4 or v6) and a port number.
    Inet(IpAddr, u16),
    /// A host name and a port number. The name is stored verbatim; the constructors in
    /// this module do not validate or resolve it.
    Name(Cow<'static, str>, u16),
}

impl NetAddr {
    /// Create a [`NetAddr::Name`] from a host name and a port number.
    ///
    /// The name is stored as given; it is not checked for being an IP literal.
    pub fn named<S>(name: S, port: u16) -> Self
    where
        S: Into<Cow<'static, str>>,
    {
        Self::Name(name.into(), port)
    }

    /// Get the port number of an address.
    pub fn port(&self) -> u16 {
        match self {
            Self::Inet(_, port) | Self::Name(_, port) => *port,
        }
    }

    /// Set the address port.
    pub fn set_port(&mut self, p: u16) {
        match self {
            Self::Inet(_, port) | Self::Name(_, port) => *port = p,
        }
    }

    /// Return this address with its port replaced by `p`.
    pub fn with_port(mut self, p: u16) -> Self {
        self.set_port(p);
        self
    }

    /// Return this address with its port increased by `o`.
    ///
    /// # Panics
    ///
    /// Panics if the current port plus `o` does not fit in a `u16`. The check is made in
    /// every build profile; the port is never allowed to wrap around.
    pub fn with_offset(mut self, o: u16) -> Self {
        // `checked_add` instead of `+`: with overflow checks disabled (the release-profile
        // default) `+` would silently wrap to a small, unrelated port.
        let port = self
            .port()
            .checked_add(o)
            .expect("NetAddr::with_offset: port + offset overflows u16");
        self.set_port(port);
        self
    }

    /// Whether this is an IP address (as opposed to a host name).
    pub fn is_ip(&self) -> bool {
        matches!(self, Self::Inet(..))
    }
}
63
64impl fmt::Display for NetAddr {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        match self {
67            Self::Inet(a, p) => write!(f, "{a}:{p}"),
68            Self::Name(h, p) => write!(f, "{h}:{p}"),
69        }
70    }
71}
72
73impl From<(&str, u16)> for NetAddr {
74    fn from((h, p): (&str, u16)) -> Self {
75        Self::Name(h.to_string().into(), p)
76    }
77}
78
79impl From<(String, u16)> for NetAddr {
80    fn from((h, p): (String, u16)) -> Self {
81        Self::Name(h.into(), p)
82    }
83}
84
85impl From<(IpAddr, u16)> for NetAddr {
86    fn from((ip, p): (IpAddr, u16)) -> Self {
87        Self::Inet(ip, p)
88    }
89}
90
91impl From<(Ipv4Addr, u16)> for NetAddr {
92    fn from((ip, p): (Ipv4Addr, u16)) -> Self {
93        Self::Inet(IpAddr::V4(ip), p)
94    }
95}
96
97impl From<(Ipv6Addr, u16)> for NetAddr {
98    fn from((ip, p): (Ipv6Addr, u16)) -> Self {
99        Self::Inet(IpAddr::V6(ip), p)
100    }
101}
102
103impl From<SocketAddr> for NetAddr {
104    fn from(a: SocketAddr) -> Self {
105        Self::Inet(a.ip(), a.port())
106    }
107}
108
109impl std::str::FromStr for NetAddr {
110    type Err = InvalidNetAddr;
111
112    fn from_str(s: &str) -> Result<Self, Self::Err> {
113        let parse = |a: &str, p: Option<&str>| {
114            let p: u16 = if let Some(p) = p {
115                p.parse().map_err(|_| InvalidNetAddr(()))?
116            } else {
117                0
118            };
119            IpAddr::from_str(a)
120                .map(|a| Self::Inet(a, p))
121                .or_else(|_| Ok(Self::Name(a.to_string().into(), p)))
122        };
123        match s.rsplit_once(':') {
124            None => parse(s, None),
125            Some((a, p)) => parse(a, Some(p)),
126        }
127    }
128}
129
130impl TryFrom<&str> for NetAddr {
131    type Error = InvalidNetAddr;
132
133    fn try_from(val: &str) -> Result<Self, Self::Error> {
134        val.parse()
135    }
136}
137
138#[derive(Debug, Clone, thiserror::Error)]
139#[error("invalid network address")]
140pub struct InvalidNetAddr(());
141
142impl Serialize for NetAddr {
143    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
144        self.to_string().serialize(s)
145    }
146}
147
148impl<'de> Deserialize<'de> for NetAddr {
149    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
150        let s = String::deserialize(d)?;
151        let a = s.parse().map_err(de::Error::custom)?;
152        Ok(a)
153    }
154}
155
#[cfg(test)]
mod tests {
    use std::net::IpAddr;

    use super::NetAddr;

    #[test]
    fn test_parse() {
        let a: NetAddr = "127.0.0.1:1234".parse().unwrap();
        let NetAddr::Inet(a, p) = a else {
            unreachable!()
        };
        assert_eq!(IpAddr::from([127, 0, 0, 1]), a);
        assert_eq!(1234, p);

        let a: NetAddr = "::1:1234".parse().unwrap();
        let NetAddr::Inet(a, p) = a else {
            unreachable!()
        };
        assert_eq!("::1".parse::<IpAddr>().unwrap(), a);
        assert_eq!(1234, p);

        let a: NetAddr = "localhost:1234".parse().unwrap();
        let NetAddr::Name(h, p) = a else {
            unreachable!()
        };
        assert_eq!("localhost", &h);
        assert_eq!(1234, p);

        let a: NetAddr = "sub.domain.com:1234".parse().unwrap();
        let NetAddr::Name(h, p) = a else {
            unreachable!()
        };
        assert_eq!("sub.domain.com", &h);
        assert_eq!(1234, p);
    }

    /// A missing `:port` suffix yields port 0 for both names and IP literals.
    #[test]
    fn test_parse_without_port_defaults_to_zero() {
        let a: NetAddr = "localhost".parse().unwrap();
        assert_eq!(NetAddr::named("localhost", 0), a);

        let a: NetAddr = "10.0.0.1".parse().unwrap();
        assert_eq!(NetAddr::Inet(IpAddr::from([10, 0, 0, 1]), 0), a);
    }

    /// Anything after the last `:` that is not a `u16` is rejected.
    #[test]
    fn test_parse_invalid_port() {
        assert!("localhost:http".parse::<NetAddr>().is_err());
        assert!("localhost:65536".parse::<NetAddr>().is_err());
        assert!("localhost:".parse::<NetAddr>().is_err());
    }

    /// `Display` output parses back to the same value (this is what serde relies on).
    #[test]
    fn test_display_round_trip() {
        let addrs = [
            NetAddr::Inet(IpAddr::from([127, 0, 0, 1]), 1234),
            NetAddr::Inet("::1".parse().unwrap(), 1234),
            NetAddr::named("sub.domain.com", 1234),
        ];
        for a in addrs {
            let s = a.to_string();
            assert_eq!(a, s.parse::<NetAddr>().unwrap(), "round-trip of {s}");
        }
    }

    #[test]
    fn test_port_helpers() {
        let a = NetAddr::named("localhost", 1000);
        assert_eq!(1001, a.clone().with_offset(1).port());
        assert_eq!(2000, a.clone().with_port(2000).port());

        let mut b = a;
        b.set_port(3000);
        assert_eq!(3000, b.port());
        assert!(!b.is_ip());
    }
}