espresso_types/v0/impls/
reward.rs

1use std::{borrow::Borrow, collections::HashSet, iter::once, str::FromStr};
2
3use alloy::primitives::{
4    utils::{parse_units, ParseUnits},
5    Address, B256, U256,
6};
7use anyhow::{bail, ensure, Context};
8use ark_serialize::{
9    CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate,
10};
11use hotshot::types::BLSPubKey;
12use hotshot_contract_adapter::sol_types::AccruedRewardsProofSol;
13use hotshot_types::{
14    data::{EpochNumber, ViewNumber},
15    traits::{election::Membership, node_implementation::ConsensusTime},
16    utils::epoch_from_block_number,
17};
18use jf_merkle_tree::{
19    prelude::MerkleNode, ForgetableMerkleTreeScheme, ForgetableUniversalMerkleTreeScheme,
20    LookupResult, MerkleTreeScheme, PersistentUniversalMerkleTreeScheme, ToTraversalPath,
21    UniversalMerkleTreeScheme,
22};
23use num_traits::CheckedSub;
24use sequencer_utils::{
25    impl_serde_from_string_or_integer, impl_to_fixed_bytes, ser::FromStringOrInteger,
26};
27use vbs::version::StaticVersionType;
28
29use super::{
30    v0_3::{RewardAmount, Validator, COMMISSION_BASIS_POINTS},
31    v0_4::{
32        RewardAccountProofV2, RewardAccountQueryDataV2, RewardAccountV2, RewardMerkleCommitmentV2,
33        RewardMerkleProofV2, RewardMerkleTreeV2,
34    },
35    Leaf2, NodeState, ValidatedState,
36};
37use crate::{
38    eth_signature_key::EthKeyPair,
39    v0_3::{
40        RewardAccountProofV1, RewardAccountV1, RewardMerkleCommitmentV1, RewardMerkleProofV1,
41        RewardMerkleTreeV1,
42    },
43    v0_4::{Delta, REWARD_MERKLE_TREE_V2_ARITY, REWARD_MERKLE_TREE_V2_HEIGHT},
44    DrbAndHeaderUpgradeVersion, EpochVersion, FeeAccount,
45};
46
47impl_serde_from_string_or_integer!(RewardAmount);
48impl_to_fixed_bytes!(RewardAmount, U256);
49
50impl From<u64> for RewardAmount {
51    fn from(amt: u64) -> Self {
52        Self(U256::from(amt))
53    }
54}
55
56impl CheckedSub for RewardAmount {
57    fn checked_sub(&self, v: &Self) -> Option<Self> {
58        self.0.checked_sub(v.0).map(RewardAmount)
59    }
60}
61
62impl FromStr for RewardAmount {
63    type Err = <U256 as FromStr>::Err;
64
65    fn from_str(s: &str) -> Result<Self, Self::Err> {
66        Ok(Self(s.parse()?))
67    }
68}
69
70impl FromStringOrInteger for RewardAmount {
71    type Binary = U256;
72    type Integer = u64;
73
74    fn from_binary(b: Self::Binary) -> anyhow::Result<Self> {
75        Ok(Self(b))
76    }
77
78    fn from_integer(i: Self::Integer) -> anyhow::Result<Self> {
79        Ok(i.into())
80    }
81
82    fn from_string(s: String) -> anyhow::Result<Self> {
83        // For backwards compatibility, we have an ad hoc parser for WEI amounts represented as hex
84        // strings.
85        if let Some(s) = s.strip_prefix("0x") {
86            return Ok(Self(s.parse()?));
87        }
88
89        // Strip an optional non-numeric suffix, which will be interpreted as a unit.
90        let (base, unit) = s
91            .split_once(char::is_whitespace)
92            .unwrap_or((s.as_str(), "wei"));
93        match parse_units(base, unit)? {
94            ParseUnits::U256(n) => Ok(Self(n)),
95            ParseUnits::I256(_) => bail!("amount cannot be negative"),
96        }
97    }
98
99    fn to_binary(&self) -> anyhow::Result<Self::Binary> {
100        Ok(self.0)
101    }
102
103    fn to_string(&self) -> anyhow::Result<String> {
104        Ok(format!("{self}"))
105    }
106}
107
108impl RewardAmount {
109    pub fn as_u64(&self) -> Option<u64> {
110        if self.0 <= U256::from(u64::MAX) {
111            Some(self.0.to::<u64>())
112        } else {
113            None
114        }
115    }
116}
117
118impl From<[u8; 20]> for RewardAccountV1 {
119    fn from(bytes: [u8; 20]) -> Self {
120        Self(Address::from(bytes))
121    }
122}
123
124impl AsRef<[u8]> for RewardAccountV1 {
125    fn as_ref(&self) -> &[u8] {
126        self.0.as_slice()
127    }
128}
129
130impl<const ARITY: usize> ToTraversalPath<ARITY> for RewardAccountV1 {
131    fn to_traversal_path(&self, height: usize) -> Vec<usize> {
132        self.0
133            .as_slice()
134            .iter()
135            .take(height)
136            .map(|i| *i as usize)
137            .collect()
138    }
139}
140
141impl RewardAccountV2 {
142    /// Return inner `Address`
143    pub fn address(&self) -> Address {
144        self.0
145    }
146    /// Return byte slice representation of inner `Address` type
147    pub fn as_bytes(&self) -> &[u8] {
148        self.0.as_slice()
149    }
150    /// Return array containing underlying bytes of inner `Address` type
151    pub fn to_fixed_bytes(self) -> [u8; 20] {
152        self.0.into_array()
153    }
154    pub fn test_key_pair() -> EthKeyPair {
155        EthKeyPair::from_mnemonic(
156            "test test test test test test test test test test test junk",
157            0u32,
158        )
159        .unwrap()
160    }
161}
162
163impl RewardAccountV1 {
164    /// Return inner `Address`
165    pub fn address(&self) -> Address {
166        self.0
167    }
168    /// Return byte slice representation of inner `Address` type
169    pub fn as_bytes(&self) -> &[u8] {
170        self.0.as_slice()
171    }
172    /// Return array containing underlying bytes of inner `Address` type
173    pub fn to_fixed_bytes(self) -> [u8; 20] {
174        self.0.into_array()
175    }
176    pub fn test_key_pair() -> EthKeyPair {
177        EthKeyPair::from_mnemonic(
178            "test test test test test test test test test test test junk",
179            0u32,
180        )
181        .unwrap()
182    }
183}
184
185impl FromStr for RewardAccountV2 {
186    type Err = anyhow::Error;
187
188    fn from_str(s: &str) -> Result<Self, Self::Err> {
189        Ok(Self(s.parse()?))
190    }
191}
192
193impl FromStr for RewardAccountV1 {
194    type Err = anyhow::Error;
195
196    fn from_str(s: &str) -> Result<Self, Self::Err> {
197        Ok(Self(s.parse()?))
198    }
199}
200
201impl Valid for RewardAmount {
202    fn check(&self) -> Result<(), SerializationError> {
203        Ok(())
204    }
205}
206
207impl Valid for RewardAccountV2 {
208    fn check(&self) -> Result<(), SerializationError> {
209        Ok(())
210    }
211}
212
213impl Valid for RewardAccountV1 {
214    fn check(&self) -> Result<(), SerializationError> {
215        Ok(())
216    }
217}
218
219impl CanonicalSerialize for RewardAmount {
220    fn serialize_with_mode<W: std::io::prelude::Write>(
221        &self,
222        mut writer: W,
223        _compress: Compress,
224    ) -> Result<(), SerializationError> {
225        Ok(writer.write_all(&self.to_fixed_bytes())?)
226    }
227
228    fn serialized_size(&self, _compress: Compress) -> usize {
229        core::mem::size_of::<U256>()
230    }
231}
232impl CanonicalDeserialize for RewardAmount {
233    fn deserialize_with_mode<R: Read>(
234        mut reader: R,
235        _compress: Compress,
236        _validate: Validate,
237    ) -> Result<Self, SerializationError> {
238        let mut bytes = [0u8; core::mem::size_of::<U256>()];
239        reader.read_exact(&mut bytes)?;
240        let value = U256::from_le_slice(&bytes);
241        Ok(Self(value))
242    }
243}
244
245impl CanonicalSerialize for RewardAccountV2 {
246    fn serialize_with_mode<W: std::io::prelude::Write>(
247        &self,
248        mut writer: W,
249        _compress: Compress,
250    ) -> Result<(), SerializationError> {
251        Ok(writer.write_all(self.0.as_slice())?)
252    }
253
254    fn serialized_size(&self, _compress: Compress) -> usize {
255        core::mem::size_of::<Address>()
256    }
257}
258impl CanonicalDeserialize for RewardAccountV2 {
259    fn deserialize_with_mode<R: Read>(
260        mut reader: R,
261        _compress: Compress,
262        _validate: Validate,
263    ) -> Result<Self, SerializationError> {
264        let mut bytes = [0u8; core::mem::size_of::<Address>()];
265        reader.read_exact(&mut bytes)?;
266        let value = Address::from_slice(&bytes);
267        Ok(Self(value))
268    }
269}
270
271impl CanonicalSerialize for RewardAccountV1 {
272    fn serialize_with_mode<W: std::io::prelude::Write>(
273        &self,
274        mut writer: W,
275        _compress: Compress,
276    ) -> Result<(), SerializationError> {
277        Ok(writer.write_all(self.0.as_slice())?)
278    }
279
280    fn serialized_size(&self, _compress: Compress) -> usize {
281        core::mem::size_of::<Address>()
282    }
283}
284impl CanonicalDeserialize for RewardAccountV1 {
285    fn deserialize_with_mode<R: Read>(
286        mut reader: R,
287        _compress: Compress,
288        _validate: Validate,
289    ) -> Result<Self, SerializationError> {
290        let mut bytes = [0u8; core::mem::size_of::<Address>()];
291        reader.read_exact(&mut bytes)?;
292        let value = Address::from_slice(&bytes);
293        Ok(Self(value))
294    }
295}
296
297impl From<[u8; 20]> for RewardAccountV2 {
298    fn from(bytes: [u8; 20]) -> Self {
299        Self(Address::from(bytes))
300    }
301}
302
303impl AsRef<[u8]> for RewardAccountV2 {
304    fn as_ref(&self) -> &[u8] {
305        self.0.as_slice()
306    }
307}
308
309impl<const ARITY: usize> ToTraversalPath<ARITY> for RewardAccountV2 {
310    fn to_traversal_path(&self, height: usize) -> Vec<usize> {
311        let mut result = vec![0; height];
312
313        // Convert 20-byte address to U256
314        let mut value = U256::from_be_slice(self.0.as_slice());
315
316        // Extract digits using modulo and division (LSB first)
317        for item in result.iter_mut().take(height) {
318            let digit = (value % U256::from(ARITY)).to::<usize>();
319            *item = digit;
320            value /= U256::from(ARITY);
321        }
322
323        result
324    }
325}
326
327impl RewardAccountProofV2 {
328    pub fn presence(
329        pos: FeeAccount,
330        proof: <RewardMerkleTreeV2 as MerkleTreeScheme>::MembershipProof,
331    ) -> Self {
332        Self {
333            account: pos.into(),
334            proof: RewardMerkleProofV2::Presence(proof),
335        }
336    }
337
338    pub fn absence(
339        pos: RewardAccountV2,
340        proof: <RewardMerkleTreeV2 as UniversalMerkleTreeScheme>::NonMembershipProof,
341    ) -> Self {
342        Self {
343            account: pos.into(),
344            proof: RewardMerkleProofV2::Absence(proof),
345        }
346    }
347
348    pub fn prove(tree: &RewardMerkleTreeV2, account: Address) -> Option<(Self, U256)> {
349        match tree.universal_lookup(RewardAccountV2(account)) {
350            LookupResult::Ok(balance, proof) => Some((
351                Self {
352                    account,
353                    proof: RewardMerkleProofV2::Presence(proof),
354                },
355                balance.0,
356            )),
357            LookupResult::NotFound(proof) => Some((
358                Self {
359                    account,
360                    proof: RewardMerkleProofV2::Absence(proof),
361                },
362                U256::ZERO,
363            )),
364            LookupResult::NotInMemory => None,
365        }
366    }
367
368    pub fn verify(&self, comm: &RewardMerkleCommitmentV2) -> anyhow::Result<U256> {
369        match &self.proof {
370            RewardMerkleProofV2::Presence(proof) => {
371                ensure!(
372                    RewardMerkleTreeV2::verify(comm, RewardAccountV2(self.account), proof)?.is_ok(),
373                    "invalid proof"
374                );
375                Ok(proof
376                    .elem()
377                    .context("presence proof is missing account balance")?
378                    .0)
379            },
380            RewardMerkleProofV2::Absence(proof) => {
381                let tree = RewardMerkleTreeV2::from_commitment(comm);
382                ensure!(
383                    RewardMerkleTreeV2::non_membership_verify(
384                        tree.commitment(),
385                        RewardAccountV2(self.account),
386                        proof
387                    )?,
388                    "invalid proof"
389                );
390                Ok(U256::ZERO)
391            },
392        }
393    }
394
395    pub fn remember(&self, tree: &mut RewardMerkleTreeV2) -> anyhow::Result<()> {
396        match &self.proof {
397            RewardMerkleProofV2::Presence(proof) => {
398                tree.remember(
399                    RewardAccountV2(self.account),
400                    proof
401                        .elem()
402                        .context("presence proof is missing account balance")?,
403                    proof,
404                )?;
405                Ok(())
406            },
407            RewardMerkleProofV2::Absence(proof) => {
408                tree.non_membership_remember(RewardAccountV2(self.account), proof)?;
409                Ok(())
410            },
411        }
412    }
413}
414
415impl TryInto<AccruedRewardsProofSol> for RewardAccountProofV2 {
416    type Error = anyhow::Error;
417
418    /// Generate a Solidity-compatible proof for this account
419    ///
420    /// The proof is returned without leaf value. The caller is expected to
421    /// obtain the leaf value from the jellyfish proof (Self).
422    ///
423    /// TODO: review error handling / panics
424    fn try_into(self) -> anyhow::Result<AccruedRewardsProofSol> {
425        // NOTE: rustfmt fails to format this file if the nesting is too deep.
426        let proof = if let RewardMerkleProofV2::Presence(proof) = &self.proof {
427            proof
428        } else {
429            bail!("only presence proofs supported")
430        };
431
432        let path = ToTraversalPath::<REWARD_MERKLE_TREE_V2_ARITY>::to_traversal_path(
433            &RewardAccountV2(self.account),
434            REWARD_MERKLE_TREE_V2_HEIGHT,
435        );
436
437        if path.len() != REWARD_MERKLE_TREE_V2_HEIGHT {
438            bail!("Invalid proof: unexpected path length: {}", path.len());
439        };
440
441        let siblings: Vec<B256> = proof
442            .proof
443            .iter()
444            .enumerate()
445            .skip(1) // Skip the leaf node (first element)
446            .filter_map(|(level_idx, node)| match node {
447                MerkleNode::Branch { children, .. } => {
448                    // Use the path to determine which sibling we need
449                    let path_direction = path
450                        .get(level_idx - 1)
451                        .copied()
452                        .expect("exists");
453                    let sibling_idx = if path_direction == 0 { 1 } else { 0 };
454                    if sibling_idx >= children.len() {
455                        panic!(
456                            "Invalid proof: index={sibling_idx} length={}",
457                            children.len()
458                        );
459                    };
460
461                    match children[sibling_idx].as_ref() {
462                        MerkleNode::Empty => Some(B256::ZERO),
463                        MerkleNode::Leaf { value, .. } => {
464                            let bytes = value.as_ref();
465                            Some(B256::from_slice(bytes))
466                        }
467                        MerkleNode::Branch { value, .. } => {
468                            let bytes = value.as_ref();
469                            Some(B256::from_slice(bytes))
470                        }
471                        MerkleNode::ForgettenSubtree { value } => {
472                            let bytes = value.as_ref();
473                            Some(B256::from_slice(bytes))
474                        }
475                    }
476                }
477                _ => None,
478            })
479            .collect();
480
481        Ok(AccruedRewardsProofSol { siblings })
482    }
483}
484
485impl RewardAccountProofV1 {
486    pub fn presence(
487        pos: FeeAccount,
488        proof: <RewardMerkleTreeV1 as MerkleTreeScheme>::MembershipProof,
489    ) -> Self {
490        Self {
491            account: pos.into(),
492            proof: RewardMerkleProofV1::Presence(proof),
493        }
494    }
495
496    pub fn absence(
497        pos: RewardAccountV1,
498        proof: <RewardMerkleTreeV1 as UniversalMerkleTreeScheme>::NonMembershipProof,
499    ) -> Self {
500        Self {
501            account: pos.into(),
502            proof: RewardMerkleProofV1::Absence(proof),
503        }
504    }
505
506    pub fn prove(tree: &RewardMerkleTreeV1, account: Address) -> Option<(Self, U256)> {
507        match tree.universal_lookup(RewardAccountV1(account)) {
508            LookupResult::Ok(balance, proof) => Some((
509                Self {
510                    account,
511                    proof: RewardMerkleProofV1::Presence(proof),
512                },
513                balance.0,
514            )),
515            LookupResult::NotFound(proof) => Some((
516                Self {
517                    account,
518                    proof: RewardMerkleProofV1::Absence(proof),
519                },
520                U256::ZERO,
521            )),
522            LookupResult::NotInMemory => None,
523        }
524    }
525
526    pub fn verify(&self, comm: &RewardMerkleCommitmentV1) -> anyhow::Result<U256> {
527        match &self.proof {
528            RewardMerkleProofV1::Presence(proof) => {
529                ensure!(
530                    RewardMerkleTreeV1::verify(comm, RewardAccountV1(self.account), proof)?.is_ok(),
531                    "invalid proof"
532                );
533                Ok(proof
534                    .elem()
535                    .context("presence proof is missing account balance")?
536                    .0)
537            },
538            RewardMerkleProofV1::Absence(proof) => {
539                let tree = RewardMerkleTreeV1::from_commitment(comm);
540                ensure!(
541                    RewardMerkleTreeV1::non_membership_verify(
542                        tree.commitment(),
543                        RewardAccountV1(self.account),
544                        proof
545                    )?,
546                    "invalid proof"
547                );
548                Ok(U256::ZERO)
549            },
550        }
551    }
552
553    pub fn remember(&self, tree: &mut RewardMerkleTreeV1) -> anyhow::Result<()> {
554        match &self.proof {
555            RewardMerkleProofV1::Presence(proof) => {
556                tree.remember(
557                    RewardAccountV1(self.account),
558                    proof
559                        .elem()
560                        .context("presence proof is missing account balance")?,
561                    proof,
562                )?;
563                Ok(())
564            },
565            RewardMerkleProofV1::Absence(proof) => {
566                tree.non_membership_remember(RewardAccountV1(self.account), proof)?;
567                Ok(())
568            },
569        }
570    }
571}
572
573impl From<(RewardAccountProofV2, U256)> for RewardAccountQueryDataV2 {
574    fn from((proof, balance): (RewardAccountProofV2, U256)) -> Self {
575        Self { balance, proof }
576    }
577}
578
579#[derive(Clone, Debug)]
580pub struct ComputedRewards {
581    leader_address: Address,
582    // leader commission reward
583    leader_commission: RewardAmount,
584    // delegator rewards
585    delegators: Vec<(Address, RewardAmount)>,
586}
587
588impl ComputedRewards {
589    pub fn new(
590        delegators: Vec<(Address, RewardAmount)>,
591        leader_address: Address,
592        leader_commission: RewardAmount,
593    ) -> Self {
594        Self {
595            delegators,
596            leader_address,
597            leader_commission,
598        }
599    }
600
601    pub fn leader_commission(&self) -> &RewardAmount {
602        &self.leader_commission
603    }
604
605    pub fn delegators(&self) -> &Vec<(Address, RewardAmount)> {
606        &self.delegators
607    }
608
609    // chains delegation rewards and leader commission reward
610    pub fn all_rewards(self) -> Vec<(Address, RewardAmount)> {
611        self.delegators
612            .into_iter()
613            .chain(once((self.leader_address, self.leader_commission)))
614            .collect()
615    }
616}
617
618pub struct RewardDistributor {
619    validator: Validator<BLSPubKey>,
620    block_reward: RewardAmount,
621    total_distributed: RewardAmount,
622}
623
624impl RewardDistributor {
625    pub fn new(
626        validator: Validator<BLSPubKey>,
627        block_reward: RewardAmount,
628        total_distributed: RewardAmount,
629    ) -> Self {
630        Self {
631            validator,
632            block_reward,
633            total_distributed,
634        }
635    }
636
637    pub fn validator(&self) -> Validator<BLSPubKey> {
638        self.validator.clone()
639    }
640
641    pub fn block_reward(&self) -> RewardAmount {
642        self.block_reward
643    }
644
645    pub fn total_distributed(&self) -> RewardAmount {
646        self.total_distributed
647    }
648
649    pub fn update_rewards_delta(&self, delta: &mut Delta) -> anyhow::Result<()> {
650        // Update delta rewards
651        delta
652            .rewards_delta
653            .insert(RewardAccountV2(self.validator().account));
654        delta.rewards_delta.extend(
655            self.validator()
656                .delegators
657                .keys()
658                .map(|d| RewardAccountV2(*d)),
659        );
660
661        Ok(())
662    }
663
664    fn update_reward_balance<P>(
665        tree: &mut P,
666        account: &P::Index,
667        amount: P::Element,
668    ) -> anyhow::Result<()>
669    where
670        P: PersistentUniversalMerkleTreeScheme,
671        P: MerkleTreeScheme<Element = RewardAmount>,
672        P::Index: Borrow<<P as MerkleTreeScheme>::Index> + std::fmt::Display,
673    {
674        let mut err = None;
675        *tree = tree.persistent_update_with(account.clone(), |balance| {
676            let balance = balance.copied();
677            match balance.unwrap_or_default().0.checked_add(amount.0) {
678                Some(updated) => Some(updated.into()),
679                None => {
680                    err = Some(format!("overflowed reward balance for account {account}"));
681                    balance
682                },
683            }
684        })?;
685
686        if let Some(error) = err {
687            tracing::warn!(error);
688            bail!(error)
689        }
690
691        Ok(())
692    }
693
694    pub fn apply_rewards(
695        &mut self,
696        version: vbs::version::Version,
697        state: &mut ValidatedState,
698    ) -> anyhow::Result<()> {
699        let computed_rewards = self.compute_rewards()?;
700
701        if version <= EpochVersion::version() {
702            for (address, reward) in computed_rewards.all_rewards() {
703                Self::update_reward_balance(
704                    &mut state.reward_merkle_tree_v1,
705                    &RewardAccountV1(address),
706                    reward,
707                )?;
708                tracing::debug!(%address, %reward, "applied v1 rewards");
709            }
710        } else {
711            for (address, reward) in computed_rewards.all_rewards() {
712                Self::update_reward_balance(
713                    &mut state.reward_merkle_tree_v2,
714                    &RewardAccountV2(address),
715                    reward,
716                )?;
717                tracing::debug!(%address, %reward, "applied v2 rewards");
718            }
719        }
720
721        self.total_distributed += self.block_reward();
722
723        Ok(())
724    }
725
726    /// Computes the reward in a block for the validator and its delegators
727    /// based on the commission rate, individual delegator stake, and total block reward.
728    ///
729    /// The block reward is distributed among the delegators first based on their stake,
730    /// with the remaining amount from the block reward given to the validator as the commission.
731    /// Any minor discrepancies due to rounding off errors are adjusted in the leader reward
732    /// to ensure the total reward is exactly equal to block reward.
733    pub fn compute_rewards(&self) -> anyhow::Result<ComputedRewards> {
734        ensure!(
735            self.validator.commission <= COMMISSION_BASIS_POINTS,
736            "commission must not exceed {COMMISSION_BASIS_POINTS}"
737        );
738
739        let mut rewards = Vec::new();
740
741        let total_reward = self.block_reward.0;
742        let delegators_ratio_basis_points = U256::from(COMMISSION_BASIS_POINTS)
743            .checked_sub(U256::from(self.validator.commission))
744            .context("overflow")?;
745        let delegators_reward = delegators_ratio_basis_points
746            .checked_mul(total_reward)
747            .context("overflow")?;
748
749        // Distribute delegator rewards
750        let total_stake = self.validator.stake;
751        let mut delegators_total_reward_distributed = U256::from(0);
752        for (delegator_address, delegator_stake) in &self.validator.delegators {
753            let delegator_reward = RewardAmount::from(
754                (delegator_stake
755                    .checked_mul(delegators_reward)
756                    .context("overflow")?
757                    .checked_div(total_stake)
758                    .context("overflow")?)
759                .checked_div(U256::from(COMMISSION_BASIS_POINTS))
760                .context("overflow")?,
761            );
762
763            delegators_total_reward_distributed += delegator_reward.0;
764
765            rewards.push((*delegator_address, delegator_reward));
766        }
767
768        let leader_commission = total_reward
769            .checked_sub(delegators_total_reward_distributed)
770            .context("overflow")?;
771
772        Ok(ComputedRewards::new(
773            rewards,
774            self.validator.account,
775            leader_commission.into(),
776        ))
777    }
778}
779
780/// Distributes the block reward for a given block height
781///
782/// Rewards are only distributed if the block belongs to an epoch beyond the second epoch.
783///
784/// The function also calculates the appropriate reward (fixed or dynamic) based
785/// on the protocol version.
786pub async fn distribute_block_reward(
787    instance_state: &NodeState,
788    validated_state: &mut ValidatedState,
789    parent_leaf: &Leaf2,
790    view_number: ViewNumber,
791    version: vbs::version::Version,
792) -> anyhow::Result<Option<RewardDistributor>> {
793    let height = parent_leaf.height() + 1;
794
795    let epoch_height = instance_state
796        .epoch_height
797        .context("epoch height not found")?;
798    let epoch = EpochNumber::new(epoch_from_block_number(height, epoch_height));
799    let coordinator = instance_state.coordinator.clone();
800    let first_epoch = {
801        coordinator
802            .membership()
803            .read()
804            .await
805            .first_epoch()
806            .context("The first epoch was not set.")?
807    };
808
809    // Rewards are distributed only if the current epoch is not the first or second epoch
810    // this is because we don't have stake table from the contract for the first two epochs
811    if epoch <= first_epoch + 1 {
812        return Ok(None);
813    }
814
815    // Determine who the block leader is for this view and ensure missing block
816    // rewards are fetched from peers if needed.
817
818    let leader = get_leader_and_fetch_missing_rewards(
819        instance_state,
820        validated_state,
821        parent_leaf,
822        view_number,
823    )
824    .await?;
825
826    let parent_header = parent_leaf.block_header();
827    // Initialize the total rewards distributed so far in this block.
828
829    let mut previously_distributed = parent_header.total_reward_distributed().unwrap_or_default();
830
831    // Decide whether to use a fixed or dynamic block reward.
832    let block_reward = if version >= DrbAndHeaderUpgradeVersion::version() {
833        let block_reward = instance_state
834            .block_reward(Some(EpochNumber::new(*epoch)))
835            .await
836            .with_context(|| format!("block reward is None for epoch {epoch}"))?;
837
838        // If the current block is the start block of the new v4 version,
839        // we use *fixed block reward* for calculating the total rewards distributed so far.
840        if parent_header.version() == EpochVersion::version() {
841            ensure!(
842                instance_state.epoch_start_block != 0,
843                "epoch_start_block is zero"
844            );
845
846            let fixed_block_reward = instance_state
847                .block_reward(None)
848                .await
849                .with_context(|| format!("block reward is None for epoch {epoch}"))?;
850
851            // Compute the first block where rewards start being distributed.
852            // Rewards begin only after the first two epochs
853            // Example:
854            //   epoch_height = 10, first_epoch = 1
855            // first_reward_block = 31
856            let first_reward_block = (*first_epoch + 2) * epoch_height + 1;
857
858            // If v4 upgrade started at block 101, and first_reward_block is 31:
859            // total_distributed = (101 - 31) * fixed_block_reward
860            let blocks = height
861                .checked_sub(first_reward_block)
862                .context("height - epoch_start_block underflowed")?;
863
864            previously_distributed = U256::from(blocks)
865                .checked_mul(fixed_block_reward.0)
866                .context("overflow during total_distributed calculation")?
867                .into();
868        }
869
870        block_reward
871    } else {
872        instance_state
873            .block_reward(None)
874            .await
875            .with_context(|| format!("fixed block reward is None for epoch {epoch}"))?
876    };
877
878    if block_reward.0.is_zero() {
879        tracing::info!("block reward is zero. height={height}. epoch={epoch}");
880        return Ok(None);
881    }
882
883    let mut reward_distributor =
884        RewardDistributor::new(leader, block_reward, previously_distributed);
885
886    reward_distributor.apply_rewards(version, validated_state)?;
887
888    Ok(Some(reward_distributor))
889}
890
891pub async fn get_leader_and_fetch_missing_rewards(
892    instance_state: &NodeState,
893    validated_state: &mut ValidatedState,
894    parent_leaf: &Leaf2,
895    view: ViewNumber,
896) -> anyhow::Result<Validator<BLSPubKey>> {
897    let parent_height = parent_leaf.height();
898    let parent_view = parent_leaf.view_number();
899    let new_height = parent_height + 1;
900
901    let epoch_height = instance_state
902        .epoch_height
903        .context("epoch height not found")?;
904    if epoch_height == 0 {
905        bail!("epoch height is 0. can not catchup reward accounts");
906    }
907    let epoch = EpochNumber::new(epoch_from_block_number(new_height, epoch_height));
908
909    let coordinator = instance_state.coordinator.clone();
910
911    let epoch_membership = coordinator.membership_for_epoch(Some(epoch)).await?;
912    let membership = epoch_membership.coordinator.membership().read().await;
913
914    let leader: BLSPubKey = membership
915        .leader(view, Some(epoch))
916        .context(format!("leader for epoch {epoch:?} not found"))?;
917
918    let validator = membership
919        .get_validator_config(&epoch, leader)
920        .context("validator not found")?;
921    drop(membership);
922
923    let mut reward_accounts = HashSet::new();
924    reward_accounts.insert(validator.account.into());
925    let delegators = validator
926        .delegators
927        .keys()
928        .cloned()
929        .map(|a| a.into())
930        .collect::<Vec<RewardAccountV2>>();
931
932    reward_accounts.extend(delegators.clone());
933
934    let parent_header = parent_leaf.block_header();
935
936    if parent_header.version() <= EpochVersion::version() {
937        let accts: HashSet<_> = reward_accounts
938            .into_iter()
939            .map(RewardAccountV1::from)
940            .collect();
941        let missing_reward_accts = validated_state.forgotten_reward_accounts_v1(accts);
942
943        if !missing_reward_accts.is_empty() {
944            tracing::warn!(
945                parent_height,
946                ?parent_view,
947                ?missing_reward_accts,
948                "fetching missing v1 reward accounts from peers"
949            );
950
951            let missing_account_proofs = instance_state
952                .state_catchup
953                .fetch_reward_accounts_v1(
954                    instance_state,
955                    parent_height,
956                    parent_view,
957                    validated_state.reward_merkle_tree_v1.commitment(),
958                    missing_reward_accts,
959                )
960                .await?;
961
962            for proof in missing_account_proofs.iter() {
963                proof
964                    .remember(&mut validated_state.reward_merkle_tree_v1)
965                    .expect("proof previously verified");
966            }
967        }
968    } else {
969        let missing_reward_accts = validated_state.forgotten_reward_accounts_v2(reward_accounts);
970
971        if !missing_reward_accts.is_empty() {
972            tracing::warn!(
973                parent_height,
974                ?parent_view,
975                ?missing_reward_accts,
976                "fetching missing reward accounts from peers"
977            );
978
979            let missing_account_proofs = instance_state
980                .state_catchup
981                .fetch_reward_accounts_v2(
982                    instance_state,
983                    parent_height,
984                    parent_view,
985                    validated_state.reward_merkle_tree_v2.commitment(),
986                    missing_reward_accts,
987                )
988                .await?;
989
990            for proof in missing_account_proofs.iter() {
991                proof
992                    .remember(&mut validated_state.reward_merkle_tree_v2)
993                    .expect("proof previously verified");
994            }
995        }
996    }
997
998    Ok(validator)
999}
1000
#[cfg(test)]
pub mod tests {

