1use std::{
2 borrow::Cow,
3 fmt,
4 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
5};
6
7use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
8
/// A network endpoint: either a literal IP address or a host name, each paired with a port.
///
/// The derived `PartialOrd`/`Ord` compare the variant first (in declaration order), so every
/// `Inet` value sorts before every `Name` value.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum NetAddr {
    /// An IPv4 or IPv6 address together with a port.
    Inet(IpAddr, u16),
    /// A host name together with a port; the name is stored as given (nothing here resolves it).
    Name(Cow<'static, str>, u16),
}
17
18impl NetAddr {
19 pub fn named<S>(name: S, port: u16) -> Self
20 where
21 S: Into<Cow<'static, str>>,
22 {
23 Self::Name(name.into(), port)
24 }
25
26 pub fn port(&self) -> u16 {
28 match self {
29 Self::Inet(_, p) => *p,
30 Self::Name(_, p) => *p,
31 }
32 }
33
34 pub fn set_port(&mut self, p: u16) {
36 match self {
37 Self::Inet(_, o) => *o = p,
38 Self::Name(_, o) => *o = p,
39 }
40 }
41
42 pub fn with_port(mut self, p: u16) -> Self {
43 match self {
44 Self::Inet(ip, _) => self = Self::Inet(ip, p),
45 Self::Name(hn, _) => self = Self::Name(hn, p),
46 }
47 self
48 }
49
50 pub fn with_offset(mut self, o: u16) -> Self {
51 debug_assert!(self.port().checked_add(o).is_some());
52 match self {
53 Self::Inet(ip, p) => self = Self::Inet(ip, p + o),
54 Self::Name(hn, p) => self = Self::Name(hn, p + o),
55 }
56 self
57 }
58
59 pub fn is_ip(&self) -> bool {
60 matches!(self, Self::Inet(..))
61 }
62}
63
64impl fmt::Display for NetAddr {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 match self {
67 Self::Inet(a, p) => write!(f, "{a}:{p}"),
68 Self::Name(h, p) => write!(f, "{h}:{p}"),
69 }
70 }
71}
72
73impl From<(&str, u16)> for NetAddr {
74 fn from((h, p): (&str, u16)) -> Self {
75 Self::Name(h.to_string().into(), p)
76 }
77}
78
79impl From<(String, u16)> for NetAddr {
80 fn from((h, p): (String, u16)) -> Self {
81 Self::Name(h.into(), p)
82 }
83}
84
85impl From<(IpAddr, u16)> for NetAddr {
86 fn from((ip, p): (IpAddr, u16)) -> Self {
87 Self::Inet(ip, p)
88 }
89}
90
91impl From<(Ipv4Addr, u16)> for NetAddr {
92 fn from((ip, p): (Ipv4Addr, u16)) -> Self {
93 Self::Inet(IpAddr::V4(ip), p)
94 }
95}
96
97impl From<(Ipv6Addr, u16)> for NetAddr {
98 fn from((ip, p): (Ipv6Addr, u16)) -> Self {
99 Self::Inet(IpAddr::V6(ip), p)
100 }
101}
102
103impl From<SocketAddr> for NetAddr {
104 fn from(a: SocketAddr) -> Self {
105 Self::Inet(a.ip(), a.port())
106 }
107}
108
109impl std::str::FromStr for NetAddr {
110 type Err = InvalidNetAddr;
111
112 fn from_str(s: &str) -> Result<Self, Self::Err> {
113 let parse = |a: &str, p: Option<&str>| {
114 let p: u16 = if let Some(p) = p {
115 p.parse().map_err(|_| InvalidNetAddr(()))?
116 } else {
117 0
118 };
119 IpAddr::from_str(a)
120 .map(|a| Self::Inet(a, p))
121 .or_else(|_| Ok(Self::Name(a.to_string().into(), p)))
122 };
123 match s.rsplit_once(':') {
124 None => parse(s, None),
125 Some((a, p)) => parse(a, Some(p)),
126 }
127 }
128}
129
130impl TryFrom<&str> for NetAddr {
131 type Error = InvalidNetAddr;
132
133 fn try_from(val: &str) -> Result<Self, Self::Error> {
134 val.parse()
135 }
136}
137
138#[derive(Debug, Clone, thiserror::Error)]
139#[error("invalid network address")]
140pub struct InvalidNetAddr(());
141
142impl Serialize for NetAddr {
143 fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
144 self.to_string().serialize(s)
145 }
146}
147
148impl<'de> Deserialize<'de> for NetAddr {
149 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
150 let s = String::deserialize(d)?;
151 let a = s.parse().map_err(de::Error::custom)?;
152 Ok(a)
153 }
154}
155
#[cfg(test)]
mod tests {
    use std::net::IpAddr;

    use super::NetAddr;

    /// Parses `s`, failing the test with the offending input if parsing is rejected.
    fn parse(s: &str) -> NetAddr {
        s.parse()
            .unwrap_or_else(|e| panic!("failed to parse {s:?}: {e}"))
    }

    #[test]
    fn test_parse() {
        // Whole-value comparisons print both sides on mismatch, unlike `let ... else { unreachable!() }`.
        assert_eq!(
            parse("127.0.0.1:1234"),
            NetAddr::Inet(IpAddr::from([127, 0, 0, 1]), 1234)
        );
        let v6: IpAddr = "::1".parse().unwrap();
        assert_eq!(parse("::1:1234"), NetAddr::Inet(v6, 1234));
        assert_eq!(parse("localhost:1234"), NetAddr::named("localhost", 1234));
        assert_eq!(
            parse("sub.domain.com:1234"),
            NetAddr::named("sub.domain.com", 1234)
        );
    }

    #[test]
    fn test_parse_without_port_defaults_to_zero() {
        assert_eq!(parse("localhost"), NetAddr::named("localhost", 0));
        assert_eq!(parse("10.0.0.1"), NetAddr::Inet(IpAddr::from([10, 0, 0, 1]), 0));
    }

    #[test]
    fn test_parse_rejects_bad_port() {
        assert!("localhost:http".parse::<NetAddr>().is_err());
        assert!("localhost:65536".parse::<NetAddr>().is_err());
        assert!("localhost:".parse::<NetAddr>().is_err());
    }

    #[test]
    fn test_display_round_trip() {
        for a in [
            NetAddr::Inet(IpAddr::from([192, 168, 0, 1]), 80),
            NetAddr::Inet("::1".parse().unwrap(), 1234),
            NetAddr::Inet("fe80::".parse().unwrap(), 0),
            NetAddr::named("example.com", 443),
        ] {
            assert_eq!(parse(&a.to_string()), a);
        }
    }
}