diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index bc0d42ac191..d9e1f2ce251 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -53,8 +53,6 @@ use lightning::routing::utxo::UtxoLookup; use lightning::sign::{ ChangeDestinationSource, ChangeDestinationSourceSync, EntropySource, OutputSpender, }; -#[cfg(not(c_bindings))] -use lightning::util::async_poll::MaybeSend; use lightning::util::logger::Logger; use lightning::util::persist::{ KVStore, KVStoreSync, KVStoreSyncWrapper, CHANNEL_MANAGER_PERSISTENCE_KEY, @@ -63,6 +61,8 @@ use lightning::util::persist::{ NETWORK_GRAPH_PERSISTENCE_SECONDARY_NAMESPACE, SCORER_PERSISTENCE_KEY, SCORER_PERSISTENCE_PRIMARY_NAMESPACE, SCORER_PERSISTENCE_SECONDARY_NAMESPACE, }; +#[cfg(not(c_bindings))] +use lightning::util::native_async::MaybeSend; use lightning::util::sweep::{OutputSweeper, OutputSweeperSync}; #[cfg(feature = "std")] use lightning::util::wakers::Sleeper; diff --git a/lightning-block-sync/src/async_poll.rs b/lightning-block-sync/src/async_poll.rs new file mode 120000 index 00000000000..eb85cdac697 --- /dev/null +++ b/lightning-block-sync/src/async_poll.rs @@ -0,0 +1 @@ +../../lightning/src/util/async_poll.rs \ No newline at end of file diff --git a/lightning-block-sync/src/gossip.rs b/lightning-block-sync/src/gossip.rs index 596098350c7..004914daac9 100644 --- a/lightning-block-sync/src/gossip.rs +++ b/lightning-block-sync/src/gossip.rs @@ -49,8 +49,12 @@ pub trait UtxoSource: BlockSource + 'static { pub struct TokioSpawner; #[cfg(feature = "tokio")] impl FutureSpawner for TokioSpawner { - fn spawn + Send + 'static>(&self, future: T) { - tokio::spawn(future); + type E = tokio::task::JoinError; + type SpawnedFutureResult = tokio::task::JoinHandle; + fn spawn + Send + 'static>( + &self, future: F, + ) -> Self::SpawnedFutureResult { + tokio::spawn(future) } } @@ -280,7 +284,7 @@ where let gossiper = Arc::clone(&self.gossiper); let block_cache = Arc::clone(&self.block_cache); let pmw = Arc::clone(&self.peer_manager_wake); - self.spawn.spawn(async move { + let _ = self.spawn.spawn(async move { let res = Self::retrieve_utxo(source, block_cache, short_channel_id).await; fut.resolve(gossiper.network_graph(), &*gossiper, res); (pmw)(); diff --git a/lightning-block-sync/src/init.rs b/lightning-block-sync/src/init.rs index a870f8ca88c..52eea92e244 100644 --- a/lightning-block-sync/src/init.rs +++ b/lightning-block-sync/src/init.rs @@ -1,11 +1,11 @@ //! Utilities to assist in the initial sync required to initialize or reload Rust-Lightning objects //! from disk. -use crate::poll::{ChainPoller, Validate, ValidatedBlockHeader}; -use crate::{BlockSource, BlockSourceResult, Cache, ChainNotifier}; +use crate::async_poll::{MultiResultFuturePoller, ResultFuture}; +use crate::poll::{ChainPoller, Poll, Validate, ValidatedBlockHeader}; +use crate::{BlockData, BlockSource, BlockSourceResult, ChainNotifier, HeaderCache}; use bitcoin::block::Header; -use bitcoin::hash_types::BlockHash; use bitcoin::network::Network; use lightning::chain; @@ -32,19 +32,18 @@ where /// Performs a one-time sync of chain listeners using a single *trusted* block source, bringing each /// listener's view of the chain from its paired block hash to `block_source`'s best chain tip. /// -/// Upon success, the returned header can be used to initialize [`SpvClient`]. In the case of -/// failure, each listener may be left at a different block hash than the one it was originally -/// paired with. +/// Upon success, the returned header and header cache can be used to initialize [`SpvClient`]. In +/// the case of failure, each listener may be left at a different block hash than the one it was +/// originally paired with. /// /// Useful during startup to bring the [`ChannelManager`] and each [`ChannelMonitor`] in sync before /// switching to [`SpvClient`]. For example: /// /// ``` -/// use bitcoin::hash_types::BlockHash; /// use bitcoin::network::Network; /// /// use lightning::chain; -/// use lightning::chain::Watch; +/// use lightning::chain::{BestBlock, Watch}; /// use lightning::chain::chainmonitor; /// use lightning::chain::chainmonitor::ChainMonitor; /// use lightning::chain::channelmonitor::ChannelMonitor; @@ -89,14 +88,14 @@ where /// logger: &L, /// persister: &P, /// ) { -/// // Read a serialized channel monitor paired with the block hash when it was persisted. +/// // Read a serialized channel monitor paired with the best block when it was persisted. /// let serialized_monitor = "..."; -/// let (monitor_block_hash, mut monitor) = <(BlockHash, ChannelMonitor)>::read( +/// let (monitor_best_block, mut monitor) = <(BestBlock, ChannelMonitor)>::read( /// &mut Cursor::new(&serialized_monitor), (entropy_source, signer_provider)).unwrap(); /// -/// // Read the channel manager paired with the block hash when it was persisted. +/// // Read the channel manager paired with the best block when it was persisted. /// let serialized_manager = "..."; -/// let (manager_block_hash, mut manager) = { +/// let (manager_best_block, mut manager) = { /// let read_args = ChannelManagerReadArgs::new( /// entropy_source, /// node_signer, @@ -110,19 +109,18 @@ where /// config, /// vec![&mut monitor], /// ); -/// <(BlockHash, ChannelManager<&ChainMonitor, &T, &ES, &NS, &SP, &F, &R, &MR, &L>)>::read( +/// <(BestBlock, ChannelManager<&ChainMonitor, &T, &ES, &NS, &SP, &F, &R, &MR, &L>)>::read( /// &mut Cursor::new(&serialized_manager), read_args).unwrap() /// }; /// /// // Synchronize any channel monitors and the channel manager to be on the best block. -/// let mut cache = UnboundedCache::new(); /// let mut monitor_listener = (monitor, &*tx_broadcaster, &*fee_estimator, &*logger); /// let listeners = vec![ -/// (monitor_block_hash, &monitor_listener as &dyn chain::Listen), -/// (manager_block_hash, &manager as &dyn chain::Listen), +/// (monitor_best_block, &monitor_listener as &dyn chain::Listen), +/// (manager_best_block, &manager as &dyn chain::Listen), /// ]; -/// let chain_tip = init::synchronize_listeners( -/// block_source, Network::Bitcoin, &mut cache, listeners).await.unwrap(); +/// let (chain_cache, chain_tip) = init::synchronize_listeners( +/// block_source, Network::Bitcoin, listeners).await.unwrap(); /// /// // Allow the chain monitor to watch any channels. /// let monitor = monitor_listener.0; @@ -131,7 +129,7 @@ where /// // Create an SPV client to notify the chain monitor and channel manager of block events. /// let chain_poller = poll::ChainPoller::new(block_source, Network::Bitcoin); /// let mut chain_listener = (chain_monitor, &manager); -/// let spv_client = SpvClient::new(chain_tip, chain_poller, &mut cache, &chain_listener); +/// let spv_client = SpvClient::new(chain_tip, chain_poller, chain_cache, &chain_listener); /// } /// ``` /// @@ -140,85 +138,87 @@ where /// [`ChannelMonitor`]: lightning::chain::channelmonitor::ChannelMonitor pub async fn synchronize_listeners< B: Deref + Sized + Send + Sync, - C: Cache, L: chain::Listen + ?Sized, >( - block_source: B, network: Network, header_cache: &mut C, - mut chain_listeners: Vec<(BlockHash, &L)>, -) -> BlockSourceResult + block_source: B, network: Network, + mut chain_listeners: Vec<(BestBlock, &L)>, +) -> BlockSourceResult<(HeaderCache, ValidatedBlockHeader)> where B::Target: BlockSource, { let best_header = validate_best_block_header(&*block_source).await?; - // Fetch the header for the block hash paired with each listener. - let mut chain_listeners_with_old_headers = Vec::new(); - for (old_block_hash, chain_listener) in chain_listeners.drain(..) { - let old_header = match header_cache.look_up(&old_block_hash) { - Some(header) => *header, - None => { - block_source.get_header(&old_block_hash, None).await?.validate(old_block_hash)? - }, - }; - chain_listeners_with_old_headers.push((old_header, chain_listener)) - } - // Find differences and disconnect blocks for each listener individually. let mut chain_poller = ChainPoller::new(block_source, network); let mut chain_listeners_at_height = Vec::new(); - let mut most_common_ancestor = None; let mut most_connected_blocks = Vec::new(); - for (old_header, chain_listener) in chain_listeners_with_old_headers.drain(..) { + let mut header_cache = HeaderCache::new(); + for (old_best_block, chain_listener) in chain_listeners.drain(..) { // Disconnect any stale blocks, but keep them in the cache for the next iteration. - let header_cache = &mut ReadOnlyCache(header_cache); let (common_ancestor, connected_blocks) = { let chain_listener = &DynamicChainListener(chain_listener); - let mut chain_notifier = ChainNotifier { header_cache, chain_listener }; + let mut chain_notifier = ChainNotifier { header_cache: &mut header_cache, chain_listener }; let difference = - chain_notifier.find_difference(best_header, &old_header, &mut chain_poller).await?; - chain_notifier.disconnect_blocks(difference.disconnected_blocks); + chain_notifier.find_difference_from_best_block(best_header, old_best_block, &mut chain_poller).await?; + if difference.common_ancestor.block_hash != old_best_block.block_hash { + chain_notifier.disconnect_blocks(difference.common_ancestor); + } (difference.common_ancestor, difference.connected_blocks) }; // Keep track of the most common ancestor and all blocks connected across all listeners. chain_listeners_at_height.push((common_ancestor.height, chain_listener)); if connected_blocks.len() > most_connected_blocks.len() { - most_common_ancestor = Some(common_ancestor); most_connected_blocks = connected_blocks; } } - // Connect new blocks for all listeners at once to avoid re-fetching blocks. - if let Some(common_ancestor) = most_common_ancestor { - let chain_listener = &ChainListenerSet(chain_listeners_at_height); - let mut chain_notifier = ChainNotifier { header_cache, chain_listener }; - chain_notifier - .connect_blocks(common_ancestor, most_connected_blocks, &mut chain_poller) - .await - .map_err(|(e, _)| e)?; - } - - Ok(best_header) -} - -/// A wrapper to make a cache read-only. -/// -/// Used to prevent losing headers that may be needed to disconnect blocks common to more than one -/// listener. -struct ReadOnlyCache<'a, C: Cache>(&'a mut C); + while !most_connected_blocks.is_empty() { + #[cfg(not(test))] + const MAX_BLOCKS_AT_ONCE: usize = 6 * 6; // Six hours of blocks, 144MiB encoded + #[cfg(test)] + const MAX_BLOCKS_AT_ONCE: usize = 2; + + let mut fetch_block_futures = + Vec::with_capacity(core::cmp::min(MAX_BLOCKS_AT_ONCE, most_connected_blocks.len())); + for header in most_connected_blocks.iter().rev().take(MAX_BLOCKS_AT_ONCE) { + let fetch_future = chain_poller.fetch_block(header); + fetch_block_futures.push(ResultFuture::Pending(Box::pin(async move { + (header, fetch_future.await) + }))); + } + let results = MultiResultFuturePoller::new(fetch_block_futures).await.into_iter(); -impl<'a, C: Cache> Cache for ReadOnlyCache<'a, C> { - fn look_up(&self, block_hash: &BlockHash) -> Option<&ValidatedBlockHeader> { - self.0.look_up(block_hash) - } + let mut fetched_blocks = [const { None }; MAX_BLOCKS_AT_ONCE]; + for ((header, block_res), result) in results.into_iter().zip(fetched_blocks.iter_mut()) { + *result = Some((header.height, block_res?)); + } + debug_assert!(fetched_blocks.iter().take(most_connected_blocks.len()).all(|r| r.is_some())); + debug_assert!(fetched_blocks.is_sorted_by_key(|r| r.as_ref().map(|(height, _)| *height).unwrap_or(u32::MAX))); + + for (listener_height, listener) in chain_listeners_at_height.iter() { + // Connect blocks for this listener. + for result in fetched_blocks.iter() { + if let Some((height, block_data)) = result { + if *height > *listener_height { + match &**block_data { + BlockData::FullBlock(block) => { + listener.block_connected(&block, *height); + }, + BlockData::HeaderOnly(header_data) => { + listener.filtered_block_connected(&header_data, &[], *height); + }, + } + } + } + } + } - fn block_connected(&mut self, _block_hash: BlockHash, _block_header: ValidatedBlockHeader) { - unreachable!() + most_connected_blocks + .truncate(most_connected_blocks.len().saturating_sub(MAX_BLOCKS_AT_ONCE)); } - fn block_disconnected(&mut self, _block_hash: &BlockHash) -> Option { - None - } + Ok((header_cache, best_header)) } /// Wrapper for supporting dynamically sized chain listeners. @@ -236,33 +236,6 @@ impl<'a, L: chain::Listen + ?Sized> chain::Listen for DynamicChainListener<'a, L } } -/// A set of dynamically sized chain listeners, each paired with a starting block height. -struct ChainListenerSet<'a, L: chain::Listen + ?Sized>(Vec<(u32, &'a L)>); - -impl<'a, L: chain::Listen + ?Sized> chain::Listen for ChainListenerSet<'a, L> { - fn block_connected(&self, block: &bitcoin::Block, height: u32) { - for (starting_height, chain_listener) in self.0.iter() { - if height > *starting_height { - chain_listener.block_connected(block, height); - } - } - } - - fn filtered_block_connected( - &self, header: &Header, txdata: &chain::transaction::TransactionData, height: u32, - ) { - for (starting_height, chain_listener) in self.0.iter() { - if height > *starting_height { - chain_listener.filtered_block_connected(header, txdata, height); - } - } - } - - fn blocks_disconnected(&self, _fork_point: BestBlock) { - unreachable!() - } -} - #[cfg(test)] mod tests { use super::*; @@ -282,13 +255,12 @@ mod tests { let listener_3 = MockChainListener::new().expect_block_connected(*chain.at_height(4)); let listeners = vec![ - (chain.at_height(1).block_hash, &listener_1 as &dyn chain::Listen), - (chain.at_height(2).block_hash, &listener_2 as &dyn chain::Listen), - (chain.at_height(3).block_hash, &listener_3 as &dyn chain::Listen), + (chain.best_block_at_height(1), &listener_1 as &dyn chain::Listen), + (chain.best_block_at_height(2), &listener_2 as &dyn chain::Listen), + (chain.best_block_at_height(3), &listener_3 as &dyn chain::Listen), ]; - let mut cache = chain.header_cache(0..=4); - match synchronize_listeners(&chain, Network::Bitcoin, &mut cache, listeners).await { - Ok(header) => assert_eq!(header, chain.tip()), + match synchronize_listeners(&chain, Network::Bitcoin, listeners).await { + Ok((_, header)) => assert_eq!(header, chain.tip()), Err(e) => panic!("Unexpected error: {:?}", e), } } @@ -314,15 +286,12 @@ mod tests { .expect_block_connected(*main_chain.at_height(4)); let listeners = vec![ - (fork_chain_1.tip().block_hash, &listener_1 as &dyn chain::Listen), - (fork_chain_2.tip().block_hash, &listener_2 as &dyn chain::Listen), - (fork_chain_3.tip().block_hash, &listener_3 as &dyn chain::Listen), + (fork_chain_1.best_block(), &listener_1 as &dyn chain::Listen), + (fork_chain_2.best_block(), &listener_2 as &dyn chain::Listen), + (fork_chain_3.best_block(), &listener_3 as &dyn chain::Listen), ]; - let mut cache = fork_chain_1.header_cache(2..=4); - cache.extend(fork_chain_2.header_cache(3..=4)); - cache.extend(fork_chain_3.header_cache(4..=4)); - match synchronize_listeners(&main_chain, Network::Bitcoin, &mut cache, listeners).await { - Ok(header) => assert_eq!(header, main_chain.tip()), + match synchronize_listeners(&main_chain, Network::Bitcoin, listeners).await { + Ok((_, header)) => assert_eq!(header, main_chain.tip()), Err(e) => panic!("Unexpected error: {:?}", e), } } @@ -351,37 +320,12 @@ mod tests { .expect_block_connected(*main_chain.at_height(4)); let listeners = vec![ - (fork_chain_1.tip().block_hash, &listener_1 as &dyn chain::Listen), - (fork_chain_2.tip().block_hash, &listener_2 as &dyn chain::Listen), - (fork_chain_3.tip().block_hash, &listener_3 as &dyn chain::Listen), + (fork_chain_1.best_block(), &listener_1 as &dyn chain::Listen), + (fork_chain_2.best_block(), &listener_2 as &dyn chain::Listen), + (fork_chain_3.best_block(), &listener_3 as &dyn chain::Listen), ]; - let mut cache = fork_chain_1.header_cache(2..=4); - cache.extend(fork_chain_2.header_cache(3..=4)); - cache.extend(fork_chain_3.header_cache(4..=4)); - match synchronize_listeners(&main_chain, Network::Bitcoin, &mut cache, listeners).await { - Ok(header) => assert_eq!(header, main_chain.tip()), - Err(e) => panic!("Unexpected error: {:?}", e), - } - } - - #[tokio::test] - async fn cache_connected_and_keep_disconnected_blocks() { - let main_chain = Blockchain::default().with_height(2); - let fork_chain = main_chain.fork_at_height(1); - let new_tip = main_chain.tip(); - let old_tip = fork_chain.tip(); - - let listener = MockChainListener::new() - .expect_blocks_disconnected(*fork_chain.at_height(1)) - .expect_block_connected(*new_tip); - - let listeners = vec![(old_tip.block_hash, &listener as &dyn chain::Listen)]; - let mut cache = fork_chain.header_cache(2..=2); - match synchronize_listeners(&main_chain, Network::Bitcoin, &mut cache, listeners).await { - Ok(_) => { - assert!(cache.contains_key(&new_tip.block_hash)); - assert!(cache.contains_key(&old_tip.block_hash)); - }, + match synchronize_listeners(&main_chain, Network::Bitcoin, listeners).await { + Ok((_, header)) => assert_eq!(header, main_chain.tip()), Err(e) => panic!("Unexpected error: {:?}", e), } } diff --git a/lightning-block-sync/src/lib.rs b/lightning-block-sync/src/lib.rs index 02593047658..ab91d3f396f 100644 --- a/lightning-block-sync/src/lib.rs +++ b/lightning-block-sync/src/lib.rs @@ -16,9 +16,11 @@ #![deny(rustdoc::broken_intra_doc_links)] #![deny(rustdoc::private_intra_doc_links)] #![deny(missing_docs)] -#![deny(unsafe_code)] #![cfg_attr(docsrs, feature(doc_cfg))] +extern crate alloc; +extern crate core; + #[cfg(any(feature = "rest-client", feature = "rpc-client"))] pub mod http; @@ -42,6 +44,9 @@ mod test_utils; #[cfg(any(feature = "rest-client", feature = "rpc-client"))] mod utils; +#[allow(unused)] +mod async_poll; + use crate::poll::{ChainTip, Poll, ValidatedBlockHeader}; use bitcoin::block::{Block, Header}; @@ -170,18 +175,13 @@ pub enum BlockData { /// sources for the best chain tip. During this process it detects any chain forks, determines which /// constitutes the best chain, and updates the listener accordingly with any blocks that were /// connected or disconnected since the last poll. -/// -/// Block headers for the best chain are maintained in the parameterized cache, allowing for a -/// custom cache eviction policy. This offers flexibility to those sensitive to resource usage. -/// Hence, there is a trade-off between a lower memory footprint and potentially increased network -/// I/O as headers are re-fetched during fork detection. -pub struct SpvClient<'a, P: Poll, C: Cache, L: Deref> +pub struct SpvClient where L::Target: chain::Listen, { chain_tip: ValidatedBlockHeader, chain_poller: P, - chain_notifier: ChainNotifier<'a, C, L>, + chain_notifier: ChainNotifier, } /// The `Cache` trait defines behavior for managing a block header cache, where block headers are @@ -194,37 +194,83 @@ where /// Implementations may define how long to retain headers such that it's unlikely they will ever be /// needed to disconnect a block. In cases where block sources provide access to headers on stale /// forks reliably, caches may be entirely unnecessary. -pub trait Cache { +pub(crate) trait Cache { /// Retrieves the block header keyed by the given block hash. fn look_up(&self, block_hash: &BlockHash) -> Option<&ValidatedBlockHeader>; + /// Inserts the given block header during a find_difference operation, implying it might not be + /// the best header. + fn insert_during_diff(&mut self, block_hash: BlockHash, block_header: ValidatedBlockHeader); + /// Called when a block has been connected to the best chain to ensure it is available to be /// disconnected later if needed. fn block_connected(&mut self, block_hash: BlockHash, block_header: ValidatedBlockHeader); - /// Called when a block has been disconnected from the best chain. Once disconnected, a block's - /// header is no longer needed and thus can be removed. - fn block_disconnected(&mut self, block_hash: &BlockHash) -> Option; + /// Called when blocks have been disconnected from the best chain. Only the fork point + /// (best comon ancestor) is provided. + /// + /// Once disconnected, a block's header is no longer needed and thus can be removed. + fn blocks_disconnected(&mut self, fork_point: &ValidatedBlockHeader); } -/// Unbounded cache of block headers keyed by block hash. -pub type UnboundedCache = std::collections::HashMap; +/// Bounded cache of block headers keyed by block hash. +/// +/// Retains only the last `ANTI_REORG_DELAY * 2` block headers based on height. +pub struct HeaderCache(std::collections::HashMap); -impl Cache for UnboundedCache { +impl HeaderCache { + /// Creates a new empty header cache. + pub fn new() -> Self { + Self(std::collections::HashMap::new()) + } +} + +impl Cache for HeaderCache { fn look_up(&self, block_hash: &BlockHash) -> Option<&ValidatedBlockHeader> { - self.get(block_hash) + self.0.get(block_hash) + } + + fn insert_during_diff(&mut self, block_hash: BlockHash, block_header: ValidatedBlockHeader) { + self.0.insert(block_hash, block_header); + + // Remove headers older than our newest header minus a week. + let best_height = self.0.iter().map(|(_, header)| header.height).max().unwrap_or(0); + let cutoff_height = best_height.saturating_sub(6 * 24 * 7); + self.0.retain(|_, header| header.height >= cutoff_height); } fn block_connected(&mut self, block_hash: BlockHash, block_header: ValidatedBlockHeader) { - self.insert(block_hash, block_header); + self.0.insert(block_hash, block_header); + + // Remove headers older than a week. + let cutoff_height = block_header.height.saturating_sub(6 * 24 * 7); + self.0.retain(|_, header| header.height >= cutoff_height); } - fn block_disconnected(&mut self, block_hash: &BlockHash) -> Option { - self.remove(block_hash) + fn blocks_disconnected(&mut self, fork_point: &ValidatedBlockHeader) { + self.0.retain(|_, block_info| block_info.height < fork_point.height); } } -impl<'a, P: Poll, C: Cache, L: Deref> SpvClient<'a, P, C, L> +impl Cache for &mut HeaderCache { + fn look_up(&self, block_hash: &BlockHash) -> Option<&ValidatedBlockHeader> { + self.0.get(block_hash) + } + + fn insert_during_diff(&mut self, block_hash: BlockHash, block_header: ValidatedBlockHeader) { + (*self).insert_during_diff(block_hash, block_header); + } + + fn block_connected(&mut self, block_hash: BlockHash, block_header: ValidatedBlockHeader) { + (*self).block_connected(block_hash, block_header); + } + + fn blocks_disconnected(&mut self, fork_point: &ValidatedBlockHeader) { + self.0.retain(|_, block_info| block_info.height < fork_point.height); + } +} + +impl SpvClient where L::Target: chain::Listen, { @@ -239,7 +285,7 @@ where /// /// [`poll_best_tip`]: SpvClient::poll_best_tip pub fn new( - chain_tip: ValidatedBlockHeader, chain_poller: P, header_cache: &'a mut C, + chain_tip: ValidatedBlockHeader, chain_poller: P, header_cache: HeaderCache, chain_listener: L, ) -> Self { let chain_notifier = ChainNotifier { header_cache, chain_listener }; @@ -293,15 +339,15 @@ where /// Notifies [listeners] of blocks that have been connected or disconnected from the chain. /// /// [listeners]: lightning::chain::Listen -pub struct ChainNotifier<'a, C: Cache, L: Deref> +pub(crate) struct ChainNotifier where L::Target: chain::Listen, { /// Cache for looking up headers before fetching from a block source. - header_cache: &'a mut C, + pub(crate) header_cache: C, /// Listener that will be notified of connected or disconnected blocks. - chain_listener: L, + pub(crate) chain_listener: L, } /// Changes made to the chain between subsequent polls that transformed it from having one chain tip @@ -315,14 +361,11 @@ struct ChainDifference { /// If there are any disconnected blocks, this is where the chain forked. common_ancestor: ValidatedBlockHeader, - /// Blocks that were disconnected from the chain since the last poll. - disconnected_blocks: Vec, - /// Blocks that were connected to the chain since the last poll. connected_blocks: Vec, } -impl<'a, C: Cache, L: Deref> ChainNotifier<'a, C, L> +impl ChainNotifier where L::Target: chain::Listen, { @@ -338,23 +381,65 @@ where chain_poller: &mut P, ) -> Result<(), (BlockSourceError, Option)> { let difference = self - .find_difference(new_header, old_header, chain_poller) + .find_difference_from_header(new_header, old_header, chain_poller) .await .map_err(|e| (e, None))?; - self.disconnect_blocks(difference.disconnected_blocks); + if difference.common_ancestor != *old_header { + self.disconnect_blocks(difference.common_ancestor); + } self.connect_blocks(difference.common_ancestor, difference.connected_blocks, chain_poller) .await } + /// Returns the changes needed to produce the chain with `current_header` as its tip from the + /// chain with `prev_best_block` as its tip. + /// + /// First resolves `prev_best_block` to a `ValidatedBlockHeader` using the `previous_blocks` + /// field as fallback if needed, then finds the common ancestor. + async fn find_difference_from_best_block( + &mut self, current_header: ValidatedBlockHeader, prev_best_block: BestBlock, + chain_poller: &mut P, + ) -> BlockSourceResult { + // Try to resolve the header for the previous best block. First try the block_hash, + // then fall back to previous_blocks if that fails. + let cur_tip = core::iter::once((0, &prev_best_block.block_hash)); + let prev_tips = prev_best_block.previous_blocks.iter().enumerate().filter_map(|(idx, hash_opt)| { + if let Some(block_hash) = hash_opt { + Some((idx as u32 + 1, block_hash)) + } else { + None + } + }); + let mut found_header = None; + for (height_diff, block_hash) in cur_tip.chain(prev_tips) { + if let Some(header) = self.header_cache.look_up(block_hash) { + found_header = Some(*header); + break; + } + let height = prev_best_block.height.checked_sub(height_diff).ok_or( + BlockSourceError::persistent("BestBlock had more previous_blocks than its height") + )?; + if let Ok(header) = chain_poller.get_header(block_hash, Some(height)).await { + found_header = Some(header); + self.header_cache.insert_during_diff(*block_hash, header); + break; + } + } + let found_header = found_header.ok_or_else(|| { + BlockSourceError::persistent("could not resolve any block from BestBlock") + })?; + + self.find_difference_from_header(current_header, &found_header, chain_poller).await + } + /// Returns the changes needed to produce the chain with `current_header` as its tip from the /// chain with `prev_header` as its tip. /// /// Walks backwards from `current_header` and `prev_header`, finding the common ancestor. - async fn find_difference( + async fn find_difference_from_header( &self, current_header: ValidatedBlockHeader, prev_header: &ValidatedBlockHeader, chain_poller: &mut P, ) -> BlockSourceResult { - let mut disconnected_blocks = Vec::new(); let mut connected_blocks = Vec::new(); let mut current = current_header; let mut previous = *prev_header; @@ -369,7 +454,6 @@ where let current_height = current.height; let previous_height = previous.height; if current_height <= previous_height { - disconnected_blocks.push(previous); previous = self.look_up_previous_header(chain_poller, &previous).await?; } if current_height >= previous_height { @@ -379,7 +463,7 @@ where } let common_ancestor = current; - Ok(ChainDifference { common_ancestor, disconnected_blocks, connected_blocks }) + Ok(ChainDifference { common_ancestor, connected_blocks }) } /// Returns the previous header for the given header, either by looking it up in the cache or @@ -394,16 +478,10 @@ where } /// Notifies the chain listeners of disconnected blocks. - fn disconnect_blocks(&mut self, disconnected_blocks: Vec) { - for header in disconnected_blocks.iter() { - if let Some(cached_header) = self.header_cache.block_disconnected(&header.block_hash) { - assert_eq!(cached_header, *header); - } - } - if let Some(block) = disconnected_blocks.last() { - let fork_point = BestBlock::new(block.header.prev_blockhash, block.height - 1); - self.chain_listener.blocks_disconnected(fork_point); - } + fn disconnect_blocks(&mut self, fork_point: ValidatedBlockHeader) { + self.header_cache.blocks_disconnected(&fork_point); + let best_block = BestBlock::new(fork_point.block_hash, fork_point.height); + self.chain_listener.blocks_disconnected(best_block); } /// Notifies the chain listeners of connected blocks. @@ -447,9 +525,9 @@ mod spv_client_tests { let best_tip = chain.at_height(1); let poller = poll::ChainPoller::new(&mut chain, Network::Testnet); - let mut cache = UnboundedCache::new(); + let cache = HeaderCache::new(); let mut listener = NullChainListener {}; - let mut client = SpvClient::new(best_tip, poller, &mut cache, &mut listener); + let mut client = SpvClient::new(best_tip, poller, cache, &mut listener); match client.poll_best_tip().await { Err(e) => { assert_eq!(e.kind(), BlockSourceErrorKind::Persistent); @@ -466,9 +544,9 @@ mod spv_client_tests { let common_tip = chain.tip(); let poller = poll::ChainPoller::new(&mut chain, Network::Testnet); - let mut cache = UnboundedCache::new(); + let cache = HeaderCache::new(); let mut listener = NullChainListener {}; - let mut client = SpvClient::new(common_tip, poller, &mut cache, &mut listener); + let mut client = SpvClient::new(common_tip, poller, cache, &mut listener); match client.poll_best_tip().await { Err(e) => panic!("Unexpected error: {:?}", e), Ok((chain_tip, blocks_connected)) => { @@ -486,9 +564,9 @@ mod spv_client_tests { let old_tip = chain.at_height(1); let poller = poll::ChainPoller::new(&mut chain, Network::Testnet); - let mut cache = UnboundedCache::new(); + let cache = HeaderCache::new(); let mut listener = NullChainListener {}; - let mut client = SpvClient::new(old_tip, poller, &mut cache, &mut listener); + let mut client = SpvClient::new(old_tip, poller, cache, &mut listener); match client.poll_best_tip().await { Err(e) => panic!("Unexpected error: {:?}", e), Ok((chain_tip, blocks_connected)) => { @@ -506,9 +584,9 @@ mod spv_client_tests { let old_tip = chain.at_height(1); let poller = poll::ChainPoller::new(&mut chain, Network::Testnet); - let mut cache = UnboundedCache::new(); + let cache = HeaderCache::new(); let mut listener = NullChainListener {}; - let mut client = SpvClient::new(old_tip, poller, &mut cache, &mut listener); + let mut client = SpvClient::new(old_tip, poller, cache, &mut listener); match client.poll_best_tip().await { Err(e) => panic!("Unexpected error: {:?}", e), Ok((chain_tip, blocks_connected)) => { @@ -526,9 +604,9 @@ mod spv_client_tests { let old_tip = chain.at_height(1); let poller = poll::ChainPoller::new(&mut chain, Network::Testnet); - let mut cache = UnboundedCache::new(); + let cache = HeaderCache::new(); let mut listener = NullChainListener {}; - let mut client = SpvClient::new(old_tip, poller, &mut cache, &mut listener); + let mut client = SpvClient::new(old_tip, poller, cache, &mut listener); match client.poll_best_tip().await { Err(e) => panic!("Unexpected error: {:?}", e), Ok((chain_tip, blocks_connected)) => { @@ -547,9 +625,9 @@ mod spv_client_tests { let worse_tip = chain.tip(); let poller = poll::ChainPoller::new(&mut chain, Network::Testnet); - let mut cache = UnboundedCache::new(); + let cache = HeaderCache::new(); let mut listener = NullChainListener {}; - let mut client = SpvClient::new(best_tip, poller, &mut cache, &mut listener); + let mut client = SpvClient::new(best_tip, poller, cache, &mut listener); match client.poll_best_tip().await { Err(e) => panic!("Unexpected error: {:?}", e), Ok((chain_tip, blocks_connected)) => { diff --git a/lightning-block-sync/src/poll.rs b/lightning-block-sync/src/poll.rs index 13e0403c3b6..fd8c546c56f 100644 --- a/lightning-block-sync/src/poll.rs +++ b/lightning-block-sync/src/poll.rs @@ -31,6 +31,11 @@ pub trait Poll { fn fetch_block<'a>( &'a self, header: &'a ValidatedBlockHeader, ) -> impl Future> + Send + 'a; + + /// Returns the header for a given hash and optional height hint. + fn get_header<'a>( + &'a self, block_hash: &'a BlockHash, height_hint: Option, + ) -> impl Future> + Send + 'a; } /// A chain tip relative to another chain tip in terms of block hash and chainwork. @@ -258,6 +263,14 @@ impl + Sized + Send + Sync, T: BlockSource + ?Sized> Poll ) -> impl Future> + Send + 'a { async move { self.block_source.get_block(&header.block_hash).await?.validate(header.block_hash) } } + + fn get_header<'a>( + &'a self, block_hash: &'a BlockHash, height_hint: Option, + ) -> impl Future> + Send + 'a { + Box::pin(async move { + self.block_source.get_header(block_hash, height_hint).await?.validate(*block_hash) + }) + } } #[cfg(test)] diff --git a/lightning-block-sync/src/test_utils.rs b/lightning-block-sync/src/test_utils.rs index 40788e4d08c..89cb3e81d60 100644 --- a/lightning-block-sync/src/test_utils.rs +++ b/lightning-block-sync/src/test_utils.rs @@ -1,6 +1,7 @@ use crate::poll::{Validate, ValidatedBlockHeader}; use crate::{ - BlockData, BlockHeaderData, BlockSource, BlockSourceError, BlockSourceResult, UnboundedCache, + BlockData, BlockHeaderData, BlockSource, BlockSourceError, BlockSourceResult, Cache, + HeaderCache, }; use bitcoin::block::{Block, Header, Version}; @@ -104,6 +105,18 @@ impl Blockchain { block_header.validate(block_hash).unwrap() } + pub fn best_block_at_height(&self, height: usize) -> BestBlock { + let mut previous_blocks = [None; 12]; + for (i, height) in (0..height).rev().take(12).enumerate() { + previous_blocks[i] = Some(self.blocks[height].block_hash()); + } + BestBlock { + height: height as u32, + block_hash: self.blocks[height].block_hash(), + previous_blocks, + } + } + fn at_height_unvalidated(&self, height: usize) -> BlockHeaderData { assert!(!self.blocks.is_empty()); assert!(height < self.blocks.len()); @@ -123,16 +136,21 @@ impl Blockchain { self.at_height(self.blocks.len() - 1) } + pub fn best_block(&self) -> BestBlock { + assert!(!self.blocks.is_empty()); + self.best_block_at_height(self.blocks.len() - 1) + } + pub fn disconnect_tip(&mut self) -> Option { self.blocks.pop() } - pub fn header_cache(&self, heights: std::ops::RangeInclusive) -> UnboundedCache { - let mut cache = UnboundedCache::new(); + pub fn header_cache(&self, heights: std::ops::RangeInclusive) -> HeaderCache { + let mut cache = HeaderCache::new(); for i in heights { let value = self.at_height(i); let key = value.header.block_hash(); - assert!(cache.insert(key, value).is_none()); + cache.block_connected(key, value); } cache } diff --git a/lightning/src/chain/chainmonitor.rs b/lightning/src/chain/chainmonitor.rs index d8d6c90921f..32f36fb3a78 100644 --- a/lightning/src/chain/chainmonitor.rs +++ b/lightning/src/chain/chainmonitor.rs @@ -51,10 +51,9 @@ use crate::sign::ecdsa::EcdsaChannelSigner; use crate::sign::{EntropySource, PeerStorageKey, SignerProvider}; use crate::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard}; use crate::types::features::{InitFeatures, NodeFeatures}; -use crate::util::async_poll::{MaybeSend, MaybeSync}; use crate::util::errors::APIError; use crate::util::logger::{Logger, WithContext}; -use crate::util::native_async::FutureSpawner; +use crate::util::native_async::{FutureSpawner, MaybeSend, MaybeSync}; use crate::util::persist::{KVStore, MonitorName, MonitorUpdatingPersisterAsync}; #[cfg(peer_storage)] use crate::util::ser::{VecWriter, Writeable}; diff --git a/lightning/src/chain/channelmonitor.rs b/lightning/src/chain/channelmonitor.rs index 1d035b68650..e2ea5855576 100644 --- a/lightning/src/chain/channelmonitor.rs +++ b/lightning/src/chain/channelmonitor.rs @@ -1059,7 +1059,7 @@ impl Readable for IrrevocablyResolvedHTLC { /// You MUST ensure that no ChannelMonitors for a given channel anywhere contain out-of-date /// information and are actively monitoring the chain. /// -/// Like the [`ChannelManager`], deserialization is implemented for `(BlockHash, ChannelMonitor)`, +/// Like the [`ChannelManager`], deserialization is implemented for `(BestBlock, ChannelMonitor)`, /// providing you with the last block hash which was connected before shutting down. You must begin /// syncing the chain from that point, disconnecting and connecting blocks as required to get to /// the best chain on startup. Note that all [`ChannelMonitor`]s passed to a [`ChainMonitor`] must @@ -1067,7 +1067,7 @@ impl Readable for IrrevocablyResolvedHTLC { /// initialization. /// /// For those loading potentially-ancient [`ChannelMonitor`]s, deserialization is also implemented -/// for `Option<(BlockHash, ChannelMonitor)>`. LDK can no longer deserialize a [`ChannelMonitor`] +/// for `Option<(BestBlock, ChannelMonitor)>`. LDK can no longer deserialize a [`ChannelMonitor`] /// that was first created in LDK prior to 0.0.110 and last updated prior to LDK 0.0.119. In such /// cases, the `Option<(..)>` deserialization option may return `Ok(None)` rather than failing to /// deserialize, allowing you to differentiate between the two cases. @@ -6420,7 +6420,7 @@ where const MAX_ALLOC_SIZE: usize = 64 * 1024; impl<'a, 'b, ES: EntropySource, SP: SignerProvider> ReadableArgs<(&'a ES, &'b SP)> - for (BlockHash, ChannelMonitor) + for (BestBlock, ChannelMonitor) { fn read(reader: &mut R, args: (&'a ES, &'b SP)) -> Result { match >::read(reader, args) { @@ -6432,7 +6432,7 @@ impl<'a, 'b, ES: EntropySource, SP: SignerProvider> ReadableArgs<(&'a ES, &'b SP } impl<'a, 'b, ES: EntropySource, SP: SignerProvider> ReadableArgs<(&'a ES, &'b SP)> - for Option<(BlockHash, ChannelMonitor)> + for Option<(BestBlock, ChannelMonitor)> { #[rustfmt::skip] fn read(reader: &mut R, args: (&'a ES, &'b SP)) -> Result { @@ -6857,14 +6857,14 @@ impl<'a, 'b, ES: EntropySource, SP: SignerProvider> ReadableArgs<(&'a ES, &'b SP To continue, run a v0.1 release, send/route a payment over the channel or close it."); } } - Ok(Some((best_block.block_hash, monitor))) + Ok(Some((best_block, monitor))) } } #[cfg(test)] mod tests { use bitcoin::amount::Amount; - use bitcoin::hash_types::{BlockHash, Txid}; + use bitcoin::hash_types::Txid; use bitcoin::hashes::sha256::Hash as Sha256; use bitcoin::hashes::Hash; use bitcoin::hex::FromHex; @@ -6964,7 +6964,7 @@ mod tests { nodes[1].chain_monitor.chain_monitor.transactions_confirmed(&new_header, &[(0, broadcast_tx)], conf_height); - let (_, pre_update_monitor) = <(BlockHash, ChannelMonitor<_>)>::read( + let (_, pre_update_monitor) = <(BestBlock, ChannelMonitor<_>)>::read( &mut io::Cursor::new(&get_monitor!(nodes[1], channel.2).encode()), (&nodes[1].keys_manager.backing, &nodes[1].keys_manager.backing)).unwrap(); diff --git a/lightning/src/chain/mod.rs b/lightning/src/chain/mod.rs index b4cc6a302ae..8ef796ecb2e 100644 --- a/lightning/src/chain/mod.rs +++ b/lightning/src/chain/mod.rs @@ -18,9 +18,10 @@ use bitcoin::network::Network; use bitcoin::script::{Script, ScriptBuf}; use bitcoin::secp256k1::PublicKey; -use crate::chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdate, MonitorEvent}; +use crate::chain::channelmonitor::{ + ANTI_REORG_DELAY, ChannelMonitor, ChannelMonitorUpdate, MonitorEvent, +}; use crate::chain::transaction::{OutPoint, TransactionData}; -use crate::impl_writeable_tlv_based; use crate::ln::types::ChannelId; use crate::sign::ecdsa::EcdsaChannelSigner; use crate::sign::HTLCDescriptor; @@ -42,13 +43,20 @@ pub struct BestBlock { pub block_hash: BlockHash, /// The height at which the block was confirmed. pub height: u32, + /// Previous blocks immediately before [`Self::block_hash`], in reverse chronological order. + /// + /// These ensure we can find the fork point of a reorg if our block source no longer has the + /// previous best tip after a restart. + pub previous_blocks: [Option; ANTI_REORG_DELAY as usize * 2], } impl BestBlock { /// Constructs a `BestBlock` that represents the genesis block at height 0 of the given /// network. pub fn from_network(network: Network) -> Self { - BestBlock { block_hash: genesis_block(network).header.block_hash(), height: 0 } + let block_hash = genesis_block(network).header.block_hash(); + let previous_blocks = [None; ANTI_REORG_DELAY as usize * 2]; + BestBlock { block_hash, height: 0, previous_blocks } } /// Returns a `BestBlock` as identified by the given block hash and height. @@ -56,13 +64,75 @@ impl BestBlock { /// This is not exported to bindings users directly as the bindings auto-generate an /// equivalent `new`. pub fn new(block_hash: BlockHash, height: u32) -> Self { - BestBlock { block_hash, height } + let previous_blocks = [None; ANTI_REORG_DELAY as usize * 2]; + BestBlock { block_hash, height, previous_blocks } + } + + /// Advances to a new block at height [`Self::height`] + 1. + pub fn advance(&mut self, new_hash: BlockHash) { + // Shift all block hashes to the right (making room for the old tip at index 0) + for i in (1..self.previous_blocks.len()).rev() { + self.previous_blocks[i] = self.previous_blocks[i - 1]; + } + + // The old tip becomes the new index 0 (tip-1) + self.previous_blocks[0] = Some(self.block_hash); + + // Update to the new tip + self.block_hash = new_hash; + self.height += 1; + } + + /// Returns the block hash at the given height, if available in our history. + pub fn get_hash_at_height(&self, height: u32) -> Option { + if height > self.height { + return None; + } + if height == self.height { + return Some(self.block_hash); + } + + // offset = 1 means we want tip-1, which is block_hashes[0] + // offset = 2 means we want tip-2, which is block_hashes[1], etc. + let offset = self.height.saturating_sub(height) as usize; + if offset >= 1 && offset <= self.previous_blocks.len() { + self.previous_blocks[offset - 1] + } else { + None + } + } + + /// Find the most recent common ancestor between two BestBlocks by searching their block hash + /// histories. + /// + /// Returns the common block hash and height, or None if no common block is found in the + /// available histories. + pub fn find_common_ancestor(&self, other: &BestBlock) -> Option<(BlockHash, u32)> { + // First check if either tip matches + if self.block_hash == other.block_hash && self.height == other.height { + return Some((self.block_hash, self.height)); + } + + // Check all heights covered by self's history + let min_height = self.height.saturating_sub(self.previous_blocks.len() as u32); + for check_height in (min_height..=self.height).rev() { + if let Some(self_hash) = self.get_hash_at_height(check_height) { + if let Some(other_hash) = other.get_hash_at_height(check_height) { + if self_hash == other_hash { + return Some((self_hash, check_height)); + } + } + } + } + None } } impl_writeable_tlv_based!(BestBlock, { (0, block_hash, required), + (1, previous_blocks_read, (legacy, [Option; ANTI_REORG_DELAY as usize * 2], |us: &BestBlock| Some(us.previous_blocks))), (2, height, required), + (unused, previous_blocks, (static_value, previous_blocks_read.unwrap_or([None; ANTI_REORG_DELAY as usize * 2]))), }); /// The `Listen` trait is used to notify when blocks have been connected or disconnected from the diff --git a/lightning/src/events/bump_transaction/mod.rs b/lightning/src/events/bump_transaction/mod.rs index e141d9b8abc..9c35194751f 100644 --- a/lightning/src/events/bump_transaction/mod.rs +++ b/lightning/src/events/bump_transaction/mod.rs @@ -37,7 +37,7 @@ use crate::sign::{ ChannelDerivationParameters, HTLCDescriptor, SignerProvider, P2WPKH_WITNESS_WEIGHT, }; use crate::sync::Mutex; -use crate::util::async_poll::{MaybeSend, MaybeSync}; +use crate::util::native_async::{MaybeSend, MaybeSync}; use crate::util::logger::Logger; use bitcoin::amount::Amount; diff --git a/lightning/src/events/bump_transaction/sync.rs b/lightning/src/events/bump_transaction/sync.rs index 1328c2c1b3a..03ffa479097 100644 --- a/lightning/src/events/bump_transaction/sync.rs +++ b/lightning/src/events/bump_transaction/sync.rs @@ -18,8 +18,9 @@ use crate::chain::chaininterface::BroadcasterInterface; use crate::chain::ClaimId; use crate::prelude::*; use crate::sign::SignerProvider; -use crate::util::async_poll::{dummy_waker, MaybeSend, MaybeSync}; +use crate::util::async_poll::dummy_waker; use crate::util::logger::Logger; +use crate::util::native_async::{MaybeSend, MaybeSync}; use bitcoin::{Psbt, ScriptBuf, Transaction, TxOut}; diff --git a/lightning/src/ln/chanmon_update_fail_tests.rs b/lightning/src/ln/chanmon_update_fail_tests.rs index 1a9af4f2071..87eec7c48aa 100644 --- a/lightning/src/ln/chanmon_update_fail_tests.rs +++ b/lightning/src/ln/chanmon_update_fail_tests.rs @@ -16,7 +16,7 @@ use crate::chain::chaininterface::LowerBoundedFeeEstimator; use crate::chain::chainmonitor::ChainMonitor; use crate::chain::channelmonitor::{ChannelMonitor, MonitorEvent, ANTI_REORG_DELAY}; use crate::chain::transaction::OutPoint; -use crate::chain::{ChannelMonitorUpdateStatus, Listen, Watch}; +use crate::chain::{BestBlock, ChannelMonitorUpdateStatus, Listen, Watch}; use crate::events::{ClosureReason, Event, HTLCHandlingFailureType, PaymentPurpose}; use crate::ln::channel::AnnouncementSigsState; use crate::ln::channelmanager::{PaymentId, RAACommitmentOrder, RecipientOnionFields}; @@ -93,7 +93,7 @@ fn test_monitor_and_persister_update_fail() { let chain_mon = { let new_monitor = { let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(chan.2).unwrap(); - let (_, new_monitor) = <(BlockHash, ChannelMonitor)>::read( + let (_, new_monitor) = <(BestBlock, ChannelMonitor)>::read( &mut &monitor.encode()[..], (nodes[0].keys_manager, nodes[0].keys_manager), ) diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index 644920557d2..03ce7affe20 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -1941,7 +1941,6 @@ where /// detailed in the [`ChannelManagerReadArgs`] documentation. /// /// ``` -/// use bitcoin::BlockHash; /// use bitcoin::network::Network; /// use lightning::chain::BestBlock; /// # use lightning::chain::channelmonitor::ChannelMonitor; @@ -1990,8 +1989,8 @@ where /// entropy_source, node_signer, signer_provider, fee_estimator, chain_monitor, tx_broadcaster, /// router, message_router, logger, config, channel_monitors.iter().collect(), /// ); -/// let (block_hash, channel_manager) = -/// <(BlockHash, ChannelManager<_, _, _, _, _, _, _, _, _>)>::read(&mut reader, args)?; +/// let (best_block, channel_manager) = +/// <(BestBlock, ChannelManager<_, _, _, _, _, _, _, _, _>)>::read(&mut reader, args)?; /// /// // Update the ChannelManager and ChannelMonitors with the latest chain data /// // ... @@ -2557,7 +2556,7 @@ where /// [`read`], those channels will be force-closed based on the `ChannelMonitor` state and no funds /// will be lost (modulo on-chain transaction fees). /// -/// Note that the deserializer is only implemented for `(`[`BlockHash`]`, `[`ChannelManager`]`)`, which +/// Note that the deserializer is only implemented for `(`[`BestBlock`]`, `[`ChannelManager`]`)`, which /// tells you the last block hash which was connected. You should get the best block tip before using the manager. /// See [`chain::Listen`] and [`chain::Confirm`] for more details. /// @@ -2624,7 +2623,6 @@ where /// [`peer_disconnected`]: msgs::BaseMessageHandler::peer_disconnected /// [`funding_created`]: msgs::FundingCreated /// [`funding_transaction_generated`]: Self::funding_transaction_generated -/// [`BlockHash`]: bitcoin::hash_types::BlockHash /// [`update_channel`]: chain::Watch::update_channel /// [`ChannelUpdate`]: msgs::ChannelUpdate /// [`read`]: ReadableArgs::read @@ -16695,7 +16693,7 @@ impl< MR: Deref, L: Deref + Clone, > ReadableArgs> - for (BlockHash, Arc>) + for (BestBlock, Arc>) where M::Target: chain::Watch<::EcdsaSigner>, T::Target: BroadcasterInterface, @@ -16710,9 +16708,9 @@ where fn read( reader: &mut Reader, args: ChannelManagerReadArgs<'a, M, T, ES, NS, SP, F, R, MR, L>, ) -> Result { - let (blockhash, chan_manager) = - <(BlockHash, ChannelManager)>::read(reader, args)?; - Ok((blockhash, Arc::new(chan_manager))) + let (best_block, chan_manager) = + <(BestBlock, ChannelManager)>::read(reader, args)?; + Ok((best_block, Arc::new(chan_manager))) } } @@ -16728,7 +16726,7 @@ impl< MR: Deref, L: Deref + Clone, > ReadableArgs> - for (BlockHash, ChannelManager) + for (BestBlock, ChannelManager) where M::Target: chain::Watch<::EcdsaSigner>, T::Target: BroadcasterInterface, @@ -18402,7 +18400,7 @@ where //TODO: Broadcast channel update for closed channels, but only after we've made a //connection or two. - Ok((best_block_hash.clone(), channel_manager)) + Ok((best_block, channel_manager)) } } diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs index 271d458bcc8..57e8b5aa1cd 100644 --- a/lightning/src/ln/functional_test_utils.rs +++ b/lightning/src/ln/functional_test_utils.rs @@ -847,7 +847,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> { let mon = self.chain_monitor.chain_monitor.get_monitor(channel_id).unwrap(); mon.write(&mut w).unwrap(); let (_, deserialized_monitor) = - <(BlockHash, ChannelMonitor)>::read( + <(BestBlock, ChannelMonitor)>::read( &mut io::Cursor::new(&w.0), (self.keys_manager, self.keys_manager), ) @@ -875,7 +875,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> { let mut w = test_utils::TestVecWriter(Vec::new()); self.node.write(&mut w).unwrap(); <( - BlockHash, + BestBlock, ChannelManager< &test_utils::TestChainMonitor, &test_utils::TestBroadcaster, @@ -1332,7 +1332,7 @@ pub fn _reload_node<'a, 'b, 'c>( let mut monitors_read = Vec::with_capacity(monitors_encoded.len()); for encoded in monitors_encoded { let mut monitor_read = &encoded[..]; - let (_, monitor) = <(BlockHash, ChannelMonitor)>::read( + let (_, monitor) = <(BestBlock, ChannelMonitor)>::read( &mut monitor_read, (node.keys_manager, node.keys_manager), ) @@ -1347,7 +1347,7 @@ pub fn _reload_node<'a, 'b, 'c>( for monitor in monitors_read.iter() { assert!(channel_monitors.insert(monitor.channel_id(), monitor).is_none()); } - <(BlockHash, TestChannelManager<'b, 'c>)>::read( + <(BestBlock, TestChannelManager<'b, 'c>)>::read( &mut node_read, ChannelManagerReadArgs { config, diff --git a/lightning/src/ln/functional_tests.rs b/lightning/src/ln/functional_tests.rs index c161a9664c0..dd57fd5ecd4 100644 --- a/lightning/src/ln/functional_tests.rs +++ b/lightning/src/ln/functional_tests.rs @@ -19,6 +19,7 @@ use crate::chain::channelmonitor::{ LATENCY_GRACE_PERIOD_BLOCKS, }; use crate::chain::transaction::OutPoint; +use crate::chain::BestBlock; use crate::chain::{ChannelMonitorUpdateStatus, Confirm, Listen, Watch}; use crate::events::{ ClosureReason, Event, HTLCHandlingFailureType, PathFailure, PaymentFailureReason, @@ -7239,7 +7240,7 @@ pub fn test_update_err_monitor_lockdown() { let new_monitor = { let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(chan_1.2).unwrap(); let new_monitor = - <(BlockHash, channelmonitor::ChannelMonitor)>::read( + <(BestBlock, channelmonitor::ChannelMonitor)>::read( &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager), ) @@ -7345,7 +7346,7 @@ pub fn test_concurrent_monitor_claim() { let new_monitor = { let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(chan_1.2).unwrap(); let new_monitor = - <(BlockHash, channelmonitor::ChannelMonitor)>::read( + <(BestBlock, channelmonitor::ChannelMonitor)>::read( &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager), ) @@ -7395,7 +7396,7 @@ pub fn test_concurrent_monitor_claim() { let new_monitor = { let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(chan_1.2).unwrap(); let new_monitor = - <(BlockHash, channelmonitor::ChannelMonitor)>::read( + <(BestBlock, channelmonitor::ChannelMonitor)>::read( &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager), ) diff --git a/lightning/src/ln/reload_tests.rs b/lightning/src/ln/reload_tests.rs index 8c9e552fa04..9fb5ce8394c 100644 --- a/lightning/src/ln/reload_tests.rs +++ b/lightning/src/ln/reload_tests.rs @@ -11,7 +11,7 @@ //! Functional tests which test for correct behavior across node restarts. -use crate::chain::{ChannelMonitorUpdateStatus, Watch}; +use crate::chain::{BestBlock, ChannelMonitorUpdateStatus, Watch}; use crate::chain::chaininterface::LowerBoundedFeeEstimator; use crate::chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdateStep}; use crate::routing::router::{PaymentParameters, RouteParameters}; @@ -29,7 +29,6 @@ use crate::util::ser::{Writeable, ReadableArgs}; use crate::util::config::UserConfig; use bitcoin::hashes::Hash; -use bitcoin::hash_types::BlockHash; use types::payment::{PaymentHash, PaymentPreimage}; use crate::prelude::*; @@ -410,7 +409,7 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() { let mut node_0_stale_monitors = Vec::new(); for serialized in node_0_stale_monitors_serialized.iter() { let mut read = &serialized[..]; - let (_, monitor) = <(BlockHash, ChannelMonitor)>::read(&mut read, (keys_manager, keys_manager)).unwrap(); + let (_, monitor) = <(BestBlock, ChannelMonitor)>::read(&mut read, (keys_manager, keys_manager)).unwrap(); assert!(read.is_empty()); node_0_stale_monitors.push(monitor); } @@ -418,14 +417,14 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() { let mut node_0_monitors = Vec::new(); for serialized in node_0_monitors_serialized.iter() { let mut read = &serialized[..]; - let (_, monitor) = <(BlockHash, ChannelMonitor)>::read(&mut read, (keys_manager, keys_manager)).unwrap(); + let (_, monitor) = <(BestBlock, ChannelMonitor)>::read(&mut read, (keys_manager, keys_manager)).unwrap(); assert!(read.is_empty()); node_0_monitors.push(monitor); } let mut nodes_0_read = &nodes_0_serialized[..]; if let Err(msgs::DecodeError::DangerousValue) = - <(BlockHash, ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestMessageRouter, &test_utils::TestLogger>)>::read(&mut nodes_0_read, ChannelManagerReadArgs { + <(BestBlock, ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestMessageRouter, &test_utils::TestLogger>)>::read(&mut nodes_0_read, ChannelManagerReadArgs { config: UserConfig::default(), entropy_source: keys_manager, node_signer: keys_manager, @@ -443,7 +442,7 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() { let mut nodes_0_read = &nodes_0_serialized[..]; let (_, nodes_0_deserialized_tmp) = - <(BlockHash, ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestMessageRouter, &test_utils::TestLogger>)>::read(&mut nodes_0_read, ChannelManagerReadArgs { + <(BestBlock, ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestMessageRouter, &test_utils::TestLogger>)>::read(&mut nodes_0_read, ChannelManagerReadArgs { config: UserConfig::default(), entropy_source: keys_manager, node_signer: keys_manager, diff --git a/lightning/src/sign/mod.rs b/lightning/src/sign/mod.rs index 6d0d5bf405a..d562ff13406 100644 --- a/lightning/src/sign/mod.rs +++ b/lightning/src/sign/mod.rs @@ -58,7 +58,7 @@ use crate::ln::script::ShutdownScript; use crate::offers::invoice::UnsignedBolt12Invoice; use crate::types::features::ChannelTypeFeatures; use crate::types::payment::PaymentPreimage; -use crate::util::async_poll::MaybeSend; +use crate::util::native_async::MaybeSend; use crate::util::ser::{ReadableArgs, Writeable}; use crate::util::transaction_utils; diff --git a/lightning/src/sync/nostd_sync.rs b/lightning/src/sync/nostd_sync.rs index 12070741918..18055d1ebe4 100644 --- a/lightning/src/sync/nostd_sync.rs +++ b/lightning/src/sync/nostd_sync.rs @@ -61,7 +61,7 @@ impl<'a, T: 'a> LockTestExt<'a> for Mutex { } type ExclLock = MutexGuard<'a, T>; #[inline] - fn unsafe_well_ordered_double_lock_self(&'a self) -> MutexGuard { + fn unsafe_well_ordered_double_lock_self(&'a self) -> MutexGuard<'a, T> { self.lock().unwrap() } } @@ -132,7 +132,7 @@ impl<'a, T: 'a> LockTestExt<'a> for RwLock { } type ExclLock = RwLockWriteGuard<'a, T>; #[inline] - fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard { + fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard<'a, T> { self.write().unwrap() } } diff --git a/lightning/src/util/async_poll.rs b/lightning/src/util/async_poll.rs index 9c2ca4c247f..23ca1aad603 100644 --- a/lightning/src/util/async_poll.rs +++ b/lightning/src/util/async_poll.rs @@ -15,26 +15,100 @@ use core::marker::Unpin; use core::pin::Pin; use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; -pub(crate) enum ResultFuture>, E: Unpin> { +pub(crate) enum ResultFuture + Unpin, O> { Pending(F), - Ready(Result<(), E>), + Ready(O), } -pub(crate) struct MultiResultFuturePoller> + Unpin, E: Unpin> { - futures_state: Vec>, +pub(crate) struct TwoFutureJoiner< + AO, + BO, + AF: Future + Unpin, + BF: Future + Unpin, +> { + a: Option>, + b: Option>, } -impl> + Unpin, E: Unpin> MultiResultFuturePoller { - pub fn new(futures_state: Vec>) -> Self { +impl + Unpin, BF: Future + Unpin> + TwoFutureJoiner +{ + pub fn new(future_a: AF, future_b: BF) -> Self { + Self { a: Some(ResultFuture::Pending(future_a)), b: Some(ResultFuture::Pending(future_b)) } + } +} + +impl + Unpin, BF: Future + Unpin> Future + for TwoFutureJoiner +{ + type Output = (AO, BO); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<(AO, BO)> { + let mut have_pending_futures = false; + // SAFETY: While we are pinned, we can't get direct access to our internal state because we + // aren't `Unpin`. However, we don't actually need the `Pin` - we only use it below on the + // `Future` in the `ResultFuture::Pending` case, and the `Future` is bound by `Unpin`. + // Thus, the `Pin` is not actually used, and its safe to bypass it and access the inner + // reference directly. + let state = unsafe { &mut self.get_unchecked_mut() }; + macro_rules! poll_future { + ($future: ident) => { + match state.$future { + Some(ResultFuture::Pending(ref mut fut)) => match Pin::new(fut).poll(cx) { + Poll::Ready(res) => { + state.$future = Some(ResultFuture::Ready(res)); + }, + Poll::Pending => { + have_pending_futures = true; + }, + }, + Some(ResultFuture::Ready(_)) => {}, + None => { + debug_assert!(false, "Future polled after Ready"); + return Poll::Pending; + }, + } + }; + } + poll_future!(a); + poll_future!(b); + + if have_pending_futures { + Poll::Pending + } else { + Poll::Ready(( + match state.a.take() { + Some(ResultFuture::Ready(a)) => a, + _ => unreachable!(), + }, + match state.b.take() { + Some(ResultFuture::Ready(b)) => b, + _ => unreachable!(), + }, + )) + } + } +} + +pub(crate) struct MultiResultFuturePoller + Unpin, O> { + futures_state: Vec>, +} + +impl + Unpin, O> MultiResultFuturePoller { + pub fn new(futures_state: Vec>) -> Self { Self { futures_state } } } -impl> + Unpin, E: Unpin> Future for MultiResultFuturePoller { - type Output = Vec>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { +impl + Unpin, O> Future for MultiResultFuturePoller { + type Output = Vec; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut have_pending_futures = false; - let futures_state = &mut self.get_mut().futures_state; + // SAFETY: While we are pinned, we can't get direct access to `futures_state` because we + // aren't `Unpin`. However, we don't actually need the `Pin` - we only use it below on the + // `Future` in the `ResultFuture::Pending` case, and the `Future` is bound by `Unpin`. + // Thus, the `Pin` is not actually used, and its safe to bypass it and access the inner + // reference directly. + let futures_state = unsafe { &mut self.get_unchecked_mut().futures_state }; for state in futures_state.iter_mut() { match state { ResultFuture::Pending(ref mut fut) => match Pin::new(fut).poll(cx) { @@ -90,31 +164,3 @@ const DUMMY_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new( pub(crate) fn dummy_waker() -> Waker { unsafe { Waker::from_raw(RawWaker::new(core::ptr::null(), &DUMMY_WAKER_VTABLE)) } } - -/// Marker trait to optionally implement `Sync` under std. -/// -/// This is not exported to bindings users as async is only supported in Rust. -#[cfg(feature = "std")] -pub use core::marker::Sync as MaybeSync; - -#[cfg(not(feature = "std"))] -/// Marker trait to optionally implement `Sync` under std. -/// -/// This is not exported to bindings users as async is only supported in Rust. -pub trait MaybeSync {} -#[cfg(not(feature = "std"))] -impl MaybeSync for T where T: ?Sized {} - -/// Marker trait to optionally implement `Send` under std. -/// -/// This is not exported to bindings users as async is only supported in Rust. -#[cfg(feature = "std")] -pub use core::marker::Send as MaybeSend; - -#[cfg(not(feature = "std"))] -/// Marker trait to optionally implement `Send` under std. -/// -/// This is not exported to bindings users as async is only supported in Rust. -pub trait MaybeSend {} -#[cfg(not(feature = "std"))] -impl MaybeSend for T where T: ?Sized {} diff --git a/lightning/src/util/mod.rs b/lightning/src/util/mod.rs index dcbea904b51..51e608d185f 100644 --- a/lightning/src/util/mod.rs +++ b/lightning/src/util/mod.rs @@ -20,7 +20,7 @@ pub mod mut_global; pub mod anchor_channel_reserves; -pub mod async_poll; +pub(crate) mod async_poll; #[cfg(fuzzing)] pub mod base32; #[cfg(not(fuzzing))] diff --git a/lightning/src/util/native_async.rs b/lightning/src/util/native_async.rs index 886146e976d..34bdb056329 100644 --- a/lightning/src/util/native_async.rs +++ b/lightning/src/util/native_async.rs @@ -8,23 +8,45 @@ //! environment. #[cfg(all(test, feature = "std"))] -use crate::sync::Mutex; -use crate::util::async_poll::{MaybeSend, MaybeSync}; +use crate::sync::{Arc, Mutex}; + +#[cfg(test)] +use alloc::boxed::Box; +#[cfg(all(test, not(feature = "std")))] +use alloc::rc::Rc; #[cfg(all(test, not(feature = "std")))] use core::cell::RefCell; +#[cfg(test)] +use core::convert::Infallible; use core::future::Future; #[cfg(test)] use core::pin::Pin; +#[cfg(test)] +use core::task::{Context, Poll}; -/// A generic trait which is able to spawn futures in the background. +/// A generic trait which is able to spawn futures to be polled in the background. +/// +/// When the spawned future completes, the returned [`Self::SpawnedFutureResult`] should resolve +/// with the output of the spawned future. +/// +/// Spawned futures must be polled independently in the background even if the returned +/// [`Self::SpawnedFutureResult`] is dropped without being polled. This matches the semantics of +/// `tokio::spawn`. /// /// This is not exported to bindings users as async is only supported in Rust. pub trait FutureSpawner: MaybeSend + MaybeSync + 'static { + /// The error type of [`Self::SpawnedFutureResult`]. This can be used to indicate that the + /// spawned future was cancelled or panicked. + type E; + /// The result of [`Self::spawn`], a future which completes when the spawned future completes. + type SpawnedFutureResult: Future> + Unpin; /// Spawns the given future as a background task. /// /// This method MUST NOT block on the given future immediately. - fn spawn + MaybeSend + 'static>(&self, future: T); + fn spawn + MaybeSend + 'static>( + &self, future: T, + ) -> Self::SpawnedFutureResult; } #[cfg(test)] @@ -32,6 +54,34 @@ trait MaybeSendableFuture: Future + MaybeSend + 'static {} #[cfg(test)] impl + MaybeSend + 'static> MaybeSendableFuture for F {} +/// Marker trait to optionally implement `Sync` under std. +/// +/// This is not exported to bindings users as async is only supported in Rust. +#[cfg(feature = "std")] +pub use core::marker::Sync as MaybeSync; + +#[cfg(not(feature = "std"))] +/// Marker trait to optionally implement `Sync` under std. +/// +/// This is not exported to bindings users as async is only supported in Rust. +pub trait MaybeSync {} +#[cfg(not(feature = "std"))] +impl MaybeSync for T where T: ?Sized {} + +/// Marker trait to optionally implement `Send` under std. +/// +/// This is not exported to bindings users as async is only supported in Rust. +#[cfg(feature = "std")] +pub use core::marker::Send as MaybeSend; + +#[cfg(not(feature = "std"))] +/// Marker trait to optionally implement `Send` under std. +/// +/// This is not exported to bindings users as async is only supported in Rust. +pub trait MaybeSend {} +#[cfg(not(feature = "std"))] +impl MaybeSend for T where T: ?Sized {} + /// A simple [`FutureSpawner`] which holds [`Future`]s until they are manually polled via /// [`Self::poll_futures`]. #[cfg(all(test, feature = "std"))] @@ -39,6 +89,69 @@ pub(crate) struct FutureQueue(Mutex>>>); #[cfg(all(test, not(feature = "std")))] pub(crate) struct FutureQueue(RefCell>>>); +#[cfg(all(test, feature = "std"))] +pub struct FutureQueueCompletion(Arc>>); +#[cfg(all(test, not(feature = "std")))] +pub struct FutureQueueCompletion(Rc>>); + +#[cfg(all(test, feature = "std"))] +impl FutureQueueCompletion { + fn new() -> Self { + Self(Arc::new(Mutex::new(None))) + } + + fn complete(&self, o: O) { + *self.0.lock().unwrap() = Some(o); + } +} + +#[cfg(all(test, feature = "std"))] +impl Clone for FutureQueueCompletion { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +#[cfg(all(test, not(feature = "std")))] +impl FutureQueueCompletion { + fn new() -> Self { + Self(Rc::new(RefCell::new(None))) + } + + fn complete(&self, o: O) { + *self.0.borrow_mut() = Some(o); + } +} + +#[cfg(all(test, not(feature = "std")))] +impl Clone for FutureQueueCompletion { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +#[cfg(all(test, feature = "std"))] +impl Future for FutureQueueCompletion { + type Output = Result; + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + match Pin::into_inner(self).0.lock().unwrap().take() { + None => Poll::Pending, + Some(o) => Poll::Ready(Ok(o)), + } + } +} + +#[cfg(all(test, not(feature = "std")))] +impl Future for FutureQueueCompletion { + type Output = Result; + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + match Pin::into_inner(self).0.borrow_mut().take() { + None => Poll::Pending, + Some(o) => Poll::Ready(Ok(o)), + } + } +} + #[cfg(test)] impl FutureQueue { pub(crate) fn new() -> Self { @@ -74,7 +187,6 @@ impl FutureQueue { futures = self.0.borrow_mut(); } futures.retain_mut(|fut| { - use core::task::{Context, Poll}; let waker = crate::util::async_poll::dummy_waker(); match fut.as_mut().poll(&mut Context::from_waker(&waker)) { Poll::Ready(()) => false, @@ -86,7 +198,16 @@ impl FutureQueue { #[cfg(test)] impl FutureSpawner for FutureQueue { - fn spawn + MaybeSend + 'static>(&self, future: T) { + type E = Infallible; + type SpawnedFutureResult = FutureQueueCompletion; + fn spawn + MaybeSend + 'static>( + &self, f: F, + ) -> FutureQueueCompletion { + let completion = FutureQueueCompletion::new(); + let compl_ref = completion.clone(); + let future = async move { + compl_ref.complete(f.await); + }; #[cfg(feature = "std")] { self.0.lock().unwrap().push(Box::pin(future)); @@ -95,6 +216,7 @@ impl FutureSpawner for FutureQueue { { self.0.borrow_mut().push(Box::pin(future)); } + completion } } @@ -102,7 +224,16 @@ impl FutureSpawner for FutureQueue { impl + MaybeSend + MaybeSync + 'static> FutureSpawner for D { - fn spawn + MaybeSend + 'static>(&self, future: T) { + type E = Infallible; + type SpawnedFutureResult = FutureQueueCompletion; + fn spawn + MaybeSend + 'static>( + &self, f: F, + ) -> FutureQueueCompletion { + let completion = FutureQueueCompletion::new(); + let compl_ref = completion.clone(); + let future = async move { + compl_ref.complete(f.await); + }; #[cfg(feature = "std")] { self.0.lock().unwrap().push(Box::pin(future)); @@ -111,5 +242,6 @@ impl + MaybeSend + MaybeSync + 'static { self.0.borrow_mut().push(Box::pin(future)); } + completion } } diff --git a/lightning/src/util/persist.rs b/lightning/src/util/persist.rs index 7feb781a57a..807cf2bf9fd 100644 --- a/lightning/src/util/persist.rs +++ b/lightning/src/util/persist.rs @@ -14,8 +14,9 @@ use alloc::sync::Arc; use bitcoin::hashes::hex::FromHex; -use bitcoin::{BlockHash, Txid}; +use bitcoin::Txid; +use core::convert::Infallible; use core::future::Future; use core::mem; use core::ops::Deref; @@ -31,12 +32,15 @@ use crate::chain::chaininterface::{BroadcasterInterface, FeeEstimator}; use crate::chain::chainmonitor::Persist; use crate::chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdate}; use crate::chain::transaction::OutPoint; +use crate::chain::BestBlock; use crate::ln::types::ChannelId; use crate::sign::{ecdsa::EcdsaChannelSigner, EntropySource, SignerProvider}; use crate::sync::Mutex; -use crate::util::async_poll::{dummy_waker, MaybeSend, MaybeSync}; +use crate::util::async_poll::{ + dummy_waker, MultiResultFuturePoller, ResultFuture, TwoFutureJoiner, +}; use crate::util::logger::Logger; -use crate::util::native_async::FutureSpawner; +use crate::util::native_async::{FutureSpawner, MaybeSend, MaybeSync}; use crate::util::ser::{Readable, ReadableArgs, Writeable}; use crate::util::wakers::Notifier; @@ -444,7 +448,7 @@ impl Persist( kv_store: K, entropy_source: ES, signer_provider: SP, -) -> Result::EcdsaSigner>)>, io::Error> +) -> Result::EcdsaSigner>)>, io::Error> where K::Target: KVStoreSync, ES::Target: EntropySource + Sized, @@ -456,7 +460,7 @@ where CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE, CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE, )? { - match ::EcdsaSigner>)>>::read( + match ::EcdsaSigner>)>>::read( &mut io::Cursor::new(kv_store.read( CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE, CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE, @@ -464,7 +468,7 @@ where )?), (&*entropy_source, &*signer_provider), ) { - Ok(Some((block_hash, channel_monitor))) => { + Ok(Some((best_block, channel_monitor))) => { let monitor_name = MonitorName::from_str(&stored_key)?; if channel_monitor.persistence_key() != monitor_name { return Err(io::Error::new( @@ -473,7 +477,7 @@ where )); } - res.push((block_hash, channel_monitor)); + res.push((best_block, channel_monitor)); }, Ok(None) => {}, Err(_) => { @@ -489,7 +493,11 @@ where struct PanicingSpawner; impl FutureSpawner for PanicingSpawner { - fn spawn + MaybeSend + 'static>(&self, _: T) { + type E = Infallible; + type SpawnedFutureResult = Box> + Unpin>; + fn spawn + MaybeSend + 'static>( + &self, _: T, + ) -> Self::SpawnedFutureResult { unreachable!(); } } @@ -569,15 +577,6 @@ fn poll_sync_future(future: F) -> F::Output { /// list channel monitors themselves and load channels individually using /// [`MonitorUpdatingPersister::read_channel_monitor_with_updates`]. /// -/// ## EXTREMELY IMPORTANT -/// -/// It is extremely important that your [`KVStoreSync::read`] implementation uses the -/// [`io::ErrorKind::NotFound`] variant correctly: that is, when a file is not found, and _only_ in -/// that circumstance (not when there is really a permissions error, for example). This is because -/// neither channel monitor reading function lists updates. Instead, either reads the monitor, and -/// using its stored `update_id`, synthesizes update storage keys, and tries them in sequence until -/// one is not found. All _other_ errors will be bubbled up in the function's [`Result`]. -/// /// # Pruning stale channel updates /// /// Stale updates are pruned when the consolidation threshold is reached according to `maximum_pending_updates`. @@ -651,14 +650,10 @@ where } /// Reads all stored channel monitors, along with any stored updates for them. - /// - /// It is extremely important that your [`KVStoreSync::read`] implementation uses the - /// [`io::ErrorKind::NotFound`] variant correctly. For more information, please see the - /// documentation for [`MonitorUpdatingPersister`]. pub fn read_all_channel_monitors_with_updates( &self, ) -> Result< - Vec<(BlockHash, ChannelMonitor<::EcdsaSigner>)>, + Vec<(BestBlock, ChannelMonitor<::EcdsaSigner>)>, io::Error, > { poll_sync_future(self.0.read_all_channel_monitors_with_updates()) @@ -666,10 +661,6 @@ where /// Read a single channel monitor, along with any stored updates for it. /// - /// It is extremely important that your [`KVStoreSync::read`] implementation uses the - /// [`io::ErrorKind::NotFound`] variant correctly. For more information, please see the - /// documentation for [`MonitorUpdatingPersister`]. - /// /// For `monitor_key`, channel storage keys can be the channel's funding [`OutPoint`], with an /// underscore `_` between txid and index for v1 channels. For example, given: /// @@ -685,7 +676,7 @@ where /// function to accomplish this. Take care to limit the number of parallel readers. pub fn read_channel_monitor_with_updates( &self, monitor_key: &str, - ) -> Result<(BlockHash, ChannelMonitor<::EcdsaSigner>), io::Error> + ) -> Result<(BestBlock, ChannelMonitor<::EcdsaSigner>), io::Error> { poll_sync_future(self.0.read_channel_monitor_with_updates(monitor_key)) } @@ -863,34 +854,81 @@ where /// Reads all stored channel monitors, along with any stored updates for them. /// - /// It is extremely important that your [`KVStore::read`] implementation uses the - /// [`io::ErrorKind::NotFound`] variant correctly. For more information, please see the - /// documentation for [`MonitorUpdatingPersister`]. + /// While the reads themselves are performend in parallel, deserializing the + /// [`ChannelMonitor`]s is not. For large [`ChannelMonitor`]s actively used for forwarding, + /// this may substantially limit the parallelism of this method. pub async fn read_all_channel_monitors_with_updates( &self, ) -> Result< - Vec<(BlockHash, ChannelMonitor<::EcdsaSigner>)>, + Vec<(BestBlock, ChannelMonitor<::EcdsaSigner>)>, io::Error, > { let primary = CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE; let secondary = CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE; let monitor_list = self.0.kv_store.list(primary, secondary).await?; - let mut res = Vec::with_capacity(monitor_list.len()); + let mut futures = Vec::with_capacity(monitor_list.len()); for monitor_key in monitor_list { - let result = - self.0.maybe_read_channel_monitor_with_updates(monitor_key.as_str()).await?; - if let Some(read_res) = result { + futures.push(ResultFuture::Pending(Box::pin(async move { + self.0.maybe_read_channel_monitor_with_updates(monitor_key.as_str()).await + }))); + } + let future_results = MultiResultFuturePoller::new(futures).await; + let mut res = Vec::with_capacity(future_results.len()); + for result in future_results { + if let Some(read_res) = result? { res.push(read_res); } } Ok(res) } - /// Read a single channel monitor, along with any stored updates for it. + /// Reads all stored channel monitors, along with any stored updates for them, in parallel. + /// + /// Because deserializing large [`ChannelMonitor`]s from forwarding nodes is often CPU-bound, + /// this version of [`Self::read_all_channel_monitors_with_updates`] uses the [`FutureSpawner`] + /// to parallelize deserialization as well as the IO operations. /// - /// It is extremely important that your [`KVStoreSync::read`] implementation uses the - /// [`io::ErrorKind::NotFound`] variant correctly. For more information, please see the - /// documentation for [`MonitorUpdatingPersister`]. + /// Because [`FutureSpawner`] requires that the spawned future be `'static` (matching `tokio` + /// and other multi-threaded runtime requirements), this method requires that `self` be an + /// `Arc` that can live for `'static` and be sent and accessed across threads. + pub async fn read_all_channel_monitors_with_updates_parallel( + self: &Arc, + ) -> Result< + Vec<(BestBlock, ChannelMonitor<::EcdsaSigner>)>, + io::Error, + > where + K: MaybeSend + MaybeSync + 'static, + L: MaybeSend + MaybeSync + 'static, + ES: MaybeSend + MaybeSync + 'static, + SP: MaybeSend + MaybeSync + 'static, + BI: MaybeSend + MaybeSync + 'static, + FE: MaybeSend + MaybeSync + 'static, + ::EcdsaSigner: MaybeSend, + { + let primary = CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE; + let secondary = CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE; + let monitor_list = self.0.kv_store.list(primary, secondary).await?; + let mut futures = Vec::with_capacity(monitor_list.len()); + for monitor_key in monitor_list { + let us = Arc::clone(&self); + futures.push(ResultFuture::Pending(self.0.future_spawner.spawn(async move { + us.0.maybe_read_channel_monitor_with_updates(monitor_key.as_str()).await + }))); + } + let future_results = MultiResultFuturePoller::new(futures).await; + let mut res = Vec::with_capacity(future_results.len()); + for result in future_results { + match result { + Err(_) => return Err(io::Error::new(io::ErrorKind::Other, "Future was cancelled")), + Ok(Err(e)) => return Err(e), + Ok(Ok(Some(read_res))) => res.push(read_res), + Ok(Ok(None)) => {}, + } + } + Ok(res) + } + + /// Read a single channel monitor, along with any stored updates for it. /// /// For `monitor_key`, channel storage keys can be the channel's funding [`OutPoint`], with an /// underscore `_` between txid and index for v1 channels. For example, given: @@ -907,7 +945,7 @@ where /// function to accomplish this. Take care to limit the number of parallel readers. pub async fn read_channel_monitor_with_updates( &self, monitor_key: &str, - ) -> Result<(BlockHash, ChannelMonitor<::EcdsaSigner>), io::Error> + ) -> Result<(BestBlock, ChannelMonitor<::EcdsaSigner>), io::Error> { self.0.read_channel_monitor_with_updates(monitor_key).await } @@ -952,7 +990,7 @@ where let future = inner.persist_new_channel(monitor_name, monitor); let channel_id = monitor.channel_id(); let completion = (monitor.channel_id(), monitor.get_latest_update_id()); - self.0.future_spawner.spawn(async move { + let _runs_free = self.0.future_spawner.spawn(async move { match future.await { Ok(()) => { inner.async_completed_updates.lock().unwrap().push(completion); @@ -984,7 +1022,7 @@ where None }; let inner = Arc::clone(&self.0); - self.0.future_spawner.spawn(async move { + let _runs_free = self.0.future_spawner.spawn(async move { match future.await { Ok(()) => if let Some(completion) = completion { inner.async_completed_updates.lock().unwrap().push(completion); @@ -1002,7 +1040,7 @@ where pub(crate) fn spawn_async_archive_persisted_channel(&self, monitor_name: MonitorName) { let inner = Arc::clone(&self.0); - self.0.future_spawner.spawn(async move { + let _runs_free = self.0.future_spawner.spawn(async move { inner.archive_persisted_channel(monitor_name).await; }); } @@ -1027,7 +1065,7 @@ where { pub async fn read_channel_monitor_with_updates( &self, monitor_key: &str, - ) -> Result<(BlockHash, ChannelMonitor<::EcdsaSigner>), io::Error> + ) -> Result<(BestBlock, ChannelMonitor<::EcdsaSigner>), io::Error> { match self.maybe_read_channel_monitor_with_updates(monitor_key).await? { Some(res) => Ok(res), @@ -1046,32 +1084,33 @@ where async fn maybe_read_channel_monitor_with_updates( &self, monitor_key: &str, ) -> Result< - Option<(BlockHash, ChannelMonitor<::EcdsaSigner>)>, + Option<(BestBlock, ChannelMonitor<::EcdsaSigner>)>, io::Error, > { let monitor_name = MonitorName::from_str(monitor_key)?; - let read_res = self.maybe_read_monitor(&monitor_name, monitor_key).await?; - let (block_hash, monitor) = match read_res { + let read_future = pin!(self.maybe_read_monitor(&monitor_name, monitor_key)); + let list_future = pin!(self + .kv_store + .list(CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, monitor_key)); + let (read_res, list_res) = TwoFutureJoiner::new(read_future, list_future).await; + let (best_block, monitor) = match read_res? { Some(res) => res, None => return Ok(None), }; - let mut current_update_id = monitor.get_latest_update_id(); - // TODO: Parallelize this loop by speculatively reading a batch of updates - loop { - current_update_id = match current_update_id.checked_add(1) { - Some(next_update_id) => next_update_id, - None => break, - }; - let update_name = UpdateName::from(current_update_id); - let update = match self.read_monitor_update(monitor_key, &update_name).await { - Ok(update) => update, - Err(err) if err.kind() == io::ErrorKind::NotFound => { - // We can't find any more updates, so we are done. - break; - }, - Err(err) => return Err(err), - }; - + let current_update_id = monitor.get_latest_update_id(); + let updates: Result, _> = + list_res?.into_iter().map(|name| UpdateName::new(name)).collect(); + let mut updates = updates?; + updates.sort_unstable(); + let updates_to_load = updates.iter().filter(|update| update.0 > current_update_id); + let mut update_futures = Vec::with_capacity(updates_to_load.clone().count()); + for update_name in updates_to_load { + update_futures.push(ResultFuture::Pending(Box::pin(async move { + (update_name, self.read_monitor_update(monitor_key, update_name).await) + }))); + } + for (update_name, update_res) in MultiResultFuturePoller::new(update_futures).await { + let update = update_res?; monitor .update_monitor(&update, &self.broadcaster, &self.fee_estimator, &self.logger) .map_err(|e| { @@ -1085,14 +1124,14 @@ where io::Error::new(io::ErrorKind::Other, "Monitor update failed") })?; } - Ok(Some((block_hash, monitor))) + Ok(Some((best_block, monitor))) } /// Read a channel monitor. async fn maybe_read_monitor( &self, monitor_name: &MonitorName, monitor_key: &str, ) -> Result< - Option<(BlockHash, ChannelMonitor<::EcdsaSigner>)>, + Option<(BestBlock, ChannelMonitor<::EcdsaSigner>)>, io::Error, > { let primary = CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE; @@ -1103,12 +1142,12 @@ where if monitor_cursor.get_ref().starts_with(MONITOR_UPDATING_PERSISTER_PREPEND_SENTINEL) { monitor_cursor.set_position(MONITOR_UPDATING_PERSISTER_PREPEND_SENTINEL.len() as u64); } - match ::EcdsaSigner>)>>::read( + match ::EcdsaSigner>)>>::read( &mut monitor_cursor, (&*self.entropy_source, &*self.signer_provider), ) { Ok(None) => Ok(None), - Ok(Some((blockhash, channel_monitor))) => { + Ok(Some((best_block, channel_monitor))) => { if channel_monitor.persistence_key() != *monitor_name { log_error!( self.logger, @@ -1120,7 +1159,7 @@ where "ChannelMonitor was stored under the wrong key", )) } else { - Ok(Some((blockhash, channel_monitor))) + Ok(Some((best_block, channel_monitor))) } }, Err(e) => { @@ -1299,7 +1338,7 @@ where async fn archive_persisted_channel(&self, monitor_name: MonitorName) { let monitor_key = monitor_name.to_string(); let monitor = match self.read_channel_monitor_with_updates(&monitor_key).await { - Ok((_block_hash, monitor)) => monitor, + Ok((_best_block, monitor)) => monitor, Err(_) => return, }; let primary = ARCHIVED_CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE; @@ -1458,7 +1497,7 @@ impl core::fmt::Display for MonitorName { /// let monitor_name = "some_monitor_name"; /// let storage_key = format!("channel_monitor_updates/{}/{}", monitor_name, update_name.as_str()); /// ``` -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct UpdateName(pub u64, String); impl UpdateName { diff --git a/lightning/src/util/ser.rs b/lightning/src/util/ser.rs index f821aa5afc0..2bd8a9eb879 100644 --- a/lightning/src/util/ser.rs +++ b/lightning/src/util/ser.rs @@ -1448,6 +1448,33 @@ impl Readable for BlockHash { } } +impl Writeable for [Option; 12] { + fn write(&self, w: &mut W) -> Result<(), io::Error> { + for hash_opt in self { + match hash_opt { + Some(hash) => hash.write(w)?, + None => ([0u8; 32]).write(w)?, + } + } + Ok(()) + } +} + +impl Readable for [Option; 12] { + fn read(r: &mut R) -> Result { + use bitcoin::hashes::Hash; + + let mut res = [None; 12]; + for hash_opt in res.iter_mut() { + let buf: [u8; 32] = Readable::read(r)?; + if buf != [0; 32] { + *hash_opt = Some(BlockHash::from_slice(&buf[..]).unwrap()); + } + } + Ok(res) + } +} + impl Writeable for ChainHash { fn write(&self, w: &mut W) -> Result<(), io::Error> { w.write_all(self.as_bytes()) diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index b4db17bee20..8e66b8b0dfb 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -20,6 +20,7 @@ use crate::chain::channelmonitor::{ ChannelMonitor, ChannelMonitorUpdate, ChannelMonitorUpdateStep, MonitorEvent, }; use crate::chain::transaction::OutPoint; +use crate::chain::BestBlock; use crate::chain::WatchedOutput; use crate::events::bump_transaction::sync::WalletSourceSync; use crate::events::bump_transaction::Utxo; @@ -50,7 +51,6 @@ use crate::sign::{self, ReceiveAuthKey}; use crate::sign::{ChannelSigner, PeerStorageKey}; use crate::sync::RwLock; use crate::types::features::{ChannelFeatures, InitFeatures, NodeFeatures}; -use crate::util::async_poll::MaybeSend; use crate::util::config::UserConfig; use crate::util::dyn_signer::{ DynKeysInterface, DynKeysInterfaceTrait, DynPhantomKeysInterface, DynSigner, @@ -58,6 +58,7 @@ use crate::util::dyn_signer::{ use crate::util::logger::{Logger, Record}; #[cfg(feature = "std")] use crate::util::mut_global::MutGlobal; +use crate::util::native_async::MaybeSend; use crate::util::persist::{KVStore, KVStoreSync, MonitorName}; use crate::util::ser::{Readable, ReadableArgs, Writeable, Writer}; use crate::util::test_channel_signer::{EnforcementState, TestChannelSigner}; @@ -66,7 +67,7 @@ use bitcoin::amount::Amount; use bitcoin::block::Block; use bitcoin::constants::genesis_block; use bitcoin::constants::ChainHash; -use bitcoin::hash_types::{BlockHash, Txid}; +use bitcoin::hash_types::Txid; use bitcoin::hashes::Hash; use bitcoin::network::Network; use bitcoin::script::{Builder, Script, ScriptBuf}; @@ -533,7 +534,7 @@ impl<'a> TestChainMonitor<'a> { // underlying `ChainMonitor`. let mut w = TestVecWriter(Vec::new()); monitor.write(&mut w).unwrap(); - let new_monitor = <(BlockHash, ChannelMonitor)>::read( + let new_monitor = <(BestBlock, ChannelMonitor)>::read( &mut io::Cursor::new(&w.0), (self.keys_manager, self.keys_manager), ) @@ -565,7 +566,7 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { // monitor to a serialized copy and get he same one back. let mut w = TestVecWriter(Vec::new()); monitor.write(&mut w).unwrap(); - let new_monitor = <(BlockHash, ChannelMonitor)>::read( + let new_monitor = <(BestBlock, ChannelMonitor)>::read( &mut io::Cursor::new(&w.0), (self.keys_manager, self.keys_manager), ) @@ -621,7 +622,7 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { let monitor = self.chain_monitor.get_monitor(channel_id).unwrap(); w.0.clear(); monitor.write(&mut w).unwrap(); - let new_monitor = <(BlockHash, ChannelMonitor)>::read( + let new_monitor = <(BestBlock, ChannelMonitor)>::read( &mut io::Cursor::new(&w.0), (self.keys_manager, self.keys_manager), )