Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ldk-node-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ jobs:
run: |
cd ldk-node
export TEST_VSS_BASE_URL="http://localhost:8080/vss"
RUSTFLAGS="--cfg vss_test" cargo test io::vss_store
RUSTFLAGS="--cfg vss_test" cargo test --test integration_tests_vss
RUSTFLAGS="--cfg vss_test" cargo test io::vss_store -- --test-threads=1
RUSTFLAGS="--cfg vss_test" cargo test --test integration_tests_vss -- --test-threads=1
3 changes: 1 addition & 2 deletions rust/impls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ async-trait = "0.1.77"
api = { path = "../api" }
chrono = "0.4.38"
tokio-postgres = { version = "0.7.12", features = ["with-chrono-0_4"] }
bb8-postgres = "0.7"
bytes = "1.4.0"
tokio = { version = "1.38.0", default-features = false }
tokio = { version = "1.38.0", default-features = false, features = ["rt", "macros"] }
native-tls = { version = "0.2.14", default-features = false }
postgres-native-tls = { version = "0.5.2", default-features = false, features = ["runtime"] }

Expand Down
118 changes: 75 additions & 43 deletions rust/impls/src/postgres_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@ use api::types::{
ListKeyVersionsRequest, ListKeyVersionsResponse, PutObjectRequest, PutObjectResponse,
};
use async_trait::async_trait;
use bb8_postgres::bb8::Pool;
use bb8_postgres::PostgresConnectionManager;
use bytes::Bytes;
use chrono::Utc;
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use std::cmp::min;
use std::io::{self, Error, ErrorKind};
use tokio::sync::Mutex;
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::{error, Client, NoTls, Socket, Transaction};

Expand Down Expand Up @@ -47,6 +46,72 @@ pub const LIST_KEY_VERSIONS_MAX_PAGE_SIZE: i32 = 100;
/// Exceeding this value will result in request rejection through [`VssError::InvalidRequestError`].
pub const MAX_PUT_REQUEST_ITEM_COUNT: usize = 1000;

const POOL_SIZE: usize = 10;

struct SmallPool<T> {
connections: [Mutex<Client>; POOL_SIZE],
endpoint: String,
db_name: String,
tls: T,
}

impl<T> SmallPool<T>
where
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
T::Stream: Send + Sync,
T::TlsConnect: Send,
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
async fn new(postgres_endpoint: &str, vss_db: &str, tls: T) -> Result<Self, Error> {
let connections = [
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
];

let pool = SmallPool {
connections,
endpoint: String::from(postgres_endpoint),
db_name: String::from(vss_db),
tls,
};
Ok(pool)
}

async fn get(&self) -> Result<tokio::sync::MutexGuard<'_, Client>, Error> {
let mut conn = tokio::select! {
conn_0 = self.connections[0].lock() => conn_0,
conn_1 = self.connections[1].lock() => conn_1,
conn_2 = self.connections[2].lock() => conn_2,
conn_3 = self.connections[3].lock() => conn_3,
conn_4 = self.connections[4].lock() => conn_4,
conn_5 = self.connections[5].lock() => conn_5,
conn_6 = self.connections[6].lock() => conn_6,
conn_7 = self.connections[7].lock() => conn_7,
conn_8 = self.connections[8].lock() => conn_8,
conn_9 = self.connections[9].lock() => conn_9,
};
self.ensure_connected(&mut conn).await?;
Ok(conn)
}

async fn ensure_connected(&self, client: &mut Client) -> Result<(), Error> {
if client.is_closed() || client.check_connection().await.is_err() {
let new_client =
make_db_connection(&self.endpoint, &self.db_name, self.tls.clone()).await?;
*client = new_client;
}
Ok(())
}
}

/// A [PostgreSQL](https://www.postgresql.org/) based backend implementation for VSS.
pub struct PostgresBackend<T>
where
Expand All @@ -55,7 +120,7 @@ where
<T as MakeTlsConnect<Socket>>::TlsConnect: Send,
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
pool: Pool<PostgresConnectionManager<T>>,
pool: SmallPool<T>,
}

/// A postgres backend with plaintext connections to the database
Expand Down Expand Up @@ -183,22 +248,8 @@ where
postgres_endpoint: &str, default_db: &str, vss_db: &str, tls: T,
) -> Result<Self, Error> {
create_database(postgres_endpoint, default_db, vss_db, tls.clone()).await?;
let vss_dsn = format!("{}/{}", postgres_endpoint, vss_db);
let manager =
PostgresConnectionManager::new_from_stringlike(vss_dsn, tls).map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Failed to create PostgresConnectionManager: {}", e),
)
})?;
// By default, Pool maintains 0 long-running connections, so returning a pool
// here is no guarantee that Pool established a connection to the database.
//
// See Builder::min_idle to increase the long-running connection count.
let pool = Pool::builder()
.build(manager)
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?;

let pool = SmallPool::new(postgres_endpoint, vss_db, tls).await?;
let postgres_backend = PostgresBackend { pool };

#[cfg(not(test))]
Expand All @@ -208,10 +259,7 @@ where
}

async fn migrate_vss_database(&self, migrations: &[&str]) -> Result<(usize, usize), Error> {
let mut conn = self.pool.get().await.map_err(|e| {
Error::new(ErrorKind::Other, format!("Failed to fetch a connection from Pool: {}", e))
})?;

let mut conn = self.pool.get().await?;
// Get the next migration to be applied.
let migration_start = match conn.query_one(GET_VERSION_STMT, &[]).await {
Ok(row) => {
Expand Down Expand Up @@ -464,11 +512,7 @@ where
async fn get(
&self, user_token: String, request: GetObjectRequest,
) -> Result<GetObjectResponse, VssError> {
let conn = self
.pool
.get()
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
let conn = self.pool.get().await?;
let stmt = "SELECT key, value, version FROM vss_db WHERE user_token = $1 AND store_id = $2 AND key = $3";
let row = conn
.query_opt(stmt, &[&user_token, &request.store_id, &request.key])
Expand Down Expand Up @@ -525,11 +569,7 @@ where
vss_put_records.push(global_version_record);
}

let mut conn = self
.pool
.get()
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
let mut conn = self.pool.get().await?;
let transaction = conn
.transaction()
.await
Expand Down Expand Up @@ -573,11 +613,7 @@ where
})?;
let vss_record = self.build_vss_record(user_token, store_id, key_value);

let mut conn = self
.pool
.get()
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
let mut conn = self.pool.get().await?;
let transaction = conn
.transaction()
.await
Expand Down Expand Up @@ -622,11 +658,7 @@ where

let limit = min(page_size, LIST_KEY_VERSIONS_MAX_PAGE_SIZE) as i64;

let conn = self
.pool
.get()
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
let conn = self.pool.get().await?;

let stmt = "SELECT key, version FROM vss_db WHERE user_token = $1 AND store_id = $2 AND key > $3 AND key LIKE $4 ORDER BY key LIMIT $5";

Expand Down