1use std::{fmt::Display, str::FromStr as _};
2
3use derive_more::{Add, From};
4use hotshot_types::{light_client::StateSignKey, signature_key::BLSPrivKey};
5use rust_decimal::{prelude::ToPrimitive as _, Decimal};
6use tagged_base64::{TaggedBase64, Tb64Error};
7use thiserror::Error;
8
9pub fn parse_bls_priv_key(s: &str) -> Result<BLSPrivKey, Tb64Error> {
10 TaggedBase64::parse(s)?.try_into()
11}
12
13pub fn parse_state_priv_key(s: &str) -> Result<StateSignKey, Tb64Error> {
14 TaggedBase64::parse(s)?.try_into()
15}
16
/// A commission rate stored as an integer count of hundredths of a percent.
///
/// The parser and `TryFrom` conversions in this file only produce values in
/// `0..=10000`, i.e. 0.00 % to 100.00 %.
///
/// NOTE(review): the derived `Add` presumably sums the inner values without
/// re-applying that range check — confirm callers cannot exceed 10000.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Add)]
pub struct Commission(u16);
19
20impl Commission {
21 pub fn to_evm(self) -> u16 {
22 self.0
23 }
24}
25
26impl TryFrom<&str> for Commission {
27 type Error = ParseCommissionError;
28
29 fn try_from(s: &str) -> Result<Self, Self::Error> {
30 parse_commission(s)
31 }
32}
33
34impl TryFrom<u16> for Commission {
35 type Error = ParseCommissionError;
36
37 fn try_from(s: u16) -> Result<Self, Self::Error> {
38 if s > 10000 {
39 return Err("Commission must be between 0 (0.00%) and 100 (100.00%)"
40 .to_string()
41 .into());
42 }
43 Ok(Commission(s))
44 }
45}
46
47impl TryFrom<u64> for Commission {
48 type Error = ParseCommissionError;
49
50 fn try_from(s: u64) -> Result<Self, Self::Error> {
51 if s > 10000 {
52 return Err("Commission must be between 0 (0.00%) and 100 (100.00%)"
53 .to_string()
54 .into());
55 }
56 Ok(Commission(s as u16))
57 }
58}
59
60impl Display for Commission {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "{:.2} %", Decimal::from(self.0) / Decimal::new(100, 0))
63 }
64}
65
66#[derive(Clone, Debug, From, Error)]
67#[error("failed to parse ByteSize. {msg}")]
68pub struct ParseCommissionError {
69 msg: String,
70}
71
72pub fn parse_commission(s: &str) -> Result<Commission, ParseCommissionError> {
74 let s = s.trim().trim_end_matches("%").trim();
76 let dec = Decimal::from_str(s).map_err(|e| ParseCommissionError { msg: e.to_string() })?;
77 if dec != dec.round_dp(2) {
78 return Err(
79 "Commission must be in percent with at most 2 decimal places"
80 .to_string()
81 .into(),
82 );
83 }
84 let hundred = Decimal::new(100, 0);
85 if dec < Decimal::ZERO || dec > hundred {
86 return Err(
87 format!("Commission must be between 0 (0.00%) and 100 (100.00%), got {dec}")
88 .to_string()
89 .into(),
90 );
91 }
92 Ok(Commission(
93 dec.checked_mul(hundred)
94 .expect("multiplication succeeds")
95 .to_u16()
96 .expect("conversion to u64 succeeds"),
97 ))
98}
99
#[cfg(test)]
mod test {
    use super::*;

    /// Formatting a commission gives a two-decimal percentage, and parsing
    /// that text yields the same commission again.
    #[test]
    fn test_commission_display() {
        let round_trips: &[(u16, &str)] = &[
            (0, "0.00 %"),
            (1, "0.01 %"),
            (100, "1.00 %"),
            (200, "2.00 %"),
            (1234, "12.34 %"),
            (10000, "100.00 %"),
        ];
        for &(raw, rendered) in round_trips {
            let commission = Commission(raw);
            assert_eq!(commission.to_string(), rendered);
            assert_eq!(parse_commission(rendered).unwrap(), commission);
        }
    }

    /// Checks the accepted spellings of a percentage and the inputs that must
    /// be rejected.
    #[test]
    fn test_parse_commission() {
        // (text, expected value in hundredths of a percent)
        let accepted: &[(&str, u16)] = &[
            ("0", 0),
            ("0.0", 0),
            ("0.00", 0),
            ("0.000", 0),
            ("0.01", 1),
            ("1", 100),
            ("2", 200),
            ("1.000000", 100),
            ("1.2", 120),
            ("12.34", 1234),
            ("12.34%", 1234),
            ("12.34% ", 1234),
            ("12.34 % ", 1234),
            ("100", 10000),
            ("100.0", 10000),
            ("100.00", 10000),
            ("100.000", 10000),
        ];
        for &(input, expected) in accepted {
            let parsed = parse_commission(input).unwrap().to_evm();
            assert_eq!(
                parsed, expected,
                "input: {input}, parsed: {parsed} != expected {expected}"
            );
        }

        // Negative, too precise, out of range, or not a number at all.
        let rejected: &[&str] = &[
            "-1", "-0.001", "0.123", "0.1234", "99.999", ".001", "100.01", "100.1", "1000", "fooo",
            "0.0.",
        ];
        for &input in rejected {
            assert!(
                parse_commission(input).is_err(),
                "input: {input} did not fail"
            );
        }
    }
}