diff --git a/fuzz/src/bin/gen_target.sh b/fuzz/src/bin/gen_target.sh index b4f0c7a12b9..00546baf1ed 100755 --- a/fuzz/src/bin/gen_target.sh +++ b/fuzz/src/bin/gen_target.sh @@ -29,6 +29,7 @@ GEN_TEST fromstr_to_netaddress GEN_TEST feature_flags GEN_TEST lsps_message GEN_TEST fs_store +GEN_TEST resource_manager GEN_TEST msg_accept_channel msg_targets:: GEN_TEST msg_announcement_signatures msg_targets:: diff --git a/fuzz/src/bin/resource_manager_target.rs b/fuzz/src/bin/resource_manager_target.rs new file mode 100644 index 00000000000..792022ec666 --- /dev/null +++ b/fuzz/src/bin/resource_manager_target.rs @@ -0,0 +1,133 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +// This file is auto-generated by gen_target.sh based on target_template.txt +// To modify it, modify target_template.txt and run gen_target.sh instead. + +#![cfg_attr(feature = "libfuzzer_fuzz", no_main)] +#![cfg_attr(rustfmt, rustfmt_skip)] + +#[cfg(not(fuzzing))] +compile_error!("Fuzz targets need cfg=fuzzing"); + +#[cfg(not(hashes_fuzz))] +compile_error!("Fuzz targets need cfg=hashes_fuzz"); + +#[cfg(not(secp256k1_fuzz))] +compile_error!("Fuzz targets need cfg=secp256k1_fuzz"); + +extern crate lightning_fuzz; +use lightning_fuzz::resource_manager::*; +use lightning_fuzz::utils::test_logger; + +#[cfg(feature = "afl")] +#[macro_use] extern crate afl; +#[cfg(feature = "afl")] +fn main() { + fuzz!(|data| { + resource_manager_test(&data, test_logger::DevNull {}); + }); +} + +#[cfg(feature = "honggfuzz")] +#[macro_use] extern crate honggfuzz; +#[cfg(feature = "honggfuzz")] +fn main() { + loop { + fuzz!(|data| { + resource_manager_test(&data, test_logger::DevNull {}); + }); + } +} + +#[cfg(feature = "libfuzzer_fuzz")] +#[macro_use] extern crate libfuzzer_sys; +#[cfg(feature = "libfuzzer_fuzz")] +fuzz_target!(|data: &[u8]| { + resource_manager_test(data, test_logger::DevNull {}); +}); + +#[cfg(feature = "stdin_fuzz")] +fn main() { + use std::io::Read; + + // On macOS, panic=abort causes the process to send SIGABRT which can leave it + // stuck in an uninterruptible state due to the ReportCrash daemon. Using + // process::exit in a panic hook avoids this by terminating cleanly. + #[cfg(target_os = "macos")] + std::panic::set_hook(Box::new(|panic_info| { + use std::io::Write; + let _ = std::io::stdout().flush(); + eprintln!("{}\n{}", panic_info, std::backtrace::Backtrace::force_capture()); + let _ = std::io::stderr().flush(); + std::process::exit(1); + })); + + let mut data = Vec::with_capacity(8192); + std::io::stdin().read_to_end(&mut data).unwrap(); + resource_manager_test(&data, lightning_fuzz::utils::test_logger::Stdout {}); +} + +#[test] +fn run_test_cases() { + use std::fs; + use std::io::Read; + use lightning_fuzz::utils::test_logger::StringBuffer; + + use std::sync::{atomic, Arc}; + { + let data: Vec = vec![0]; + resource_manager_test(&data, test_logger::DevNull {}); + } + let mut threads = Vec::new(); + let threads_running = Arc::new(atomic::AtomicUsize::new(0)); + if let Ok(tests) = fs::read_dir("test_cases/resource_manager") { + for test in tests { + let mut data: Vec = Vec::new(); + let path = test.unwrap().path(); + fs::File::open(&path).unwrap().read_to_end(&mut data).unwrap(); + threads_running.fetch_add(1, atomic::Ordering::AcqRel); + + let thread_count_ref = Arc::clone(&threads_running); + let main_thread_ref = std::thread::current(); + threads.push((path.file_name().unwrap().to_str().unwrap().to_string(), + std::thread::spawn(move || { + let string_logger = StringBuffer::new(); + + let panic_logger = string_logger.clone(); + let res = if ::std::panic::catch_unwind(move || { + resource_manager_test(&data, panic_logger); + }).is_err() { + Some(string_logger.into_string()) + } else { None }; + thread_count_ref.fetch_sub(1, atomic::Ordering::AcqRel); + main_thread_ref.unpark(); + res + }) + )); + while threads_running.load(atomic::Ordering::Acquire) > 32 { + std::thread::park(); + } + } + } + let mut failed_outputs = Vec::new(); + for (test, thread) in threads.drain(..) { + if let Some(output) = thread.join().unwrap() { + println!("\nOutput of {}:\n{}\n", test, output); + failed_outputs.push(test); + } + } + if !failed_outputs.is_empty() { + println!("Test cases which failed: "); + for case in failed_outputs { + println!("{}", case); + } + panic!(); + } +} diff --git a/fuzz/src/lib.rs b/fuzz/src/lib.rs index 582fa346c54..de734cd206f 100644 --- a/fuzz/src/lib.rs +++ b/fuzz/src/lib.rs @@ -42,6 +42,7 @@ pub mod peer_crypt; pub mod process_network_graph; pub mod process_onion_failure; pub mod refund_deser; +pub mod resource_manager; pub mod router; pub mod static_invoice_deser; pub mod zbase32; diff --git a/fuzz/src/resource_manager.rs b/fuzz/src/resource_manager.rs new file mode 100644 index 00000000000..55d6d31a36e --- /dev/null +++ b/fuzz/src/resource_manager.rs @@ -0,0 +1,280 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +use crate::utils::test_logger; + +use lightning::ln::resource_manager::{ + DefaultResourceManager, ForwardingOutcome, PendingHTLCReplay, ResourceManagerConfig, +}; +use lightning::sign::EntropySource; +use lightning::util::ser::{ReadableArgs, Writeable}; + +use lightning::util::hash_tables::new_hash_map; + +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; + +struct CountingEntropy { + counter: AtomicU64, +} + +impl CountingEntropy { + fn new() -> Self { + CountingEntropy { counter: AtomicU64::new(0) } + } +} + +impl EntropySource for CountingEntropy { + fn get_secure_random_bytes(&self) -> [u8; 32] { + let ctr = self.counter.fetch_add(1, Ordering::Relaxed); + let mut bytes = [0u8; 32]; + bytes[..8].copy_from_slice(&ctr.to_le_bytes()); + bytes + } +} + +struct TrackedHtlc { + incoming_channel_id: u64, + htlc_id: u64, + outgoing_channel_id: u64, + incoming_amount_msat: u64, + outgoing_amount_msat: u64, + incoming_cltv_expiry: u32, + height_added: u32, + incoming_accountable: bool, + added_at: u64, +} + +const AMOUNT_TABLE: [u64; 4] = [1_000, 100_000, 500_000_000, 2_000_000_000]; +const FEE_TABLE: [u64; 4] = [100, 1_000, 10_000, 100_000]; +const TIME_DELTAS: [u64; 7] = [0, 1, 60, 3600, 86400, 604800, 1209600]; +const CLTV_DELTAS: [u32; 4] = [20, 144, 500, 2016]; +const MAX_HTLCS_TABLE: [u16; 4] = [12, 50, 114, 483]; +const MAX_IN_FLIGHT_TABLE: [u64; 4] = [1_000_000, 100_000_000, 5_000_000_000, 500_000_000_000]; +// (general_allocation_pct, congestion_allocation_pct) +const CONFIG_TABLE: [(u8, u8); 7] = [ + (30, 30), + (60, 10), + (10, 10), + (20, 5), + (50, 30), + (5, 5), + (70, 20), +]; + +#[inline] +pub fn do_test(data: &[u8], _out: Out) { + let entropy_source = CountingEntropy::new(); + + let mut read_pos = 0; + macro_rules! get_slice { + ($len:expr) => {{ + let len = $len as usize; + if data.len() < read_pos + len { + return; + } + read_pos += len; + &data[read_pos - len..read_pos] + }}; + } + let config_byte = get_slice!(1)[0]; + let (general_pct, congestion_pct) = if config_byte < 128 { + (40_u8, 20_u8) + } else { + CONFIG_TABLE[((config_byte - 128) % 7) as usize] + }; + let config = ResourceManagerConfig { + general_allocation_pct: general_pct, + congestion_allocation_pct: congestion_pct, + ..ResourceManagerConfig::default() + }; + let mut rm = DefaultResourceManager::new(config); + + let mut current_time: u64 = 1_700_000_000; + let current_height: u32 = 1000; + let mut channel_limits: HashMap = HashMap::new(); + let mut pending_htlcs: Vec = Vec::new(); + let mut next_htlc_id: u64 = 0; + + loop { + let action_byte = get_slice!(1)[0]; + match action_byte % 7 { + // Add channel + 0 => { + let params = get_slice!(2); + let channel_id = (params[0] % 8) as u64 + 1; + let max_htlcs = MAX_HTLCS_TABLE[(params[1] % 4) as usize]; + let max_in_flight = MAX_IN_FLIGHT_TABLE[(params[1] / 4 % 4) as usize]; + if rm.add_channel(channel_id, max_in_flight, max_htlcs, current_time).is_ok() { + channel_limits.insert(channel_id, (max_in_flight, max_htlcs)); + } + }, + // Remove channel + 1 => { + let idx = get_slice!(1)[0]; + let channel_id = (idx % 8) as u64 + 1; + let _ = rm.remove_channel(channel_id); + channel_limits.remove(&channel_id); + pending_htlcs.retain(|h| { + h.incoming_channel_id != channel_id && h.outgoing_channel_id != channel_id + }); + }, + // Add HTLC + 2 => { + let params = get_slice!(4); + let incoming_id = (params[0] % 8) as u64 + 1; + let outgoing_id = (params[1] % 8) as u64 + 1; + if incoming_id == outgoing_id { + continue; + } + let outgoing_amount = AMOUNT_TABLE[(params[2] % 4) as usize]; + let fee = FEE_TABLE[(params[3] % 4) as usize]; + let incoming_amount = outgoing_amount + fee; + let accountable = params[3] >= 4; + let cltv_delta = CLTV_DELTAS[(params[2] / 4 % 4) as usize]; + let htlc_id = next_htlc_id; + next_htlc_id += 1; + let cltv_expiry = current_height + cltv_delta; + + let result = rm.add_htlc( + incoming_id, + incoming_amount, + cltv_expiry, + outgoing_id, + outgoing_amount, + accountable, + htlc_id, + current_height, + current_time, + &entropy_source, + ); + match result { + Ok(ForwardingOutcome::Forward(_)) => { + pending_htlcs.push(TrackedHtlc { + incoming_channel_id: incoming_id, + htlc_id, + outgoing_channel_id: outgoing_id, + incoming_amount_msat: incoming_amount, + outgoing_amount_msat: outgoing_amount, + incoming_cltv_expiry: cltv_expiry, + height_added: current_height, + incoming_accountable: accountable, + added_at: current_time, + }); + }, + _ => {}, + } + }, + // Resolve HTLC + 3 => { + let params = get_slice!(2); + if pending_htlcs.is_empty() { + continue; + } + let idx = params[0] as usize % pending_htlcs.len(); + let settled = params[1] % 2 == 1; + let htlc = &pending_htlcs[idx]; + let result = rm.resolve_htlc( + htlc.incoming_channel_id, + htlc.htlc_id, + htlc.outgoing_channel_id, + settled, + current_time, + ); + if result.is_ok() { + pending_htlcs.swap_remove(idx); + } + }, + // Serialization roundtrip + 4 => { + let mut buf = Vec::new(); + rm.write(&mut buf).unwrap(); + + let mut ldk_channel_limits = new_hash_map(); + for (k, v) in &channel_limits { + ldk_channel_limits.insert(*k, *v); + } + let deserialized = DefaultResourceManager::read( + &mut &buf[..], + ( + ResourceManagerConfig { + general_allocation_pct: general_pct, + congestion_allocation_pct: congestion_pct, + ..ResourceManagerConfig::default() + }, + &entropy_source, + &ldk_channel_limits, + ), + ); + let deserialized = match deserialized { + Ok(d) => d, + Err(_) => continue, + }; + + let replays: Vec = pending_htlcs + .iter() + .map(|h| PendingHTLCReplay { + incoming_channel_id: h.incoming_channel_id, + incoming_amount_msat: h.incoming_amount_msat, + incoming_htlc_id: h.htlc_id, + incoming_cltv_expiry: h.incoming_cltv_expiry, + incoming_accountable: h.incoming_accountable, + outgoing_channel_id: h.outgoing_channel_id, + outgoing_amount_msat: h.outgoing_amount_msat, + added_at_unix_seconds: h.added_at, + height_added: h.height_added, + }) + .collect(); + + if let Ok(outcomes) = deserialized.replay_pending_htlcs(&replays, &entropy_source) { + let mut to_remove = Vec::new(); + for (i, outcome) in outcomes.iter().enumerate() { + if *outcome == ForwardingOutcome::Fail { + to_remove.push(i); + } + } + for idx in to_remove.into_iter().rev() { + pending_htlcs.swap_remove(idx); + } + } + + rm = deserialized; + }, + // Advance time + 5 => { + let delta_byte = get_slice!(1)[0]; + current_time += TIME_DELTAS[(delta_byte % 7) as usize]; + }, + // Resolve all pending HTLCs + 6 => { + let settled = get_slice!(1)[0] % 2 == 1; + let htlcs: Vec = pending_htlcs.drain(..).collect(); + for htlc in &htlcs { + let _ = rm.resolve_htlc( + htlc.incoming_channel_id, + htlc.htlc_id, + htlc.outgoing_channel_id, + settled, + current_time, + ); + } + }, + _ => unreachable!(), + } + } +} + +pub fn resource_manager_test(data: &[u8], out: Out) { + do_test(data, out); +} + +#[no_mangle] +pub extern "C" fn resource_manager_run(data: *const u8, datalen: usize) { + do_test(unsafe { std::slice::from_raw_parts(data, datalen) }, test_logger::DevNull {}); +} diff --git a/fuzz/targets.h b/fuzz/targets.h index 921439836af..5bd00426eca 100644 --- a/fuzz/targets.h +++ b/fuzz/targets.h @@ -22,6 +22,7 @@ void fromstr_to_netaddress_run(const unsigned char* data, size_t data_len); void feature_flags_run(const unsigned char* data, size_t data_len); void lsps_message_run(const unsigned char* data, size_t data_len); void fs_store_run(const unsigned char* data, size_t data_len); +void resource_manager_run(const unsigned char* data, size_t data_len); void msg_accept_channel_run(const unsigned char* data, size_t data_len); void msg_announcement_signatures_run(const unsigned char* data, size_t data_len); void msg_channel_reestablish_run(const unsigned char* data, size_t data_len); diff --git a/lightning/src/ln/mod.rs b/lightning/src/ln/mod.rs index d6e0b92f1d0..ffb2da63d54 100644 --- a/lightning/src/ln/mod.rs +++ b/lightning/src/ln/mod.rs @@ -24,6 +24,10 @@ pub mod msgs; pub mod onion_payment; pub mod our_peer_storage; pub mod peer_handler; +#[cfg(fuzzing)] +pub mod resource_manager; +#[cfg(not(fuzzing))] +mod resource_manager; pub mod script; pub mod types; diff --git a/lightning/src/ln/resource_manager.rs b/lightning/src/ln/resource_manager.rs new file mode 100644 index 00000000000..cc694e049f9 --- /dev/null +++ b/lightning/src/ln/resource_manager.rs @@ -0,0 +1,2742 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +#![allow(dead_code)] + +use bitcoin::io::Read; +use core::{fmt::Display, time::Duration}; + +use crate::{ + crypto::chacha20::ChaCha20, + io, + ln::{channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS, msgs::DecodeError}, + prelude::{hash_map::Entry, new_hash_map, new_hash_set, HashMap, HashSet}, + sign::EntropySource, + sync::Mutex, + util::ser::{CollectionLength, Readable, ReadableArgs, Writeable, Writer}, +}; + +/// The minimum number of accepted HTLCs required for a channel to be added to the resource +/// manager. With the default bucket allocations of 40%/20%/40% (general/congestion/protected), +/// the general bucket (40%) needs at least 5 usable HTLC slots to function effectively. To +/// ensure the general bucket gets 5 slots at its 40% share, we need at least 12 total HTLC +/// slots. +const MIN_ACCEPTED_HTLCS: u16 = 12; + +/// The minimum `max_htlc_value_in_flight_msat` required for a channel to be added to the resource +/// manager. This corresponds to the default `min_funding_satoshis` of 1000 in +/// [`crate::util::config::ChannelHandshakeLimits`], which is the smallest channel size LDK will +/// accept. +const MIN_MAX_IN_FLIGHT_MSAT: u64 = 1_000_000; + +/// Resolution time in seconds that is considered "good". HTLCs resolved within this period are +/// considered normal and are rewarded in the reputation score. HTLCs resolved slower than this +/// will incur an opportunity cost to penalize slow resolving payments. +const ACCEPTABLE_RESOLUTION_PERIOD_SECS: u8 = 90; + +/// The maximum time (in seconds) that a HTLC can be held. Corresponds to the largest cltv delta +/// allowed in the protocol which is 2016 blocks. Assuming 10 minute blocks, this is roughly 2 +/// weeks. +const REVENUE_WINDOW: u64 = 2016 * 10 * 60; + +/// Configuration parameters for the resource manager. +/// +/// This configuration controls how the resource manager allocates channel resources (HTLC slots +/// and liquidity) across three buckets (general, congestion, and protected). +pub struct ResourceManagerConfig { + /// The percentage of channel resources allocated to the general bucket. + /// The general bucket is available to all traffic with basic denial-of-service protections. + /// + /// Default: 40% + pub general_allocation_pct: u8, + + /// The percentage of channel resources allocated to the congestion bucket. + /// The congestion bucket is used when the general bucket is saturated. It allows an outgoing + /// channel that does not have reputation to have a chance of getting the HTLC forwarded. + /// + /// Default: 20% + pub congestion_allocation_pct: u8, + + /// The amount of time a HTLC is allowed to resolve in that classifies as "good" behavior. + /// HTLCs resolved within this period are rewarded in the reputation score. HTLCs resolved + /// slower than this will incur an opportunity cost penalty. + /// + /// Default: 90 seconds + pub resolution_period: Duration, + + /// The rolling window over which we track the revenue on the incoming channel. + /// + /// This corresponds to the largest cltv delta from the current block height that a node will + /// allow a HTLC to set before failing it with `expiry_too_far`. Assuming 10 minute blocks, + /// the default 2016 blocks is roughly 2 weeks. + /// + /// Default: 2016 blocks * 10 minutes = ~2 weeks + pub revenue_window: Duration, + + /// A multiplier applied to [`revenue_window`] to determine the rolling window over which an + /// outgoing channel's forwarding history is considered when calculating reputation. The + /// outgoing channel reputation is tracked over a period of `revenue_window * reputation_multiplier`. + /// + /// Default: 12 (meaning reputation is tracked over 12 * 2 weeks = 24 weeks) + /// + /// [`revenue_window`]: Self::revenue_window + pub reputation_multiplier: u8, +} + +impl Default for ResourceManagerConfig { + fn default() -> ResourceManagerConfig { + Self { + general_allocation_pct: 40, + congestion_allocation_pct: 20, + resolution_period: Duration::from_secs(ACCEPTABLE_RESOLUTION_PERIOD_SECS.into()), + revenue_window: Duration::from_secs(REVENUE_WINDOW), + reputation_multiplier: 12, + } + } +} + +/// The outcome of an HTLC forwarding decision. +#[derive(PartialEq, Eq, Debug)] +pub enum ForwardingOutcome { + /// Forward the HTLC with the specified accountable signal. + Forward(bool), + /// Fail to forward the HTLC. + Fail, +} + +impl Display for ForwardingOutcome { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + ForwardingOutcome::Forward(signal) => { + write!(f, "Forward as {}", if *signal { "accountable" } else { "unaccountable" }) + }, + ForwardingOutcome::Fail => { + write!(f, "Fail") + }, + } + } +} + +#[derive(Clone, PartialEq, Eq, Debug)] +enum BucketAssigned { + General, + Congestion, + Protected, +} + +struct GeneralBucket { + /// Our SCID + scid: u64, + + total_slots: u16, + total_liquidity: u64, + + /// The number of slots in the general bucket that each forwarding channel pair gets. + per_channel_slots: u8, + /// The liquidity amount of each slot in the general bucket that each forwarding channel pair + /// gets. + per_slot_msat: u64, + + /// Tracks the occupancy of HTLC slots in the bucket where the index represents the slot + /// number and the optional value indicates which channel is currently using the slot. + slots_occupied: Vec>, + + /// SCID -> (slots assigned, salt) + /// Maps short channel IDs to the slots that the channel is allowed to use and the salt. The + /// salt is stored to deterministically generate the slots for each channel on restarts. + channels_slots: HashMap, [u8; 32])>, +} + +impl GeneralBucket { + fn new(scid: u64, slots_allocated: u16, liquidity_allocated: u64) -> Self { + let general_slot_allocation = + u8::max(5, u8::try_from((slots_allocated * 5).div_ceil(100)).unwrap()); + + let general_liquidity_allocation = + liquidity_allocated * general_slot_allocation as u64 / slots_allocated as u64; + GeneralBucket { + scid, + total_slots: slots_allocated, + total_liquidity: liquidity_allocated, + per_channel_slots: general_slot_allocation, + per_slot_msat: general_liquidity_allocation, + slots_occupied: vec![None; slots_allocated as usize], + channels_slots: new_hash_map(), + } + } + + /// Returns the available slots that could be used by the outgoing scid for the specified + /// htlc amount. + fn slots_for_amount( + &mut self, outgoing_scid: u64, htlc_amount_msat: u64, entropy_source: &ES, + ) -> Result>, ()> { + let slots_needed = u64::max(1, htlc_amount_msat.div_ceil(self.per_slot_msat)); + + let channel_entry = match self.channels_slots.entry(outgoing_scid) { + Entry::Occupied(e) => e.into_mut(), + Entry::Vacant(entry) => { + let (slots, salt) = assign_slots_for_channel( + self.scid, + outgoing_scid, + None, + entropy_source, + self.per_channel_slots, + self.total_slots, + )?; + entry.insert((slots, salt)) + }, + }; + + let slots_to_use: Vec = channel_entry + .0 + .iter() + .filter(|idx| match self.slots_occupied.get(**idx as usize) { + Some(is_occupied) => is_occupied.is_none(), + None => { + debug_assert!(false, "assigned slot {} is not present in slots_occupied", idx); + false + }, + }) + .take(slots_needed as usize) + .copied() + .collect(); + + if (slots_to_use.len() as u64) < slots_needed { + Ok(None) + } else { + Ok(Some(slots_to_use)) + } + } + + fn can_add_htlc( + &mut self, outgoing_scid: u64, htlc_amount_msat: u64, entropy_source: &ES, + ) -> Result { + Ok(self.slots_for_amount(outgoing_scid, htlc_amount_msat, entropy_source)?.is_some()) + } + + fn add_htlc( + &mut self, outgoing_scid: u64, htlc_amount_msat: u64, entropy_source: &ES, + ) -> Result, ()> { + match self.slots_for_amount(outgoing_scid, htlc_amount_msat, entropy_source)? { + Some(slots) => { + for slot_idx in &slots { + debug_assert!(self.slots_occupied[*slot_idx as usize].is_none()); + self.slots_occupied[*slot_idx as usize] = Some(outgoing_scid); + } + Ok(slots) + }, + None => Err(()), + } + } + + fn remove_htlc(&mut self, outgoing_scid: u64, htlc_amount_msat: u64) -> Result<(), ()> { + let channel_slots = match self.channels_slots.get(&outgoing_scid) { + Some((slots, _)) => slots, + None => return Err(()), + }; + + let slots_needed = u64::max(1, htlc_amount_msat.div_ceil(self.per_slot_msat)); + + let mut slots_used_by_channel: Vec = channel_slots + .iter() + .filter(|slot_idx| self.slots_occupied[**slot_idx as usize] == Some(outgoing_scid)) + .copied() + .collect(); + + if slots_needed > slots_used_by_channel.len() as u64 { + return Err(()); + } + let slots_released: Vec = + slots_used_by_channel.drain(0..slots_needed as usize).collect(); + + for slot_idx in slots_released { + debug_assert!(self.slots_occupied[slot_idx as usize] == Some(outgoing_scid)); + self.slots_occupied[slot_idx as usize] = None; + } + Ok(()) + } + + fn remove_channel_slots(&mut self, outgoing_scid: u64) { + if let Some((slots, _)) = self.channels_slots.remove(&outgoing_scid) { + for slot_idx in slots { + if self.slots_occupied[slot_idx as usize] == Some(outgoing_scid) { + self.slots_occupied[slot_idx as usize] = None; + } + } + } + } +} + +fn assign_slots_for_channel( + incoming_scid: u64, outgoing_scid: u64, salt: Option<[u8; 32]>, entropy_source: &ES, + per_channel_slots: u8, total_slots: u16, +) -> Result<(Vec, [u8; 32]), ()> { + debug_assert_ne!(incoming_scid, outgoing_scid); + + let mut channel_slots = Vec::with_capacity(per_channel_slots.into()); + let mut slots_assigned_counter = 0; + let salt = salt.unwrap_or(entropy_source.get_secure_random_bytes()); + + let mut nonce = [0u8; 12]; + nonce[..4].copy_from_slice(&incoming_scid.to_be_bytes()[..4]); + nonce[4..].copy_from_slice(&outgoing_scid.to_be_bytes()); + let mut prng = ChaCha20::new(&salt, &nonce); + let mut buf = [0u8; 4]; + + let max_attempts = per_channel_slots * 10; + for _ in 0..max_attempts { + if slots_assigned_counter == per_channel_slots { + break; + } + + prng.process_in_place(&mut buf); + let slot_idx: u16 = (u32::from_le_bytes(buf) % total_slots as u32) as u16; + if !channel_slots.contains(&slot_idx) { + channel_slots.push(slot_idx); + slots_assigned_counter += 1; + } + } + + if slots_assigned_counter < per_channel_slots { + return Err(()); + } + + Ok((channel_slots, salt)) +} + +struct GeneralBucketData { + scid: u64, + channel_salts: HashMap, +} + +impl_writeable_tlv_based!(GeneralBucketData, { + (1, scid, required), + (3, channel_salts, required), +}); + +struct BucketResources { + slots_allocated: u16, + slots_used: u16, + liquidity_allocated: u64, + liquidity_used: u64, +} + +impl BucketResources { + fn new(slots_allocated: u16, liquidity_allocated: u64) -> Self { + BucketResources { slots_allocated, slots_used: 0, liquidity_allocated, liquidity_used: 0 } + } + + fn resources_available(&self, htlc_amount_msat: u64) -> bool { + return (self.liquidity_used + htlc_amount_msat <= self.liquidity_allocated) + && (self.slots_used < self.slots_allocated); + } + + fn add_htlc(&mut self, htlc_amount_msat: u64) -> Result<(), ()> { + if !self.resources_available(htlc_amount_msat) { + return Err(()); + } + + self.slots_used += 1; + self.liquidity_used += htlc_amount_msat; + Ok(()) + } + + fn remove_htlc(&mut self, htlc_amount_msat: u64) -> Result<(), ()> { + if self.slots_used == 0 || self.liquidity_used < htlc_amount_msat { + return Err(()); + } + self.slots_used -= 1; + self.liquidity_used -= htlc_amount_msat; + Ok(()) + } +} + +struct BucketAllocations { + general_slots: u16, + general_liquidity: u64, + congestion_slots: u16, + congestion_liquidity: u64, + protected_slots: u16, + protected_liquidity: u64, +} + +fn bucket_allocations( + max_accepted_htlcs: u16, max_htlc_value_in_flight_msat: u64, general_pct: u8, + congestion_pct: u8, +) -> BucketAllocations { + let general_slots = (max_accepted_htlcs as f64 * general_pct as f64 / 100.0).round() as u16; + let general_liquidity = + (max_htlc_value_in_flight_msat as f64 * general_pct as f64 / 100.0).round() as u64; + + let congestion_slots = + (max_accepted_htlcs as f64 * congestion_pct as f64 / 100.0).round() as u16; + let congestion_liquidity = + (max_htlc_value_in_flight_msat as f64 * congestion_pct as f64 / 100.0).round() as u64; + + let protected_slots = max_accepted_htlcs - general_slots - congestion_slots; + let protected_liquidity = + max_htlc_value_in_flight_msat - general_liquidity - congestion_liquidity; + + BucketAllocations { + general_slots, + general_liquidity, + congestion_slots, + congestion_liquidity, + protected_slots, + protected_liquidity, + } +} + +#[derive(Debug, Clone)] +struct PendingHTLC { + incoming_amount_msat: u64, + fee: u64, + outgoing_accountable: bool, + added_at_unix_seconds: u64, + in_flight_risk: u64, + bucket: BucketAssigned, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct HtlcRef { + incoming_channel_id: u64, + htlc_id: u64, +} + +struct Channel { + max_htlc_value_in_flight_msat: u64, + max_accepted_htlcs: u16, + + /// The reputation this channel has accrued as an outgoing link. + outgoing_reputation: DecayingAverage, + + /// The revenue this channel has earned us as an incoming link. + incoming_revenue: AggregatedWindowAverage, + + /// HTLC Ref incoming channel -> pending HTLC outgoing. + /// It tracks all the pending HTLCs where this channel is the outgoing link. + pending_htlcs: HashMap, + + general_bucket: GeneralBucket, + congestion_bucket: BucketResources, + /// SCID -> unix seconds timestamp + /// Tracks which channels have misused the congestion bucket and the unix timestamp. + last_congestion_misuse: HashMap, + protected_bucket: BucketResources, +} + +impl Channel { + fn new( + scid: u64, max_htlc_value_in_flight_msat: u64, max_accepted_htlcs: u16, + general_bucket_pct: u8, congestion_bucket_pct: u8, reputation_window: Duration, + revenue_window_weeks: u8, revenue_week_avg: u8, timestamp_unix_secs: u64, + ) -> Result { + if max_accepted_htlcs > 483 + || max_htlc_value_in_flight_msat >= TOTAL_BITCOIN_SUPPLY_SATOSHIS * 1000 + { + return Err(()); + } + + if max_accepted_htlcs < MIN_ACCEPTED_HTLCS + || max_htlc_value_in_flight_msat < MIN_MAX_IN_FLIGHT_MSAT + { + return Err(()); + } + + if general_bucket_pct + congestion_bucket_pct >= 100 { + return Err(()); + } + + let alloc = bucket_allocations( + max_accepted_htlcs, + max_htlc_value_in_flight_msat, + general_bucket_pct, + congestion_bucket_pct, + ); + if alloc.general_slots < 5 || alloc.congestion_slots == 0 || alloc.protected_slots == 0 { + return Err(()); + } + + Ok(Channel { + max_htlc_value_in_flight_msat, + max_accepted_htlcs, + outgoing_reputation: DecayingAverage::new(timestamp_unix_secs, reputation_window), + incoming_revenue: AggregatedWindowAverage::new( + revenue_week_avg, + revenue_window_weeks, + timestamp_unix_secs, + ), + pending_htlcs: new_hash_map(), + general_bucket: GeneralBucket::new(scid, alloc.general_slots, alloc.general_liquidity), + congestion_bucket: BucketResources::new( + alloc.congestion_slots, + alloc.congestion_liquidity, + ), + last_congestion_misuse: new_hash_map(), + protected_bucket: BucketResources::new( + alloc.protected_slots, + alloc.protected_liquidity, + ), + }) + } + + fn general_available( + &mut self, incoming_amount_msat: u64, outgoing_channel_id: u64, entropy_source: &ES, + ) -> Result { + Ok(self.general_bucket.can_add_htlc( + outgoing_channel_id, + incoming_amount_msat, + entropy_source, + )?) + } + + fn congestion_eligible( + &mut self, pending_htlcs_in_congestion: bool, incoming_amount_msat: u64, + outgoing_channel_id: u64, at_timestamp: u64, + ) -> Result { + Ok(!pending_htlcs_in_congestion + && self.can_add_htlc_congestion( + outgoing_channel_id, + incoming_amount_msat, + at_timestamp, + )?) + } + + fn misused_congestion(&mut self, channel_id: u64, misuse_timestamp: u64) { + self.last_congestion_misuse.insert(channel_id, misuse_timestamp); + } + + // Returns whether the outgoing channel has misused the congestion bucket in the last two + // weeks. + fn has_misused_congestion( + &mut self, outgoing_scid: u64, at_timestamp: u64, + ) -> Result { + match self.last_congestion_misuse.entry(outgoing_scid) { + Entry::Vacant(_) => Ok(false), + Entry::Occupied(last_misuse) => { + if at_timestamp < *last_misuse.get() { + return Err(()); + } + // If the last misuse of the congestion bucket was over more than two + // weeks ago, remove the entry. + const TWO_WEEKS: u64 = 2016 * 10 * 60; + let since_last_misuse = at_timestamp - last_misuse.get(); + if since_last_misuse < TWO_WEEKS { + return Ok(true); + } else { + last_misuse.remove(); + return Ok(false); + } + }, + } + } + + fn can_add_htlc_congestion( + &mut self, channel_id: u64, htlc_amount_msat: u64, at_timestamp: u64, + ) -> Result { + let congestion_resources_available = + self.congestion_bucket.resources_available(htlc_amount_msat); + let misused_congestion = self.has_misused_congestion(channel_id, at_timestamp)?; + + let below_liquidity_limit = htlc_amount_msat + <= self.congestion_bucket.liquidity_allocated + / self.congestion_bucket.slots_allocated as u64; + + Ok(congestion_resources_available && !misused_congestion && below_liquidity_limit) + } + + fn pending_htlcs_in_congestion(&self, channel_id: u64) -> bool { + self.pending_htlcs + .iter() + .find(|(htlc_ref, pending_htlc)| { + htlc_ref.incoming_channel_id == channel_id + && pending_htlc.bucket == BucketAssigned::Congestion + }) + .is_some() + } + + fn sufficient_reputation( + &mut self, in_flight_htlc_risk: u64, outgoing_reputation: i64, + outgoing_in_flight_risk: u64, at_timestamp: u64, + ) -> Result { + let incoming_revenue_threshold = self.incoming_revenue.value_at_timestamp(at_timestamp)?; + + Ok(outgoing_reputation + .saturating_sub(i64::try_from(outgoing_in_flight_risk).unwrap_or(i64::MAX)) + .saturating_sub(i64::try_from(in_flight_htlc_risk).unwrap_or(i64::MAX)) + >= incoming_revenue_threshold) + } + + fn outgoing_in_flight_risk(&self) -> u64 { + // We only account the in-flight risk for HTLCs that are accountable + self.pending_htlcs + .iter() + .map(|htlc| if htlc.1.outgoing_accountable { htlc.1.in_flight_risk } else { 0 }) + .sum() + } +} + +impl Writeable for Channel { + fn write(&self, writer: &mut W) -> Result<(), io::Error> { + let general_bucket_data = GeneralBucketData { + scid: self.general_bucket.scid, + channel_salts: self + .general_bucket + .channels_slots + .iter() + .map(|(scid, (_slots, salt))| (*scid, *salt)) + .collect(), + }; + write_tlv_fields!(writer, { + (1, self.outgoing_reputation, required), + (3, self.incoming_revenue, required), + (5, general_bucket_data, required), + (7, self.last_congestion_misuse, required), + }); + Ok(()) + } +} + +impl ReadableArgs<(u64, u16, &ResourceManagerConfig, &ES)> for Channel { + fn read( + reader: &mut R, args: (u64, u16, &ResourceManagerConfig, &ES), + ) -> Result { + let (max_htlc_value_in_flight_msat, max_accepted_htlcs, config, entropy_source) = args; + _init_and_read_len_prefixed_tlv_fields!(reader, { + (1, outgoing_reputation, required), + (3, incoming_revenue, required), + (5, general_bucket_data, required), + (7, last_congestion_misuse, required), + }); + + let general_bucket_data: GeneralBucketData = general_bucket_data.0.unwrap(); + + let alloc = bucket_allocations( + max_accepted_htlcs, + max_htlc_value_in_flight_msat, + config.general_allocation_pct, + config.congestion_allocation_pct, + ); + + let mut general_bucket = GeneralBucket::new( + general_bucket_data.scid, + alloc.general_slots, + alloc.general_liquidity, + ); + for (outgoing_scid, salt) in general_bucket_data.channel_salts { + let entry = assign_slots_for_channel( + general_bucket.scid, + outgoing_scid, + Some(salt), + entropy_source, + general_bucket.per_channel_slots, + general_bucket.total_slots, + ) + .map_err(|_| DecodeError::InvalidValue)?; + general_bucket.channels_slots.insert(outgoing_scid, entry); + } + + Ok(Channel { + max_htlc_value_in_flight_msat, + max_accepted_htlcs, + outgoing_reputation: outgoing_reputation.0.unwrap(), + incoming_revenue: incoming_revenue.0.unwrap(), + general_bucket, + pending_htlcs: new_hash_map(), + congestion_bucket: BucketResources::new( + alloc.congestion_slots, + alloc.congestion_liquidity, + ), + last_congestion_misuse: last_congestion_misuse.0.unwrap(), + protected_bucket: BucketResources::new( + alloc.protected_slots, + alloc.protected_liquidity, + ), + }) + } +} + +/// An implementation for managing channel resources and informing HTLC forwarding decisions. It +/// implements the core of the mitigation as proposed in https://github.com/lightning/bolts/pull/1280. +pub struct DefaultResourceManager { + config: ResourceManagerConfig, + channels: Mutex>, + /// Tracks HTLCs that returned [`ForwardingOutcome::Fail`] during [`Self::replay_pending_htlcs`]. + /// When [`Self::resolve_htlc`] is called for one of these, it is silently ignored instead of + /// returning an error. This should not happen often but there is a chance that HTLCs are + /// failed on replay even if it was accepted previously. This could happen if reputation has + /// decayed and reputation check fails. + failed_replays: Mutex>, +} + +impl DefaultResourceManager { + pub fn new(config: ResourceManagerConfig) -> Self { + debug_assert!(config.resolution_period > Duration::ZERO); + DefaultResourceManager { + config, + channels: Mutex::new(new_hash_map()), + failed_replays: Mutex::new(new_hash_set()), + } + } + + // To calculate the risk of pending HTLCs, we assume they will resolve in the worst + // possible case. Here we assume block times of 10 minutes. + fn htlc_in_flight_risk(&self, fee: u64, incoming_cltv_expiry: u32, height_added: u32) -> u64 { + let maximum_hold_time = (incoming_cltv_expiry.saturating_sub(height_added)) * 10 * 60; + self.opportunity_cost(Duration::from_secs(maximum_hold_time as u64), fee) + } + + fn opportunity_cost(&self, resolution_time: Duration, fee_msat: u64) -> u64 { + let resolution_period = self.config.resolution_period.as_secs_f64(); + let opportunity_cost = 0_f64 + .max((resolution_time.as_secs_f64() - resolution_period) / resolution_period) + * fee_msat as f64; + + opportunity_cost.round() as u64 + } + + fn effective_fees( + &self, fee_msat: u64, resolution_time: Duration, accountable: bool, settled: bool, + ) -> i64 { + let fee = i64::try_from(fee_msat).unwrap_or(i64::MAX); + if accountable { + let opportunity_cost = + i64::try_from(self.opportunity_cost(resolution_time, fee_msat)).unwrap_or(i64::MAX); + if settled { + fee - opportunity_cost + } else { + -opportunity_cost + } + } else { + if settled && resolution_time <= self.config.resolution_period { + fee + } else { + 0 + } + } + } +} + +impl DefaultResourceManager { + /// Registers a new channel with the resource manager for tracking. + /// + /// This should be called when a channel becomes ready for forwarding + pub fn add_channel( + &self, channel_id: u64, max_htlc_value_in_flight_msat: u64, max_accepted_htlcs: u16, + timestamp_unix_secs: u64, + ) -> Result<(), ()> { + let mut channels_lock = self.channels.lock().unwrap(); + let revenue_window_weeks_avg = self.config.revenue_window.as_secs() / (60 * 60 * 24 * 7); + match channels_lock.entry(channel_id) { + Entry::Vacant(entry) => { + let channel = Channel::new( + channel_id, + max_htlc_value_in_flight_msat, + max_accepted_htlcs, + self.config.general_allocation_pct, + self.config.congestion_allocation_pct, + self.config.revenue_window * self.config.reputation_multiplier.into(), + self.config.reputation_multiplier, + revenue_window_weeks_avg as u8, + timestamp_unix_secs, + )?; + entry.insert(channel); + Ok(()) + }, + Entry::Occupied(_) => Ok(()), + } + } + + /// Removes a channel from the resource manager. + /// + /// This should be called when a channel is closing. + pub fn remove_channel(&self, channel_id: u64) -> Result<(), ()> { + let mut channels_lock = self.channels.lock().unwrap(); + + // Release bucket resources on each incoming channel for its pending HTLCs. + if let Some(removed_channel) = channels_lock.remove(&channel_id) { + for (htlc_ref, pending_htlc) in &removed_channel.pending_htlcs { + if let Some(incoming_channel) = channels_lock.get_mut(&htlc_ref.incoming_channel_id) + { + let _ = match pending_htlc.bucket { + BucketAssigned::General => incoming_channel + .general_bucket + .remove_htlc(channel_id, pending_htlc.incoming_amount_msat), + BucketAssigned::Congestion => incoming_channel + .congestion_bucket + .remove_htlc(pending_htlc.incoming_amount_msat), + BucketAssigned::Protected => incoming_channel + .protected_bucket + .remove_htlc(pending_htlc.incoming_amount_msat), + }; + } + } + } + + // Clean up pending HTLC entries and channel slots. + for (_, channel) in channels_lock.iter_mut() { + channel.pending_htlcs.retain(|htlc_ref, _| htlc_ref.incoming_channel_id != channel_id); + channel.general_bucket.remove_channel_slots(channel_id); + } + Ok(()) + } + + /// Evaluates whether an HTLC should be forwarded and updates resource tracking. + /// + /// This is called when deciding whether to accept and forward an incoming HTLC. The + /// implementation determines if sufficient resources are available on the incoming + /// channel and whether the outgoing channel is suitable for forwarding. + /// + /// Returns a [`ForwardingOutcome`] indicating the forwarding decision: + /// - `ForwardingOutcome::Forward(accountable)`: The HTLC should be forwarded. The boolean + /// flag indicates the accountable signal to use for the outgoing HTLC. + /// - `ForwardingOutcome::Fail`: The HTLC should be failed back to the sender. + pub fn add_htlc( + &self, incoming_channel_id: u64, incoming_amount_msat: u64, incoming_cltv_expiry: u32, + outgoing_channel_id: u64, outgoing_amount_msat: u64, incoming_accountable: bool, + htlc_id: u64, height_added: u32, added_at: u64, entropy_source: &ES, + ) -> Result { + if (outgoing_amount_msat > incoming_amount_msat) || (height_added >= incoming_cltv_expiry) { + return Err(()); + } + + let mut channels_lock = self.channels.lock().unwrap(); + + let htlc_ref = HtlcRef { incoming_channel_id, htlc_id }; + let outgoing_channel = channels_lock.get_mut(&outgoing_channel_id).ok_or(())?; + + if outgoing_channel.pending_htlcs.get(&htlc_ref).is_some() { + return Err(()); + } + + let outgoing_reputation = + outgoing_channel.outgoing_reputation.value_at_timestamp(added_at)?; + + let outgoing_in_flight_risk: u64 = outgoing_channel.outgoing_in_flight_risk(); + let fee = incoming_amount_msat - outgoing_amount_msat; + let in_flight_htlc_risk = self.htlc_in_flight_risk(fee, incoming_cltv_expiry, height_added); + let pending_htlcs_in_congestion = + outgoing_channel.pending_htlcs_in_congestion(incoming_channel_id); + + let incoming_channel = channels_lock.get_mut(&incoming_channel_id).ok_or(())?; + + let (accountable, bucket_assigned) = if !incoming_accountable { + if incoming_channel.general_available( + incoming_amount_msat, + outgoing_channel_id, + entropy_source, + )? { + (false, BucketAssigned::General) + } else if incoming_channel.sufficient_reputation( + in_flight_htlc_risk, + outgoing_reputation, + outgoing_in_flight_risk, + added_at, + )? && incoming_channel + .protected_bucket + .resources_available(incoming_amount_msat) + { + (true, BucketAssigned::Protected) + } else if incoming_channel.congestion_eligible( + pending_htlcs_in_congestion, + incoming_amount_msat, + outgoing_channel_id, + added_at, + )? { + (true, BucketAssigned::Congestion) + } else { + return Ok(ForwardingOutcome::Fail); + } + } else { + // If the incoming HTLC is accountable, we only forward it if the outgoing + // channel has sufficient reputation, otherwise we fail it. + if incoming_channel.sufficient_reputation( + in_flight_htlc_risk, + outgoing_reputation, + outgoing_in_flight_risk, + added_at, + )? { + if incoming_channel.protected_bucket.resources_available(incoming_amount_msat) { + (true, BucketAssigned::Protected) + } else if incoming_channel.general_available( + incoming_amount_msat, + outgoing_channel_id, + entropy_source, + )? { + (true, BucketAssigned::General) + } else { + return Ok(ForwardingOutcome::Fail); + } + } else { + return Ok(ForwardingOutcome::Fail); + } + }; + + match bucket_assigned { + BucketAssigned::General => { + incoming_channel.general_bucket.add_htlc( + outgoing_channel_id, + incoming_amount_msat, + entropy_source, + )?; + }, + BucketAssigned::Congestion => { + incoming_channel.congestion_bucket.add_htlc(incoming_amount_msat)?; + }, + BucketAssigned::Protected => { + incoming_channel.protected_bucket.add_htlc(incoming_amount_msat)?; + }, + } + + let outgoing_channel = channels_lock.get_mut(&outgoing_channel_id).ok_or(())?; + let pending_htlc = PendingHTLC { + incoming_amount_msat, + fee, + outgoing_accountable: accountable, + added_at_unix_seconds: added_at, + in_flight_risk: in_flight_htlc_risk, + bucket: bucket_assigned, + }; + outgoing_channel.pending_htlcs.insert(htlc_ref, pending_htlc); + + Ok(ForwardingOutcome::Forward(accountable)) + } + + /// Records the resolution of a forwarded HTLC. + /// + /// This must be called for HTLCs where [`add_htlc`] returned [`ForwardingOutcome::Forward`]. + /// It reports if the HTLC was successfully settled or failed. This allows the implementation + /// to release resources and update any internal tracking state. + /// + /// [`add_htlc`]: DefaultResourceManager::add_htlc + pub fn resolve_htlc( + &self, incoming_channel_id: u64, htlc_id: u64, outgoing_channel_id: u64, settled: bool, + resolved_at: u64, + ) -> Result<(), ()> { + let htlc_ref = HtlcRef { incoming_channel_id, htlc_id }; + + { + let mut failed_replays = self.failed_replays.lock().unwrap(); + if failed_replays.remove(&htlc_ref) { + return Ok(()); + } + } + + let mut channels_lock = self.channels.lock().unwrap(); + let outgoing_channel = channels_lock.get_mut(&outgoing_channel_id).ok_or(())?; + let pending_htlc = outgoing_channel.pending_htlcs.get(&htlc_ref).ok_or(())?.clone(); + + if resolved_at < pending_htlc.added_at_unix_seconds { + return Err(()); + } + let resolution_time = Duration::from_secs(resolved_at - pending_htlc.added_at_unix_seconds); + let effective_fee = self.effective_fees( + pending_htlc.fee, + resolution_time, + pending_htlc.outgoing_accountable, + settled, + ); + outgoing_channel.outgoing_reputation.add_value(effective_fee, resolved_at)?; + outgoing_channel.pending_htlcs.remove(&htlc_ref).ok_or(())?; + + let incoming_channel = channels_lock.get_mut(&incoming_channel_id).ok_or(())?; + match pending_htlc.bucket { + BucketAssigned::General => incoming_channel + .general_bucket + .remove_htlc(outgoing_channel_id, pending_htlc.incoming_amount_msat)?, + BucketAssigned::Congestion => { + // Mark that congestion bucket was misused if it took more than the valid + // resolution period + if resolution_time > self.config.resolution_period { + incoming_channel.misused_congestion(outgoing_channel_id, resolved_at); + } + + incoming_channel.congestion_bucket.remove_htlc(pending_htlc.incoming_amount_msat)? + }, + BucketAssigned::Protected => { + incoming_channel.protected_bucket.remove_htlc(pending_htlc.incoming_amount_msat)? + }, + } + + if settled { + let fee: i64 = i64::try_from(pending_htlc.fee).unwrap_or(i64::MAX); + incoming_channel.incoming_revenue.add_value(fee, resolved_at)?; + } + + Ok(()) + } +} + +pub struct PendingHTLCReplay { + pub incoming_channel_id: u64, + pub incoming_amount_msat: u64, + pub incoming_htlc_id: u64, + pub incoming_cltv_expiry: u32, + pub incoming_accountable: bool, + pub outgoing_channel_id: u64, + pub outgoing_amount_msat: u64, + pub added_at_unix_seconds: u64, + pub height_added: u32, +} + +impl Writeable for DefaultResourceManager { + fn write(&self, writer: &mut W) -> Result<(), io::Error> { + let channels = self.channels.lock().unwrap(); + write_tlv_fields!(writer, { + (1, channels, required), + }); + Ok(()) + } +} + +impl ReadableArgs<(ResourceManagerConfig, &ES, &HashMap)> + for DefaultResourceManager +{ + fn read( + reader: &mut R, args: (ResourceManagerConfig, &ES, &HashMap), + ) -> Result { + let (config, entropy_source, channel_limits) = args; + _init_and_read_len_prefixed_tlv_fields!(reader, { + (1, channels, (required: ReadableArgs, (&config, entropy_source, channel_limits))), + }); + let channels: HashMap = channels.0.unwrap(); + Ok(DefaultResourceManager { + config, + channels: Mutex::new(channels), + failed_replays: Mutex::new(new_hash_set()), + }) + } +} + +impl ReadableArgs<(&ResourceManagerConfig, &ES, &HashMap)> + for HashMap +{ + fn read( + r: &mut R, args: (&ResourceManagerConfig, &ES, &HashMap), + ) -> Result { + let (config, entropy_source, channel_limits) = args; + let len: CollectionLength = Readable::read(r)?; + let mut ret = new_hash_map(); + for _ in 0..len.0 { + let k: u64 = Readable::read(r)?; + let &(max_htlc_value_in_flight_msat, max_accepted_htlcs) = + channel_limits.get(&k).ok_or(DecodeError::InvalidValue)?; + let v = Channel::read( + r, + (max_htlc_value_in_flight_msat, max_accepted_htlcs, config, entropy_source), + )?; + if ret.insert(k, v).is_some() { + return Err(DecodeError::InvalidValue); + } + } + Ok(ret) + } +} + +impl DefaultResourceManager { + // This should only be called once during startup to replay pending HTLCs we had before + // shutdown. + pub fn replay_pending_htlcs( + &self, pending_htlcs: &[PendingHTLCReplay], entropy_source: &ES, + ) -> Result, DecodeError> { + let mut forwarding_outcomes = Vec::with_capacity(pending_htlcs.len()); + let mut failed_replays = self.failed_replays.lock().unwrap(); + for htlc in pending_htlcs { + let outcome = self + .add_htlc( + htlc.incoming_channel_id, + htlc.incoming_amount_msat, + htlc.incoming_cltv_expiry, + htlc.outgoing_channel_id, + htlc.outgoing_amount_msat, + htlc.incoming_accountable, + htlc.incoming_htlc_id, + htlc.height_added, + htlc.added_at_unix_seconds, + entropy_source, + ) + .map_err(|_| DecodeError::InvalidValue)?; + + if outcome == ForwardingOutcome::Fail { + failed_replays.insert(HtlcRef { + incoming_channel_id: htlc.incoming_channel_id, + htlc_id: htlc.incoming_htlc_id, + }); + } + + forwarding_outcomes.push(outcome); + } + Ok(forwarding_outcomes) + } +} + +/// A weighted average that decays over a specified window. +/// +/// It enables tracking of historical behavior without storing individual data points. +/// Instead of maintaining a complete history of events (such as HTLC forwards for tracking +/// reputation), the decaying average continuously adjusts a single accumulated value based on the +/// elapsed time in the window. +struct DecayingAverage { + value: i64, + last_updated_unix_secs: u64, + window: Duration, + half_life: f64, +} + +impl DecayingAverage { + fn new(start_timestamp_unix_secs: u64, window: Duration) -> Self { + DecayingAverage { + value: 0, + last_updated_unix_secs: start_timestamp_unix_secs, + window, + half_life: window.as_secs_f64() * 2_f64.ln(), + } + } + + fn value_at_timestamp(&mut self, timestamp_unix_secs: u64) -> Result { + if timestamp_unix_secs < self.last_updated_unix_secs { + return Err(()); + } + + let elapsed_secs = (timestamp_unix_secs - self.last_updated_unix_secs) as f64; + let decay_rate = 0.5_f64.powf(elapsed_secs / self.half_life); + self.value = (self.value as f64 * decay_rate).round() as i64; + self.last_updated_unix_secs = timestamp_unix_secs; + Ok(self.value) + } + + fn add_value(&mut self, value: i64, timestamp_unix_secs: u64) -> Result { + self.value_at_timestamp(timestamp_unix_secs)?; + self.value = self.value.saturating_add(value); + self.last_updated_unix_secs = timestamp_unix_secs; + Ok(self.value) + } +} + +impl_writeable_tlv_based!(DecayingAverage, { + (1, value, required), + (3, last_updated_unix_secs, required), + (5, window, required), + (_unused, half_life, (static_value, { + let w: Duration = window.0.unwrap(); + w.as_secs_f64() * 2_f64.ln() + })), +}); + +/// Approximates an [`Self::avg_weeks`]-week average by tracking a decaying average over a larger +/// [`Self::window_weeks`] window to smooth out volatility. +struct AggregatedWindowAverage { + start_timestamp_unix_secs: u64, + avg_weeks: u8, + window_weeks: u8, + aggregated_revenue_decaying: DecayingAverage, +} + +impl AggregatedWindowAverage { + fn new(avg_weeks: u8, window_multiplier: u8, start_timestamp_unix_secs: u64) -> Self { + let window_weeks = avg_weeks * window_multiplier; + let window_duration = Duration::from_secs(60 * 60 * 24 * 7 * window_weeks as u64); + AggregatedWindowAverage { + start_timestamp_unix_secs, + avg_weeks, + window_weeks, + aggregated_revenue_decaying: DecayingAverage::new( + start_timestamp_unix_secs, + window_duration, + ), + } + } + + fn add_value(&mut self, value: i64, timestamp: u64) -> Result { + self.aggregated_revenue_decaying.add_value(value, timestamp) + } + + fn value_at_timestamp(&mut self, timestamp_unix_secs: u64) -> Result { + if timestamp_unix_secs < self.start_timestamp_unix_secs { + return Err(()); + } + + let num_windows = (self.window_weeks / self.avg_weeks) as f64; + let elapsed = (timestamp_unix_secs - self.start_timestamp_unix_secs) as f64; + // Early on when elapsed < 5*window, the decaying average underestimates the true sum. + // The warmup_factor (1 - e^(-elapsed/window)) corrects for this. + let warmup_factor = + 1.0 - (-elapsed / (self.window_weeks as u64 * 60 * 60 * 24 * 7) as f64).exp(); + let divisor = f64::max(num_windows * warmup_factor, 1.0); + + Ok((self.aggregated_revenue_decaying.value_at_timestamp(timestamp_unix_secs)? as f64 + / divisor) + .round() as i64) + } +} + +impl_writeable_tlv_based!(AggregatedWindowAverage, { + (1, start_timestamp_unix_secs, required), + (3, avg_weeks, required), + (5, window_weeks, required), + (7, aggregated_revenue_decaying, required), +}); + +#[cfg(test)] +mod tests { + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use crate::{ + crypto::chacha20::ChaCha20, + ln::{ + channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS, + resource_manager::{ + assign_slots_for_channel, AggregatedWindowAverage, BucketAssigned, BucketResources, + Channel, DecayingAverage, DefaultResourceManager, ForwardingOutcome, GeneralBucket, + HtlcRef, ResourceManagerConfig, + }, + }, + sign::EntropySource, + util::{ + ser::{ReadableArgs, Writeable}, + test_utils::TestKeysInterface, + }, + }; + use bitcoin::Network; + + const WINDOW: Duration = Duration::from_secs(2016 * 10 * 60); + + #[test] + fn test_general_bucket_channel_slots_count() { + struct TestCase { + general_slots: u16, + general_liquidity: u64, + expected_slots: u8, + expected_liquidity: u64, + } + + // Test that it correctly assigns the number of slots based on total slots in general + // bucket + let cases = vec![ + TestCase { + general_slots: 20, + general_liquidity: 100_000_000, + expected_slots: 5, + expected_liquidity: 25_000_000, + }, + TestCase { + general_slots: 50, + general_liquidity: 100_000_000, + expected_slots: 5, + expected_liquidity: 10_000_000, + }, + TestCase { + general_slots: 100, + general_liquidity: 100_000_000, + expected_slots: 5, + expected_liquidity: 5_000_000, + }, + TestCase { + general_slots: 114, + general_liquidity: 300_000_000, + expected_slots: 6, + expected_liquidity: 15789473, + }, + TestCase { + general_slots: 193, + general_liquidity: 100_000_000, + expected_slots: 10, + expected_liquidity: 5_181_347, + }, + ]; + + let scid = 21; + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + for case in cases { + let general_bucket = GeneralBucket::new(0, case.general_slots, case.general_liquidity); + + assert_eq!(general_bucket.per_channel_slots, case.expected_slots); + assert_eq!(general_bucket.per_slot_msat, case.expected_liquidity); + assert!(general_bucket.slots_occupied.iter().all(|slot| slot.is_none())); + + let (slots, _) = assign_slots_for_channel( + general_bucket.scid, + scid, + None, + &entropy_source, + general_bucket.per_channel_slots, + general_bucket.total_slots, + ) + .unwrap(); + assert_eq!(slots.len(), case.expected_slots as usize); + } + } + + #[test] + fn test_slots_from_salt() { + // Test deterministic slot generation from salt + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + let general_bucket = GeneralBucket::new(0, 100, 100_000_000); + let scid = 21; + + let (slots, salt) = assign_slots_for_channel( + general_bucket.scid, + scid, + None, + &entropy_source, + general_bucket.per_channel_slots, + general_bucket.total_slots, + ) + .unwrap(); + + let (slots_from_salt, _) = assign_slots_for_channel( + general_bucket.scid, + scid, + Some(salt), + &entropy_source, + general_bucket.per_channel_slots, + general_bucket.total_slots, + ) + .unwrap(); + + // Test that slots initially assigned are equal to slots assigned from salt. + assert_eq!(slots, slots_from_salt); + } + + #[test] + fn test_general_bucket_add_htlc_over_max_liquidity() { + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + let mut general_bucket = GeneralBucket::new(0, 100, 10_000); + + let scid = 21; + let htlc_amount_over_max = 3000; + // General bucket will assign 5 slots of 500 per channel. Max 5 * 500 = 2500 + // Adding an HTLC over the amount should return error. + let add_htlc_res = general_bucket.add_htlc(scid, htlc_amount_over_max, &entropy_source); + assert!(add_htlc_res.is_err()); + + // All slots for the channel should be unoccupied since adding the HTLC failed. + let slots = &general_bucket.channels_slots.get(&scid).unwrap().0; + assert!(slots + .iter() + .all(|slot_idx| general_bucket.slots_occupied[*slot_idx as usize].is_none())); + } + + #[test] + fn test_general_bucket_add_htlc() { + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + // General bucket will assign 5 slots of 500 per channel. Max 5 * 500 = 2500 + let mut general_bucket = GeneralBucket::new(0, 100, 10_000); + + let scid = 21; + // HTLC of 500 should take one slot + let add_htlc_res = general_bucket.add_htlc(scid, 500, &entropy_source); + assert!(add_htlc_res.is_ok()); + let slots_occupied = add_htlc_res.unwrap(); + assert_eq!(slots_occupied.len(), 1); + + let slot_occupied = slots_occupied[0]; + assert_eq!(general_bucket.slots_occupied[slot_occupied as usize], Some(scid)); + + // HTLC of 1200 should take 3 general slots + let add_htlc_res = general_bucket.add_htlc(scid, 1200, &entropy_source); + assert!(add_htlc_res.is_ok()); + let slots_occupied = add_htlc_res.unwrap(); + assert_eq!(slots_occupied.len(), 3); + + for slot_occupied in slots_occupied.iter() { + assert_eq!(general_bucket.slots_occupied[*slot_occupied as usize], Some(scid)); + } + + // 4 slots have been taken. Trying to add HTLC that will take 2 or more slots should fail + // now. + assert!(general_bucket.add_htlc(scid, 501, &entropy_source).is_err()); + let channel_slots = &general_bucket.channels_slots.get(&scid).unwrap().0; + let unoccupied_slots_for_channel: Vec<&u16> = channel_slots + .iter() + .filter(|slot_idx| general_bucket.slots_occupied[**slot_idx as usize].is_none()) + .collect(); + assert_eq!(unoccupied_slots_for_channel.len(), 1); + } + + #[test] + fn test_general_bucket_remove_htlc() { + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + let mut general_bucket = GeneralBucket::new(0, 100, 10_000); + + let scid = 21; + let htlc_amount = 400; + let slots_occupied = general_bucket.add_htlc(scid, htlc_amount, &entropy_source).unwrap(); + assert_eq!(slots_occupied.len(), 1); + let slot_occupied = slots_occupied[0]; + assert_eq!(general_bucket.slots_occupied[slot_occupied as usize], Some(scid)); + + // Trying to remove HTLC over number of slots previously used should result in a error + assert!(general_bucket.remove_htlc(scid, htlc_amount + 400).is_err()); + assert!(general_bucket.remove_htlc(scid, htlc_amount).is_ok()); + + assert!(general_bucket.slots_occupied[slot_occupied as usize].is_none()); + } + + fn test_bucket_resources() -> BucketResources { + BucketResources { + slots_allocated: 10, + slots_used: 0, + liquidity_allocated: 100_000, + liquidity_used: 0, + } + } + + #[test] + fn test_bucket_resources_add_htlc() { + let mut bucket_resources = test_bucket_resources(); + let available_liquidity = bucket_resources.liquidity_allocated; + assert!(bucket_resources.add_htlc(available_liquidity + 1000).is_err()); + + assert!(bucket_resources.add_htlc(21_000).is_ok()); + assert!(bucket_resources.add_htlc(42_000).is_ok()); + assert_eq!(bucket_resources.slots_used, 2); + assert_eq!(bucket_resources.liquidity_used, 63_000); + } + + #[test] + fn test_bucket_resources_add_htlc_over_resources_available() { + // Test trying to go over slot limit + let mut bucket_resources = test_bucket_resources(); + let slots_available = bucket_resources.slots_allocated; + for _ in 0..slots_available { + assert!(bucket_resources.add_htlc(10).is_ok()); + } + assert_eq!(bucket_resources.slots_used, slots_available); + assert!(bucket_resources.add_htlc(10).is_err()); + + // Test trying to go over liquidity limit + let mut bucket = test_bucket_resources(); + assert!(bucket.add_htlc(bucket.liquidity_allocated - 1000).is_ok()); + assert!(bucket.add_htlc(2000).is_err()); + } + + #[test] + fn test_bucket_resources_remove_htlc() { + let mut bucket_resources = test_bucket_resources(); + + // If no resources have been used, removing HTLC should fail + assert!(bucket_resources.remove_htlc(100).is_err()); + + bucket_resources.add_htlc(1000).unwrap(); + // Test failure if it tries to remove amount over what is currently in use. + assert!(bucket_resources.remove_htlc(1001).is_err()); + + assert!(bucket_resources.remove_htlc(1000).is_ok()); + assert_eq!(bucket_resources.slots_used, 0); + assert_eq!(bucket_resources.liquidity_used, 0); + } + + fn test_channel(config: &ResourceManagerConfig) -> Channel { + Channel::new( + 0, + 100_000_000, + 100, + config.general_allocation_pct, + config.congestion_allocation_pct, + config.revenue_window * config.reputation_multiplier as u32, + config.reputation_multiplier, + 2, + SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(), + ) + .unwrap() + } + + #[test] + fn test_invalid_channel_configs() { + // (max_inflight, max_accepted_htlcs, general_pct, congestion_pct, protected_pct) + let cases: Vec<(u64, u16, u8, u8)> = vec![ + // Invalid max_accepted_htlcs (> 483) + (100_000, 500, 40, 20), + // Invalid max_htlc_value_in_flight_msat (>= total bitcoin supply) + (TOTAL_BITCOIN_SUPPLY_SATOSHIS * 1000 + 1, 483, 40, 20), + // Invalid bucket percentages + (100_000, 483, 70, 50), + // Invalid max_accepted_htlcs (< 12) + (100_000_000, 11, 40, 20), + // Invalid max_htlc_value_in_flight_msat (< 1000 sats) + (999_999, 100, 40, 20), + ]; + + for (max_inflight, max_htlcs, general_pct, congestion_pct) in cases { + assert!(Channel::new( + 0, + max_inflight, + max_htlcs, + general_pct, + congestion_pct, + WINDOW, + 12, + 2, + 0, + ) + .is_err()); + } + } + + #[test] + fn test_misuse_congestion_bucket() { + let config = ResourceManagerConfig::default(); + let mut channel = test_channel(&config); + let misusing_channel = 1; + + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + assert_eq!(channel.has_misused_congestion(misusing_channel, now).unwrap(), false); + + channel.misused_congestion(misusing_channel, now); + assert_eq!(channel.has_misused_congestion(misusing_channel, now + 5).unwrap(), true,); + + // Congestion misuse is taken into account if the bucket has been misused in the last 2 + // weeks. Test that after 2 weeks since last misuse, it returns that the bucket has not + // been misused. + let two_weeks = config.revenue_window.as_secs(); + assert_eq!( + channel.has_misused_congestion(misusing_channel, now + two_weeks).unwrap(), + false + ); + } + + #[test] + fn test_opportunity_cost() { + let config = ResourceManagerConfig::default(); + let resource_manager = DefaultResourceManager::new(config); + + // Less than resolution_period has zero cost. + assert_eq!(resource_manager.opportunity_cost(Duration::from_secs(10), 100), 0); + + // Above resolution period it is gradually incremented. + assert_eq!(resource_manager.opportunity_cost(Duration::from_secs(91), 100), 1); + assert_eq!(resource_manager.opportunity_cost(Duration::from_secs(135), 100), 50); + assert_eq!(resource_manager.opportunity_cost(Duration::from_secs(180), 100), 100); + + // Multiple periods above resolution_period charges multiples of fee. + assert_eq!(resource_manager.opportunity_cost(Duration::from_secs(900), 100), 900); + } + + #[test] + fn test_effective_fees() { + let config = ResourceManagerConfig::default(); + let fast_resolve = config.resolution_period / 2; + let slow_resolve = config.resolution_period * 3; + + let resource_manager = DefaultResourceManager::new(config); + + let accountable = true; + let settled = true; + let cases = vec![ + (1000, fast_resolve, accountable, settled, 1000), + (1000, slow_resolve, accountable, settled, -1000), + (1000, fast_resolve, accountable, !settled, 0), + (1000, slow_resolve, accountable, !settled, -2000), + // Unaccountable HTLCs do not affect negatively + (1000, fast_resolve, !accountable, settled, 1000), + (1000, slow_resolve, !accountable, settled, 0), + (1000, fast_resolve, !accountable, !settled, 0), + (1000, slow_resolve, !accountable, !settled, 0), + ]; + + for (fee_msat, hold_time, accountable, settled, expected) in cases { + let result = resource_manager.effective_fees(fee_msat, hold_time, accountable, settled); + assert_eq!(result, expected, "Case failed: fee_msat={fee_msat:?}, hold_time={hold_time:?}, accountable={accountable:?}, settled={settled:?}"); + } + } + + const INCOMING_SCID: u64 = 100; + const OUTGOING_SCID: u64 = 200; + const INCOMING_SCID_2: u64 = 101; + const OUTGOING_SCID_2: u64 = 201; + const HTLC_AMOUNT: u64 = 10_000_000; + const FEE_AMOUNT: u64 = 1_000; + const CURRENT_HEIGHT: u32 = 1000; + const CLTV_EXPIRY: u32 = 1144; + + fn create_test_resource_manager_with_channel_pairs(n_pairs: u8) -> DefaultResourceManager { + let config = ResourceManagerConfig::default(); + let rm = DefaultResourceManager::new(config); + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + for i in 0..n_pairs { + rm.add_channel(INCOMING_SCID + i as u64, 5_000_000_000, 114, now).unwrap(); + rm.add_channel(OUTGOING_SCID + i as u64, 5_000_000_000, 114, now).unwrap(); + } + rm + } + + fn create_test_resource_manager_with_channels() -> DefaultResourceManager { + create_test_resource_manager_with_channel_pairs(1) + } + + fn add_test_htlc( + rm: &DefaultResourceManager, accountable: bool, htlc_id: u64, added_at: Option, + entropy_source: &ES, + ) -> Result { + rm.add_htlc( + INCOMING_SCID, + HTLC_AMOUNT + FEE_AMOUNT, + CLTV_EXPIRY, + OUTGOING_SCID, + HTLC_AMOUNT, + accountable, + htlc_id, + CURRENT_HEIGHT, + added_at.unwrap_or(SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs()), + entropy_source, + ) + } + + fn add_reputation(rm: &DefaultResourceManager, outgoing_scid: u64, target_reputation: i64) { + let mut channels = rm.channels.lock().unwrap(); + let outgoing_channel = channels.get_mut(&outgoing_scid).unwrap(); + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + outgoing_channel.outgoing_reputation.add_value(target_reputation, now).unwrap(); + } + + fn add_revenue(rm: &DefaultResourceManager, incoming_scid: u64, revenue: i64) { + let mut channels = rm.channels.lock().unwrap(); + let channel = channels.get_mut(&incoming_scid).unwrap(); + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + channel.incoming_revenue.add_value(revenue, now).unwrap(); + } + + fn fill_general_bucket(rm: &DefaultResourceManager, incoming_scid: u64) { + let mut channels = rm.channels.lock().unwrap(); + let incoming_channel = channels.get_mut(&incoming_scid).unwrap(); + for slot in incoming_channel.general_bucket.slots_occupied.iter_mut() { + *slot = Some(0); + } + } + + fn fill_congestion_bucket(rm: &DefaultResourceManager, incoming_scid: u64) { + let mut channels = rm.channels.lock().unwrap(); + let incoming_channel = channels.get_mut(&incoming_scid).unwrap(); + let slots_allocated = incoming_channel.congestion_bucket.slots_allocated; + let liquidity_allocated = incoming_channel.congestion_bucket.liquidity_allocated; + incoming_channel.congestion_bucket.slots_used = slots_allocated; + incoming_channel.congestion_bucket.liquidity_used = liquidity_allocated; + } + + fn fill_protected_bucket(rm: &DefaultResourceManager, incoming_scid: u64) { + let mut channels = rm.channels.lock().unwrap(); + let incoming_channel = channels.get_mut(&incoming_scid).unwrap(); + let slots_allocated = incoming_channel.protected_bucket.slots_allocated; + let liquidity_allocated = incoming_channel.protected_bucket.liquidity_allocated; + incoming_channel.protected_bucket.slots_used = slots_allocated; + incoming_channel.protected_bucket.liquidity_used = liquidity_allocated; + } + + fn mark_congestion_misused( + rm: &DefaultResourceManager, incoming_scid: u64, outgoing_scid: u64, + ) { + let mut channels = rm.channels.lock().unwrap(); + let incoming_channel = channels.get_mut(&incoming_scid).unwrap(); + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + incoming_channel.misused_congestion(outgoing_scid, now); + } + + fn get_htlc_bucket( + rm: &DefaultResourceManager, incoming_channel_id: u64, htlc_id: u64, + outgoing_channel_id: u64, + ) -> Option { + let channels = rm.channels.lock().unwrap(); + let htlc_ref = HtlcRef { incoming_channel_id, htlc_id }; + let htlc = channels.get(&outgoing_channel_id).unwrap().pending_htlcs.get(&htlc_ref); + htlc.map(|htlc| htlc.bucket.clone()) + } + + fn count_pending_htlcs(rm: &DefaultResourceManager, outgoing_scid: u64) -> usize { + let channels = rm.channels.lock().unwrap(); + channels.get(&outgoing_scid).unwrap().pending_htlcs.len() + } + + fn assert_general_bucket_slots_used( + rm: &DefaultResourceManager, incoming_scid: u64, outgoing_scid: u64, expected_count: usize, + ) { + let channels = rm.channels.lock().unwrap(); + let channel = channels.get(&incoming_scid).unwrap(); + let slots = &channel.general_bucket.channels_slots.get(&outgoing_scid).unwrap().0; + let used_count = slots + .iter() + .filter(|slot_idx| { + channel.general_bucket.slots_occupied[**slot_idx as usize] == Some(outgoing_scid) + }) + .count(); + assert_eq!(used_count, expected_count); + } + + fn test_congestion_eligible(rm: &DefaultResourceManager, incoming_htlc_amount: u64) -> bool { + let mut channels_lock = rm.channels.lock().unwrap(); + let outgoing_channel = channels_lock.get_mut(&OUTGOING_SCID).unwrap(); + let pending_htlcs_in_congestion = + outgoing_channel.pending_htlcs_in_congestion(INCOMING_SCID); + + let incoming_channel = channels_lock.get_mut(&INCOMING_SCID).unwrap(); + incoming_channel + .congestion_eligible( + pending_htlcs_in_congestion, + incoming_htlc_amount, + OUTGOING_SCID, + SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(), + ) + .unwrap() + } + + #[test] + fn test_not_congestion_eligible() { + // Test not congestion eligible for: + // - Outgoing channel already has HTLC in congestion bucket. + // - Congestion bucket is full + // - Congestion bucket was misused + let cases = vec![ + |rm: &DefaultResourceManager| { + fill_general_bucket(&rm, INCOMING_SCID); + let htlc_id = 1; + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + add_test_htlc(&rm, false, htlc_id, None, &entropy_source).unwrap(); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID).unwrap(), + BucketAssigned::Congestion + ); + }, + |rm: &DefaultResourceManager| { + fill_congestion_bucket(rm, INCOMING_SCID); + }, + |rm: &DefaultResourceManager| { + mark_congestion_misused(rm, INCOMING_SCID, OUTGOING_SCID); + }, + ]; + + for case_setup in cases { + let rm = create_test_resource_manager_with_channels(); + case_setup(&rm); + assert_eq!(test_congestion_eligible(&rm, HTLC_AMOUNT + FEE_AMOUNT), false); + } + } + + #[test] + fn test_congestion_eligible_htlc_over_slot_limit() { + let rm = create_test_resource_manager_with_channels(); + assert!(test_congestion_eligible(&rm, HTLC_AMOUNT + FEE_AMOUNT)); + + // Get the congestion bucket's per-slot limit + let channels = rm.channels.lock().unwrap(); + let incoming_channel = channels.get(&INCOMING_SCID).unwrap(); + let slot_limit = incoming_channel.congestion_bucket.liquidity_allocated + / incoming_channel.congestion_bucket.slots_allocated as u64; + drop(channels); + + // Try to add HTLC that exceeds the slot limit + let htlc_amount_over_limit = slot_limit + 1000; + assert!(!test_congestion_eligible(&rm, htlc_amount_over_limit)); + } + + fn test_sufficient_reputation(rm: &DefaultResourceManager) -> bool { + let mut channels_lock = rm.channels.lock().unwrap(); + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + + let outgoing_channel = channels_lock.get_mut(&OUTGOING_SCID).unwrap(); + let outgoing_reputation = + outgoing_channel.outgoing_reputation.value_at_timestamp(now).unwrap(); + let outgoing_in_flight_risk: u64 = outgoing_channel.outgoing_in_flight_risk(); + let fee = FEE_AMOUNT; + let in_flight_htlc_risk = rm.htlc_in_flight_risk(fee, CLTV_EXPIRY, CURRENT_HEIGHT); + + let incoming_channel = channels_lock.get_mut(&INCOMING_SCID).unwrap(); + incoming_channel + .sufficient_reputation( + in_flight_htlc_risk, + outgoing_reputation, + outgoing_in_flight_risk, + now, + ) + .unwrap() + } + + #[test] + fn test_insufficient_reputation_outgoing_in_flight_risk() { + let rm = create_test_resource_manager_with_channels(); + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + let reputation = 50_000_000; + add_reputation(&rm, OUTGOING_SCID, reputation); + + // Successfully add unaccountable HTLC that should not count in the outgoing + // accumulated outgoing in-flight risk. + assert!(add_test_htlc(&rm, false, 0, None, &entropy_source).is_ok()); + + let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + let high_cltv_expiry = CURRENT_HEIGHT + 2000; + + // Add accountable HTLC that will add 49_329_633 to the in-flight risk. This is based + // on the 3700 and CLTV delta added. + assert!(rm + .add_htlc( + INCOMING_SCID, + HTLC_AMOUNT + 3700, + high_cltv_expiry, + OUTGOING_SCID, + HTLC_AMOUNT, + true, + 1, + CURRENT_HEIGHT, + current_time, + &entropy_source, + ) + .is_ok()); + + // Since we have added an accountable HTLC with in-fligh risk that is close to the + // reputation we added, the next accountable HTLC we try to add should fail. + assert_eq!(test_sufficient_reputation(&rm), false); + } + + #[test] + fn test_insufficient_reputation_higher_incoming_revenue_threshold() { + let rm = create_test_resource_manager_with_channels(); + add_reputation(&rm, OUTGOING_SCID, 10_000); + + let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + let mut channels = rm.channels.lock().unwrap(); + let incoming_channel = channels.get_mut(&INCOMING_SCID).unwrap(); + // Add revenue to incoming channel so that it goes above outgoing's reputation + incoming_channel.incoming_revenue.add_value(50_000, current_time).unwrap(); + drop(channels); + + assert_eq!(test_sufficient_reputation(&rm), false); + } + + #[test] + fn test_sufficient_reputation_exactly_at_threshold() { + let rm = create_test_resource_manager_with_channels(); + + let in_flight_risk = rm.htlc_in_flight_risk(FEE_AMOUNT, CLTV_EXPIRY, CURRENT_HEIGHT); + let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + let mut channels = rm.channels.lock().unwrap(); + + // Set incoming revenue threshold + let threshold = 10_000_000; + let incoming_channel = channels.get_mut(&INCOMING_SCID).unwrap(); + incoming_channel.incoming_revenue.add_value(threshold, current_time).unwrap(); + + // Set outgoing reputation to match threshold plus in-flight risk + let reputation_needed = threshold + i64::try_from(in_flight_risk).unwrap(); + let outgoing_channel = channels.get_mut(&OUTGOING_SCID).unwrap(); + outgoing_channel.outgoing_reputation.add_value(reputation_needed, current_time).unwrap(); + drop(channels); + + assert_eq!(test_sufficient_reputation(&rm), true); + } + + #[test] + fn test_add_htlc_unaccountable_forwarding_decisions() { + struct TestCase { + description: &'static str, + setup: fn(&DefaultResourceManager), + expected_outcome: ForwardingOutcome, + expected_bucket: Option, + } + + let cases = vec![ + TestCase { + description: "general bucket available", + setup: |_rm| {}, + expected_outcome: ForwardingOutcome::Forward(false), + expected_bucket: Some(BucketAssigned::General), + }, + TestCase { + description: "general full, sufficient reputation goes to protected", + setup: |rm| { + add_reputation(rm, OUTGOING_SCID, HTLC_AMOUNT as i64); + fill_general_bucket(rm, INCOMING_SCID); + }, + expected_outcome: ForwardingOutcome::Forward(true), + expected_bucket: Some(BucketAssigned::Protected), + }, + TestCase { + description: "general full, insufficient reputation goes to congestion", + setup: |rm| fill_general_bucket(rm, INCOMING_SCID), + expected_outcome: ForwardingOutcome::Forward(true), + expected_bucket: Some(BucketAssigned::Congestion), + }, + TestCase { + description: "congestion misused recently fails", + setup: |rm| { + fill_general_bucket(rm, INCOMING_SCID); + mark_congestion_misused(rm, INCOMING_SCID, OUTGOING_SCID); + }, + expected_outcome: ForwardingOutcome::Fail, + expected_bucket: None, + }, + ]; + + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + let htlc_id = 1; + for case in cases { + let rm = create_test_resource_manager_with_channels(); + (case.setup)(&rm); + + let result = add_test_htlc(&rm, false, htlc_id, None, &entropy_source); + assert!(result.is_ok(), "case '{}': add_htlc returned Err", case.description); + assert_eq!( + result.unwrap(), + case.expected_outcome, + "case '{}': unexpected forwarding outcome", + case.description + ); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID), + case.expected_bucket, + "case '{}': unexpected bucket assignment", + case.description + ); + } + } + + #[test] + fn test_add_htlc_unaccountable_congestion_already_has_htlc() { + let rm = create_test_resource_manager_with_channels(); + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + fill_general_bucket(&rm, INCOMING_SCID); + + // With general bucket full, adding HTLC here should go to congestion bucket. + let mut htlc_id = 1; + let result_1 = add_test_htlc(&rm, false, htlc_id, None, &entropy_source); + assert!(result_1.is_ok()); + assert_eq!(result_1.unwrap(), ForwardingOutcome::Forward(true)); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID).unwrap(), + BucketAssigned::Congestion + ); + + // Adding a second HTLC should fail because outgoing channel is already using a slot in + // the congestion bucket and it does not have sufficient reputation to get into the + // protected bucket. + htlc_id = 2; + let result_2 = add_test_htlc(&rm, false, htlc_id, None, &entropy_source); + assert_eq!(result_2.unwrap(), ForwardingOutcome::Fail); + assert!(get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID).is_none()); + } + + #[test] + fn test_add_htlc_accountable_forwarding_decisions() { + struct TestCase { + description: &'static str, + setup: fn(&DefaultResourceManager), + expected_outcome: ForwardingOutcome, + expected_bucket: Option, + } + + let cases = vec![ + TestCase { + description: "sufficient reputation goes to protected", + setup: |rm| add_reputation(rm, OUTGOING_SCID, HTLC_AMOUNT as i64), + expected_outcome: ForwardingOutcome::Forward(true), + expected_bucket: Some(BucketAssigned::Protected), + }, + TestCase { + description: "insufficient reputation fails", + setup: |_rm| {}, + expected_outcome: ForwardingOutcome::Fail, + expected_bucket: None, + }, + TestCase { + description: "sufficient reputation, protected full, falls back to general", + setup: |rm| { + add_reputation(rm, OUTGOING_SCID, HTLC_AMOUNT as i64); + fill_protected_bucket(rm, INCOMING_SCID); + }, + expected_outcome: ForwardingOutcome::Forward(true), + expected_bucket: Some(BucketAssigned::General), + }, + TestCase { + description: "sufficient reputation, protected and general full, fails", + setup: |rm| { + add_reputation(rm, OUTGOING_SCID, HTLC_AMOUNT as i64); + fill_general_bucket(rm, INCOMING_SCID); + fill_protected_bucket(rm, INCOMING_SCID); + }, + expected_outcome: ForwardingOutcome::Fail, + expected_bucket: None, + }, + ]; + + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + let htlc_id = 1; + + for case in cases { + let rm = create_test_resource_manager_with_channels(); + (case.setup)(&rm); + + let result = add_test_htlc(&rm, true, htlc_id, None, &entropy_source); + assert!(result.is_ok(), "case '{}': add_htlc returned Err", case.description); + assert_eq!( + result.unwrap(), + case.expected_outcome, + "case '{}': unexpected forwarding outcome", + case.description + ); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID), + case.expected_bucket, + "case '{}': unexpected bucket assignment", + case.description + ); + } + } + + #[test] + fn test_add_htlc_stores_correct_pending_htlc_data() { + let rm = create_test_resource_manager_with_channels(); + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + + let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + let htlc_id = 42; + let result = rm.add_htlc( + INCOMING_SCID, + HTLC_AMOUNT + FEE_AMOUNT, + CLTV_EXPIRY, + OUTGOING_SCID, + HTLC_AMOUNT, + false, + htlc_id, + CURRENT_HEIGHT, + current_time, + &entropy_source, + ); + assert!(result.is_ok()); + + let channels = rm.channels.lock().unwrap(); + let htlc_ref = HtlcRef { incoming_channel_id: INCOMING_SCID, htlc_id }; + let pending_htlc = channels.get(&OUTGOING_SCID).unwrap().pending_htlcs.get(&htlc_ref); + assert!(pending_htlc.is_some()); + // HTLC should only get added to pending list for outgoing channel + assert!(channels.get(&INCOMING_SCID).unwrap().pending_htlcs.get(&htlc_ref).is_none()); + + let pending_htlc = pending_htlc.unwrap(); + assert_eq!(pending_htlc.incoming_amount_msat, HTLC_AMOUNT + FEE_AMOUNT); + assert_eq!(pending_htlc.fee, FEE_AMOUNT); + assert_eq!(pending_htlc.added_at_unix_seconds, current_time); + + let expected_in_flight_risk = + rm.htlc_in_flight_risk(FEE_AMOUNT, CLTV_EXPIRY, CURRENT_HEIGHT); + assert_eq!(pending_htlc.in_flight_risk, expected_in_flight_risk); + } + + #[test] + fn test_resolve_htlc_unaccountable_outcomes() { + struct TestCase { + hold_time: Duration, + settled: bool, + expected_reputation: i64, + expected_revenue: i64, + } + + let config = ResourceManagerConfig::default(); + let fast_resolve = config.resolution_period / 2; + let slow_resolve = config.resolution_period * 3; + + let cases = vec![ + TestCase { + hold_time: fast_resolve, + settled: true, + expected_reputation: FEE_AMOUNT as i64, // effective_fee = fee + expected_revenue: FEE_AMOUNT as i64, + }, + TestCase { + hold_time: slow_resolve, + settled: true, + expected_reputation: 0, // effective_fee = 0 (slow unaccountable) + expected_revenue: FEE_AMOUNT as i64, // revenue increases regardless of speed + }, + TestCase { + hold_time: fast_resolve, + settled: false, + expected_reputation: 0, + expected_revenue: 0, + }, + TestCase { + hold_time: slow_resolve, + settled: false, + expected_reputation: 0, + expected_revenue: 0, + }, + ]; + + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + + for case in &cases { + let rm = create_test_resource_manager_with_channels(); + let htlc_id = 1; + + assert_eq!( + add_test_htlc(&rm, false, htlc_id, None, &entropy_source).unwrap(), + ForwardingOutcome::Forward(false), + ); + + let resolved_at = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() + + case.hold_time.as_secs(); + rm.resolve_htlc(INCOMING_SCID, htlc_id, OUTGOING_SCID, case.settled, resolved_at) + .unwrap(); + + let channels = rm.channels.lock().unwrap(); + assert_eq!( + channels.get(&OUTGOING_SCID).unwrap().outgoing_reputation.value, + case.expected_reputation, + ); + assert_eq!( + channels + .get(&INCOMING_SCID) + .unwrap() + .incoming_revenue + .aggregated_revenue_decaying + .value, + case.expected_revenue, + ); + } + } + + #[test] + fn test_resolve_htlc_congestion_outcomes() { + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + let config = ResourceManagerConfig::default(); + let fast_resolve = config.resolution_period / 2; + let slow_resolve = config.resolution_period * 3; + + let rm = create_test_resource_manager_with_channels(); + fill_general_bucket(&rm, INCOMING_SCID); + let mut htlc_id = 1; + assert_eq!( + add_test_htlc(&rm, false, htlc_id, None, &entropy_source).unwrap(), + ForwardingOutcome::Forward(true), + ); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID).unwrap(), + BucketAssigned::Congestion, + ); + + let resolved_at = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() + + fast_resolve.as_secs(); + rm.resolve_htlc(INCOMING_SCID, htlc_id, OUTGOING_SCID, false, resolved_at).unwrap(); + + let mut channels = rm.channels.lock().unwrap(); + let incoming = channels.get_mut(&INCOMING_SCID).unwrap(); + + // The HTLC in congestion bucket resolved fast so it does not count as having misused the + // congestion bucket. + assert!(!incoming.has_misused_congestion(OUTGOING_SCID, resolved_at).unwrap()); + assert_eq!(incoming.congestion_bucket.slots_used, 0); + + drop(channels); + + // Since it does not count as congestion misused, this HTLC can be added to congestion + htlc_id += 1; + let added_at = resolved_at; + assert_eq!( + add_test_htlc(&rm, false, htlc_id, Some(added_at), &entropy_source).unwrap(), + ForwardingOutcome::Forward(true), + ); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID).unwrap(), + BucketAssigned::Congestion, + ); + + // Slow resolution + let resolved_at = added_at + slow_resolve.as_secs(); + rm.resolve_htlc(INCOMING_SCID, htlc_id, OUTGOING_SCID, false, resolved_at).unwrap(); + + let mut channels = rm.channels.lock().unwrap(); + let incoming = channels.get_mut(&INCOMING_SCID).unwrap(); + + // The HTLC in congestion bucket resolved slowly so it does count as having misused the + // congestion bucket. + assert!(incoming.has_misused_congestion(OUTGOING_SCID, resolved_at).unwrap()); + + drop(channels); + + // Congestion was misused so trying to add an HTLC should fail because the channel does + // not have reputation to get into protected. + htlc_id += 1; + let added_at = resolved_at; + assert_eq!( + add_test_htlc(&rm, false, htlc_id, Some(added_at), &entropy_source).unwrap(), + ForwardingOutcome::Fail, + ); + + let mut channels = rm.channels.lock().unwrap(); + let incoming = channels.get_mut(&INCOMING_SCID).unwrap(); + + // After two weeks, the misused entry should be removed and congestion bucket should be + // available again for use. + let after_two_weeks = added_at + config.revenue_window.as_secs(); + assert!(!incoming.has_misused_congestion(OUTGOING_SCID, after_two_weeks).unwrap()); + assert!(incoming.last_congestion_misuse.get(&OUTGOING_SCID).is_none()); + + drop(channels); + + htlc_id += 1; + assert_eq!( + add_test_htlc(&rm, false, htlc_id, Some(after_two_weeks), &entropy_source).unwrap(), + ForwardingOutcome::Forward(true), + ); + } + + #[test] + fn test_resolve_htlc_accountable_outcomes() { + let rm = create_test_resource_manager_with_channels(); + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + let fast_resolve = rm.config.resolution_period / 2; + let accountable = true; + + add_reputation(&rm, OUTGOING_SCID, HTLC_AMOUNT as i64); + + let mut htlc_id = 1; + let added_at = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + assert_eq!( + add_test_htlc(&rm, accountable, htlc_id, Some(added_at), &entropy_source).unwrap(), + ForwardingOutcome::Forward(true), + ); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID).unwrap(), + BucketAssigned::Protected, + ); + + let get_reputation = |at_timestamp: u64| -> i64 { + let mut channels = rm.channels.lock().unwrap(); + channels + .get_mut(&OUTGOING_SCID) + .unwrap() + .outgoing_reputation + .value_at_timestamp(at_timestamp) + .unwrap() + }; + + // Check fast settled resolution adds to reputation + let resolved_at = added_at + fast_resolve.as_secs(); + let current_rep = get_reputation(resolved_at); + + rm.resolve_htlc(INCOMING_SCID, htlc_id, OUTGOING_SCID, true, resolved_at).unwrap(); + + let reputation_after_fast_resolve = get_reputation(resolved_at); + assert_eq!(reputation_after_fast_resolve, (current_rep + FEE_AMOUNT as i64)); + + let mut channels = rm.channels.lock().unwrap(); + let revenue = channels + .get_mut(&INCOMING_SCID) + .unwrap() + .incoming_revenue + .value_at_timestamp(resolved_at) + .unwrap(); + assert_eq!(revenue, FEE_AMOUNT as i64,); + drop(channels); + + // Fast failing accountable HTLC does not affect reputation + htlc_id += 1; + let added_at = resolved_at; + assert_eq!( + add_test_htlc(&rm, accountable, htlc_id, Some(added_at), &entropy_source).unwrap(), + ForwardingOutcome::Forward(true), + ); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID).unwrap(), + BucketAssigned::Protected, + ); + + let resolved_at = added_at + fast_resolve.as_secs(); + let reputation_before_resolve = get_reputation(resolved_at); + + rm.resolve_htlc(INCOMING_SCID, htlc_id, OUTGOING_SCID, false, resolved_at).unwrap(); + + assert_eq!(get_reputation(resolved_at), reputation_before_resolve); + + // Slow resolution should decrease reputation by effective fee + let slow_resolve = rm.config.resolution_period * 10; + let added_at = resolved_at; + assert_eq!( + add_test_htlc(&rm, accountable, htlc_id, Some(added_at), &entropy_source).unwrap(), + ForwardingOutcome::Forward(true), + ); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID).unwrap(), + BucketAssigned::Protected, + ); + + let resolved_at = added_at + slow_resolve.as_secs(); + let reputation_before_slow_resolve = get_reputation(resolved_at); + let effective_fee_slow_resolve = rm.effective_fees(FEE_AMOUNT, slow_resolve, true, true); + rm.resolve_htlc(INCOMING_SCID, htlc_id, OUTGOING_SCID, true, resolved_at).unwrap(); + let reputation_after_slow_resolve = get_reputation(resolved_at); + + assert_eq!( + reputation_after_slow_resolve, + reputation_before_slow_resolve + effective_fee_slow_resolve + ); + } + + #[test] + fn test_multi_channel_general_bucket_saturation_flow() { + let rm = create_test_resource_manager_with_channel_pairs(2); + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + + // Fill general bucket (it should have been assigned 5 slots) + let mut htlc_ids = Vec::new(); + for i in 1..=5 { + let result = add_test_htlc(&rm, false, i, None, &entropy_source); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), ForwardingOutcome::Forward(false)); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, i, OUTGOING_SCID).unwrap(), + BucketAssigned::General + ); + htlc_ids.push(i); + } + assert_general_bucket_slots_used(&rm, INCOMING_SCID, OUTGOING_SCID, 5); + + // With the 5 slots in the general bucket used, the 6th HTLC goes to congestion + let result = add_test_htlc(&rm, false, 6, None, &entropy_source); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), ForwardingOutcome::Forward(true)); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, 6, OUTGOING_SCID).unwrap(), + BucketAssigned::Congestion + ); + + // 7th HTLC fails because it is already using a congestion slot and channel does not + // have sufficient reputation to get into protected bucket. + let result = add_test_htlc(&rm, false, 7, None, &entropy_source); + assert_eq!(result.unwrap(), ForwardingOutcome::Fail); + assert!(get_htlc_bucket(&rm, INCOMING_SCID, 7, OUTGOING_SCID).is_none()); + + // Resolve 3 HTLCs that were assigned to the general bucket. It should end up with 2 in + // general and one in congestion. + let resolved_at = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + rm.resolve_htlc(INCOMING_SCID, htlc_ids[0], OUTGOING_SCID, true, resolved_at).unwrap(); + rm.resolve_htlc(INCOMING_SCID, htlc_ids[2], OUTGOING_SCID, true, resolved_at).unwrap(); + rm.resolve_htlc(INCOMING_SCID, htlc_ids[4], OUTGOING_SCID, true, resolved_at).unwrap(); + assert_general_bucket_slots_used(&rm, INCOMING_SCID, OUTGOING_SCID, 2); + assert_eq!(count_pending_htlcs(&rm, OUTGOING_SCID), 3); + + // Adding more HTLCs should now use the freed general slots. + for i in 8..=10 { + let result = add_test_htlc(&rm, false, i, None, &entropy_source); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), ForwardingOutcome::Forward(false)); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, i, OUTGOING_SCID).unwrap(), + BucketAssigned::General + ); + } + assert_general_bucket_slots_used(&rm, INCOMING_SCID, OUTGOING_SCID, 5); + + // Adding HTLCs to a different outgoing channel should use its own slots. Some + // slots may conflict with OUTGOING_SCID's assigned slots, so we check how many + // are actually available. + let conflicting_slots = { + let mut channels = rm.channels.lock().unwrap(); + let incoming = channels.get_mut(&INCOMING_SCID).unwrap(); + let entry = assign_slots_for_channel( + incoming.general_bucket.scid, + OUTGOING_SCID_2, + None, + &entropy_source, + incoming.general_bucket.per_channel_slots, + incoming.general_bucket.total_slots, + ) + .unwrap(); + + let slots_1 = &incoming.general_bucket.channels_slots.get(&OUTGOING_SCID).unwrap().0; + let ret = entry.0.iter().filter(|s| slots_1.contains(s)).count(); + + incoming.general_bucket.channels_slots.insert(OUTGOING_SCID_2, entry); + ret + }; + + let available_slots = 5 - conflicting_slots; + for i in 11..11 + available_slots as u64 { + let result = rm.add_htlc( + INCOMING_SCID, + HTLC_AMOUNT + FEE_AMOUNT, + CLTV_EXPIRY, + OUTGOING_SCID_2, + HTLC_AMOUNT, + false, + i, + CURRENT_HEIGHT, + SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(), + &entropy_source, + ); + assert!(result.is_ok()); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, i, OUTGOING_SCID_2).unwrap(), + BucketAssigned::General + ); + } + assert_general_bucket_slots_used(&rm, INCOMING_SCID, OUTGOING_SCID_2, available_slots); + + // Different incoming uses its own bucket + for i in 12 + available_slots as u64..=20 { + let result = rm.add_htlc( + INCOMING_SCID_2, + HTLC_AMOUNT + FEE_AMOUNT, + CLTV_EXPIRY, + OUTGOING_SCID, + HTLC_AMOUNT, + false, + i, + CURRENT_HEIGHT, + SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(), + &entropy_source, + ); + assert!(result.is_ok()); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID_2, i, OUTGOING_SCID).unwrap(), + BucketAssigned::General + ); + } + + // Verify original channel pair still has 5 slots used + assert_general_bucket_slots_used(&rm, INCOMING_SCID, OUTGOING_SCID, 5); + } + + #[test] + fn test_multi_channel_bucket_fallback_with_earned_reputation() { + let entropy_source = TestKeysInterface::new(&[2; 32], Network::Testnet); + let rm = create_test_resource_manager_with_channel_pairs(2); + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + + let add_htlc_between = + |incoming_scid: u64, outgoing_scid: u64, accountable: bool, htlc_id: u64| { + rm.add_htlc( + incoming_scid, + HTLC_AMOUNT + FEE_AMOUNT, + CLTV_EXPIRY, + outgoing_scid, + HTLC_AMOUNT, + accountable, + htlc_id, + CURRENT_HEIGHT, + now, + &entropy_source, + ) + }; + + // Build a revenue threshold of 5000 on INCOMING_SCID. + for i in 1..=5_u64 { + assert_eq!( + add_htlc_between(INCOMING_SCID, OUTGOING_SCID_2, false, i).unwrap(), + ForwardingOutcome::Forward(false), + ); + rm.resolve_htlc(INCOMING_SCID, i, OUTGOING_SCID_2, true, now).unwrap(); + } + + // Use all generate slots available in INCOMING_SCID for both outgoing channels. + for i in 6..=10_u64 { + assert_eq!( + add_htlc_between(INCOMING_SCID, OUTGOING_SCID, false, i).unwrap(), + ForwardingOutcome::Forward(false), + ); + } + for i in 11..=15_u64 { + assert_eq!( + add_htlc_between(INCOMING_SCID, OUTGOING_SCID_2, false, i).unwrap(), + ForwardingOutcome::Forward(false), + ); + } + let mut htlc_id = 16_u64; + + // Acquire a congestion slot for both outgoing channels. Reputation has not been earned + // yet, so unaccountable HTLCs fall to congestion. + let congestion_htlc_outgoing = htlc_id; + assert_eq!( + add_htlc_between(INCOMING_SCID, OUTGOING_SCID, false, congestion_htlc_outgoing) + .unwrap(), + ForwardingOutcome::Forward(true), + ); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, congestion_htlc_outgoing, OUTGOING_SCID).unwrap(), + BucketAssigned::Congestion, + ); + htlc_id += 1; + + let congestion_htlc_outgoing_2 = htlc_id; + assert_eq!( + add_htlc_between(INCOMING_SCID, OUTGOING_SCID_2, false, congestion_htlc_outgoing_2) + .unwrap(), + ForwardingOutcome::Forward(true), + ); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, congestion_htlc_outgoing_2, OUTGOING_SCID_2) + .unwrap(), + BucketAssigned::Congestion, + ); + htlc_id += 1; + + // Build reputation for OUTGOING_SCID channel but not for OUTGOING_SCID_2. + let rep_fee = 200_000_u64; + let rep_htlc_amount = 1_000_000_u64; + for i in htlc_id..htlc_id + 10 { + assert_eq!( + rm.add_htlc( + INCOMING_SCID_2, + rep_htlc_amount + rep_fee, + CLTV_EXPIRY, + OUTGOING_SCID, + rep_htlc_amount, + false, + i, + CURRENT_HEIGHT, + now, + &entropy_source, + ) + .unwrap(), + ForwardingOutcome::Forward(false), + ); + rm.resolve_htlc(INCOMING_SCID_2, i, OUTGOING_SCID, true, now).unwrap(); + } + htlc_id += 10; + + // Accountable HTLC forwarding decisions diverge based on earned reputation. + // + // - OUTGOING_SCID has reputation so accountable HTLC will get access to protected + // bucket. + // - OUTGOING_SCID_2 does not have reputation so it should fail. + assert_eq!( + add_htlc_between(INCOMING_SCID, OUTGOING_SCID, true, htlc_id).unwrap(), + ForwardingOutcome::Forward(true), + ); + assert_eq!( + get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID).unwrap(), + BucketAssigned::Protected, + ); + htlc_id += 1; + + assert_eq!( + add_htlc_between(INCOMING_SCID, OUTGOING_SCID_2, true, htlc_id).unwrap(), + ForwardingOutcome::Fail, + ); + assert!(get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID_2).is_none()); + } + + #[test] + fn test_simple_manager_serialize_deserialize() { + // This is not a complete test of the serialization/deserialization of the resource + // manager because the pending HTLCs will be replayed through `replay_pending_htlcs` by + // the upstream i.e ChannelManager. + let rm = create_test_resource_manager_with_channels(); + let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet); + + add_test_htlc(&rm, false, 0, None, &entropy_source).unwrap(); + + let reputation = 50_000_000; + add_reputation(&rm, OUTGOING_SCID, reputation); + + let revenue = 70_000_000; + add_revenue(&rm, INCOMING_SCID, revenue); + + let serialized_rm = rm.encode(); + + let channels = rm.channels.lock().unwrap(); + let expected_incoming_channel = channels.get(&INCOMING_SCID).unwrap(); + let (expected_slots, expected_salt) = expected_incoming_channel + .general_bucket + .channels_slots + .get(&OUTGOING_SCID) + .unwrap() + .clone(); + + let mut channel_limits = crate::prelude::new_hash_map(); + channel_limits.insert(INCOMING_SCID, (5_000_000_000u64, 114u16)); + channel_limits.insert(OUTGOING_SCID, (5_000_000_000u64, 114u16)); + + let deserialized_rm = DefaultResourceManager::read( + &mut serialized_rm.as_slice(), + (ResourceManagerConfig::default(), &entropy_source, &channel_limits), + ) + .unwrap(); + let deserialized_channels = deserialized_rm.channels.lock().unwrap(); + assert_eq!(2, deserialized_channels.len()); + + let outgoing_channel = deserialized_channels.get(&OUTGOING_SCID).unwrap(); + assert!(outgoing_channel.general_bucket.channels_slots.is_empty()); + + assert_eq!(outgoing_channel.outgoing_reputation.value, reputation); + + let incoming_channel = deserialized_channels.get(&INCOMING_SCID).unwrap(); + assert_eq!(incoming_channel.incoming_revenue.aggregated_revenue_decaying.value, revenue); + + assert_eq!(incoming_channel.general_bucket.channels_slots.len(), 1); + + let (slots, salt) = + incoming_channel.general_bucket.channels_slots.get(&OUTGOING_SCID).unwrap().clone(); + assert_eq!(slots, expected_slots); + assert_eq!(salt, expected_salt); + + let congestion_bucket = &incoming_channel.congestion_bucket; + assert_eq!( + congestion_bucket.slots_allocated, + expected_incoming_channel.congestion_bucket.slots_allocated + ); + assert_eq!( + congestion_bucket.liquidity_allocated, + expected_incoming_channel.congestion_bucket.liquidity_allocated + ); + let protected_bucket = &incoming_channel.protected_bucket; + assert_eq!( + protected_bucket.slots_allocated, + expected_incoming_channel.protected_bucket.slots_allocated + ); + assert_eq!( + protected_bucket.liquidity_allocated, + expected_incoming_channel.protected_bucket.liquidity_allocated + ); + } + + #[test] + fn test_decaying_average_error() { + let timestamp = 1000; + let mut decaying_average = DecayingAverage::new(timestamp, WINDOW); + assert!(decaying_average.value_at_timestamp(timestamp - 100).is_err()); + assert!(decaying_average.add_value(500, timestamp - 100).is_err()); + } + + #[test] + fn test_decaying_average_bounds() { + for (start, bound) in [(1000, i64::MAX), (-1000, i64::MIN)] { + let timestamp = 1000; + let mut avg = DecayingAverage::new(timestamp, WINDOW); + assert_eq!(avg.add_value(start, timestamp).unwrap(), start); + assert_eq!(avg.add_value(bound, timestamp).unwrap(), bound); + } + } + + #[test] + fn test_value_decays_to_zero_eventually() { + let timestamp = 1000; + let mut avg = DecayingAverage::new(timestamp, Duration::from_secs(100)); + assert_eq!(avg.add_value(100_000_000, timestamp).unwrap(), 100_000_000); + + // After many window periods, value should decay to 0 + let result = avg.value_at_timestamp(timestamp * 1000); + assert_eq!(result, Ok(0)); + } + + #[test] + fn test_decaying_average_values() { + // Test average decay at different timestamps. The values we are asserting have been + // independently calculated. + let mut current_timestamp = 0; + let mut avg = DecayingAverage::new(current_timestamp, WINDOW); + + assert_eq!(avg.add_value(1000, current_timestamp).unwrap(), 1000); + + let one_week = 60 * 60 * 24 * 7; + + current_timestamp += one_week; // 1 week + assert_eq!(avg.value_at_timestamp(current_timestamp).unwrap(), 607); + assert_eq!(avg.add_value(500, current_timestamp).unwrap(), 1107); + + current_timestamp += one_week / 2; // 1.5 weeks + assert_eq!(avg.value_at_timestamp(current_timestamp).unwrap(), 862); + + current_timestamp += one_week / 2; // 2 weeks + assert_eq!(avg.value_at_timestamp(current_timestamp).unwrap(), 671); + assert_eq!(avg.add_value(200, current_timestamp).unwrap(), 871); + + current_timestamp += one_week * 2; // 4 weeks + assert_eq!(avg.value_at_timestamp(current_timestamp).unwrap(), 320); + + current_timestamp += one_week * 6; // 10 weeks + assert_eq!(avg.value_at_timestamp(current_timestamp).unwrap(), 16); + assert_eq!(avg.add_value(1000, current_timestamp).unwrap(), 1016); + + current_timestamp += avg.half_life as u64; + assert_eq!(avg.value_at_timestamp(current_timestamp).unwrap(), 1016 / 2); + } + + #[test] + fn test_aggregated_window_average() { + let avg_weeks: u8 = 2; + let window_weeks: u8 = 12; + let num_windows = (window_weeks / avg_weeks) as usize; + let week_secs: u64 = 60 * 60 * 24 * 7; + let sum_window_secs = avg_weeks as u64 * week_secs; + + let num_points: usize = 50_000; + let duration_weeks: u64 = 120; + let skip_weeks: u64 = 10; + let duration_secs = duration_weeks * week_secs; + let start_timestamp: u64 = 0; + + let mut prng = ChaCha20::new(&[42u8; 32], &[0u8; 12]); + let mut timestamps = Vec::with_capacity(num_points); + let mut values = Vec::with_capacity(num_points); + for _ in 0..num_points { + let mut buf = [0u8; 8]; + prng.process_in_place(&mut buf); + let ts_offset = u64::from_le_bytes(buf) % duration_secs; + timestamps.push(start_timestamp + ts_offset); + + let mut buf = [0u8; 4]; + prng.process_in_place(&mut buf); + let val = (u32::from_le_bytes(buf) % 49_001 + 1_000) as i64; + values.push(val); + } + + let mut indices: Vec = (0..num_points).collect(); + indices.sort_by_key(|&i| timestamps[i]); + let sorted_ts: Vec = indices.iter().map(|&i| timestamps[i]).collect(); + let sorted_vals: Vec = indices.iter().map(|&i| values[i]).collect(); + + let mut avg = AggregatedWindowAverage::new(avg_weeks, window_weeks, start_timestamp); + let mut data_idx = 0; + + for w in 1..=duration_weeks { + let sample_time = start_timestamp + w * week_secs; + + // Add all data points up to this sample time. + while data_idx < num_points && sorted_ts[data_idx] <= sample_time { + avg.add_value(sorted_vals[data_idx], sorted_ts[data_idx]).unwrap(); + data_idx += 1; + } + + let approx_avg = avg.value_at_timestamp(sample_time).unwrap(); + + let mut window_sums = Vec::with_capacity(num_windows); + for i in 0..num_windows { + let window_end = sample_time - i as u64 * sum_window_secs; + if window_end < sum_window_secs + start_timestamp { + break; + } + let window_start = window_end - sum_window_secs; + let window_sum: i64 = sorted_ts + .iter() + .zip(sorted_vals.iter()) + .filter(|(&t, _)| t > window_start && t <= window_end) + .map(|(_, &v)| v) + .sum(); + window_sums.push(window_sum); + } + + let actual_avg = if window_sums.is_empty() { + 0 + } else { + (window_sums.iter().sum::() as f64 / window_sums.len() as f64).round() as i64 + }; + + let error_pct = if actual_avg != 0 { + (approx_avg - actual_avg) as f64 / actual_avg as f64 * 100.0 + } else { + 0.0 + }; + + if w >= skip_weeks { + assert!( + error_pct.abs() < 3.0, + "week {w}: error {error_pct:.2}% exceeds 3% \ + (approx={approx_avg}, actual={actual_avg})" + ); + } + } + } +}