    use super::*;

    /// Block reward shared by the tests below (1.902e18 base units).
    const TEST_BLOCK_REWARD: u128 = 1_902_000_000_000_000_000;

    // TODO: current tests are just sanity checks, we need more.

    /// Checks that the computed rewards always sum to exactly the block reward for several
    /// commission rates, and that a commission of 10001 basis points is rejected.
    #[test]
    fn test_reward_calculation_sanity_checks() {
        // Due to rounding effects in distribution, the validator may receive a slightly higher
        // amount because the remainder after delegator distribution is sent to the validator;
        // the grand total must still equal the block reward.
        let mut distributor = RewardDistributor::new(
            Validator::mock(),
            RewardAmount(U256::from(TEST_BLOCK_REWARD)),
            U256::ZERO.into(),
        );
        let total = |rewards: ComputedRewards| {
            rewards
                .all_rewards()
                .iter()
                .fold(U256::ZERO, |acc, (_, r)| acc + r.0)
        };

        // Commission as configured by `Validator::mock()`.
        let rewards = distributor.compute_rewards().unwrap();
        assert_eq!(total(rewards), distributor.block_reward.0);

        // 0% commission.
        distributor.validator.commission = 0;
        let rewards = distributor.compute_rewards().unwrap();
        assert_eq!(total(rewards), distributor.block_reward.0);

        // 100% commission: the leader's commission is the whole block reward.
        distributor.validator.commission = 10000;
        let rewards = distributor.compute_rewards().unwrap();
        assert_eq!(*rewards.leader_commission(), distributor.block_reward);
        assert_eq!(total(rewards), distributor.block_reward.0);

        // Anything above 100% must be rejected.
        distributor.validator.commission = 10001;
        let err = distributor
            .compute_rewards()
            .err()
            .expect("a commission above 10000 basis points must be rejected");
        assert!(err.to_string().contains("must not exceed"));
    }

    /// Checks that the leader's commission, expressed in basis points of the block reward,
    /// matches the configured commission at 0%, 3% and 100%.
    #[test]
    fn test_compute_rewards_validator_commission() {
        let mut distributor = RewardDistributor::new(
            Validator::mock(),
            RewardAmount(U256::from(TEST_BLOCK_REWARD)),
            U256::ZERO.into(),
        );

        // 0%
        distributor.validator.commission = 0;
        let rewards = distributor.compute_rewards().unwrap();
        let leader_commission = rewards.leader_commission();
        let commission_bps =
            leader_commission.0 * U256::from(COMMISSION_BASIS_POINTS) / distributor.block_reward.0;
        assert_eq!(commission_bps, U256::ZERO);

        // 3%
        distributor.validator.commission = 300;
        let rewards = distributor.compute_rewards().unwrap();
        let leader_commission = rewards.leader_commission();
        let commission_bps =
            leader_commission.0 * U256::from(COMMISSION_BASIS_POINTS) / distributor.block_reward.0;
        assert_eq!(commission_bps, U256::from(300));

        // 100%
        distributor.validator.commission = 10000;
        let rewards = distributor.compute_rewards().unwrap();
        assert_eq!(*rewards.leader_commission(), distributor.block_reward);
    }
}