1use std::collections::HashMap;
4
5use alloy::primitives::{FixedBytes, U256};
6use anyhow::bail;
7use espresso_types::SeqTypes;
8use hotshot_query_service::availability::Error;
9use hotshot_task_impls::helpers::derive_signed_state_digest;
10use hotshot_types::{
11 light_client::StateVerKey,
12 simple_certificate::LightClientStateUpdateCertificateV2,
13 stake_table::HSStakeTable,
14 traits::signature_key::{LCV2StateSignatureKey, LCV3StateSignatureKey, StakeTableEntryType},
15};
16use tide_disco::StatusCode;
17
18#[derive(Debug, thiserror::Error)]
20pub enum StateCertFetchError {
21 #[error("Failed to fetch state certificate: {0}")]
22 FetchError(#[source] anyhow::Error),
23
24 #[error("State certificate validation failed: {0}")]
25 ValidationError(#[source] anyhow::Error),
26
27 #[error("State certificate error: {0}")]
28 Other(#[source] anyhow::Error),
29}
30
31impl From<StateCertFetchError> for hotshot_query_service::availability::Error {
32 fn from(err: StateCertFetchError) -> Self {
33 match err {
34 StateCertFetchError::FetchError(e) => Error::Custom {
35 message: format!("Failed to fetch state cert from peers: {e}"),
36 status: StatusCode::NOT_FOUND,
37 },
38 StateCertFetchError::ValidationError(e) => Error::Custom {
39 message: format!("State certificate validation failed: {e}"),
40 status: StatusCode::INTERNAL_SERVER_ERROR,
41 },
42 StateCertFetchError::Other(e) => Error::Custom {
43 message: format!("Failed to process state cert: {e}"),
44 status: StatusCode::INTERNAL_SERVER_ERROR,
45 },
46 }
47 }
48}
49
50pub fn validate_state_cert(
52 cert: &LightClientStateUpdateCertificateV2<SeqTypes>,
53 stake_table: &HSStakeTable<SeqTypes>,
54) -> anyhow::Result<()> {
55 let signed_state_digest = derive_signed_state_digest(
56 &cert.light_client_state,
57 &cert.next_stake_table_state,
58 &cert.auth_root,
59 );
60
61 let use_lcv2_only = cert.auth_root == FixedBytes::<32>::default();
64
65 let signature_map: HashMap<&StateVerKey, _> = cert
66 .signatures
67 .iter()
68 .map(|(key, lcv3_sig, lcv2_sig)| (key, (lcv3_sig, lcv2_sig)))
69 .collect();
70
71 let mut accumulated_weight = U256::ZERO;
73
74 for peer in stake_table.iter() {
75 if let Some((lcv3_sig, lcv2_sig)) = signature_map.get(&peer.state_ver_key) {
76 let lcv2_valid = <StateVerKey as LCV2StateSignatureKey>::verify_state_sig(
77 &peer.state_ver_key,
78 lcv2_sig,
79 &cert.light_client_state,
80 &cert.next_stake_table_state,
81 );
82
83 let is_valid = if use_lcv2_only {
84 lcv2_valid
85 } else {
86 let lcv3_valid = <StateVerKey as LCV3StateSignatureKey>::verify_state_sig(
87 &peer.state_ver_key,
88 lcv3_sig,
89 signed_state_digest,
90 );
91
92 lcv3_valid && lcv2_valid
93 };
94
95 if is_valid {
96 accumulated_weight += peer.stake_table_entry.stake();
97 } else {
98 bail!(format!(
99 "Invalid signature from key: {}",
100 peer.state_ver_key
101 ))
102 }
103 }
104 }
105
106 let total_stake = stake_table.total_stakes();
108 let threshold = hotshot_types::stake_table::one_honest_threshold(total_stake);
109 if accumulated_weight < threshold {
110 bail!(
111 "State certificate validation failed: accumulated weight {accumulated_weight} is \
112 below threshold {threshold}",
113 );
114 }
115
116 Ok(())
117}