Skip to content

Commit a35b9f0

Browse files
committed
♻️ refactor: replace Proxy with Context in handler and use async_trait
- Replace handler function signature to accept Context<ZeroKms> instead of Proxy + client_id - Add config access methods to Context for database, TLS, and other settings - Convert EncryptionService trait from impl Future to async_trait for cleaner syntax - Update ZeroKms implementation to use async_trait - Modify startup functions to work with Context instead of Proxy - Remove services module as functionality is now integrated into Context - Fix type annotation errors by using concrete Context<ZeroKms> type This refactoring simplifies the architecture by making Context the primary interface for handler operations, avoiding dyn trait objects while maintaining strong typing.
1 parent 2f3300b commit a35b9f0

File tree

16 files changed

+296
-442
lines changed

16 files changed

+296
-442
lines changed

packages/cipherstash-proxy-integration/src/schema_change.rs

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,30 +22,4 @@ mod tests {
2222

2323
assert!(rows.is_empty());
2424
}
25-
26-
#[tokio::test]
27-
async fn disable_mapping_disables_schema_reload() {
28-
let client = connect_with_tls(PROXY).await;
29-
30-
let sql = "SET CIPHERSTASH.UNSAFE_DISABLE_MAPPING = true";
31-
client.query(sql, &[]).await.unwrap();
32-
33-
let id = random_id();
34-
35-
let sql = format!(
36-
"CREATE TABLE table_{id} (
37-
id bigint,
38-
PRIMARY KEY(id)
39-
);"
40-
);
41-
42-
let _ = client.execute(&sql, &[]).await.unwrap();
43-
44-
let sql = "SET CIPHERSTASH.UNSAFE_DISABLE_MAPPING = false";
45-
client.query(sql, &[]).await.unwrap();
46-
47-
let sql = format!("SELECT id FROM table_{id}");
48-
let result = client.query(&sql, &[]).await;
49-
assert!(result.is_err());
50-
}
5125
}

packages/cipherstash-proxy/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ pub mod log;
99
pub mod postgresql;
1010
pub mod prometheus;
1111
pub mod proxy;
12-
pub mod services;
1312
pub mod tls;
1413

1514
pub use crate::cli::Args;

