1use std::collections::HashMap;
10
11use alloy::primitives::U256;
12use ark_ed_on_bn254::EdwardsConfig as Config;
13use ark_ff::PrimeField;
14use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
15use jf_crhf::CRHF;
16use jf_rescue::{crhf::VariableLengthRescueCRHF, RescueError, RescueParameter};
17use jf_signature::schnorr;
18use jf_utils::to_bytes;
19use rand::SeedableRng;
20use rand_chacha::ChaCha20Rng;
21use serde::{Deserialize, Serialize};
22use tagged_base64::tagged;
23
24use crate::signature_key::BLSPubKey;
25
26pub const DEFAULT_STAKE_TABLE_CAPACITY: usize = 200;
28pub type CircuitField = ark_ed_on_bn254::Fq;
30pub type LightClientState = GenericLightClientState<CircuitField>;
32pub type LightClientStateMsg = GenericLightClientStateMsg<CircuitField>;
34pub type StakeTableState = GenericStakeTableState<CircuitField>;
36pub type StateSignatureScheme =
38 jf_signature::schnorr::SchnorrSignatureScheme<ark_ed_on_bn254::EdwardsConfig>;
39pub type StateSignature = schnorr::Signature<Config>;
41pub type StateVerKey = schnorr::VerKey<Config>;
43pub type StateSignKey = schnorr::SignKey<ark_ed_on_bn254::Fr>;
/// Schnorr key pair over the Ed-on-BN254 curve, used for light client state signatures.
///
/// Thin newtype around [`schnorr::KeyPair`]; the inner key pair is public.
#[derive(Debug, Default, Clone)]
pub struct StateKeyPair(pub schnorr::KeyPair<Config>);
48
/// A signed light client state submitted together with the next stake table state.
///
/// Declaration order determines the `CanonicalSerialize` byte layout and the serde field
/// order, so do not reorder the fields.
#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize, Serialize, Deserialize)]
pub struct StateSignatureRequestBody {
    /// Schnorr verification key of the signer.
    pub key: StateVerKey,
    /// Light client state that was signed.
    pub state: LightClientState,
    /// Stake table state accompanying `state`.
    pub next_stake: StakeTableState,
    /// Schnorr signature produced with the key matching `key`.
    // NOTE(review): presumably computed over `state` and `next_stake` — confirm against the signer.
    pub signature: StateSignature,
}
61
/// Legacy form of [`StateSignatureRequestBody`] without the `next_stake` field.
///
/// Converting into [`StateSignatureRequestBody`] with `From` fills `next_stake` with
/// `StakeTableState::default()`. Declaration order determines the serialized layout.
#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize, Serialize, Deserialize)]
pub struct LegacyStateSignatureRequestBody {
    /// Schnorr verification key of the signer.
    pub key: StateVerKey,
    /// Light client state that was signed.
    pub state: LightClientState,
    /// Schnorr signature produced with the key matching `key`.
    pub signature: StateSignature,
}
72
73impl From<LegacyStateSignatureRequestBody> for StateSignatureRequestBody {
74 fn from(value: LegacyStateSignatureRequestBody) -> Self {
75 Self {
76 key: value.key,
77 state: value.state,
78 next_stake: StakeTableState::default(),
80 signature: value.signature,
81 }
82 }
83}
84
/// A light client state and stake table state, together with the Schnorr signatures
/// collected for them.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StateSignaturesBundle {
    /// Light client state the signatures are for.
    pub state: LightClientState,
    /// Stake table state accompanying `state`.
    pub next_stake: StakeTableState,
    /// Collected signatures, keyed by the signer's verification key.
    pub signatures: HashMap<StateVerKey, StateSignature>,
    /// Accumulated weight of the bundle.
    // NOTE(review): presumably the total stake behind the keys in `signatures` — confirm.
    pub accumulated_weight: U256,
}
97
/// Light client state, generic over the prime field `F` holding the block commitment root.
///
/// `#[tagged]` gives the type a tagged-base64 representation with tag `LIGHT_CLIENT_STATE`.
/// Declaration order determines the `CanonicalSerialize` layout and the derived (lexicographic)
/// `Ord`, so do not reorder the fields.
#[tagged("LIGHT_CLIENT_STATE")]
#[derive(
    Clone,
    Debug,
    CanonicalSerialize,
    CanonicalDeserialize,
    Default,
    Eq,
    PartialEq,
    PartialOrd,
    Ord,
    Hash,
    Copy,
)]
pub struct GenericLightClientState<F: PrimeField> {
    /// View number of this state.
    pub view_number: u64,
    /// Block height of this state.
    pub block_height: u64,
    /// Block commitment root as a field element. `GenericLightClientState::new` derives it by
    /// hashing the raw commitment bytes with `hash_bytes_to_field`.
    pub block_comm_root: F,
}
121
/// Field-element encoding of a [`GenericLightClientState`], laid out as
/// `[view_number, block_height, block_comm_root]`.
pub type GenericLightClientStateMsg<F> = [F; 3];
123
124impl<F: PrimeField> From<GenericLightClientState<F>> for GenericLightClientStateMsg<F> {
125 fn from(state: GenericLightClientState<F>) -> Self {
126 [
127 F::from(state.view_number),
128 F::from(state.block_height),
129 state.block_comm_root,
130 ]
131 }
132}
133
134impl<F: PrimeField> From<&GenericLightClientState<F>> for GenericLightClientStateMsg<F> {
135 fn from(state: &GenericLightClientState<F>) -> Self {
136 [
137 F::from(state.view_number),
138 F::from(state.block_height),
139 state.block_comm_root,
140 ]
141 }
142}
143
144impl<F: PrimeField + RescueParameter> GenericLightClientState<F> {
145 pub fn new(
146 view_number: u64,
147 block_height: u64,
148 block_comm_root: &[u8],
149 ) -> anyhow::Result<Self> {
150 Ok(Self {
151 view_number,
152 block_height,
153 block_comm_root: hash_bytes_to_field(block_comm_root)?,
154 })
155 }
156}
157
/// Commitments describing a stake table, generic over the prime field `F`.
///
/// `#[tagged]` gives the type a tagged-base64 representation with tag `STAKE_TABLE_STATE`.
/// Declaration order determines the `CanonicalSerialize` layout, the derived (lexicographic)
/// `Ord`, and the order of the `[F; 4]` conversion, so do not reorder the fields.
#[tagged("STAKE_TABLE_STATE")]
#[derive(
    Clone,
    Debug,
    CanonicalSerialize,
    CanonicalDeserialize,
    Default,
    Eq,
    PartialEq,
    PartialOrd,
    Ord,
    Hash,
    Copy,
)]
pub struct GenericStakeTableState<F: PrimeField> {
    /// Commitment to the BLS keys of the stake table.
    pub bls_key_comm: F,
    /// Commitment to the Schnorr keys of the stake table.
    pub schnorr_key_comm: F,
    /// Commitment to the stake amounts.
    pub amount_comm: F,
    /// Threshold, encoded as a field element.
    // NOTE(review): presumably the stake threshold a signature quorum must reach — confirm.
    pub threshold: F,
}
183
184impl<F: PrimeField> From<GenericStakeTableState<F>> for [F; 4] {
185 fn from(state: GenericStakeTableState<F>) -> Self {
186 [
187 state.bls_key_comm,
188 state.schnorr_key_comm,
189 state.amount_comm,
190 state.threshold,
191 ]
192 }
193}
194
195impl std::ops::Deref for StateKeyPair {
196 type Target = schnorr::KeyPair<Config>;
197
198 fn deref(&self) -> &Self::Target {
199 &self.0
200 }
201}
202
203impl StateKeyPair {
204 #[must_use]
206 pub fn from_sign_key(sk: StateSignKey) -> Self {
207 Self(schnorr::KeyPair::<Config>::from(sk))
208 }
209
210 #[must_use]
212 pub fn generate() -> StateKeyPair {
213 schnorr::KeyPair::generate(&mut rand::thread_rng()).into()
214 }
215
216 #[must_use]
218 pub fn generate_from_seed(seed: [u8; 32]) -> StateKeyPair {
219 schnorr::KeyPair::generate(&mut ChaCha20Rng::from_seed(seed)).into()
220 }
221
222 #[must_use]
224 pub fn generate_from_seed_indexed(seed: [u8; 32], index: u64) -> StateKeyPair {
225 let mut hasher = blake3::Hasher::new();
226 hasher.update(&seed);
227 hasher.update(&index.to_le_bytes());
228 let new_seed = *hasher.finalize().as_bytes();
229 Self::generate_from_seed(new_seed)
230 }
231}
232
233impl From<schnorr::KeyPair<Config>> for StateKeyPair {
234 fn from(value: schnorr::KeyPair<Config>) -> Self {
235 StateKeyPair(value)
236 }
237}
238
239pub fn hash_bytes_to_field<F: RescueParameter>(bytes: &[u8]) -> Result<F, RescueError> {
240 let bytes_len = (<F as PrimeField>::MODULUS_BIT_SIZE.div_ceil(8) - 1) as usize;
242 let elem = bytes
243 .chunks(bytes_len)
244 .map(F::from_le_bytes_mod_order)
245 .collect::<Vec<_>>();
246 Ok(VariableLengthRescueCRHF::<_, 1>::evaluate(elem)?[0])
247}
248
/// Conversion of a value into a fixed number of [`CircuitField`] elements.
///
/// Implemented in this module for Schnorr verification keys (`StateVerKey`) and BLS public
/// keys (`BLSPubKey`).
pub trait ToFieldsLightClientCompat {
    /// Number of field elements produced by `to_fields`.
    const SIZE: usize;
    /// Encodes `self` as field elements. The implementations here return exactly `SIZE`
    /// elements.
    fn to_fields(&self) -> Vec<CircuitField>;
}
255
256impl ToFieldsLightClientCompat for StateVerKey {
257 const SIZE: usize = 2;
258 fn to_fields(&self) -> Vec<CircuitField> {
260 let p = self.to_affine();
261 vec![p.x, p.y]
262 }
263}
264
265impl ToFieldsLightClientCompat for BLSPubKey {
266 const SIZE: usize = 3;
267 fn to_fields(&self) -> Vec<CircuitField> {
269 match to_bytes!(&self.to_affine()) {
270 Ok(bytes) => {
271 vec![
272 CircuitField::from_le_bytes_mod_order(&bytes[..31]),
273 CircuitField::from_le_bytes_mod_order(&bytes[31..62]),
274 CircuitField::from_le_bytes_mod_order(&bytes[62..]),
275 ]
276 },
277 Err(_) => unreachable!(),
278 }
279 }
280}