staking_cli/
parse.rs

1use std::{fmt::Display, str::FromStr as _};
2
3use derive_more::{Add, From};
4use hotshot_types::{light_client::StateSignKey, signature_key::BLSPrivKey};
5use rust_decimal::{prelude::ToPrimitive as _, Decimal};
6use tagged_base64::{TaggedBase64, Tb64Error};
7use thiserror::Error;
8
9pub fn parse_bls_priv_key(s: &str) -> Result<BLSPrivKey, Tb64Error> {
10    TaggedBase64::parse(s)?.try_into()
11}
12
13pub fn parse_state_priv_key(s: &str) -> Result<StateSignKey, Tb64Error> {
14    TaggedBase64::parse(s)?.try_into()
15}
16
17#[derive(Debug, Copy, Clone, PartialEq, Eq, Add)]
18pub struct Commission(u16);
19
20impl Commission {
21    pub fn to_evm(self) -> u16 {
22        self.0
23    }
24}
25
26impl TryFrom<&str> for Commission {
27    type Error = ParseCommissionError;
28
29    fn try_from(s: &str) -> Result<Self, Self::Error> {
30        parse_commission(s)
31    }
32}
33
34impl TryFrom<u16> for Commission {
35    type Error = ParseCommissionError;
36
37    fn try_from(s: u16) -> Result<Self, Self::Error> {
38        if s > 10000 {
39            return Err("Commission must be between 0 (0.00%) and 100 (100.00%)"
40                .to_string()
41                .into());
42        }
43        Ok(Commission(s))
44    }
45}
46
47impl TryFrom<u64> for Commission {
48    type Error = ParseCommissionError;
49
50    fn try_from(s: u64) -> Result<Self, Self::Error> {
51        if s > 10000 {
52            return Err("Commission must be between 0 (0.00%) and 100 (100.00%)"
53                .to_string()
54                .into());
55        }
56        Ok(Commission(s as u16))
57    }
58}
59
60impl Display for Commission {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        write!(f, "{:.2} %", Decimal::from(self.0) / Decimal::new(100, 0))
63    }
64}
65
66#[derive(Clone, Debug, From, Error)]
67#[error("failed to parse ByteSize. {msg}")]
68pub struct ParseCommissionError {
69    msg: String,
70}
71
72/// Parse a percentage string into a `Percentage` type.
73pub fn parse_commission(s: &str) -> Result<Commission, ParseCommissionError> {
74    // strip trailing whitespace and percentage sign
75    let s = s.trim().trim_end_matches("%").trim();
76    let dec = Decimal::from_str(s).map_err(|e| ParseCommissionError { msg: e.to_string() })?;
77    if dec != dec.round_dp(2) {
78        return Err(
79            "Commission must be in percent with at most 2 decimal places"
80                .to_string()
81                .into(),
82        );
83    }
84    let hundred = Decimal::new(100, 0);
85    if dec < Decimal::ZERO || dec > hundred {
86        return Err(
87            format!("Commission must be between 0 (0.00%) and 100 (100.00%), got {dec}")
88                .to_string()
89                .into(),
90        );
91    }
92    Ok(Commission(
93        dec.checked_mul(hundred)
94            .expect("multiplication succeeds")
95            .to_u16()
96            .expect("conversion to u64 succeeds"),
97    ))
98}
99
#[cfg(test)]
mod test {
    use super::*;

    /// Each rendered commission must parse back to the same value.
    #[test]
    fn test_commission_display() {
        let round_trips: &[(u16, &str)] = &[
            (0, "0.00 %"),
            (1, "0.01 %"),
            (100, "1.00 %"),
            (200, "2.00 %"),
            (1234, "12.34 %"),
            (10000, "100.00 %"),
        ];
        for &(bps, rendered) in round_trips {
            let commission = Commission(bps);
            assert_eq!(commission.to_string(), rendered);
            assert_eq!(parse_commission(rendered).unwrap(), commission);
        }
    }

    /// Accepted inputs map to the expected hundredths-of-a-percent value.
    /// Negative, over-precise, out-of-range, and malformed inputs are rejected.
    #[test]
    fn test_parse_commission() {
        let accepted: &[(&str, u16)] = &[
            ("0", 0),
            ("0.0", 0),
            ("0.00", 0),
            ("0.000", 0),
            ("0.01", 1),
            ("1", 100),
            ("2", 200),
            ("1.000000", 100),
            ("1.2", 120),
            ("12.34", 1234),
            ("12.34%", 1234),
            ("12.34% ", 1234),
            ("12.34 % ", 1234),
            ("100", 10000),
            ("100.0", 10000),
            ("100.00", 10000),
            ("100.000", 10000),
        ];
        for &(input, expected) in accepted {
            let parsed = parse_commission(input).unwrap().to_evm();
            assert_eq!(
                parsed, expected,
                "input: {input}, parsed: {parsed} != expected {expected}"
            );
        }

        let rejected = [
            "-1", "-0.001", "0.123", "0.1234", "99.999", ".001", "100.01", "100.1", "1000", "fooo",
            "0.0.",
        ];
        rejected.iter().copied().for_each(|input| {
            assert!(
                parse_commission(input).is_err(),
                "input: {input} did not fail"
            );
        });
    }
}