Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ldk-node-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ jobs:
run: |
cd vss-server/rust
cargo build
cargo run --no-default-features server/vss-server-config.toml&
cargo run --no-default-features --features testing server/vss-server-config.toml&
- name: Run LDK Node Integration tests
run: |
cd ldk-node
export TEST_VSS_BASE_URL="http://localhost:8080/vss"
RUSTFLAGS="--cfg vss_test" cargo test io::vss_store
RUSTFLAGS="--cfg vss_test" cargo test --test integration_tests_vss
RUSTFLAGS="--cfg vss_test" cargo test io::vss_store -- --test-threads=1
RUSTFLAGS="--cfg vss_test" cargo test --test integration_tests_vss -- --test-threads=1
4 changes: 3 additions & 1 deletion rust/api/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::error::VssError;
use async_trait::async_trait;
use std::collections::HashMap;
use std::string::ToString;

/// Response returned for [`Authorizer`] request if user is authenticated and authorized.
#[derive(Debug, Clone)]
Expand All @@ -21,10 +20,13 @@ pub trait Authorizer: Send + Sync {
}

/// A no-operation authorizer, which lets any user-request go through.
#[cfg(feature = "_test_utils")]
pub struct NoopAuthorizer {}

#[cfg(feature = "_test_utils")]
const UNAUTHENTICATED_USER: &str = "unauth-user";

#[cfg(feature = "_test_utils")]
#[async_trait]
impl Authorizer for NoopAuthorizer {
async fn verify(
Expand Down
4 changes: 2 additions & 2 deletions rust/impls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ async-trait = "0.1.77"
api = { path = "../api" }
chrono = "0.4.38"
tokio-postgres = { version = "0.7.12", features = ["with-chrono-0_4"] }
bb8-postgres = "0.7"
bytes = "1.4.0"
tokio = { version = "1.38.0", default-features = false }
tokio = { version = "1.38.0", default-features = false, features = ["rt", "macros"] }
native-tls = { version = "0.2.14", default-features = false }
postgres-native-tls = { version = "0.5.2", default-features = false, features = ["runtime"] }
log = { version = "0.4.29", default-features = false }

[dev-dependencies]
tokio = { version = "1.38.0", default-features = false, features = ["rt-multi-thread", "macros"] }
Expand Down
127 changes: 81 additions & 46 deletions rust/impls/src/postgres_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@ use api::types::{
ListKeyVersionsRequest, ListKeyVersionsResponse, PutObjectRequest, PutObjectResponse,
};
use async_trait::async_trait;
use bb8_postgres::bb8::Pool;
use bb8_postgres::PostgresConnectionManager;
use bytes::Bytes;
use chrono::Utc;
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use std::cmp::min;
use std::io::{self, Error, ErrorKind};
use tokio::sync::Mutex;
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::{error, Client, NoTls, Socket, Transaction};

use log::{debug, error, info, warn};

pub use native_tls::Certificate;

pub(crate) struct VssDbRecord {
Expand Down Expand Up @@ -47,6 +48,73 @@ pub const LIST_KEY_VERSIONS_MAX_PAGE_SIZE: i32 = 100;
/// Exceeding this value will result in request rejection through [`VssError::InvalidRequestError`].
pub const MAX_PUT_REQUEST_ITEM_COUNT: usize = 1000;

const POOL_SIZE: usize = 10;

struct SmallPool<T> {
connections: [Mutex<Client>; POOL_SIZE],
endpoint: String,
db_name: String,
tls: T,
}

impl<T> SmallPool<T>
where
T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
T::Stream: Send + Sync,
T::TlsConnect: Send,
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
async fn new(postgres_endpoint: &str, vss_db: &str, tls: T) -> Result<Self, Error> {
let connections = [
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
Mutex::new(make_db_connection(postgres_endpoint, vss_db, tls.clone()).await?),
];

let pool = SmallPool {
connections,
endpoint: String::from(postgres_endpoint),
db_name: String::from(vss_db),
tls,
};
Ok(pool)
}

async fn get(&self) -> Result<tokio::sync::MutexGuard<'_, Client>, Error> {
let mut conn = tokio::select! {
conn_0 = self.connections[0].lock() => conn_0,
conn_1 = self.connections[1].lock() => conn_1,
conn_2 = self.connections[2].lock() => conn_2,
conn_3 = self.connections[3].lock() => conn_3,
conn_4 = self.connections[4].lock() => conn_4,
conn_5 = self.connections[5].lock() => conn_5,
conn_6 = self.connections[6].lock() => conn_6,
conn_7 = self.connections[7].lock() => conn_7,
conn_8 = self.connections[8].lock() => conn_8,
conn_9 = self.connections[9].lock() => conn_9,
};
self.ensure_connected(&mut conn).await?;
Ok(conn)
}

async fn ensure_connected(&self, client: &mut Client) -> Result<(), Error> {
if client.is_closed() || client.check_connection().await.is_err() {
debug!("Rotating connection to the postgres database");
let new_client =
make_db_connection(&self.endpoint, &self.db_name, self.tls.clone()).await?;
*client = new_client;
}
Ok(())
}
}

/// A [PostgreSQL](https://www.postgresql.org/) based backend implementation for VSS.
pub struct PostgresBackend<T>
where
Expand All @@ -55,7 +123,7 @@ where
<T as MakeTlsConnect<Socket>>::TlsConnect: Send,
<<T as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
pool: Pool<PostgresConnectionManager<T>>,
pool: SmallPool<T>,
}