packages/cipherstash-proxy/src/main.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
8282
},
8383
_ = sighup() => {
8484
info!(msg = "Received SIGHUP. Reloading configuration");
85-
(listener, proxy) = reload_config(listener, &args, proxy).await;
85+
(listener, proxy) = match reload_config(listener, &args).await {
86+
Ok((listener, proxy)) => (listener, proxy),
87+
Err(_) => todo!(),
88+
};
89+
8690
info!(msg = "Reloaded configuration");
8791
},
8892
_ = sigterm() => {
@@ -91,16 +95,15 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
9195
},
9296
Ok(client_stream) = AsyncStream::accept(&listener) => {
9397

94-
let proxy = proxy.clone();
95-
9698
client_id += 1;
9799

100+
let context = proxy.context(client_id);
101+
98102
tracker.spawn(async move {
99-
let proxy = proxy.clone();
100103

101104
gauge!(CLIENTS_ACTIVE_CONNECTIONS).increment(1);
102105

103-
match pg::handler(client_stream, proxy, client_id).await {
106+
match pg::handler(client_stream,context).await {
104107
Ok(_) => (),
105108
Err(err) => {
106109

@@ -261,15 +264,15 @@ async fn sighup() -> std::io::Result<()> {
261264
Ok(())
262265
}
263266

264-
async fn reload_config(listener: TcpListener, args: &Args, proxy: Proxy) -> (TcpListener, Proxy) {
267+
async fn reload_config(listener: TcpListener, args: &Args) -> Result<(TcpListener, Proxy), Error> {
265268
let new_config = match TandemConfig::load(args) {
266269
Ok(config) => config,
267270
Err(err) => {
268271
warn!(
269272
msg = "Configuration could not be reloaded: {}",
270273
error = err.to_string()
271274
);
272-
return (listener, proxy);
275+
return Err(err);
273276
}
274277
};
275278

@@ -278,8 +281,8 @@ async fn reload_config(listener: TcpListener, args: &Args, proxy: Proxy) -> (Tcp
278281
// Explicit drop needed here to free the network resources before binding if using the same address & port
279282
std::mem::drop(listener);
280283

281-
(
284+
Ok((
282285
connect::bind_with_retry(&new_proxy.config.server).await,
283286
new_proxy,
284-
)
287+
))
285288
}

packages/cipherstash-proxy/src/postgresql/backend.rs

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::prometheus::{
1919
DECRYPTION_ERROR_TOTAL, DECRYPTION_REQUESTS_TOTAL, ROWS_ENCRYPTED_TOTAL,
2020
ROWS_PASSTHROUGH_TOTAL, ROWS_TOTAL, SERVER_BYTES_RECEIVED_TOTAL,
2121
};
22-
use crate::proxy::Proxy;
22+
use crate::proxy::EncryptionService;
2323
use bytes::BytesMut;
2424
use metrics::{counter, histogram};
2525
use std::time::Instant;
@@ -70,25 +70,25 @@ use tracing::{debug, error, info, warn};
7070
/// - `RowDescription`: Result column metadata (modified for encrypted columns)
7171
/// - `ParameterDescription`: Parameter metadata (modified for encrypted parameters)
7272
/// - `ReadyForQuery`: Session ready state (triggers schema reload if needed)
73-
pub struct Backend<R>
73+
pub struct Backend<R, S>
7474
where
7575
R: AsyncRead + Unpin,
76+
S: EncryptionService,
7677
{
7778
/// Sender for outgoing messages to client
7879
client_sender: Sender,
7980
/// Reader for incoming messages from server
8081
server_reader: R,
81-
/// Encryption service for column decryption
82-
proxy: Proxy,
8382
/// Session context with portal and statement metadata
84-
context: Context,
83+
context: Context<S>,
8584
/// Buffer for batching DataRow messages before decryption
8685
buffer: MessageBuffer,
8786
}
8887

89-
impl<R> Backend<R>
88+
impl<R, S> Backend<R, S>
9089
where
9190
R: AsyncRead + Unpin,
91+
S: EncryptionService,
9292
{
9393
/// Creates a new Backend instance.
9494
///
@@ -98,12 +98,11 @@ where
9898
/// * `server_reader` - Stream for reading messages from the PostgreSQL server
9999
/// * `encrypt` - Encryption service for handling column decryption
100100
/// * `context` - Session context shared with the frontend
101-
pub fn new(client_sender: Sender, server_reader: R, proxy: Proxy, context: Context) -> Self {
101+
pub fn new(client_sender: Sender, server_reader: R, context: Context<S>) -> Self {
102102
let buffer = MessageBuffer::new();
103103
Backend {
104104
client_sender,
105105
server_reader,
106-
proxy,
107106
context,
108107
buffer,
109108
}
@@ -150,19 +149,17 @@ where
150149
/// Returns `Ok(())` on successful message processing, or an `Error` if a fatal
151150
/// error occurs that should terminate the connection.
152151
pub async fn rewrite(&mut self) -> Result<(), Error> {
153-
let connection_timeout = self.proxy.config.database.connection_timeout();
154-
155152
let (code, mut bytes) = protocol::read_message(
156153
&mut self.server_reader,
157154
self.context.client_id,
158-
connection_timeout,
155+
self.context.connection_timeout(),
159156
)
160157
.await?;
161158

162159
let sent: u64 = bytes.len() as u64;
163160
counter!(SERVER_BYTES_RECEIVED_TOTAL).increment(sent);
164161

165-
if self.proxy.is_passthrough() {
162+
if self.context.is_passthrough() {
166163
debug!(target: DEVELOPMENT,
167164
client_id = self.context.client_id,
168165
msg = "Passthrough enabled"
@@ -250,7 +247,7 @@ where
250247
msg = "ReadyForQuery"
251248
);
252249
if self.context.schema_changed() {
253-
self.proxy.reload_schema().await;
250+
self.context.reload_schema().await;
254251
}
255252
}
256253

@@ -450,16 +447,12 @@ where
450447
);
451448

452449
// Decrypt CipherText -> Plaintext
453-
let plaintexts = self
454-
.proxy
455-
.decrypt(keyset_id, ciphertexts)
456-
.await
457-
.inspect_err(|_| {
458-
counter!(DECRYPTION_ERROR_TOTAL).increment(1);
459-
})?;
450+
let plaintexts = self.context.decrypt(ciphertexts).await.inspect_err(|_| {
451+
counter!(DECRYPTION_ERROR_TOTAL).increment(1);
452+
})?;
460453

461454
// Avoid the iter calculation if we can
462-
if self.proxy.config.prometheus_enabled() {
455+
if self.context.prometheus_enabled() {
463456
let decrypted_count =
464457
plaintexts
465458
.iter()
@@ -655,9 +648,10 @@ where
655648
}
656649

657650
/// Implementation of PostgreSQL error handling for the Backend component.
658-
impl<R> PostgreSqlErrorHandler for Backend<R>
651+
impl<R, S> PostgreSqlErrorHandler for Backend<R, S>
659652
where
660653
R: AsyncRead + Unpin,
654+
S: EncryptionService,
661655
{
662656
fn client_sender(&mut self) -> &mut Sender {
663657
&mut self.client_sender

packages/cipherstash-proxy/src/postgresql/column_mapper.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::{
33
error::{EncryptError, Error},
44
log::MAPPER,
55
postgresql::Column,
6-
services::SchemaService,
6+
proxy::EncryptConfig,
77
};
88
use eql_mapper::{EqlTerm, TableColumn, TypeCheckedStatement};
99
use postgres_types::Type;
@@ -14,13 +14,13 @@ use tracing::{debug, warn};
1414
/// and mapping them to encryption configurations.
1515
#[derive(Clone)]
1616
pub struct ColumnMapper {
17-
schema_service: Arc<dyn SchemaService>,
17+
encrypt_config: Arc<EncryptConfig>,
1818
}
1919

2020
impl ColumnMapper {
2121
/// Create a new ColumnProcessor with the given schema service and client ID
22-
pub fn new(schema_service: Arc<dyn SchemaService>) -> Self {
23-
Self { schema_service }
22+
pub fn new(encrypt_config: Arc<EncryptConfig>) -> Self {
23+
Self { encrypt_config }
2424
}
2525

2626
/// Maps typed statement projection columns to an Encrypt column configuration
@@ -127,7 +127,7 @@ impl ColumnMapper {
127127
identifier: Identifier,
128128
eql_term: &EqlTerm,
129129
) -> Result<Option<Column>, Error> {
130-
match self.schema_service.get_column_config(&identifier) {
130+
match self.encrypt_config.get_column_config(&identifier) {
131131
Some(config) => {
132132
debug!(
133133
target: MAPPER,

0 commit comments

Comments
 (0)