diff --git a/.github/workflows/ldk-node-integration.yml b/.github/workflows/ldk-node-integration.yml index 732c104..e79d210 100644 --- a/.github/workflows/ldk-node-integration.yml +++ b/.github/workflows/ldk-node-integration.yml @@ -45,5 +45,5 @@ jobs: run: | cd ldk-node export TEST_VSS_BASE_URL="http://localhost:8080/vss" - RUSTFLAGS="--cfg vss_test" cargo test io::vss_store - RUSTFLAGS="--cfg vss_test" cargo test --test integration_tests_vss + RUSTFLAGS="--cfg vss_test" cargo test io::vss_store -- --test-threads=1 + RUSTFLAGS="--cfg vss_test" cargo test --test integration_tests_vss -- --test-threads=1 diff --git a/rust/impls/Cargo.toml b/rust/impls/Cargo.toml index 505bda1..fd8cece 100644 --- a/rust/impls/Cargo.toml +++ b/rust/impls/Cargo.toml @@ -8,9 +8,8 @@ async-trait = "0.1.77" api = { path = "../api" } chrono = "0.4.38" tokio-postgres = { version = "0.7.12", features = ["with-chrono-0_4"] } -bb8-postgres = "0.7" bytes = "1.4.0" -tokio = { version = "1.38.0", default-features = false } +tokio = { version = "1.38.0", default-features = false, features = ["rt", "macros"] } native-tls = { version = "0.2.14", default-features = false } postgres-native-tls = { version = "0.5.2", default-features = false, features = ["runtime"] } diff --git a/rust/impls/src/postgres_store.rs b/rust/impls/src/postgres_store.rs index 3bc176c..35363ef 100644 --- a/rust/impls/src/postgres_store.rs +++ b/rust/impls/src/postgres_store.rs @@ -7,14 +7,13 @@ use api::types::{ ListKeyVersionsRequest, ListKeyVersionsResponse, PutObjectRequest, PutObjectResponse, }; use async_trait::async_trait; -use bb8_postgres::bb8::Pool; -use bb8_postgres::PostgresConnectionManager; use bytes::Bytes; use chrono::Utc; use native_tls::TlsConnector; use postgres_native_tls::MakeTlsConnector; use std::cmp::min; use std::io::{self, Error, ErrorKind}; +use tokio::sync::Mutex; use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; use tokio_postgres::{error, Client, NoTls, Socket, Transaction}; @@ -47,6 +46,72 @@ pub const LIST_KEY_VERSIONS_MAX_PAGE_SIZE: i32 = 100; /// Exceeding this value will result in request rejection through [`VssError::InvalidRequestError`]. pub const MAX_PUT_REQUEST_ITEM_COUNT: usize = 1000; +const POOL_SIZE: usize = 10; + +struct SmallPool { + connections: [Mutex; POOL_SIZE], + endpoint: String, + db_name: String, + tls: T, +} + +impl SmallPool +where + T: MakeTlsConnect + Clone + Send + Sync + 'static, + T::Stream: Send + Sync, + T::TlsConnect: Send, + <>::TlsConnect as TlsConnect>::Future: Send, +{ + async fn new(postgres_endpoint: &str, vss_db: &str, tls: T) -> Result { + let connections = [ + Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?), + Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?), + Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?), + Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?), + Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?), + Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?), + Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?), + Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?), + Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?), + Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?), + ]; + + let pool = SmallPool { + connections, + endpoint: String::from(postgres_endpoint), + db_name: String::from(vss_db), + tls, + }; + Ok(pool) + } + + async fn get(&self) -> Result, Error> { + let mut conn = tokio::select! { + conn_0 = self.connections[0].lock() => conn_0, + conn_1 = self.connections[1].lock() => conn_1, + conn_2 = self.connections[2].lock() => conn_2, + conn_3 = self.connections[3].lock() => conn_3, + conn_4 = self.connections[4].lock() => conn_4, + conn_5 = self.connections[5].lock() => conn_5, + conn_6 = self.connections[6].lock() => conn_6, + conn_7 = self.connections[7].lock() => conn_7, + conn_8 = self.connections[8].lock() => conn_8, + conn_9 = self.connections[9].lock() => conn_9, + }; + self.ensure_connected(&mut conn).await?; + Ok(conn) + } + + async fn ensure_connected(&self, client: &mut Client) -> Result<(), Error> { + if client.is_closed() || client.check_connection().await.is_err() { + let new_client = + make_db_connection(&self.endpoint, &self.db_name, self.tls.clone()).await?; + *client = new_client; + } + Ok(()) + } +} + /// A [PostgreSQL](https://www.postgresql.org/) based backend implementation for VSS. pub struct PostgresBackend where @@ -55,7 +120,7 @@ where >::TlsConnect: Send, <>::TlsConnect as TlsConnect>::Future: Send, { - pool: Pool>, + pool: SmallPool, } /// A postgres backend with plaintext connections to the database @@ -183,22 +248,8 @@ where postgres_endpoint: &str, default_db: &str, vss_db: &str, tls: T, ) -> Result { create_database(postgres_endpoint, default_db, vss_db, tls.clone()).await?; - let vss_dsn = format!("{}/{}", postgres_endpoint, vss_db); - let manager = - PostgresConnectionManager::new_from_stringlike(vss_dsn, tls).map_err(|e| { - Error::new( - ErrorKind::Other, - format!("Failed to create PostgresConnectionManager: {}", e), - ) - })?; - // By default, Pool maintains 0 long-running connections, so returning a pool - // here is no guarantee that Pool established a connection to the database. - // - // See Builder::min_idle to increase the long-running connection count. - let pool = Pool::builder() - .build(manager) - .await - .map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?; + + let pool = SmallPool::new(postgres_endpoint, vss_db, tls).await?; let postgres_backend = PostgresBackend { pool }; #[cfg(not(test))] @@ -208,10 +259,7 @@ where } async fn migrate_vss_database(&self, migrations: &[&str]) -> Result<(usize, usize), Error> { - let mut conn = self.pool.get().await.map_err(|e| { - Error::new(ErrorKind::Other, format!("Failed to fetch a connection from Pool: {}", e)) - })?; - + let mut conn = self.pool.get().await?; // Get the next migration to be applied. let migration_start = match conn.query_one(GET_VERSION_STMT, &[]).await { Ok(row) => { @@ -464,11 +512,7 @@ where async fn get( &self, user_token: String, request: GetObjectRequest, ) -> Result { - let conn = self - .pool - .get() - .await - .map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?; + let conn = self.pool.get().await?; let stmt = "SELECT key, value, version FROM vss_db WHERE user_token = $1 AND store_id = $2 AND key = $3"; let row = conn .query_opt(stmt, &[&user_token, &request.store_id, &request.key]) @@ -525,11 +569,7 @@ where vss_put_records.push(global_version_record); } - let mut conn = self - .pool - .get() - .await - .map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?; + let mut conn = self.pool.get().await?; let transaction = conn .transaction() .await @@ -573,11 +613,7 @@ where })?; let vss_record = self.build_vss_record(user_token, store_id, key_value); - let mut conn = self - .pool - .get() - .await - .map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?; + let mut conn = self.pool.get().await?; let transaction = conn .transaction() .await @@ -622,11 +658,7 @@ where let limit = min(page_size, LIST_KEY_VERSIONS_MAX_PAGE_SIZE) as i64; - let conn = self - .pool - .get() - .await - .map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?; + let conn = self.pool.get().await?; let stmt = "SELECT key, version FROM vss_db WHERE user_token = $1 AND store_id = $2 AND key > $3 AND key LIKE $4 ORDER BY key LIMIT $5";