/// A postgres backend with plaintext connections to the database
Expand All @@ -80,7 +148,7 @@ where
// Connection must be driven on a separate task, and will resolve when the client is dropped
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("Connection error: {}", e);
warn!("Connection error: {}", e);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this recoverable? error then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can recover using fn ensure_connected if this happens on one of our 10 long-lived connections to the database.

Copy link
Contributor Author

@tankyleo tankyleo Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me know if you would rather use error here it is true that if this happens during startup we do not recover.

In case we fail during startup we do have the error!("Failed to start postgres backend"); messages at the error level.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think error would be preferable in this case here, too.

}
});
Ok(client)
Expand Down Expand Up @@ -108,7 +176,7 @@ where
client.execute(&stmt, &[]).await.map_err(|e| {
Error::new(ErrorKind::Other, format!("Failed to create database {}: {}", db_name, e))
})?;
println!("Created database {}", db_name);
info!("Created database {}", db_name);
}

Ok(())
Expand Down Expand Up @@ -183,22 +251,8 @@ where
postgres_endpoint: &str, default_db: &str, vss_db: &str, tls: T,
) -> Result<Self, Error> {
create_database(postgres_endpoint, default_db, vss_db, tls.clone()).await?;
let vss_dsn = format!("{}/{}", postgres_endpoint, vss_db);
let manager =
PostgresConnectionManager::new_from_stringlike(vss_dsn, tls).map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Failed to create PostgresConnectionManager: {}", e),
)
})?;
// By default, Pool maintains 0 long-running connections, so returning a pool
// here is no guarantee that Pool established a connection to the database.
//
// See Builder::min_idle to increase the long-running connection count.
let pool = Pool::builder()
.build(manager)
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?;

let pool = SmallPool::new(postgres_endpoint, vss_db, tls).await?;
let postgres_backend = PostgresBackend { pool };

#[cfg(not(test))]
Expand All @@ -208,10 +262,7 @@ where
}

async fn migrate_vss_database(&self, migrations: &[&str]) -> Result<(usize, usize), Error> {
let mut conn = self.pool.get().await.map_err(|e| {
Error::new(ErrorKind::Other, format!("Failed to fetch a connection from Pool: {}", e))
})?;

let mut conn = self.pool.get().await?;
// Get the next migration to be applied.
let migration_start = match conn.query_one(GET_VERSION_STMT, &[]).await {
Ok(row) => {
Expand Down Expand Up @@ -243,7 +294,7 @@ where
panic!("We do not allow downgrades");
}

println!("Applying migration(s) {} through {}", migration_start, migrations.len() - 1);
info!("Applying migration(s) {} through {}", migration_start, migrations.len() - 1);

for (idx, &stmt) in (&migrations[migration_start..]).iter().enumerate() {
let _num_rows = tx.execute(stmt, &[]).await.map_err(|e| {
Expand Down Expand Up @@ -464,11 +515,7 @@ where
async fn get(
&self, user_token: String, request: GetObjectRequest,
) -> Result<GetObjectResponse, VssError> {
let conn = self
.pool
.get()
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
let conn = self.pool.get().await?;
let stmt = "SELECT key, value, version FROM vss_db WHERE user_token = $1 AND store_id = $2 AND key = $3";
let row = conn
.query_opt(stmt, &[&user_token, &request.store_id, &request.key])
Expand Down Expand Up @@ -525,11 +572,7 @@ where
vss_put_records.push(global_version_record);
}

let mut conn = self
.pool
.get()
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
let mut conn = self.pool.get().await?;
let transaction = conn
.transaction()
.await
Expand Down Expand Up @@ -573,11 +616,7 @@ where
})?;
let vss_record = self.build_vss_record(user_token, store_id, key_value);

let mut conn = self
.pool
.get()
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
let mut conn = self.pool.get().await?;
let transaction = conn
.transaction()
.await
Expand Down Expand Up @@ -622,11 +661,7 @@ where

let limit = min(page_size, LIST_KEY_VERSIONS_MAX_PAGE_SIZE) as i64;

let conn = self
.pool
.get()
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
let conn = self.pool.get().await?;

let stmt = "SELECT key, version FROM vss_db WHERE user_token = $1 AND store_id = $2 AND key > $3 AND key LIKE $4 ORDER BY key LIMIT $5";

Expand Down
4 changes: 4 additions & 0 deletions rust/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ edition = "2021"
jwt = ["auth-impls/jwt"]
sigs = ["auth-impls/sigs"]
default = [ "jwt", "sigs" ]
# This feature is NOT to be used in prod
testing = ["api/_test_utils"]

[dependencies]
api = { path = "../api" }
Expand All @@ -21,3 +23,5 @@ prost = { version = "0.11.6", default-features = false, features = ["std"] }
bytes = "1.4.0"
serde = { version = "1.0.203", default-features = false, features = ["derive"] }
toml = { version = "0.8.9", default-features = false, features = ["parse"] }
log = { version = "0.4.29", default-features = false, features = ["std"] }
chrono = { version = "0.4", default-features = false, features = ["clock"] }
Loading