From 771c1472f798e788197489c2b23d5db750dd02ae Mon Sep 17 00:00:00 2001 From: benthecarman Date: Wed, 1 Apr 2026 14:42:10 -0500 Subject: [PATCH 1/2] Gate SQLite behind a default feature flag Make rusqlite an optional dependency behind the new "sqlite" feature, which is enabled by default. This allows users who use an alternative storage backend to avoid compiling the bundled SQLite C library. Co-Authored-By: Claude Opus 4.6 (1M context) --- Cargo.toml | 5 ++-- src/builder.rs | 7 ++++- src/io/mod.rs | 1 + tests/common/mod.rs | 45 +++++++++++++++++++++++++++------ tests/integration_tests_rust.rs | 2 ++ tests/reorg_test.rs | 2 ++ 6 files changed, 51 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 539941677..29df6f20a 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,8 @@ codegen-units = 1 # Reduce number of codegen units to increase optimizations. panic = 'abort' # Abort on panic [features] -default = [] +default = ["sqlite"] +sqlite = ["dep:rusqlite"] [dependencies] #lightning = { version = "0.2.0", features = ["std"] } @@ -58,7 +59,7 @@ bdk_wallet = { version = "2.3.0", default-features = false, features = ["std", " bitreq = { version = "0.3", default-features = false, features = ["async-https", "json-using-serde"] } rustls = { version = "0.23", default-features = false } -rusqlite = { version = "0.31.0", features = ["bundled"] } +rusqlite = { version = "0.31.0", features = ["bundled"], optional = true } bitcoin = "0.32.7" bip39 = { version = "2.0.0", features = ["rand"] } bip21 = { version = "0.5", features = ["std"], default-features = false } diff --git a/src/builder.rs b/src/builder.rs index cd8cc184f..3f092f85e 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -53,6 +53,9 @@ use crate::entropy::NodeEntropy; use crate::event::EventQueue; use crate::fee_estimator::OnchainFeeEstimator; use crate::gossip::GossipSource; +#[cfg(feature = "sqlite")] +use crate::io; +#[cfg(feature = "sqlite")] use crate::io::sqlite_store::SqliteStore; use 
crate::io::utils::{ read_event_queue, read_external_pathfinding_scores_from_cache, read_network_graph, @@ -61,7 +64,7 @@ use crate::io::utils::{ }; use crate::io::vss_store::VssStoreBuilder; use crate::io::{ - self, PAYMENT_INFO_PERSISTENCE_PRIMARY_NAMESPACE, PAYMENT_INFO_PERSISTENCE_SECONDARY_NAMESPACE, + PAYMENT_INFO_PERSISTENCE_PRIMARY_NAMESPACE, PAYMENT_INFO_PERSISTENCE_SECONDARY_NAMESPACE, PENDING_PAYMENT_INFO_PERSISTENCE_PRIMARY_NAMESPACE, PENDING_PAYMENT_INFO_PERSISTENCE_SECONDARY_NAMESPACE, }; @@ -616,6 +619,7 @@ impl NodeBuilder { /// Builds a [`Node`] instance with a [`SqliteStore`] backend and according to the options /// previously configured. + #[cfg(feature = "sqlite")] pub fn build(&self, node_entropy: NodeEntropy) -> Result { let storage_dir_path = self.config.storage_dir_path.clone(); fs::create_dir_all(storage_dir_path.clone()) @@ -1083,6 +1087,7 @@ impl ArcedNodeBuilder { /// Builds a [`Node`] instance with a [`SqliteStore`] backend and according to the options /// previously configured. + #[cfg(feature = "sqlite")] pub fn build(&self, node_entropy: Arc) -> Result, BuildError> { self.inner.read().unwrap().build(*node_entropy).map(Arc::new) } diff --git a/src/io/mod.rs b/src/io/mod.rs index e080d39f7..4db1dbe21 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -7,6 +7,7 @@ //! Objects and traits for data persistence. 
+#[cfg(feature = "sqlite")] pub mod sqlite_store; #[cfg(test)] pub(crate) mod test_utils; diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 4f68f9825..1c112b379 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -29,6 +29,7 @@ use electrsd::{corepc_node, ElectrsD}; use electrum_client::ElectrumApi; use ldk_node::config::{AsyncPaymentsRole, Config, ElectrumSyncConfig, EsploraSyncConfig}; use ldk_node::entropy::{generate_entropy_mnemonic, NodeEntropy}; +#[cfg(feature = "sqlite")] use ldk_node::io::sqlite_store::SqliteStore; use ldk_node::payment::{PaymentDirection, PaymentKind, PaymentStatus}; use ldk_node::{ @@ -329,6 +330,7 @@ pub(crate) enum TestChainSource<'a> { #[derive(Clone, Copy)] pub(crate) enum TestStoreType { TestSyncStore, + #[cfg(feature = "sqlite")] Sqlite, } @@ -486,6 +488,7 @@ pub(crate) fn setup_node(chain_source: &TestChainSource, config: TestConfig) -> let kv_store = TestSyncStore::new(config.node_config.storage_dir_path.into()); builder.build_with_store(config.node_entropy.into(), kv_store).unwrap() }, + #[cfg(feature = "sqlite")] TestStoreType::Sqlite => builder.build(config.node_entropy.into()).unwrap(), }; @@ -1519,6 +1522,7 @@ struct TestSyncStoreInner { serializer: RwLock<()>, test_store: TestStore, fs_store: FilesystemStore, + #[cfg(feature = "sqlite")] sqlite_store: SqliteStore, } @@ -1528,8 +1532,11 @@ impl TestSyncStoreInner { let mut fs_dir = dest_dir.clone(); fs_dir.push("fs_store"); let fs_store = FilesystemStore::new(fs_dir); + #[cfg(feature = "sqlite")] let mut sql_dir = dest_dir.clone(); + #[cfg(feature = "sqlite")] sql_dir.push("sqlite_store"); + #[cfg(feature = "sqlite")] let sqlite_store = SqliteStore::new( sql_dir, Some("test_sync_db".to_string()), @@ -1537,24 +1544,34 @@ impl TestSyncStoreInner { ) .unwrap(); let test_store = TestStore::new(false); - Self { serializer, fs_store, sqlite_store, test_store } + #[cfg(feature = "sqlite")] + { + return Self { serializer, fs_store, sqlite_store, test_store }; + } 
+ #[cfg(not(feature = "sqlite"))] + { + Self { serializer, fs_store, test_store } + } } fn do_list( &self, primary_namespace: &str, secondary_namespace: &str, ) -> lightning::io::Result> { let fs_res = KVStoreSync::list(&self.fs_store, primary_namespace, secondary_namespace); - let sqlite_res = - KVStoreSync::list(&self.sqlite_store, primary_namespace, secondary_namespace); + #[cfg(feature = "sqlite")] + let sqlite_res = KVStoreSync::list(&self.sqlite_store, primary_namespace, secondary_namespace); let test_res = KVStoreSync::list(&self.test_store, primary_namespace, secondary_namespace); match fs_res { Ok(mut list) => { list.sort(); - let mut sqlite_list = sqlite_res.unwrap(); - sqlite_list.sort(); - assert_eq!(list, sqlite_list); + #[cfg(feature = "sqlite")] + { + let mut sqlite_list = sqlite_res.unwrap(); + sqlite_list.sort(); + assert_eq!(list, sqlite_list); + } let mut test_list = test_res.unwrap(); test_list.sort(); @@ -1563,6 +1580,7 @@ impl TestSyncStoreInner { Ok(list) }, Err(e) => { + #[cfg(feature = "sqlite")] assert!(sqlite_res.is_err()); assert!(test_res.is_err()); Err(e) @@ -1576,6 +1594,7 @@ impl TestSyncStoreInner { let _guard = self.serializer.read().unwrap(); let fs_res = KVStoreSync::read(&self.fs_store, primary_namespace, secondary_namespace, key); + #[cfg(feature = "sqlite")] let sqlite_res = KVStoreSync::read(&self.sqlite_store, primary_namespace, secondary_namespace, key); let test_res = @@ -1583,13 +1602,17 @@ impl TestSyncStoreInner { match fs_res { Ok(read) => { + #[cfg(feature = "sqlite")] assert_eq!(read, sqlite_res.unwrap()); assert_eq!(read, test_res.unwrap()); Ok(read) }, Err(e) => { - assert!(sqlite_res.is_err()); - assert_eq!(e.kind(), unsafe { sqlite_res.unwrap_err_unchecked().kind() }); + #[cfg(feature = "sqlite")] + { + assert!(sqlite_res.is_err()); + assert_eq!(e.kind(), unsafe { sqlite_res.unwrap_err_unchecked().kind() }); + } assert!(test_res.is_err()); assert_eq!(e.kind(), unsafe { test_res.unwrap_err_unchecked().kind() }); 
Err(e) @@ -1608,6 +1631,7 @@ impl TestSyncStoreInner { key, buf.clone(), ); + #[cfg(feature = "sqlite")] let sqlite_res = KVStoreSync::write( &self.sqlite_store, primary_namespace, @@ -1630,11 +1654,13 @@ impl TestSyncStoreInner { match fs_res { Ok(()) => { + #[cfg(feature = "sqlite")] assert!(sqlite_res.is_ok()); assert!(test_res.is_ok()); Ok(()) }, Err(e) => { + #[cfg(feature = "sqlite")] assert!(sqlite_res.is_err()); assert!(test_res.is_err()); Err(e) @@ -1648,6 +1674,7 @@ impl TestSyncStoreInner { let _guard = self.serializer.write().unwrap(); let fs_res = KVStoreSync::remove(&self.fs_store, primary_namespace, secondary_namespace, key, lazy); + #[cfg(feature = "sqlite")] let sqlite_res = KVStoreSync::remove( &self.sqlite_store, primary_namespace, @@ -1670,11 +1697,13 @@ impl TestSyncStoreInner { match fs_res { Ok(()) => { + #[cfg(feature = "sqlite")] assert!(sqlite_res.is_ok()); assert!(test_res.is_ok()); Ok(()) }, Err(e) => { + #[cfg(feature = "sqlite")] assert!(sqlite_res.is_err()); assert!(test_res.is_err()); Err(e) diff --git a/tests/integration_tests_rust.rs b/tests/integration_tests_rust.rs index 413b2d44a..711bd4bc3 100644 --- a/tests/integration_tests_rust.rs +++ b/tests/integration_tests_rust.rs @@ -5,6 +5,8 @@ // http://opensource.org/licenses/MIT>, at your option. You may not use this file except in // accordance with one or both of these licenses. +#![cfg(feature = "sqlite")] + mod common; use std::collections::HashSet; diff --git a/tests/reorg_test.rs b/tests/reorg_test.rs index 295d9fdd2..74892d831 100644 --- a/tests/reorg_test.rs +++ b/tests/reorg_test.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "sqlite")] + mod common; use std::collections::HashMap; From 8becefec50fa07dc70d6e0f1306d2a8f9af5b2e9 Mon Sep 17 00:00:00 2001 From: benthecarman Date: Wed, 1 Apr 2026 14:43:45 -0500 Subject: [PATCH 2/2] Add PostgreSQL storage backend Add a PostgresStore implementation behind the "postgres" feature flag, mirroring the existing SqliteStore. 
Uses tokio-postgres (async-native) with an internal tokio runtime for the sync KVStoreSync trait, following the VssStore pattern. Includes unit tests, integration tests (channel full cycle and node restart), and a CI workflow that runs both against a PostgreSQL service container. Co-Authored-By: Claude Opus 4.6 (1M context) --- .github/workflows/postgres-integration.yml | 42 + Cargo.toml | 2 + src/builder.rs | 34 + src/io/mod.rs | 2 + src/io/postgres_store/migrations.rs | 21 + src/io/postgres_store/mod.rs | 986 +++++++++++++++++++++ tests/integration_tests_postgres.rs | 128 +++ 7 files changed, 1215 insertions(+) create mode 100644 .github/workflows/postgres-integration.yml create mode 100644 src/io/postgres_store/migrations.rs create mode 100644 src/io/postgres_store/mod.rs create mode 100644 tests/integration_tests_postgres.rs diff --git a/.github/workflows/postgres-integration.yml b/.github/workflows/postgres-integration.yml new file mode 100644 index 000000000..123e52a25 --- /dev/null +++ b/.github/workflows/postgres-integration.yml @@ -0,0 +1,42 @@ +name: CI Checks - PostgreSQL Integration Tests + +on: [push, pull_request] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-and-test: + runs-on: ubuntu-latest + + services: + postgres: + image: postgres:latest + ports: + - 5432:5432 + env: + POSTGRES_DB: postgres + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - name: Checkout code + uses: actions/checkout@v3 + - name: Install Rust stable toolchain + run: | + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --profile=minimal --default-toolchain stable + - name: Run PostgreSQL store tests + env: + TEST_POSTGRES_URL: "host=localhost user=postgres password=postgres" + run: cargo test --features postgres io::postgres_store + - name: Run PostgreSQL integration 
tests + env: + TEST_POSTGRES_URL: "host=localhost user=postgres password=postgres" + run: | + RUSTFLAGS="--cfg no_download --cfg cycle_tests" cargo test --features postgres --test integration_tests_postgres diff --git a/Cargo.toml b/Cargo.toml index 29df6f20a..d4e63f9d4 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ panic = 'abort' # Abort on panic [features] default = ["sqlite"] sqlite = ["dep:rusqlite"] +postgres = ["dep:tokio-postgres"] [dependencies] #lightning = { version = "0.2.0", features = ["std"] } @@ -77,6 +78,7 @@ serde_json = { version = "1.0.128", default-features = false, features = ["std"] log = { version = "0.4.22", default-features = false, features = ["std"]} async-trait = { version = "0.1", default-features = false } +tokio-postgres = { version = "0.7", default-features = false, features = ["runtime"], optional = true } vss-client = { package = "vss-client-ng", version = "0.5" } prost = { version = "0.11.6", default-features = false} #bitcoin-payment-instructions = { version = "0.6" } diff --git a/src/builder.rs b/src/builder.rs index 3f092f85e..72b7a17c0 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -633,6 +633,24 @@ impl NodeBuilder { self.build_with_store(node_entropy, kv_store) } + /// Builds a [`Node`] instance with a [PostgreSQL] backend and according to the options + /// previously configured. + /// + /// Connects to the PostgreSQL database at the given `connection_string`. + /// The given `kv_table_name` will be used or default to + /// [`DEFAULT_KV_TABLE_NAME`](crate::io::postgres_store::DEFAULT_KV_TABLE_NAME). 
+ /// + /// [PostgreSQL]: https://www.postgresql.org + #[cfg(feature = "postgres")] + pub fn build_with_postgres_store( + &self, node_entropy: NodeEntropy, connection_string: &str, kv_table_name: Option, + ) -> Result { + let kv_store = + crate::io::postgres_store::PostgresStore::new(connection_string, kv_table_name) + .map_err(|_| BuildError::KVStoreSetupFailed)?; + self.build_with_store(node_entropy, kv_store) + } + /// Builds a [`Node`] instance with a [`FilesystemStore`] backend and according to the options /// previously configured. pub fn build_with_fs_store(&self, node_entropy: NodeEntropy) -> Result { @@ -1092,6 +1110,22 @@ impl ArcedNodeBuilder { self.inner.read().unwrap().build(*node_entropy).map(Arc::new) } + /// Builds a [`Node`] instance with a [PostgreSQL] backend and according to the options + /// previously configured. + /// + /// [PostgreSQL]: https://www.postgresql.org + #[cfg(feature = "postgres")] + pub fn build_with_postgres_store( + &self, node_entropy: Arc, connection_string: String, + kv_table_name: Option, + ) -> Result, BuildError> { + self.inner + .read() + .unwrap() + .build_with_postgres_store(*node_entropy, &connection_string, kv_table_name) + .map(Arc::new) + } + /// Builds a [`Node`] instance with a [`FilesystemStore`] backend and according to the options /// previously configured. pub fn build_with_fs_store( diff --git a/src/io/mod.rs b/src/io/mod.rs index 4db1dbe21..d2cad17d3 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -7,6 +7,8 @@ //! Objects and traits for data persistence. +#[cfg(feature = "postgres")] +pub mod postgres_store; #[cfg(feature = "sqlite")] pub mod sqlite_store; #[cfg(test)] diff --git a/src/io/postgres_store/migrations.rs b/src/io/postgres_store/migrations.rs new file mode 100644 index 000000000..c9add1c57 --- /dev/null +++ b/src/io/postgres_store/migrations.rs @@ -0,0 +1,21 @@ +// This file is Copyright its original authors, visible in version control history. 
+// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. + +use lightning::io; +use tokio_postgres::Client; + +pub(super) async fn migrate_schema( + _client: &Client, _kv_table_name: &str, from_version: u16, to_version: u16, +) -> io::Result<()> { + assert!(from_version < to_version); + // Future migrations go here, e.g.: + // if from_version == 1 && to_version >= 2 { + // migrate_v1_to_v2(client, kv_table_name).await?; + // from_version = 2; + // } + Ok(()) +} diff --git a/src/io/postgres_store/mod.rs b/src/io/postgres_store/mod.rs new file mode 100644 index 000000000..7b7f78879 --- /dev/null +++ b/src/io/postgres_store/mod.rs @@ -0,0 +1,986 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. + +//! Objects related to [`PostgresStore`] live here. +use std::collections::HashMap; +use std::future::Future; +use std::sync::atomic::{AtomicI64, AtomicU64, AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; + +use lightning::io; +use lightning::util::persist::{ + KVStore, KVStoreSync, PageToken, PaginatedKVStore, PaginatedKVStoreSync, PaginatedListResponse, +}; +use lightning_types::string::PrintableString; +use tokio_postgres::NoTls; + +use crate::io::utils::check_namespace_key_validity; + +mod migrations; + +/// The default table in which we store all data. +pub const DEFAULT_KV_TABLE_NAME: &str = "ldk_data"; + +// The current schema version for the PostgreSQL store. +const SCHEMA_VERSION: u16 = 1; + +// The number of entries returned per page in paginated list operations. +const PAGE_SIZE: usize = 50; + +// The number of worker threads for the internal runtime used by sync operations. 
+const INTERNAL_RUNTIME_WORKERS: usize = 2; + +/// A [`KVStoreSync`] implementation that writes to and reads from a [PostgreSQL] database. +/// +/// [PostgreSQL]: https://www.postgresql.org +pub struct PostgresStore { + inner: Arc, + + // Version counter to ensure that writes are applied in the correct order. It is assumed that read and list + // operations aren't sensitive to the order of execution. + next_write_version: AtomicU64, + + // An internal runtime we use to avoid any deadlocks we could hit when waiting on async + // operations to finish from a sync context. + internal_runtime: Option, +} + +// tokio::sync::Mutex (used for the DB client) contains UnsafeCell which opts out of +// RefUnwindSafe. std::sync::Mutex (used by SqliteStore) doesn't have this issue because +// it poisons on panic. This impl is needed for do_read_write_remove_list_persist which +// requires K: KVStoreSync + RefUnwindSafe. +impl std::panic::RefUnwindSafe for PostgresStore {} + +impl PostgresStore { + /// Constructs a new [`PostgresStore`]. + /// + /// Connects to the PostgreSQL database at the given `connection_string`. + /// + /// The given `kv_table_name` will be used or default to [`DEFAULT_KV_TABLE_NAME`]. 
+ pub fn new(connection_string: &str, kv_table_name: Option) -> io::Result { + let internal_runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .thread_name_fn(|| { + static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); + let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); + format!("ldk-node-pg-runtime-{id}") + }) + .worker_threads(INTERNAL_RUNTIME_WORKERS) + .max_blocking_threads(INTERNAL_RUNTIME_WORKERS) + .build() + .unwrap(); + + let connection_string = connection_string.to_string(); + let inner = tokio::task::block_in_place(|| { + internal_runtime.block_on(async { + PostgresStoreInner::new(&connection_string, kv_table_name).await + }) + })?; + + let inner = Arc::new(inner); + let next_write_version = AtomicU64::new(1); + Ok(Self { inner, next_write_version, internal_runtime: Some(internal_runtime) }) + } + + fn build_locking_key( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> String { + format!("{primary_namespace}#{secondary_namespace}#{key}") + } + + fn get_new_version_and_lock_ref( + &self, locking_key: String, + ) -> (Arc>, u64) { + let version = self.next_write_version.fetch_add(1, Ordering::Relaxed); + if version == u64::MAX { + panic!("PostgresStore version counter overflowed"); + } + + let inner_lock_ref = self.inner.get_inner_lock_ref(locking_key); + + (inner_lock_ref, version) + } +} + +impl Drop for PostgresStore { + fn drop(&mut self) { + let internal_runtime = self.internal_runtime.take(); + tokio::task::block_in_place(move || drop(internal_runtime)); + } +} + +impl KVStore for PostgresStore { + fn read( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> impl Future, io::Error>> + 'static + Send { + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + async move { inner.read_internal(&primary_namespace, &secondary_namespace, &key).await } 
+ } + + fn write( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec, + ) -> impl Future> + 'static + Send { + let locking_key = self.build_locking_key(primary_namespace, secondary_namespace, key); + let (inner_lock_ref, version) = self.get_new_version_and_lock_ref(locking_key.clone()); + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + async move { + inner + .write_internal( + inner_lock_ref, + locking_key, + version, + &primary_namespace, + &secondary_namespace, + &key, + buf, + ) + .await + } + } + + fn remove( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, _lazy: bool, + ) -> impl Future> + 'static + Send { + let locking_key = self.build_locking_key(primary_namespace, secondary_namespace, key); + let (inner_lock_ref, version) = self.get_new_version_and_lock_ref(locking_key.clone()); + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + async move { + inner + .remove_internal( + inner_lock_ref, + locking_key, + version, + &primary_namespace, + &secondary_namespace, + &key, + ) + .await + } + } + + fn list( + &self, primary_namespace: &str, secondary_namespace: &str, + ) -> impl Future, io::Error>> + 'static + Send { + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let inner = Arc::clone(&self.inner); + async move { inner.list_internal(&primary_namespace, &secondary_namespace).await } + } +} + +impl KVStoreSync for PostgresStore { + fn read( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> io::Result> { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + 
io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + let fut = async move { + inner.read_internal(&primary_namespace, &secondary_namespace, &key).await + }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } + + fn write( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec, + ) -> io::Result<()> { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let locking_key = self.build_locking_key(primary_namespace, secondary_namespace, key); + let (inner_lock_ref, version) = self.get_new_version_and_lock_ref(locking_key.clone()); + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + let fut = async move { + inner + .write_internal( + inner_lock_ref, + locking_key, + version, + &primary_namespace, + &secondary_namespace, + &key, + buf, + ) + .await + }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } + + fn remove( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, _lazy: bool, + ) -> io::Result<()> { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let locking_key = self.build_locking_key(primary_namespace, secondary_namespace, key); + let (inner_lock_ref, version) = self.get_new_version_and_lock_ref(locking_key.clone()); + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace 
= secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + let fut = async move { + inner + .remove_internal( + inner_lock_ref, + locking_key, + version, + &primary_namespace, + &secondary_namespace, + &key, + ) + .await + }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } + + fn list(&self, primary_namespace: &str, secondary_namespace: &str) -> io::Result> { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let inner = Arc::clone(&self.inner); + let fut = + async move { inner.list_internal(&primary_namespace, &secondary_namespace).await }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } +} + +impl PaginatedKVStoreSync for PostgresStore { + fn list_paginated( + &self, primary_namespace: &str, secondary_namespace: &str, page_token: Option, + ) -> io::Result { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let inner = Arc::clone(&self.inner); + let fut = async move { + inner + .list_paginated_internal(&primary_namespace, &secondary_namespace, page_token) + .await + }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } +} + +impl PaginatedKVStore for PostgresStore { + fn list_paginated( + &self, primary_namespace: &str, secondary_namespace: &str, page_token: Option, + ) -> impl Future> + 'static + Send { + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace 
= secondary_namespace.to_string(); + let inner = Arc::clone(&self.inner); + async move { + inner + .list_paginated_internal(&primary_namespace, &secondary_namespace, page_token) + .await + } + } +} + +struct PostgresStoreInner { + client: tokio::sync::Mutex, + kv_table_name: String, + write_version_locks: Mutex>>>, + next_sort_order: AtomicI64, +} + +impl PostgresStoreInner { + async fn new(connection_string: &str, kv_table_name: Option) -> io::Result { + let kv_table_name = kv_table_name.unwrap_or(DEFAULT_KV_TABLE_NAME.to_string()); + + let (client, connection) = + tokio_postgres::connect(connection_string, NoTls).await.map_err(|e| { + let msg = format!("Failed to connect to PostgreSQL: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + // Spawn the connection task so it runs in the background. + tokio::spawn(async move { + if let Err(e) = connection.await { + log::error!("PostgreSQL connection error: {e}"); + } + }); + + // Create the KV data table if it doesn't exist. + let sql = format!( + "CREATE TABLE IF NOT EXISTS {kv_table_name} ( + primary_namespace TEXT NOT NULL, + secondary_namespace TEXT NOT NULL DEFAULT '', + key TEXT NOT NULL CHECK (key <> ''), + value BYTEA, + sort_order BIGINT NOT NULL DEFAULT 0, + PRIMARY KEY (primary_namespace, secondary_namespace, key) + )" + ); + client.execute(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to create table {kv_table_name}: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + // Read the schema version from the table comment (analogous to SQLite's PRAGMA user_version). 
+ let sql = format!("SELECT obj_description('{kv_table_name}'::regclass, 'pg_class')"); + let row = client.query_one(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to read schema version for {kv_table_name}: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + let version_res: u16 = match row.get::<_, Option<&str>>(0) { + Some(version_str) => version_str.parse().map_err(|_| { + let msg = format!("Invalid schema version: {version_str}"); + io::Error::new(io::ErrorKind::Other, msg) + })?, + None => 0, + }; + + if version_res == 0 { + // New table, set our SCHEMA_VERSION. + let sql = format!("COMMENT ON TABLE {kv_table_name} IS '{SCHEMA_VERSION}'"); + client.execute(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to set schema version: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + } else if version_res < SCHEMA_VERSION { + migrations::migrate_schema(&client, &kv_table_name, version_res, SCHEMA_VERSION) + .await?; + } else if version_res > SCHEMA_VERSION { + let msg = format!( + "Failed to open database: incompatible schema version {version_res}. Expected: {SCHEMA_VERSION}" + ); + return Err(io::Error::new(io::ErrorKind::Other, msg)); + } + + // Create composite index for paginated listing. + let sql = format!( + "CREATE INDEX IF NOT EXISTS idx_{kv_table_name}_paginated ON {kv_table_name} (primary_namespace, secondary_namespace, sort_order DESC, key ASC)" + ); + client.execute(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to create index on table {kv_table_name}: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + // Initialize next_sort_order from the max existing value. 
+		let sql = format!("SELECT COALESCE(MAX(sort_order), 0) FROM {kv_table_name}");
+		let row = client.query_one(sql.as_str(), &[]).await.map_err(|e| {
+			let msg = format!("Failed to read max sort_order from {kv_table_name}: {e}");
+			io::Error::new(io::ErrorKind::Other, msg)
+		})?;
+		let max_sort_order: i64 = row.get(0);
+		let next_sort_order = AtomicI64::new(max_sort_order + 1);
+
+		let client = tokio::sync::Mutex::new(client);
+		let write_version_locks = Mutex::new(HashMap::new());
+		Ok(Self { client, kv_table_name, write_version_locks, next_sort_order })
+	}
+
+	fn get_inner_lock_ref(&self, locking_key: String) -> Arc<tokio::sync::Mutex<u64>> {
+		let mut outer_lock = self.write_version_locks.lock().unwrap();
+		Arc::clone(&outer_lock.entry(locking_key).or_default())
+	}
+
+	async fn read_internal(
+		&self, primary_namespace: &str, secondary_namespace: &str, key: &str,
+	) -> io::Result<Vec<u8>> {
+		check_namespace_key_validity(primary_namespace, secondary_namespace, Some(key), "read")?;
+
+		let locked_client = self.client.lock().await;
+		let sql = format!(
+			"SELECT value FROM {} WHERE primary_namespace=$1 AND secondary_namespace=$2 AND key=$3",
+			self.kv_table_name
+		);
+
+		let row = locked_client
+			.query_opt(sql.as_str(), &[&primary_namespace, &secondary_namespace, &key])
+			.await
+			.map_err(|e| {
+				let msg = format!(
+					"Failed to read from key {}/{}/{}: {}",
+					PrintableString(primary_namespace),
+					PrintableString(secondary_namespace),
+					PrintableString(key),
+					e
+				);
+				io::Error::new(io::ErrorKind::Other, msg)
+			})?;
+
+		match row {
+			Some(row) => {
+				let value: Vec<u8> = row.get(0);
+				Ok(value)
+			},
+			None => {
+				let msg = format!(
+					"Failed to read as key could not be found: {}/{}/{}",
+					PrintableString(primary_namespace),
+					PrintableString(secondary_namespace),
+					PrintableString(key),
+				);
+				Err(io::Error::new(io::ErrorKind::NotFound, msg))
+			},
+		}
+	}
+
+	async fn write_internal(
+		&self, inner_lock_ref: Arc<tokio::sync::Mutex<u64>>, locking_key: String, version: u64,
+		primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec<u8>,
+	) -> io::Result<()> {
+		check_namespace_key_validity(primary_namespace, secondary_namespace, Some(key), "write")?;
+
+		self.execute_locked_write(inner_lock_ref, locking_key, version, async move || {
+			let locked_client = self.client.lock().await;
+
+			let sort_order = self.next_sort_order.fetch_add(1, Ordering::Relaxed);
+
+			let sql = format!(
+				"INSERT INTO {} (primary_namespace, secondary_namespace, key, value, sort_order) \
+				VALUES ($1, $2, $3, $4, $5) \
+				ON CONFLICT (primary_namespace, secondary_namespace, key) DO UPDATE SET value = EXCLUDED.value",
+				self.kv_table_name
+			);
+
+			locked_client
+				.execute(
+					sql.as_str(),
+					&[
+						&primary_namespace,
+						&secondary_namespace,
+						&key,
+						&buf,
+						&sort_order,
+					],
+				)
+				.await
+				.map(|_| ())
+				.map_err(|e| {
+					let msg = format!(
+						"Failed to write to key {}/{}/{}: {}",
+						PrintableString(primary_namespace),
+						PrintableString(secondary_namespace),
+						PrintableString(key),
+						e
+					);
+					io::Error::new(io::ErrorKind::Other, msg)
+				})
+		})
+		.await
+	}
+
+	async fn remove_internal(
+		&self, inner_lock_ref: Arc<tokio::sync::Mutex<u64>>, locking_key: String, version: u64,
+		primary_namespace: &str, secondary_namespace: &str, key: &str,
+	) -> io::Result<()> {
+		check_namespace_key_validity(primary_namespace, secondary_namespace, Some(key), "remove")?;
+
+		self.execute_locked_write(inner_lock_ref, locking_key, version, async move || {
+			let locked_client = self.client.lock().await;
+
+			let sql = format!(
+				"DELETE FROM {} WHERE primary_namespace=$1 AND secondary_namespace=$2 AND key=$3",
+				self.kv_table_name
+			);
+
+			locked_client
+				.execute(sql.as_str(), &[&primary_namespace, &secondary_namespace, &key])
+				.await
+				.map_err(|e| {
+					let msg = format!(
+						"Failed to delete key {}/{}/{}: {}",
+						PrintableString(primary_namespace),
+						PrintableString(secondary_namespace),
+						PrintableString(key),
+						e
+					);
+					io::Error::new(io::ErrorKind::Other, msg)
+				})?;
+			Ok(())
+		})
+		.await
+	}
+
+	async fn list_internal(
+		&self, primary_namespace: &str, secondary_namespace: &str,
+	) -> io::Result<Vec<String>> {
+		check_namespace_key_validity(primary_namespace, secondary_namespace, None, "list")?;
+
+		let locked_client = self.client.lock().await;
+
+		let sql = format!(
+			"SELECT key FROM {} WHERE primary_namespace=$1 AND secondary_namespace=$2",
+			self.kv_table_name
+		);
+
+		let rows = locked_client
+			.query(sql.as_str(), &[&primary_namespace, &secondary_namespace])
+			.await
+			.map_err(|e| {
+				let msg = format!("Failed to retrieve queried rows: {e}");
+				io::Error::new(io::ErrorKind::Other, msg)
+			})?;
+
+		let keys: Vec<String> = rows.iter().map(|row| row.get(0)).collect();
+		Ok(keys)
+	}
+
+	async fn list_paginated_internal(
+		&self, primary_namespace: &str, secondary_namespace: &str, page_token: Option<PageToken>,
+	) -> io::Result<PaginatedListResponse> {
+		check_namespace_key_validity(
+			primary_namespace,
+			secondary_namespace,
+			None,
+			"list_paginated",
+		)?;
+
+		let locked_client = self.client.lock().await;
+
+		// Fetch one extra row beyond PAGE_SIZE to determine whether a next page exists.
+		let fetch_limit = (PAGE_SIZE + 1) as i64;
+
+		let mut entries: Vec<(String, i64)> = match page_token {
+			Some(ref token) => {
+				let token_sort_order: i64 = token.as_str().parse().map_err(|_| {
+					let token_str = token.as_str();
+					let msg = format!("Invalid page token: {token_str}");
+					io::Error::new(io::ErrorKind::InvalidInput, msg)
+				})?;
+				let sql = format!(
+					"SELECT key, sort_order FROM {} \
+					WHERE primary_namespace=$1 \
+					AND secondary_namespace=$2 \
+					AND sort_order < $3 \
+					ORDER BY sort_order DESC, key ASC \
+					LIMIT $4",
+					self.kv_table_name
+				);
+
+				let rows = locked_client
+					.query(
+						sql.as_str(),
+						&[
+							&primary_namespace,
+							&secondary_namespace,
+							&token_sort_order,
+							&fetch_limit,
+						],
+					)
+					.await
+					.map_err(|e| {
+						let msg = format!("Failed to retrieve queried rows: {e}");
+						io::Error::new(io::ErrorKind::Other, msg)
+					})?;
+
+				rows.iter().map(|row| (row.get(0), row.get(1))).collect()
+			},
+			None => {
+				let sql = format!(
+					"SELECT key, sort_order FROM {} \
+					WHERE primary_namespace=$1 \
+					AND secondary_namespace=$2 \
+					ORDER BY sort_order DESC, key ASC \
+					LIMIT $3",
+					self.kv_table_name
+				);
+
+				let rows = locked_client
+					.query(sql.as_str(), &[&primary_namespace, &secondary_namespace, &fetch_limit])
+					.await
+					.map_err(|e| {
+						let msg = format!("Failed to retrieve queried rows: {e}");
+						io::Error::new(io::ErrorKind::Other, msg)
+					})?;
+
+				rows.into_iter().map(|row| (row.get(0), row.get(1))).collect()
+			},
+		};
+
+		let has_more = entries.len() > PAGE_SIZE;
+		entries.truncate(PAGE_SIZE);
+
+		let next_page_token = if has_more {
+			let (_, last_sort_order) = *entries.last().expect("must be non-empty");
+			Some(PageToken::new(last_sort_order.to_string()))
+		} else {
+			None
+		};
+
+		let keys = entries.into_iter().map(|(k, _)| k).collect();
+		Ok(PaginatedListResponse { keys, next_page_token })
+	}
+
+	async fn execute_locked_write<F: Future<Output = Result<(), io::Error>>, FN: FnOnce() -> F>(
+		&self, inner_lock_ref: Arc<tokio::sync::Mutex<u64>>, locking_key: String, version: u64,
+		callback: FN,
+	) -> Result<(), io::Error> {
+		let res = {
+			let mut last_written_version = inner_lock_ref.lock().await;
+
+			// Check if we already have a newer version written/removed. This is used in async contexts to realize eventual
+			// consistency.
+			let is_stale_version = version <= *last_written_version;
+
+			// If the version is not stale, we execute the callback. Otherwise, we can and must skip writing.
+			if is_stale_version {
+				Ok(())
+			} else {
+				callback().await.map(|_| {
+					*last_written_version = version;
+				})
+			}
+		};
+
+		self.clean_locks(&inner_lock_ref, locking_key);
+
+		res
+	}
+
+	fn clean_locks(&self, inner_lock_ref: &Arc<tokio::sync::Mutex<u64>>, locking_key: String) {
+		// If there are no arcs in use elsewhere, this means that there are no in-flight writes. We can remove the map
+		// entry to prevent leaking memory. The two arcs that are expected are the one in the map and the one held here
+		// in inner_lock_ref. The outer lock is obtained first, to avoid a new arc being cloned after we've already
+		// counted.
+		let mut outer_lock = self.write_version_locks.lock().unwrap();
+
+		let strong_count = Arc::strong_count(inner_lock_ref);
+		debug_assert!(strong_count >= 2, "Unexpected PostgresStore strong count");
+
+		if strong_count == 2 {
+			outer_lock.remove(&locking_key);
+		}
+	}
+}
+
+#[cfg(test)]
+mod tests {
+	use super::*;
+	use crate::io::test_utils::{do_read_write_remove_list_persist, do_test_store};
+
+	fn test_connection_string() -> String {
+		std::env::var("TEST_POSTGRES_URL")
+			.unwrap_or_else(|_| "host=localhost user=postgres password=postgres".to_string())
+	}
+
+	fn create_test_store(table_name: &str) -> PostgresStore {
+		PostgresStore::new(&test_connection_string(), Some(table_name.to_string())).unwrap()
+	}
+
+	fn cleanup_store(store: &PostgresStore) {
+		if let Some(ref runtime) = store.internal_runtime {
+			let kv_table = store.inner.kv_table_name.clone();
+			let inner = Arc::clone(&store.inner);
+			let _ = tokio::task::block_in_place(|| {
+				runtime.block_on(async {
+					let client = 
inner.client.lock().await; + let _ = client.execute(&format!("DROP TABLE IF EXISTS {kv_table}"), &[]).await; + }) + }); + } + } + + #[test] + fn read_write_remove_list_persist() { + let store = create_test_store("test_rwrl"); + do_read_write_remove_list_persist(&store); + cleanup_store(&store); + } + + #[test] + fn test_postgres_store() { + let store_0 = create_test_store("test_pg_store_0"); + let store_1 = create_test_store("test_pg_store_1"); + do_test_store(&store_0, &store_1); + cleanup_store(&store_0); + cleanup_store(&store_1); + } + + #[test] + fn test_postgres_store_paginated_listing() { + let store = create_test_store("test_pg_paginated"); + + let primary_namespace = "test_ns"; + let secondary_namespace = "test_sub"; + let num_entries = 225; + + for i in 0..num_entries { + let key = format!("key_{:04}", i); + let data = vec![i as u8; 32]; + KVStoreSync::write(&store, primary_namespace, secondary_namespace, &key, data).unwrap(); + } + + // Paginate through all entries and collect them + let mut all_keys = Vec::new(); + let mut page_token = None; + let mut page_count = 0; + + loop { + let response = PaginatedKVStoreSync::list_paginated( + &store, + primary_namespace, + secondary_namespace, + page_token, + ) + .unwrap(); + + all_keys.extend(response.keys.clone()); + page_count += 1; + + match response.next_page_token { + Some(token) => page_token = Some(token), + None => break, + } + } + + // Verify we got exactly the right number of entries + assert_eq!(all_keys.len(), num_entries); + + // Verify correct number of pages (225 entries at 50 per page = 5 pages) + assert_eq!(page_count, 5); + + // Verify no duplicates + let mut unique_keys = all_keys.clone(); + unique_keys.sort(); + unique_keys.dedup(); + assert_eq!(unique_keys.len(), num_entries); + + // Verify ordering: newest first (highest sort_order first). 
+ assert_eq!(all_keys[0], format!("key_{:04}", num_entries - 1)); + assert_eq!(all_keys[num_entries - 1], "key_0000"); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_update_preserves_order() { + let store = create_test_store("test_pg_paginated_update"); + + let primary_namespace = "test_ns"; + let secondary_namespace = "test_sub"; + + KVStoreSync::write(&store, primary_namespace, secondary_namespace, "first", vec![1u8; 8]) + .unwrap(); + KVStoreSync::write(&store, primary_namespace, secondary_namespace, "second", vec![2u8; 8]) + .unwrap(); + KVStoreSync::write(&store, primary_namespace, secondary_namespace, "third", vec![3u8; 8]) + .unwrap(); + + // Update the first entry + KVStoreSync::write(&store, primary_namespace, secondary_namespace, "first", vec![99u8; 8]) + .unwrap(); + + // Paginated listing should still show "first" with its original creation order + let response = PaginatedKVStoreSync::list_paginated( + &store, + primary_namespace, + secondary_namespace, + None, + ) + .unwrap(); + + // Newest first: third, second, first + assert_eq!(response.keys, vec!["third", "second", "first"]); + + // Verify the updated value was persisted + let data = + KVStoreSync::read(&store, primary_namespace, secondary_namespace, "first").unwrap(); + assert_eq!(data, vec![99u8; 8]); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_empty_namespace() { + let store = create_test_store("test_pg_paginated_empty"); + + // Paginating an empty or unknown namespace returns an empty result with no token. 
+ let response = + PaginatedKVStoreSync::list_paginated(&store, "nonexistent", "ns", None).unwrap(); + assert!(response.keys.is_empty()); + assert!(response.next_page_token.is_none()); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_namespace_isolation() { + let store = create_test_store("test_pg_paginated_isolation"); + + KVStoreSync::write(&store, "ns_a", "sub", "key_1", vec![1u8; 8]).unwrap(); + KVStoreSync::write(&store, "ns_a", "sub", "key_2", vec![2u8; 8]).unwrap(); + KVStoreSync::write(&store, "ns_b", "sub", "key_3", vec![3u8; 8]).unwrap(); + KVStoreSync::write(&store, "ns_a", "other", "key_4", vec![4u8; 8]).unwrap(); + + // ns_a/sub should only contain key_1 and key_2 (newest first). + let response = PaginatedKVStoreSync::list_paginated(&store, "ns_a", "sub", None).unwrap(); + assert_eq!(response.keys, vec!["key_2", "key_1"]); + assert!(response.next_page_token.is_none()); + + // ns_b/sub should only contain key_3. + let response = PaginatedKVStoreSync::list_paginated(&store, "ns_b", "sub", None).unwrap(); + assert_eq!(response.keys, vec!["key_3"]); + + // ns_a/other should only contain key_4. 
+ let response = PaginatedKVStoreSync::list_paginated(&store, "ns_a", "other", None).unwrap(); + assert_eq!(response.keys, vec!["key_4"]); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_removal() { + let store = create_test_store("test_pg_paginated_removal"); + + let ns = "test_ns"; + let sub = "test_sub"; + + KVStoreSync::write(&store, ns, sub, "a", vec![1u8; 8]).unwrap(); + KVStoreSync::write(&store, ns, sub, "b", vec![2u8; 8]).unwrap(); + KVStoreSync::write(&store, ns, sub, "c", vec![3u8; 8]).unwrap(); + + KVStoreSync::remove(&store, ns, sub, "b", false).unwrap(); + + let response = PaginatedKVStoreSync::list_paginated(&store, ns, sub, None).unwrap(); + assert_eq!(response.keys, vec!["c", "a"]); + assert!(response.next_page_token.is_none()); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_exact_page_boundary() { + let store = create_test_store("test_pg_paginated_boundary"); + + let ns = "test_ns"; + let sub = "test_sub"; + + // Write exactly PAGE_SIZE entries (50). + for i in 0..PAGE_SIZE { + let key = format!("key_{:04}", i); + KVStoreSync::write(&store, ns, sub, &key, vec![i as u8; 8]).unwrap(); + } + + // Exactly PAGE_SIZE entries: all returned in one page with no next-page token. + let response = PaginatedKVStoreSync::list_paginated(&store, ns, sub, None).unwrap(); + assert_eq!(response.keys.len(), PAGE_SIZE); + assert!(response.next_page_token.is_none()); + + // Add one more entry (PAGE_SIZE + 1 total). First page should now have a token. + KVStoreSync::write(&store, ns, sub, "key_extra", vec![0u8; 8]).unwrap(); + let response = PaginatedKVStoreSync::list_paginated(&store, ns, sub, None).unwrap(); + assert_eq!(response.keys.len(), PAGE_SIZE); + assert!(response.next_page_token.is_some()); + + // Second page should have exactly 1 entry and no token. 
+ let response = + PaginatedKVStoreSync::list_paginated(&store, ns, sub, response.next_page_token) + .unwrap(); + assert_eq!(response.keys.len(), 1); + assert!(response.next_page_token.is_none()); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_fewer_than_page_size() { + let store = create_test_store("test_pg_paginated_few"); + + let ns = "test_ns"; + let sub = "test_sub"; + + // Write fewer entries than PAGE_SIZE. + for i in 0..5 { + let key = format!("key_{i}"); + KVStoreSync::write(&store, ns, sub, &key, vec![i as u8; 8]).unwrap(); + } + + let response = PaginatedKVStoreSync::list_paginated(&store, ns, sub, None).unwrap(); + assert_eq!(response.keys.len(), 5); + // Fewer than PAGE_SIZE means no next page. + assert!(response.next_page_token.is_none()); + // Newest first. + assert_eq!(response.keys, vec!["key_4", "key_3", "key_2", "key_1", "key_0"]); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_write_version_persists_across_restart() { + let table_name = "test_pg_write_version_restart"; + let primary_namespace = "test_ns"; + let secondary_namespace = "test_sub"; + + { + let store = create_test_store(table_name); + + KVStoreSync::write( + &store, + primary_namespace, + secondary_namespace, + "key_a", + vec![1u8; 8], + ) + .unwrap(); + KVStoreSync::write( + &store, + primary_namespace, + secondary_namespace, + "key_b", + vec![2u8; 8], + ) + .unwrap(); + + // Don't clean up since we want to reopen + } + + // Open a new store instance on the same database table and write more + { + let store = create_test_store(table_name); + + KVStoreSync::write( + &store, + primary_namespace, + secondary_namespace, + "key_c", + vec![3u8; 8], + ) + .unwrap(); + + // Paginated listing should show newest first: key_c, key_b, key_a + let response = PaginatedKVStoreSync::list_paginated( + &store, + primary_namespace, + secondary_namespace, + None, + ) + .unwrap(); + + assert_eq!(response.keys, vec!["key_c", "key_b", "key_a"]); + + 
cleanup_store(&store); + } + } +} diff --git a/tests/integration_tests_postgres.rs b/tests/integration_tests_postgres.rs new file mode 100644 index 000000000..eb1e3d86b --- /dev/null +++ b/tests/integration_tests_postgres.rs @@ -0,0 +1,128 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. + +#![cfg(feature = "postgres")] + +mod common; + +use ldk_node::entropy::NodeEntropy; +use ldk_node::Builder; +use rand::RngCore; + +fn test_connection_string() -> String { + std::env::var("TEST_POSTGRES_URL") + .unwrap_or_else(|_| "host=localhost user=postgres password=postgres".to_string()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn channel_full_cycle_with_postgres_store() { + let (bitcoind, electrsd) = common::setup_bitcoind_and_electrsd(); + println!("== Node A =="); + let esplora_url = format!("http://{}", electrsd.esplora_url.as_ref().unwrap()); + let config_a = common::random_config(true); + let mut builder_a = Builder::from_config(config_a.node_config); + builder_a.set_chain_source_esplora(esplora_url.clone(), None); + let node_a = builder_a + .build_with_postgres_store( + config_a.node_entropy, + &test_connection_string(), + Some("channel_cycle_a".to_string()), + ) + .unwrap(); + node_a.start().unwrap(); + + println!("\n== Node B =="); + let config_b = common::random_config(true); + let mut builder_b = Builder::from_config(config_b.node_config); + builder_b.set_chain_source_esplora(esplora_url.clone(), None); + let node_b = builder_b + .build_with_postgres_store( + config_b.node_entropy, + &test_connection_string(), + Some("channel_cycle_b".to_string()), + ) + .unwrap(); + node_b.start().unwrap(); + + common::do_channel_full_cycle( + node_a, + node_b, + &bitcoind.client, + &electrsd.client, + false, + true, + 
false, + ) + .await; +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn postgres_node_restart() { + let (bitcoind, electrsd) = common::setup_bitcoind_and_electrsd(); + let esplora_url = format!("http://{}", electrsd.esplora_url.as_ref().unwrap()); + let connection_string = test_connection_string(); + + let storage_path = common::random_storage_path().to_str().unwrap().to_owned(); + let mut seed_bytes = [42u8; 64]; + rand::rng().fill_bytes(&mut seed_bytes); + let node_entropy = NodeEntropy::from_seed_bytes(seed_bytes); + + // Setup initial node and fund it. + let (expected_balance_sats, expected_node_id) = { + let mut builder = Builder::new(); + builder.set_network(bitcoin::Network::Regtest); + builder.set_storage_dir_path(storage_path.clone()); + builder.set_chain_source_esplora(esplora_url.clone(), None); + let node = builder + .build_with_postgres_store( + node_entropy, + &connection_string, + Some("restart_test".to_string()), + ) + .unwrap(); + + node.start().unwrap(); + let addr = node.onchain_payment().new_address().unwrap(); + common::premine_and_distribute_funds( + &bitcoind.client, + &electrsd.client, + vec![addr], + bitcoin::Amount::from_sat(100_000), + ) + .await; + node.sync_wallets().unwrap(); + + let balance = node.list_balances().spendable_onchain_balance_sats; + assert!(balance > 0); + let node_id = node.node_id(); + + node.stop().unwrap(); + (balance, node_id) + }; + + // Verify node can be restarted from PostgreSQL backend. 
+ let mut builder = Builder::new(); + builder.set_network(bitcoin::Network::Regtest); + builder.set_storage_dir_path(storage_path); + builder.set_chain_source_esplora(esplora_url, None); + + let node = builder + .build_with_postgres_store( + node_entropy, + &connection_string, + Some("restart_test".to_string()), + ) + .unwrap(); + + node.start().unwrap(); + node.sync_wallets().unwrap(); + + assert_eq!(expected_node_id, node.node_id()); + assert_eq!(expected_balance_sats, node.list_balances().spendable_onchain_balance_sats); + + node.stop().unwrap(); +}