diff --git a/Cargo.lock b/Cargo.lock index 6f1855f5b..4aba28989 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3270,6 +3270,7 @@ dependencies = [ "serde", "serde_json", "sha1", + "smallvec", "socket2 0.5.10", "stats_alloc", "tempfile", diff --git a/docs/CLIENT_CONNECTION.md b/docs/CLIENT_CONNECTION.md new file mode 100644 index 000000000..1d533baf0 --- /dev/null +++ b/docs/CLIENT_CONNECTION.md @@ -0,0 +1,370 @@ +# Client connection lifecycle + +Traces a PostgreSQL client message from TCP accept to response delivery. + +pgdog is a connection pooler and query router. Clients connect to it as if it were Postgres; it speaks the wire protocol on both sides and maintains a pool of real server connections behind it. + +--- + +## Key concepts + +Six types recur throughout the codebase. Understand these before following the flow. + +**`BackendPid` — identity, not process ID.** Every client and every server connection gets a `BackendPid` (see [`net/messages/backend_pid.rs`](../pgdog/src/net/messages/backend_pid.rs)). For clients, pgdog mints a sequential synthetic pid; for server connections it reads the pid from the Postgres handshake. Cancel routing, server tracking, and message tagging all use `BackendPid`. + +**`Source` — message origin.** Every `Message` (in [`net/messages/mod.rs`](../pgdog/src/net/messages/mod.rs)) carries a `Source` tag: `Backend(BackendPid)` for messages from a Postgres socket, `Frontend` for messages from the client, and `Internal` (the default) for messages synthesised by pgdog. The `Debug` formatter uses this to disambiguate byte codes shared between directions — the same byte `D` is `Describe` from a client and `DataRow` from a server. + +**`Guard` — RAII pool checkout.** `Guard` in [`backend/pool/guard.rs`](../pgdog/src/backend/pool/guard.rs) wraps a `Box` with `Deref`/`DerefMut`; `Drop` triggers cleanup and check-in. The connection always returns to the pool (or is closed) when `Guard` falls out of scope. 
+ +**`Route` — routing decision.** The query engine produces a `Route` (in [`frontend/router/parser/route.rs`](../pgdog/src/frontend/router/parser/route.rs)) for each query: shard(s), role (primary or replica), and aggregation metadata for multi-shard merging. The rest of the system acts only on `Route` and never re-inspects the SQL. + +**`ClientRequest` — extended-protocol buffer.** The extended protocol is a multi-message sequence (Parse → Bind → Describe → Execute → Sync). pgdog buffers messages into a `ClientRequest` (in [`frontend/client_request.rs`](../pgdog/src/frontend/client_request.rs)) and dispatches only when the sequence is complete (Sync or Flush). Simple Query (`Q`) is dispatched immediately. + +**`Sticky` — per-client routing pins.** `Sticky` in [`frontend/client/sticky.rs`](../pgdog/src/frontend/client/sticky.rs) has two fields set at login: `omni_index` (a random value that pins all of one client's omni-shard queries to the same shard — per client, not per statement) and `role` (from the `pgdog.role` startup parameter). Both are set once and never change. + +--- + +## High-level flow + +```mermaid +sequenceDiagram + participant C as Client app + participant L as Listener + participant CL as Client + participant QE as QueryEngine + participant P as Pool + participant S as Server + + C->>L: TCP connect + L->>CL: Client::spawn() + C->>CL: Startup message + CL->>P: conn.parameters() + P-->>CL: ParameterStatus + CL-->>C: AuthOk + BackendKeyData + RFQ + loop every request + C->>CL: Q / Parse+Exec+Sync + CL->>QE: client_messages() + QE->>P: Pool::get() + P-->>QE: Guard(Server) + QE->>S: Server::send() + S-->>QE: Server::read() + QE->>P: Guard::drop() checkin + QE-->>CL: response + CL-->>C: DataRow* + RFQ + end +``` + +--- + +## 1. Connection acceptance + +Entry point: `Listener::listen()` in [`frontend/listener.rs`](../pgdog/src/frontend/listener.rs). + +Each accepted TCP socket becomes a Tokio task in `handle_client()`. 
Before constructing a `Client`, the listener resolves the startup type: + +```mermaid +flowchart TD + A[TCP accept] --> B{startup message type} + B -->|Ssl| C[TLS handshake] + C --> F + B -->|GssEnc| D[send N — client retries] + B -->|Cancel| E{verify_cancel} + E -->|secret ok| E2[databases::cancel<br/>close TCP] + E -->|secret mismatch| E3[close TCP<br/>query unaffected] + B -->|Startup| F[Client::spawn] +``` + +- **SSL** (`Startup::Ssl`): wraps the socket in TLS; records the peer certificate for mTLS auth. +- **GSS** (`Startup::GssEnc`): rejected — pgdog sends `N` and the client retries. +- **Cancel** (`Startup::Cancel`): the listener calls `comms().verify_cancel(&id)` before anything else. On mismatch the TCP connection is closed and the running query is unaffected — no error is sent to the caller. On success, `databases().cancel(id.pid)` is called and then the TCP connection is closed. Handled here, before any `Client` exists — this is why the cancel path is entirely separate from the query path. +- **Startup** (`Startup::Startup`): negotiation complete; falls through to `Client::spawn()`. + +The socket is wrapped in `Stream` (in [`net/stream.rs`](../pgdog/src/net/stream.rs)) for a uniform `send` / `read` / `flush` interface over both plain TCP and TLS. + +--- + +## 2. Login + +`Client::login()` in [`frontend/client/mod.rs`](../pgdog/src/frontend/client/mod.rs). Runs once per connection; returns a `Client` or sends an error. + +Steps in order: +1. Reject plaintext connections when `tls_client_required` is set. +2. Identify the target database and user from startup parameters; detect admin connections. +3. Mint `BackendKeyData` (in [`net/messages/backend_key.rs`](../pgdog/src/net/messages/backend_key.rs)) — synthetic pid + random secret — and create a `ClientComms` (in [`frontend/comms.rs`](../pgdog/src/frontend/comms.rs)). +4. Authenticate using the method configured for this user: Trust, MD5, SCRAM, Plaintext, or mTLS (via `stream.tls_identity()`). Passthrough auth forwards credentials directly to Postgres. +5. Send `AuthenticationOk`. +6. Reject the connection if the pooler is shutting down (`comms.offline()`) and this is not an admin connection. +7. Fetch server parameters from a pooled backend via `conn.parameters(&Request::unrouted(id.pid()))` and forward them as `ParameterStatus` messages. +8. 
Send `BackendKeyData` to the client (stored for future cancel requests). +9. Send `ReadyForQuery(Idle)`. +10. Call `comms.connect(id, addr, &params)` (in [`frontend/comms.rs`](../pgdog/src/frontend/comms.rs)) — registers the client in the process-wide map for cancel routing and shutdown. + +`Sticky` (in [`frontend/client/sticky.rs`](../pgdog/src/frontend/client/sticky.rs)) is initialised here via `Sticky::from_params(&params)`. + +--- + +## 3. Main client loop + +`Client::run()` in [`frontend/client/mod.rs`](../pgdog/src/frontend/client/mod.rs). + +```rust +loop { + select! { + _ = shutdown.notified() => { /* check offline + can_disconnect */ } + message = engine.read_backend() => server_message(message) + buffer = self.buffer(state) => client_messages(buffer) + } +} +``` + +```mermaid +flowchart TD + A[Client::run loop] --> B{select!} + B -->|shutdown.notified| C{offline and
can_disconnect?} + C -->|yes| D[send ErrorResponse::shutting_down<br/>break] + C -->|no| B + B -->|engine.read_backend| E[server_message
→ process_server_message] + B -->|self.buffer| F[client_messages] + F --> G{maintenance waiter?} + G -->|yes, not in tx| H[park until stop] + H --> I[QueryEngine::handle] + G -->|no| I + I --> B +``` + +### Shutdown arm + +`shutdown` is an `Arc` from `comms.shutting_down()` (in [`frontend/comms.rs`](../pgdog/src/frontend/comms.rs)). When it fires, the loop checks `comms.offline() && query_engine.can_disconnect()`. If true, it sends `ErrorResponse::shutting_down()` and exits. Otherwise it keeps running until the current transaction completes. + +### Backend push arm + +`engine.read_backend()` reads from a checked-out server connection. Not just `NOTIFY` — any server-pushed message goes through `server_message()` → `query_engine.process_server_message()` (in [`frontend/client/query_engine/mod.rs`](../pgdog/src/frontend/client/query_engine/mod.rs)), which handles streaming flags, explain traces, `ReadyForQuery` transitions, 2PC finalisation, and stats. + +### Client buffer arm + +`self.buffer(client_state)` reads bytes from the client socket into a `ClientRequest` (in [`frontend/client_request.rs`](../pgdog/src/frontend/client_request.rs)). A request is complete (`ClientRequest::is_complete()`) when the last message code is one of `{H, S, Q, c, f, F}` or a `CopyData` chunk reaches 4 KB. `'X'` (Terminate) causes a graceful disconnect. + +### Maintenance mode + +Before dispatching, `client_messages()` checks `maintenance_mode::waiter(&database)` (in [`backend/maintenance_mode.rs`](../pgdog/src/backend/maintenance_mode.rs)). If a waiter is active and the client is not in a transaction, the client parks until `maintenance_mode::stop()` fires. + +### Pipeline splicing + +When a client sends multiple pipelined extended-protocol requests in one buffer, `ClientRequest::spliced()` (in [`frontend/client_request.rs`](../pgdog/src/frontend/client_request.rs)) splits them at `Execute` boundaries. Each sub-request runs through `QueryEngine::handle()` independently. 
A server error mid-pipeline skips forward to the next `Sync`. + +--- + +## 4. Query engine + +`QueryEngine::handle()` in [`frontend/client/query_engine/mod.rs`](../pgdog/src/frontend/client/query_engine/mod.rs). + +```mermaid +flowchart TD + A[ClientRequest] --> B[rewrite_extended] + B --> C[cluster_check] + C --> D[parse_and_rewrite → Route] + D --> E{intercept_incomplete?} + E -->|BEGIN / COMMIT / ROLLBACK| F[synthesise response
Source::Internal] + E -->|no| G[route_query] + G --> H[hooks.before_execution] + H --> I[backend.mirror] + I --> J[dispatch handler] + J -->|needs Postgres| K[connect → Pool::get] + K --> L[Server::send / read] + L --> M[hooks.after_execution] + F --> N[ReadyForQuery] + M --> N +``` + +The pre-dispatch pipeline, all in `QueryEngine::handle()`: + +| Step | Method | What it does | +|---|---|---| +| 1 | `rewrite_extended()` | Rewrite Parse/Bind for sharding (e.g. inject shard key into parameter list) | +| 2 | `cluster_check()` | Verify the cluster is online and not in maintenance | +| 3 | `parse_and_rewrite()` | Parse SQL, extract shard key, build `Route`, rewrite query if needed | +| 4 | `intercept_incomplete()` | Synthesise responses for `BEGIN`/`COMMIT`/`ROLLBACK` without contacting Postgres | +| 5 | `route_query()` | Finalise shard selection and primary/replica choice | +| 6 | `hooks.before_execution()` | `QueryEngineHooks` extension point | +| 7 | `backend.mirror()` | Queue shadow traffic to mirror pools | +| 8 | dispatch | Call the appropriate command handler | + +### Lazy backend connection + +`connect()` in [`frontend/client/query_engine/connect.rs`](../pgdog/src/frontend/client/query_engine/connect.rs) is called from within command handlers (`execute()`, `connect_transaction()`), not from the top of `handle()`. Queries handled entirely by the engine — `BEGIN`, `COMMIT`, `ROLLBACK`, `DISCARD`, SET statements — never touch the pool. + +`connect()` returns `bool`: `false` = recoverable (no server available, engine synthesises an error); `true` = connected; `Err` = fatal. + +### Synthesised responses + +Messages produced by pgdog carry `Source::Internal` (in [`net/messages/mod.rs`](../pgdog/src/net/messages/mod.rs)). This distinguishes synthesised messages from real Postgres responses throughout the codebase. 
+ +### Hooks + +`QueryEngineHooks` in [`frontend/client/query_engine/hooks/mod.rs`](../pgdog/src/frontend/client/query_engine/hooks/mod.rs) has five callbacks: `before_execution`, `after_connected`, `after_execution`, `on_server_message`, and `on_engine_error`. The current built-in use: schema-change detection in [`frontend/client/query_engine/hooks/schema.rs`](../pgdog/src/frontend/client/query_engine/hooks/schema.rs), which marks `schema_changed` on the server so the cleanup step issues `DEALLOCATE ALL`. + +### Two-phase commit + +`TwoPc` in [`frontend/client/query_engine/two_pc/mod.rs`](../pgdog/src/frontend/client/query_engine/two_pc/mod.rs) coordinates distributed transactions across shards. When a write transaction ends with `two_pc_enabled && !rollback`, `phase_one()` issues fsync-safe `PREPARE TRANSACTION` on all shards, then `phase_two()` issues fsync-safe `COMMIT PREPARED`. The WAL in [`frontend/client/query_engine/two_pc/wal/`](../pgdog/src/frontend/client/query_engine/two_pc/wal/) records `Begin` before the prepare and `Committing` before the commit; `End` on clean completion. Format: `u32 bodylen LE | u32 crc32c LE | u8 tag | rmp-serde body`. Tags never change — format evolution uses `#[serde(default)]`. + +--- + +## 5. Backend connection checkout + +`connect()` in [`frontend/client/query_engine/connect.rs`](../pgdog/src/frontend/client/query_engine/connect.rs) calls down through `Connection::connect()` → `cluster.primary()` or `cluster.replica()` → `Pool::get()` in [`backend/pool/pool_impl.rs`](../pgdog/src/backend/pool/pool_impl.rs). + +```mermaid +flowchart TD + A[Pool::get] --> B[Inner::take] + B --> C{idle connection?} + C -->|yes| D[Taken::take
register client→server] + D --> E[return Guard] + C -->|no| F[push Waiting
block on receiver] + F --> G{woken by?} + G -->|Monitor new conn| H[Inner::put → send to waiter] + G -->|other client checkin| H + H --> D +``` + +**Fast path** (`Inner::take()` in [`backend/pool/inner.rs`](../pgdog/src/backend/pool/inner.rs)): pops a server from `idle_connections`, registers it in `Taken`, and returns it as a `Guard`. + +**Slow path** (`Waiting` in [`backend/pool/inner.rs`](../pgdog/src/backend/pool/inner.rs)): no idle connection means a `Waiting` struct with a oneshot channel is pushed onto `Inner::waiting`. The caller blocks on the receiver until `Monitor` creates a connection or another client checks one back in. + +`Guard` in [`backend/pool/guard.rs`](../pgdog/src/backend/pool/guard.rs) wraps `Box` with `Deref`/`DerefMut`. Its `Drop` spawns a cleanup task bounded by `rollback_timeout`; timeout marks the server `ForceClose`. The pool is always notified on return. + +**Multi-shard**: for queries targeting multiple shards, `Binding::MultiShard` in [`backend/pool/connection/binding.rs`](../pgdog/src/backend/pool/connection/binding.rs) holds one `Guard` per shard alongside a `MultiShard` state machine. + +### The `Taken` maps + +`Taken` in [`backend/pool/taken.rs`](../pgdog/src/backend/pool/taken.rs) answers two questions: which server is this client using, and what is its Postgres-issued cancel key? Cancel routing reads `frontend_to_cancel` directly; the reverse map exists only so check-in (which knows the backend pid, not the frontend) can find the entry to drop. 
+ +| Map | Key → Value | Purpose | +|---|---|---| +| `frontend_to_cancel` | frontend pid → server `BackendKeyData` | Cancel routing: client → server key (pid + secret) | +| `backend_to_frontend` | backend pid → frontend pid | Reverse lookup so `check_in(backend_pid)` can drop the right `frontend_to_cancel` entry | + +### Monitor + +`Monitor` in [`backend/pool/monitor.rs`](../pgdog/src/backend/pool/monitor.rs) runs four loops: maintenance every 333 ms (close idle/old, create when undersized), health checks (`SELECT 1` on idle connections), connection creation on demand, and token refresh for external auth (RDS IAM, Azure AD). + +--- + +## 6. Sending to and receiving from Postgres + +Both directions go through `Connection` → `Binding` (in [`backend/pool/connection/binding.rs`](../pgdog/src/backend/pool/connection/binding.rs)) → `Guard` (in [`backend/pool/guard.rs`](../pgdog/src/backend/pool/guard.rs)) → `Server` (in [`backend/server.rs`](../pgdog/src/backend/server.rs)). + +### Sending + +`Server::send()` in [`backend/server.rs`](../pgdog/src/backend/server.rs) marks state `Active`, calls `send_one()` per message, then `flush()`. Each message passes through `PreparedStatements::handle()` (in [`backend/prepared_statements.rs`](../pgdog/src/backend/prepared_statements.rs)) first: + +- A `Parse` already in the cache is dropped; a synthetic `ParseComplete` is queued. Postgres never sees it, and neither does the client. +- A `Bind` for a cached statement may get a `Parse` prepended if the statement needs re-establishing on this connection. +- Other messages update the in-flight state machine. + +State transitions to `ReceivingData` after flush. 
+ +### Receiving + +`Server::read()` in [`backend/server.rs`](../pgdog/src/backend/server.rs) reads from the Postgres socket, tags each message `.backend(self.key.pid())` (`Source::Backend(BackendPid)`), then passes it through `PreparedStatements::forward()` (in [`backend/prepared_statements.rs`](../pgdog/src/backend/prepared_statements.rs)). `forward()` runs a `ProtocolState` state machine returning `Ignore` or `Forward` per code: + +- `ParseComplete ('1')` — marks the statement prepared in cache; forwarded. +- `RowDescription ('T')` — caches the row description; forwarded. +- `ErrorResponse ('E')` — clears in-flight parses and describes; forwarded. +- `Ignore` messages are consumed silently. +This is the only place `Source::Backend` is set. + +### Multi-shard receive + +`MultiShard::forward()` in [`backend/pool/connection/multi_shard/mod.rs`](../pgdog/src/backend/pool/connection/multi_shard/mod.rs) aggregates messages from all shard connections: + +- `RowDescription ('T')`: validated for consistency across shards, de-duplicated. +- `DataRow ('D')`: buffered per shard; for `Route::Omni`, only one shard's rows are kept. +- `CommandComplete ('C')`: counts accumulated per shard; a synthetic `CommandComplete` with `Source::Internal` is emitted once all shards report. +- `ReadyForQuery ('Z')`: waits for all shards; synthesises error state if any shard errored. +--- + +## 7. Connection check-in + +When `Guard` (in [`backend/pool/guard.rs`](../pgdog/src/backend/pool/guard.rs)) drops, `Guard::cleanup()` runs in a spawned task bounded by `rollback_timeout`. + +```mermaid +flowchart TD + A[Guard::drop] --> B[Cleanup::new] + B --> C{dirty state?} + C -->|guard.reset| D[DISCARD ALL] + C -->|server.dirty| E[RESET ALL +
pg_advisory_unlock_all] + C -->|schema_changed| F[DEALLOCATE ALL] + C -->|clean| G[no queries] + D & E & F & G --> H[drain if out of sync] + H --> I[rollback if in tx] + I --> J[execute_batch] + J --> K[Pool::checkin] + K --> L{still healthy?} + L -->|yes + waiter| M[send to waiter] + L -->|yes no waiter| N[idle_connections.push] + L -->|no| O[drop + notify Monitor] +``` + +1. **`Cleanup::new()`** in [`backend/pool/cleanup.rs`](../pgdog/src/backend/pool/cleanup.rs) — decides what to run: + - `guard.reset` → `DISCARD ALL` + - `server.dirty()` → `RESET ALL` + `SELECT pg_advisory_unlock_all()` + - `server.schema_changed()` → `DEALLOCATE ALL` + - otherwise → nothing + - always: `server.ensure_prepared_capacity()` identifies statements to CLOSE within the limit. +2. **`Server::drain()`** in [`backend/server.rs`](../pgdog/src/backend/server.rs) — discards buffered Postgres data if the connection is out of sync. +3. **`Server::rollback()`** in [`backend/server.rs`](../pgdog/src/backend/server.rs) — sends `ROLLBACK` if in a transaction. +4. **`Server::execute_batch()`** in [`backend/server.rs`](../pgdog/src/backend/server.rs) — runs the cleanup queries. +5. **`Server::sync_prepared_statements()`** in [`backend/server.rs`](../pgdog/src/backend/server.rs) — reconciles the local cache against `pg_prepared_statements`. +6. **`Pool::checkin(server)`** → `Inner::maybe_check_in()` in [`backend/pool/inner.rs`](../pgdog/src/backend/pool/inner.rs): + - Removes from `Taken` in [`backend/pool/taken.rs`](../pgdog/src/backend/pool/taken.rs). + - Checks: error state, offline/paused, age ≥ `effective_max_age`, `force_close`, replication mode. + - Healthy: `Inner::put()` hands to a waiter or pushes to `idle_connections`. + - Unhealthy: dropped; `Monitor` in [`backend/pool/monitor.rs`](../pgdog/src/backend/pool/monitor.rs) is notified. + +The timeout ensures a stuck `ROLLBACK` or `DISCARD` doesn't starve the next waiting client. + +--- + +## 8. 
Cancel flow + +Cancel requests come in on a fresh TCP connection — no auth, no `Client` struct. The path is in [`frontend/listener.rs`](../pgdog/src/frontend/listener.rs) and [`backend/pool/`](../pgdog/src/backend/pool/): + +```mermaid +sequenceDiagram + participant C as Client app + participant L as Listener + participant CM as Comms + participant P as Pool / Taken + participant PG as Postgres + + C->>L: TCP connect (CancelRequest) + L->>CM: verify_cancel(id) + alt secret matches registered client + CM-->>L: ok + L->>P: databases().cancel(client_pid) + P->>P: cancel_key(client_pid) → server BackendKeyData + P->>PG: new TCP + CancelRequest<br/>server pid + server secret + L->>C: close TCP + else wrong pid or wrong secret + CM-->>L: reject + L->>C: close TCP (query unaffected) + end +``` + +1. `Startup::Cancel { id }` parsed in [`frontend/listener.rs`](../pgdog/src/frontend/listener.rs). +2. `comms().verify_cancel(&id)` in [`frontend/comms.rs`](../pgdog/src/frontend/comms.rs) — looks up `ConnectedClient.key` by `id.pid`, compares secrets. On mismatch the TCP connection is closed and **the running query is unaffected** — no signal reaches the backend. This gate exists because `Taken` is keyed on `BackendPid` (pid only); without it, any peer that knows a client pid could cancel its query. +3. `databases().cancel(id.pid)` routes through cluster → shard → `LoadBalancer::cancel` (in [`backend/pool/lb/mod.rs`](../pgdog/src/backend/pool/lb/mod.rs), fans out to every target) → `Pool::cancel(client_pid)` in [`backend/pool/pool_impl.rs`](../pgdog/src/backend/pool/pool_impl.rs). +4. `Inner::cancel_key(client_pid)` → `Taken::cancel_key(client_pid)` → `frontend_to_cancel.get(&client_pid)` in [`backend/pool/taken.rs`](../pgdog/src/backend/pool/taken.rs) returns the server's `BackendKeyData` (pid + secret) directly; missing entries are silently skipped, which is why fan-out across shards/replicas is safe. +5. `Server::cancel(addr, key)` in [`backend/server.rs`](../pgdog/src/backend/server.rs) opens a new TCP connection to Postgres and sends the `CancelRequest` with the server's pid and secret. +The client's secret (checked in step 2) verifies the cancel is legitimate. The server's secret (used in step 5) is what Postgres acts on. 
+--- + +## Source tagging summary + +| Variant | Set where | Meaning | +|---|---|---| +| `Backend(BackendPid)` | [`backend/server.rs`](../pgdog/src/backend/server.rs) — one place | Arrived verbatim from a Postgres socket | +| `Frontend` | [`frontend/client/mod.rs`](../pgdog/src/frontend/client/mod.rs) — one place | Arrived from the client TCP socket | +| `Internal` | default on `Message::new()` in [`net/messages/mod.rs`](../pgdog/src/net/messages/mod.rs) | Synthesised or transformed by pgdog | + +`Source` has exactly two uses: disambiguating shared byte codes in the `Debug` formatter ([`net/messages/mod.rs`](../pgdog/src/net/messages/mod.rs)), and identifying which shard's `DataRow` to keep in `MultiShard` ([`backend/pool/connection/multi_shard/mod.rs`](../pgdog/src/backend/pool/connection/multi_shard/mod.rs)) — no other production code reads it. diff --git a/integration/rust/tests/integration/cancel.rs b/integration/rust/tests/integration/cancel.rs new file mode 100644 index 000000000..564a89bdf --- /dev/null +++ b/integration/rust/tests/integration/cancel.rs @@ -0,0 +1,221 @@ +use std::time::Duration; + +use bytes::{BufMut, BytesMut}; +use rust::setup::{admin_tokio, connection_sqlx_direct}; +use sqlx::PgPool; +use tokio::{io::AsyncWriteExt, net::TcpStream, task::JoinHandle, time::timeout}; +use tokio_postgres::{CancelToken, Error as PgError, NoTls, SimpleQueryMessage}; + +/// Returns whether `pid` has an active `pg_sleep` query visible in `pg_stat_activity`. +/// Uses a direct PostgreSQL connection so the result bypasses pgdog completely. 
+async fn is_sleeping(direct: &PgPool, pid: i32) -> bool { + let count: i64 = sqlx::query_scalar( + "SELECT COUNT(*) \ + FROM pg_stat_activity \ + WHERE pid = $1 \ + AND state = 'active' \ + AND query LIKE '%pg_sleep%'", + ) + .bind(pid) + .fetch_one(direct) + .await + .unwrap(); + count == 1 +} + +/// Connect to pgdog, pin to a specific PG backend via BEGIN, capture the backend pid +/// via `pg_backend_pid()`, and launch `SELECT pg_sleep(60)` in a background task. +/// +/// `application_name` is embedded in the connection string so the caller can identify +/// this connection in `SHOW CLIENTS` if needed. +/// +/// Returns `(backend_pid, cancel_token, query_handle)`. The caller owns `cancel_token` +/// and `query_handle`; both must be driven to completion to keep the test clean. +async fn start_sleeping_connection( + application_name: &str, +) -> ( + i32, + CancelToken, + JoinHandle, PgError>>, +) { + let (client, connection) = tokio_postgres::connect( + &format!( + "host=127.0.0.1 user=pgdog dbname=pgdog password=pgdog port=6432 application_name={application_name}" + ), + NoTls, + ) + .await + .unwrap(); + + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("pgdog connection error: {}", e); + } + }); + + let cancel_token = client.cancel_token(); + + // BEGIN pins the client to one backend for the duration of the transaction. + // Without this, transaction-mode pooling may assign a different backend to + // pg_sleep than the one whose pid we captured. + client.simple_query("BEGIN").await.unwrap(); + + let row = client + .query_one("SELECT pg_backend_pid()", &[]) + .await + .unwrap(); + let backend_pid: i32 = row.get(0); + + let handle = tokio::spawn(async move { client.simple_query("SELECT pg_sleep(60)").await }); + + (backend_pid, cancel_token, handle) +} + +/// Assert that a query handle returned by `start_sleeping_connection` was cancelled: +/// it must resolve to SQLSTATE 57014 (canceling statement due to user request). 
+async fn assert_cancelled( + handle: JoinHandle, PgError>>, + label: &str, +) { + let result = timeout(Duration::from_secs(5), handle) + .await + .expect(&format!( + "{label}: cancelled query did not unblock within 5 seconds" + )) + .expect(&format!("{label}: task panicked")); + + let err = result.expect_err(&format!( + "{label}: query should have been cancelled, but it succeeded" + )); + let db_err = err.as_db_error().expect(&format!( + "{label}: expected a PostgreSQL error, not a network error" + )); + + assert_eq!( + db_err.code().code(), + "57014", + "{label}: expected SQLSTATE 57014, got {}", + db_err.code().code() + ); +} + +/// Verify that cancellation is precise: two independent connections both run a long +/// query and each cancel request stops exactly one of them. +/// +/// Steps: +/// 1. Two clients connect through pgdog; each starts `SELECT pg_sleep(60)`. +/// 2. Both queries are confirmed active on specific PG backends via `pg_stat_activity`. +/// 3. Cancel connection 1 → only backend 1 stops; backend 2 remains active. +/// 4. Cancel connection 2 → backend 2 stops. +#[tokio::test] +async fn test_cancel_query() { + let direct = connection_sqlx_direct().await; + + let (pid1, token1, handle1) = start_sleeping_connection("cancel_test").await; + let (pid2, token2, handle2) = start_sleeping_connection("cancel_test").await; + + // Give both queries time to reach their respective backends. + tokio::time::sleep(Duration::from_millis(300)).await; + + assert!( + is_sleeping(&direct, pid1).await, + "connection 1 (backend {pid1}) should be active before any cancel" + ); + assert!( + is_sleeping(&direct, pid2).await, + "connection 2 (backend {pid2}) should be active before any cancel" + ); + + // ── Cancel connection 1 ──────────────────────────────────────────────── + token1.cancel_query(NoTls).await.unwrap(); + + // Wait for the client to receive the cancellation error. + // By the time the handle resolves, the backend has already stopped. 
+ assert_cancelled(handle1, "connection 1").await; + + // Connection 1's backend is gone; connection 2 must still be running. + tokio::time::sleep(Duration::from_millis(100)).await; + assert!( + !is_sleeping(&direct, pid1).await, + "backend {pid1} should be idle after cancelling connection 1" + ); + assert!( + is_sleeping(&direct, pid2).await, + "backend {pid2} should still be active after cancelling connection 1 only" + ); + + // ── Cancel connection 2 ──────────────────────────────────────────────── + token2.cancel_query(NoTls).await.unwrap(); + + assert_cancelled(handle2, "connection 2").await; + + tokio::time::sleep(Duration::from_millis(100)).await; + assert!( + !is_sleeping(&direct, pid2).await, + "backend {pid2} should be idle after cancelling connection 2" + ); +} + +/// Verify that a cancel request carrying a wrong pid and secret is silently rejected: +/// the running query is unaffected and the client does not receive a cancellation error. +/// +/// pgdog's `verify_cancel` gate must reject the request before it reaches the pool, +/// so the backend continues executing as if nothing happened. +#[tokio::test] +async fn test_cancel_query_wrong_secret() { + let direct = connection_sqlx_direct().await; + let app_name = "cancel_test_wrong_secret"; + let (backend_pid, real_cancel_token, query_handle) = start_sleeping_connection(app_name).await; + + // Give the query time to reach the backend. + tokio::time::sleep(Duration::from_millis(300)).await; + + assert!( + is_sleeping(&direct, backend_pid).await, + "query should be running before wrong-secret cancel" + ); + + // Look up the pgdog client pid from the admin interface. + // SHOW CLIENTS exposes the pid (the 'id' column) that pgdog assigned during login — + // the same value that was sent in the K message and that verify_cancel checks against. 
+ let admin = admin_tokio().await; + let messages = admin.simple_query("SHOW CLIENTS").await.unwrap(); + let pgdog_pid: i32 = messages + .iter() + .filter_map(|m| match m { + SimpleQueryMessage::Row(row) => Some(row), + _ => None, + }) + .find(|row| row.get("application_name") == Some(app_name)) + .expect("connection should appear in SHOW CLIENTS") + .get("id") + .expect("id column should be present") + .parse() + .expect("id should be a valid i32"); + + // Send a CancelRequest with the real pgdog client pid but a wrong secret. + // pgdog will find the client in comms by pid, then reject it because + // the secret doesn't match — verify_cancel returns false. + let mut raw = TcpStream::connect("127.0.0.1:6432").await.unwrap(); + let mut buf = BytesMut::new(); + buf.put_i32(16); // total message length (including the length field) + buf.put_i32(80877102); // CancelRequest magic code + buf.put_i32(pgdog_pid); // correct pid + buf.put_i32(0); // wrong secret + raw.write_all(&buf).await.unwrap(); + // pgdog closes the connection silently after processing; no response is sent. + drop(raw); + + // Give pgdog enough time to receive and process the bogus cancel. + tokio::time::sleep(Duration::from_millis(300)).await; + + // The query must still be running — the secret mismatch was caught by verify_cancel. + assert!( + is_sleeping(&direct, backend_pid).await, + "query should still be running after wrong-secret cancel — verify_cancel must have rejected it" + ); + + // Clean up: cancel for real. 
+ real_cancel_token.cancel_query(NoTls).await.unwrap(); + assert_cancelled(query_handle, "wrong-secret test cleanup").await; +} diff --git a/integration/rust/tests/integration/mod.rs b/integration/rust/tests/integration/mod.rs index 51bc0cb72..985c9ddb1 100644 --- a/integration/rust/tests/integration/mod.rs +++ b/integration/rust/tests/integration/mod.rs @@ -3,6 +3,7 @@ pub mod auth; pub mod auto_id; pub mod avg; pub mod ban; +pub mod cancel; pub mod client_ids; pub mod connection_recovery; pub mod cross_shard_disabled; diff --git a/pgdog-postgres-types/src/interface.rs b/pgdog-postgres-types/src/interface.rs index 16bab0c1e..9917f31fd 100644 --- a/pgdog-postgres-types/src/interface.rs +++ b/pgdog-postgres-types/src/interface.rs @@ -48,19 +48,10 @@ impl ToDataRowColumn for i64 { } } -impl ToDataRowColumn for Option { +impl ToDataRowColumn for Option { fn to_data_row_column(&self) -> Data { match self { - Some(value) => ToDataRowColumn::to_data_row_column(value), - None => Data::null(), - } - } -} - -impl ToDataRowColumn for Option { - fn to_data_row_column(&self) -> Data { - match self { - Some(value) => ToDataRowColumn::to_data_row_column(value), + Some(value) => value.to_data_row_column(), None => Data::null(), } } diff --git a/pgdog/Cargo.toml b/pgdog/Cargo.toml index f2c4c875c..78d571e41 100644 --- a/pgdog/Cargo.toml +++ b/pgdog/Cargo.toml @@ -24,7 +24,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "std"] tracing-throttle = "0.4" parking_lot = "0.12" thiserror = "2" -derive_more = { version = "2", features = ["display", "error"] } +derive_more = { version = "2", features = ["display", "error", "from", "into"] } bytes = "1" clap = { version = "4", features = ["derive"] } serde = { version = "1", features = ["derive"] } @@ -77,6 +77,7 @@ azure_identity = "0.34.0" azure_core = "0.34.0" crc32c = "0.6.8" bit-vec = "0.8" +smallvec = "1" reqwest = { version = "0.12", default-features = false, features = 
["rustls-tls-webpki-roots-no-provider"] } hex = "0.4" x509-parser = "0.18" diff --git a/pgdog/src/admin/server.rs b/pgdog/src/admin/server.rs index 846e4cb88..a5381429d 100644 --- a/pgdog/src/admin/server.rs +++ b/pgdog/src/admin/server.rs @@ -9,8 +9,8 @@ use tracing::debug; use crate::frontend::ClientRequest; use crate::net::messages::command_complete::CommandComplete; use crate::net::messages::{ErrorResponse, FromBytes, Protocol, Query, ReadyForQuery}; +use crate::net::ProtocolMessage; use crate::net::ToBytes; -use crate::net::{BackendKeyData, ProtocolMessage}; use super::parser::Parser; use super::prelude::Message; @@ -63,10 +63,7 @@ impl AdminServer { self.messages.extend(messages); self.messages.push_back(ReadyForQuery::idle().message()?); - self.messages = std::mem::take(&mut self.messages) - .into_iter() - .map(|m| m.backend(BackendKeyData::default())) - .collect(); + self.messages = std::mem::take(&mut self.messages).into_iter().collect(); Ok(()) } diff --git a/pgdog/src/admin/show_client_memory.rs b/pgdog/src/admin/show_client_memory.rs index 30a14b799..785b6ff66 100644 --- a/pgdog/src/admin/show_client_memory.rs +++ b/pgdog/src/admin/show_client_memory.rs @@ -42,7 +42,7 @@ impl Command for ShowClientMemory { let user = client.paramters.get_default("user", "postgres"); let database = client.paramters.get_default("database", user); - row.add(client.id.pid as i64) + row.add(client.key.pid()) .add(database) .add(user) .add(client.addr.ip().to_string().as_str()) diff --git a/pgdog/src/admin/show_clients.rs b/pgdog/src/admin/show_clients.rs index 022724143..414e630e7 100644 --- a/pgdog/src/admin/show_clients.rs +++ b/pgdog/src/admin/show_clients.rs @@ -77,7 +77,7 @@ impl Command for ShowClients { let row = self .filter .clone() - .add("id", client.id.pid as i64) + .add("id", client.key.pid()) .add("user", user) .add("database", client.paramters.get_default("database", user)) .add("addr", client.addr.ip().to_string()) diff --git 
a/pgdog/src/admin/show_server_memory.rs b/pgdog/src/admin/show_server_memory.rs index 1f58a0f6c..88f81ae50 100644 --- a/pgdog/src/admin/show_server_memory.rs +++ b/pgdog/src/admin/show_server_memory.rs @@ -36,7 +36,7 @@ impl Command for ShowServerMemory { let mut messages = vec![rd.message()?]; let stats = stats(); - for (_, server) in stats { + for server in stats { let mut row = DataRow::new(); let memory = &server.stats.memory; @@ -45,7 +45,7 @@ impl Command for ShowServerMemory { .add(server.addr.user.as_str()) .add(server.addr.host.as_str()) .add(server.addr.port as i64) - .add(server.stats.id.pid as i64) + .add(server.stats.id) .add(memory.buffer.reallocs as i64) .add(memory.buffer.reclaims as i64) .add(memory.buffer.bytes_used as i64) diff --git a/pgdog/src/admin/show_servers.rs b/pgdog/src/admin/show_servers.rs index d2c77fab8..287db00c6 100644 --- a/pgdog/src/admin/show_servers.rs +++ b/pgdog/src/admin/show_servers.rs @@ -79,7 +79,7 @@ impl Command for ShowServers { let now = Instant::now(); let now_time = SystemTime::now(); - for (_, server) in stats { + for server in stats { let age = now.duration_since(server.stats.created_at); let request_age = now.duration_since(server.stats.last_used); let request_time = now_time - request_age; @@ -98,11 +98,8 @@ impl Command for ShowServers { format_time(server.stats.created_at_time.into()), ) .add("request_time", format_time(request_time.into())) - .add("remote_pid", server.stats.id.pid as i64) - .add( - "client_id", - server.stats.client_id.map(|client| client.pid as i64), - ) + .add("remote_pid", server.stats.id) + .add("client_id", server.stats.client_id) .add("transactions", server.stats.total.transactions) .add("queries", server.stats.total.queries) .add("rollbacks", server.stats.total.rollbacks) diff --git a/pgdog/src/backend/databases.rs b/pgdog/src/backend/databases.rs index 4b52342dd..38be8a38c 100644 --- a/pgdog/src/backend/databases.rs +++ b/pgdog/src/backend/databases.rs @@ -23,7 +23,7 @@ use 
crate::frontend::PreparedStatements; use crate::{ backend::pool::PoolConfig, config::{config, load, set, ConfigAndUsers, ManualQuery, Role, User as ConfigUser}, - net::{messages::BackendKeyData, tls}, + net::{messages::BackendPid, tls}, }; use super::{ @@ -399,7 +399,7 @@ impl Databases { } /// Cancel a query running on one of the databases proxied by the pooler. - pub async fn cancel(&self, id: &BackendKeyData) -> Result<(), Error> { + pub async fn cancel(&self, id: BackendPid) -> Result<(), Error> { for cluster in self.databases.values() { cluster.cancel(id).await?; } diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs index d8421212a..eaa80ff08 100644 --- a/pgdog/src/backend/pool/cluster.rs +++ b/pgdog/src/backend/pool/cluster.rs @@ -28,7 +28,7 @@ use crate::{ ShardedTable, User, }, frontend::{ClientRequest, RegexParser}, - net::{messages::BackendKeyData, Query}, + net::{messages::BackendPid, Query}, }; use super::{Address, Config, Error, Guard, MirrorStats, Request, Shard, ShardConfig}; @@ -388,7 +388,7 @@ impl Cluster { } /// Cancel a query executed by one of the shards. - pub async fn cancel(&self, id: &BackendKeyData) -> Result<(), super::super::Error> { + pub async fn cancel(&self, id: BackendPid) -> Result<(), super::super::Error> { for shard in &self.shards { shard.cancel(id).await?; } diff --git a/pgdog/src/backend/pool/connection/binding.rs b/pgdog/src/backend/pool/connection/binding.rs index 071eaaf4c..889a2db34 100644 --- a/pgdog/src/backend/pool/connection/binding.rs +++ b/pgdog/src/backend/pool/connection/binding.rs @@ -2,7 +2,7 @@ use crate::{ frontend::{client::query_engine::TwoPcPhase, ClientRequest}, - net::{parameter::Parameters, BackendKeyData, ProtocolMessage, Query}, + net::{parameter::Parameters, BackendPid, ProtocolMessage, Query}, state::State, }; @@ -420,7 +420,7 @@ impl Binding { /// Link client to server. 
pub async fn link_client( &mut self, - id: &BackendKeyData, + id: BackendPid, params: &Parameters, transaction_start_stmt: Option<&str>, ) -> Result { diff --git a/pgdog/src/backend/pool/connection/mirror/mod.rs b/pgdog/src/backend/pool/connection/mirror/mod.rs index 73d02d221..a831663f2 100644 --- a/pgdog/src/backend/pool/connection/mirror/mod.rs +++ b/pgdog/src/backend/pool/connection/mirror/mod.rs @@ -15,7 +15,7 @@ use crate::frontend::client::query_engine::{QueryEngine, QueryEngineContext}; use crate::frontend::client::timeouts::Timeouts; use crate::frontend::client::TransactionType; use crate::frontend::{ClientComms, PreparedStatements}; -use crate::net::{BackendKeyData, Parameter, Parameters, Stream}; +use crate::net::{BackendPid, Parameter, Parameters, Stream}; use super::Error; @@ -32,7 +32,7 @@ pub use request::*; #[derive(Debug)] pub struct Mirror { /// Random identifier for this mirror connection. - pub id: BackendKeyData, + pub id: BackendPid, /// Mirror's prepared statements. Should be similar /// to client's statements, if exposure is high. pub prepared_statements: PreparedStatements, @@ -54,7 +54,7 @@ impl Mirror { prepared_statements.set_level(config.prepared_statements()); Self { - id: BackendKeyData::new(), + id: BackendPid::random(), prepared_statements, params: params.clone(), timeouts: Timeouts::from_config(&config.config.general), @@ -95,7 +95,7 @@ impl Mirror { // Same query engine as the client, except with a potentially different database config. let mut query_engine = - QueryEngine::new(&params, &ClientComms::new(&BackendKeyData::new()), false)?; + QueryEngine::new(&params, &ClientComms::new(BackendPid::random()), false)?; // Mirror traffic handler.
let mut mirror = Self::new(&params, &config); diff --git a/pgdog/src/backend/pool/connection/multi_shard/mod.rs b/pgdog/src/backend/pool/connection/multi_shard/mod.rs index b97e847fc..0344d8017 100644 --- a/pgdog/src/backend/pool/connection/multi_shard/mod.rs +++ b/pgdog/src/backend/pool/connection/multi_shard/mod.rs @@ -9,7 +9,7 @@ use crate::{ command_complete::CommandComplete, DataRow, FromBytes, Message, Protocol, RowDescription, ToBytes, }, - BackendKeyData, Decoder, ReadyForQuery, + BackendPid, Decoder, ReadyForQuery, }, }; @@ -42,7 +42,7 @@ struct Counters { copy_done: usize, copy_out: usize, copy_data: usize, - first_backend_data: Option, + first_backend_data: Option, } /// Multi-shard state. diff --git a/pgdog/src/backend/pool/connection/multi_shard/test.rs b/pgdog/src/backend/pool/connection/multi_shard/test.rs index b65792a2d..4338cd27e 100644 --- a/pgdog/src/backend/pool/connection/multi_shard/test.rs +++ b/pgdog/src/backend/pool/connection/multi_shard/test.rs @@ -71,11 +71,11 @@ fn test_rd_before_dr() { dr.add(1i64); for _ in 0..2 { let result = multi_shard - .forward(rd.message().unwrap().backend(BackendKeyData::default())) + .forward(rd.message().unwrap().backend(BackendPid::from(1))) .unwrap(); assert!(result.is_none()); // dropped let result = multi_shard - .forward(dr.message().unwrap().backend(BackendKeyData::default())) + .forward(dr.message().unwrap().backend(BackendPid::from(1))) .unwrap(); assert!(result.is_none()); // buffered.
} @@ -92,7 +92,7 @@ fn test_rd_before_dr() { CommandComplete::from_str("SELECT 1") .message() .unwrap() - .backend(BackendKeyData::default()), + .backend(BackendPid::from(1)), ) .unwrap(); assert!(result.is_none()); @@ -100,7 +100,7 @@ fn test_rd_before_dr() { for _ in 0..2 { let result = multi_shard.message(); - let id = BackendKeyData::default(); + let id = BackendPid::from(1); assert_eq!( result.map(|m| m.backend(id)), Some(dr.message().unwrap().backend(id)) @@ -109,14 +109,14 @@ fn test_rd_before_dr() { let result = multi_shard .message() - .map(|m| m.backend(BackendKeyData::default())); + .map(|m| m.backend(BackendPid::from(1))); assert_eq!( result, Some( CommandComplete::from_str("SELECT 3") .message() .unwrap() - .backend(BackendKeyData::default()) + .backend(BackendPid::from(1)) ) ); @@ -153,9 +153,9 @@ fn test_omni_command_complete_not_summed() { let route = Route::write(ShardWithPriority::new_table_omni(Shard::All)); let mut multi_shard = MultiShard::new(vec![0, 1, 2], &route); - let backend1 = BackendKeyData::legacy(1, 1); - let backend2 = BackendKeyData::legacy(2, 2); - let backend3 = BackendKeyData::legacy(3, 3); + let backend1 = BackendPid::from(1); + let backend2 = BackendPid::from(2); + let backend3 = BackendPid::from(3); // All shards report UPDATE 5 multi_shard @@ -195,8 +195,8 @@ fn test_omni_command_complete_uses_first_shard_row_count() { let route = Route::write(ShardWithPriority::new_table_omni(Shard::All)); let mut multi_shard = MultiShard::new(vec![0, 1], &route); - let backend1 = BackendKeyData::legacy(1, 1); - let backend2 = BackendKeyData::legacy(2, 2); + let backend1 = BackendPid::from(1); + let backend2 = BackendPid::from(2); // First shard reports 7 rows multi_shard @@ -230,8 +230,8 @@ fn test_omni_data_rows_only_from_first_server() { let route = Route::write(ShardWithPriority::new_table_omni(Shard::All)); let mut multi_shard = MultiShard::new(vec![0, 1], &route); - let backend1 = BackendKeyData::legacy(1, 1); - let backend2 = 
BackendKeyData::legacy(2, 2); + let backend1 = BackendPid::from(1); + let backend2 = BackendPid::from(2); // Setup: send RowDescription from both shards let rd = RowDescription::new(&[Field::bigint("id")]); diff --git a/pgdog/src/backend/pool/error.rs b/pgdog/src/backend/pool/error.rs index 27e32ba8b..1f74de990 100644 --- a/pgdog/src/backend/pool/error.rs +++ b/pgdog/src/backend/pool/error.rs @@ -1,7 +1,7 @@ //! Connection pool errors. use thiserror::Error; -use crate::net::BackendKeyData; +use crate::net::BackendPid; #[derive(Debug, Error, PartialEq, Clone, Copy)] pub enum Error { @@ -69,10 +69,7 @@ pub enum Error { PoolUnhealthy, #[error("checked in untracked connection: {0}")] - UntrackedConnCheckin(BackendKeyData), - - #[error("mapping missing: {0}")] - MappingMissing(usize), + UntrackedConnCheckin(BackendPid), #[error("fast shutdown failed")] FastShutdown, @@ -95,7 +92,6 @@ impl Error { | Self::NoDatabases | Self::PubSubDisabled | Self::PoolNoHealthTarget(_) - | Self::MappingMissing(_) // Admin decisions — respect them. | Self::ManualBan // Programming errors. 
@@ -137,6 +133,5 @@ mod tests { assert!(!Error::PubSubDisabled.is_retryable()); assert!(!Error::FastShutdown.is_retryable()); assert!(!Error::NoShard(0).is_retryable()); - assert!(!Error::MappingMissing(0).is_retryable()); } } diff --git a/pgdog/src/backend/pool/inner.rs b/pgdog/src/backend/pool/inner.rs index 7e822f530..12ca5431b 100644 --- a/pgdog/src/backend/pool/inner.rs +++ b/pgdog/src/backend/pool/inner.rs @@ -6,13 +6,11 @@ use std::fmt::Display; use crate::backend::{stats::Counts as BackendCounts, Server}; use crate::backend::{ConnectReason, DisconnectReason}; -use crate::net::messages::BackendKeyData; +use crate::net::messages::{BackendKeyData, BackendPid}; use tokio::time::Instant; -use super::{ - lsn_monitor::ReplicaLag, Config, Error, Mapping, Oids, Pool, Request, Stats, Taken, Waiter, -}; +use super::{lsn_monitor::ReplicaLag, Config, Error, Oids, Pool, Request, Stats, Taken, Waiter}; /// Pool internals protected by a mutex. #[derive(Default)] @@ -116,15 +114,15 @@ impl Inner { self.taken.len() } - /// Get backend IDs for all currently checked out servers. - pub(super) fn checked_out_server_ids(&self) -> Vec { - self.taken.servers() + /// Cancel key for the server currently assigned to this client. + #[inline] + pub(super) fn cancel_key(&self, client: BackendPid) -> Option<&BackendKeyData> { + self.taken.cancel_key(client) } - /// Find the server currently linked to this client, if any. - #[inline] - pub(super) fn peer(&self, client_id: &BackendKeyData) -> Option { - self.taken.server(client_id) + /// All cancel keys for currently checked-out server connections. 
+ pub(super) fn cancel_keys(&self) -> impl Iterator { + self.taken.cancel_keys() } /// How many connections can be removed from the pool @@ -238,10 +236,8 @@ impl Inner { #[inline(always)] pub(super) fn take(&mut self, request: &Request) -> Result>, Error> { if let Some(conn) = self.idle_connections.pop() { - self.taken.take(&Mapping { - client: request.id, - server: *(conn.id()), - })?; + let cancel_key = conn.key().clone(); + self.taken.take(request.id, cancel_key); Ok(Some(conn)) } else { @@ -254,15 +250,12 @@ impl Inner { #[inline] pub(super) fn put(&mut self, mut conn: Box, now: Instant) -> Result<(), Error> { // Try to give it to a client that's been waiting, if any. - let id = *conn.id(); + let cancel_key = conn.key().clone(); while let Some(waiter) = self.waiting.pop_front() { if let Err(conn_ret) = waiter.tx.send(Ok(conn)) { conn = conn_ret.unwrap(); // SAFETY: We sent Ok(conn), we'll get back Ok(conn) if channel is closed. } else { - self.taken.take(&Mapping { - server: id, - client: waiter.request.id, - })?; + self.taken.take(waiter.request.id, cancel_key); self.stats.counts.server_assignment_count += 1; self.stats.counts.wait_time += now.duration_since(waiter.request.created_at); return Ok(()); @@ -396,16 +389,16 @@ impl Inner { /// This happens if the waiter timed out, e.g. checkout timeout, /// or the caller got cancelled. #[inline] - pub(super) fn remove_waiter(&mut self, id: &BackendKeyData) { + pub(super) fn remove_waiter(&mut self, id: BackendPid) { if let Some(waiter) = self.waiting.pop_front() { - if waiter.request.id != *id { + if waiter.request.id != id { // Put me back. self.waiting.push_front(waiter); // Slow search, but we should be somewhere towards the front // if the runtime is doing scheduling correctly. 
for (i, waiter) in self.waiting.iter().enumerate() { - if waiter.request.id == *id { + if waiter.request.id == id { self.waiting.remove(i); break; } @@ -476,7 +469,7 @@ mod test { use tokio::sync::oneshot::channel; - use crate::net::messages::BackendKeyData; + use crate::net::messages::{BackendKeyData, BackendPid}; use super::*; @@ -496,14 +489,8 @@ mod test { let mut inner = Inner::default(); let server = Box::new(Server::default()); - let server_id = *server.id(); - inner - .taken - .take(&Mapping { - client: BackendKeyData::new(), - server: server_id, - }) - .unwrap(); + + inner.taken.take(BackendPid::random(), server.key().clone()); let result = inner .maybe_check_in(server, Instant::now(), BackendCounts::default(), false) @@ -523,14 +510,8 @@ mod test { }; let server = Box::new(Server::default()); - let server_id = *server.id(); - inner - .taken - .take(&Mapping { - client: BackendKeyData::new(), - server: server_id, - }) - .unwrap(); + + inner.taken.take(BackendPid::random(), server.key().clone()); inner .maybe_check_in(server, Instant::now(), BackendCounts::default(), false) @@ -548,14 +529,8 @@ mod test { }; let server = Box::new(Server::default()); - let server_id = *server.id(); - inner - .taken - .take(&Mapping { - client: BackendKeyData::new(), - server: server_id, - }) - .unwrap(); + + inner.taken.take(BackendPid::random(), server.key().clone()); let result = inner .maybe_check_in(server, Instant::now(), BackendCounts::default(), false) @@ -574,16 +549,9 @@ mod test { }; let server = Box::new(Server::new_error()); - let server_id = *server.id(); // Simulate server being checked out - inner - .taken - .take(&Mapping { - client: BackendKeyData::new(), - server: server_id, - }) - .unwrap(); + inner.taken.take(BackendPid::random(), server.key().clone()); assert_eq!(inner.checked_out(), 1); let result = inner @@ -591,7 +559,7 @@ mod test { .unwrap(); assert!(result.server_error); - assert!(inner.taken.is_empty()); // Error server removed from taken + 
assert_eq!(inner.checked_out(), 0); // Error server removed from taken assert_eq!(inner.idle(), 0); // Error server not added to idle } @@ -662,11 +630,7 @@ mod test { inner.idle_connections.push(Box::new(Server::default())); inner .taken - .take(&Mapping { - client: BackendKeyData::new(), - server: BackendKeyData::new(), - }) - .unwrap(); + .take(BackendPid::random(), BackendKeyData::random_legacy()); assert_eq!(inner.idle(), 2); assert_eq!(inner.checked_out(), 1); @@ -714,14 +678,8 @@ mod test { // Add a connection let server = Box::new(Server::default()); - let server_id = *server.id(); - inner - .taken - .take(&Mapping { - client: BackendKeyData::new(), - server: server_id, - }) - .unwrap(); + + inner.taken.take(BackendPid::random(), server.key().clone()); inner .maybe_check_in(server, Instant::now(), BackendCounts::default(), false) @@ -744,7 +702,9 @@ mod test { assert_eq!(inner.total(), 0); // Simulate taking a connection - inner.taken.take(&Mapping::default()).unwrap(); + inner + .taken + .take(BackendPid::random(), BackendKeyData::random_legacy()); assert_eq!(inner.total(), 1); assert_eq!(inner.checked_out(), 1); @@ -763,14 +723,8 @@ mod test { inner.config.max_age = Duration::from_millis(60_000); let server = Box::new(Server::default()); - let server_id = *server.id(); - inner - .taken - .take(&Mapping { - client: BackendKeyData::new(), - server: server_id, - }) - .unwrap(); + + inner.taken.take(BackendPid::random(), server.key().clone()); inner .maybe_check_in( @@ -787,45 +741,38 @@ mod test { #[test] fn test_peer_lookup() { let mut inner = Inner::default(); - let client_id = BackendKeyData::new(); - let server_id = BackendKeyData::new(); + let client_id = BackendPid::random(); + let server_id = BackendPid::random(); + let cancel_key = BackendKeyData::legacy(server_id, 0); - assert_eq!(inner.peer(&client_id), None); + assert_eq!(inner.cancel_key(client_id), None); - inner - .taken - .take(&Mapping { - client: client_id, - server: server_id, - }) - 
.unwrap(); + inner.taken.take(client_id, cancel_key.clone()); - assert_eq!(inner.peer(&client_id), Some(server_id)); + assert_eq!(inner.cancel_key(client_id), Some(&cancel_key)); } #[test] fn test_taken_server_returns_server_when_mapped() { let mut taken = Taken::default(); - let client_id = BackendKeyData::new(); - let server_id = BackendKeyData::new(); + let client_id = BackendPid::random(); + let server_id = BackendPid::random(); + let cancel_key = BackendKeyData::legacy(server_id, 0); // No mapping yet - assert_eq!(taken.server(&client_id), None); + assert_eq!(taken.cancel_key(client_id), None); // Add mapping - taken - .take(&Mapping { - client: client_id, - server: server_id, - }) - .unwrap(); + taken.take(client_id, cancel_key.clone()); - // Server should be returned for mapped client - assert_eq!(taken.server(&client_id), Some(server_id)); + // Cancel key should be returned for mapped client, pid matches server_id + let stored = taken.cancel_key(client_id).unwrap(); + assert_eq!(stored.pid(), server_id); + assert_eq!(stored, &cancel_key); // Different client should return None - let other_client = BackendKeyData::new(); - assert_eq!(taken.server(&other_client), None); + let other_client = BackendPid::random(); + assert_eq!(taken.cancel_key(other_client), None); } #[test] @@ -931,7 +878,7 @@ mod test { }); assert_eq!(inner.waiting.len(), 3); - inner.remove_waiter(&target_id); + inner.remove_waiter(target_id); assert_eq!(inner.waiting.len(), 2); } @@ -971,18 +918,10 @@ mod test { // Add connections above minimum but all are checked out (no idle) inner .taken - .take(&Mapping { - client: BackendKeyData::new(), - server: BackendKeyData::new(), - }) - .unwrap(); + .take(BackendPid::random(), BackendKeyData::random_legacy()); inner .taken - .take(&Mapping { - client: BackendKeyData::new(), - server: BackendKeyData::new(), - }) - .unwrap(); + .take(BackendPid::random(), BackendKeyData::random_legacy()); // Add a waiting client inner.waiting.push_back(Waiter { @@ 
-1022,15 +961,12 @@ mod test { #[test] fn test_set_taken() { let mut inner = Inner::default(); - let mapping = Mapping { - client: BackendKeyData::new(), - server: BackendKeyData::new(), - }; + let client = BackendPid::random(); assert_eq!(inner.checked_out(), 0); let mut taken = Taken::default(); - taken.take(&mapping).unwrap(); + taken.take(client, BackendKeyData::random_legacy()); inner.set_taken(taken); assert_eq!(inner.checked_out(), 1); @@ -1125,9 +1061,9 @@ mod test { // Add two idle connections to the pool let server1 = Box::new(Server::default()); - let server1_id = *server1.id(); + let server1_id = server1.id(); let server2 = Box::new(Server::default()); - let server2_id = *server2.id(); + let server2_id = server2.id(); inner.idle_connections.push(server1); inner.idle_connections.push(server2); @@ -1136,7 +1072,7 @@ mod test { assert_eq!(inner.total(), 2); // Same client ID for both requests - let client_id = BackendKeyData::new(); + let client_id = BackendPid::random(); let request = Request::unrouted(client_id); // Check out first connection @@ -1177,7 +1113,7 @@ mod test { assert_eq!(inner.total(), 2); // Verify the specific servers are back in the idle pool - let idle_ids: Vec<_> = inner.idle_conns().iter().map(|s| *s.id()).collect(); + let idle_ids: Vec<_> = inner.idle_conns().iter().map(|s| s.id()).collect(); assert!(idle_ids.contains(&server1_id)); assert!(idle_ids.contains(&server2_id)); } diff --git a/pgdog/src/backend/pool/lb/mod.rs b/pgdog/src/backend/pool/lb/mod.rs index 340f3258e..f13283ea8 100644 --- a/pgdog/src/backend/pool/lb/mod.rs +++ b/pgdog/src/backend/pool/lb/mod.rs @@ -12,7 +12,7 @@ use rand::seq::SliceRandom; use tokio::{sync::Notify, time::timeout}; use tracing::warn; -use crate::net::messages::BackendKeyData; +use crate::net::messages::BackendPid; use crate::{ config::{LoadBalancingStrategy, ReadWriteSplit, Role}, net::Parameters, @@ -261,7 +261,7 @@ impl LoadBalancer { } /// Cancel a query if one is running. 
- pub async fn cancel(&self, id: &BackendKeyData) -> Result<(), super::super::Error> { + pub async fn cancel(&self, id: BackendPid) -> Result<(), super::super::Error> { for target in &self.targets { target.pool.cancel(id).await?; } diff --git a/pgdog/src/backend/pool/mapping.rs b/pgdog/src/backend/pool/mapping.rs deleted file mode 100644 index 377c93325..000000000 --- a/pgdog/src/backend/pool/mapping.rs +++ /dev/null @@ -1,10 +0,0 @@ -use crate::net::messages::BackendKeyData; - -/// Mapping between a client and a server. -#[derive(Debug, Copy, Clone, PartialEq, Default)] -pub(super) struct Mapping { - /// Client ID. - pub(super) client: BackendKeyData, - /// Server ID. - pub(super) server: BackendKeyData, -} diff --git a/pgdog/src/backend/pool/mod.rs b/pgdog/src/backend/pool/mod.rs index 10a3ad513..ce5635800 100644 --- a/pgdog/src/backend/pool/mod.rs +++ b/pgdog/src/backend/pool/mod.rs @@ -14,7 +14,6 @@ pub mod healthcheck; pub mod inner; pub mod lb; pub mod lsn_monitor; -pub mod mapping; pub mod mirror_stats; pub mod monitor; pub mod oids; @@ -49,7 +48,6 @@ pub use stats::Stats; use comms::Comms; use inner::Inner; -use mapping::Mapping; use shard::ShardConfig; use taken::Taken; use waiting::{Waiter, Waiting}; diff --git a/pgdog/src/backend/pool/pool_impl.rs b/pgdog/src/backend/pool/pool_impl.rs index 6c923f740..479f2adc1 100644 --- a/pgdog/src/backend/pool/pool_impl.rs +++ b/pgdog/src/backend/pool/pool_impl.rs @@ -14,7 +14,7 @@ use tracing::{debug, error}; use crate::backend::pool::LsnStats; use crate::backend::{ConnectReason, DisconnectReason, Server, ServerOptions}; use crate::config::PoolerMode; -use crate::net::messages::BackendKeyData; +use crate::net::messages::BackendPid; use crate::net::{Parameter, Parameters}; use super::inner::CheckInResult; @@ -267,17 +267,13 @@ impl Pool { Ok(()) } - /// Server connection used by the client. 
- pub fn peer(&self, id: &BackendKeyData) -> Option { - self.lock().peer(id) - } - /// Send a cancellation request if the client is connected to a server. - pub async fn cancel(&self, id: &BackendKeyData) -> Result<(), super::super::Error> { - if let Some(server) = self.peer(id) { - Server::cancel(self.addr(), &server).await?; + pub async fn cancel(&self, id: BackendPid) -> Result<(), super::super::Error> { + // Must NOT hold the lock while doing async I/O. + let key = self.lock().cancel_key(id).cloned(); + if let Some(key) = key { + Server::cancel(self.addr(), key).await?; } - Ok(()) } @@ -334,13 +330,15 @@ impl Pool { /// Send a cancellation request for all running queries. pub async fn cancel_all(&self) -> Result<(), Error> { - let taken = self.lock().checked_out_server_ids(); let addr = self.addr().clone(); - - try_join_all(taken.iter().map(|id| Server::cancel(&addr, id))) + let futures: Vec<_> = self + .lock() + .cancel_keys() + .map(|key| Server::cancel(&addr, key.clone())) + .collect(); + try_join_all(futures) .await .map_err(|_| Error::FastShutdown)?; - Ok(()) } diff --git a/pgdog/src/backend/pool/request.rs b/pgdog/src/backend/pool/request.rs index 22267ddba..55e4b42ab 100644 --- a/pgdog/src/backend/pool/request.rs +++ b/pgdog/src/backend/pool/request.rs @@ -1,17 +1,17 @@ use tokio::time::Instant; -use crate::net::messages::BackendKeyData; +use crate::net::messages::BackendPid; /// Connection request. 
#[derive(Clone, Debug, Copy)] pub struct Request { - pub id: BackendKeyData, + pub id: BackendPid, pub created_at: Instant, pub read: bool, } impl Request { - pub fn new(id: BackendKeyData, read: bool) -> Self { + pub fn new(id: BackendPid, read: bool) -> Self { Self { id, created_at: Instant::now(), @@ -19,7 +19,7 @@ impl Request { } } - pub fn unrouted(id: BackendKeyData) -> Self { + pub fn unrouted(id: BackendPid) -> Self { Self { id, created_at: Instant::now(), @@ -30,6 +30,6 @@ impl Request { impl Default for Request { fn default() -> Self { - Self::unrouted(BackendKeyData::new()) + Self::unrouted(BackendPid::random()) } } diff --git a/pgdog/src/backend/pool/shard/mod.rs b/pgdog/src/backend/pool/shard/mod.rs index 728e1d2b7..143b582ce 100644 --- a/pgdog/src/backend/pool/shard/mod.rs +++ b/pgdog/src/backend/pool/shard/mod.rs @@ -13,7 +13,7 @@ use crate::backend::pool::lb::ban::Ban; use crate::backend::PubSubListener; use crate::backend::Schema; use crate::config::{LoadBalancingStrategy, ReadWriteSplit, Role}; -use crate::net::messages::BackendKeyData; +use crate::net::messages::BackendPid; use crate::net::{NotificationResponse, Parameters}; use super::{Error, Guard, LoadBalancer, Pool, PoolConfig, Request}; @@ -195,7 +195,7 @@ impl Shard { /// /// If these connection pools aren't running the query sent by this client, this is a no-op. 
/// - pub async fn cancel(&self, id: &BackendKeyData) -> Result<(), super::super::Error> { + pub async fn cancel(&self, id: BackendPid) -> Result<(), super::super::Error> { self.lb.cancel(id).await?; Ok(()) diff --git a/pgdog/src/backend/pool/taken.rs b/pgdog/src/backend/pool/taken.rs index c660c5ea9..87f1f203f 100644 --- a/pgdog/src/backend/pool/taken.rs +++ b/pgdog/src/backend/pool/taken.rs @@ -1,69 +1,231 @@ +use std::collections::hash_map::Entry; + use fnv::FnvHashMap as HashMap; -use crate::net::BackendKeyData; +use crate::net::{BackendKeyData, BackendPid}; -use super::{Error, Mapping}; +use super::Error; +/// Track the link between a frontend connection and the backend connection it +/// currently holds, so cancel requests can be routed. +/// +/// A Postgres CancelRequest carries only the frontend's identity; it has no +/// way to name the backend connection pgdog assigned to that frontend. This +/// struct stores that mapping for the pool's checked-out connections. #[derive(Default, Clone, Debug)] pub(super) struct Taken { - /// Guaranteed to be unique per client/server connection. - taken: HashMap, - /// Guaranteed to be unique because servers can only be mapped - /// to one client at a time. - server_client: HashMap, - /// Not unique, but will contain the server that's actively executing a query - /// for that client. - client_server: HashMap, - /// Counter that guarantees uniqueness. Wraparound happens after a gazillion billion transactions. - counter: usize, + /// Frontend pid -> cancel key of the backend connection currently + /// assigned to that frontend. Cancel routing reads this directly. + frontend_to_cancel: HashMap, + /// Reverse index from backend pid to the frontend pid that holds it. On + /// check-in the pool only knows the backend pid, so we use this to find + /// which `frontend_to_cancel` entry to drop. 
+ backend_to_frontend: HashMap, } impl Taken { #[inline] - pub(super) fn take(&mut self, mapping: &Mapping) -> Result<(), Error> { - self.taken.insert(self.counter, *mapping); - self.server_client.insert(mapping.server, self.counter); - self.client_server.insert(mapping.client, mapping.server); - self.counter = self.counter.wrapping_add(1); - Ok(()) + pub(super) fn take(&mut self, frontend: BackendPid, cancel_key: BackendKeyData) { + self.backend_to_frontend.insert(cancel_key.pid(), frontend); + self.frontend_to_cancel.insert(frontend, cancel_key); } #[inline] - pub(super) fn check_in(&mut self, server: &BackendKeyData) -> Result<(), Error> { - let counter = self - .server_client - .remove(server) - .ok_or(Error::UntrackedConnCheckin(*server))?; - let mapping = self - .taken - .remove(&counter) - .ok_or(Error::MappingMissing(counter))?; - self.client_server.remove(&mapping.client); - + pub(super) fn check_in(&mut self, backend: BackendPid) -> Result<(), Error> { + let frontend = self + .backend_to_frontend + .remove(&backend) + .ok_or(Error::UntrackedConnCheckin(backend))?; + // Drop the frontend's cancel entry only when it still names this + // backend. The deferred check-in from a prior `Server::drop` may fire + // after the frontend has already taken a newer backend; in that case + // the entry belongs to the newer backend and must not be touched. + if let Entry::Occupied(entry) = self.frontend_to_cancel.entry(frontend) { + if entry.get().pid() == backend { + entry.remove(); + } + } Ok(()) } #[inline] pub(super) fn len(&self) -> usize { - self.taken.len() + self.backend_to_frontend.len() } - #[allow(dead_code)] + #[cfg(test)] + #[inline] pub(super) fn is_empty(&self) -> bool { - self.len() == 0 + self.backend_to_frontend.is_empty() } + /// Backend cancel key for this frontend's current checkout. 
#[inline] - pub(super) fn server(&self, client: &BackendKeyData) -> Option { - self.client_server.get(client).copied() + pub(super) fn cancel_key(&self, frontend: BackendPid) -> Option<&BackendKeyData> { + self.frontend_to_cancel.get(&frontend) } - pub(super) fn servers(&self) -> Vec { - self.client_server.values().copied().collect() + /// All cancel keys for currently checked-out backend connections. For + /// frontends with multiple concurrent checkouts, only the latest is + /// returned (matches prior behavior). + pub(super) fn cancel_keys(&self) -> impl Iterator { + self.frontend_to_cancel.values() } #[cfg(test)] pub(super) fn clear(&mut self) { - self.taken.clear(); + *self = Self::default(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn key(pid: BackendPid) -> BackendKeyData { + BackendKeyData::legacy(pid, 0) + } + + #[test] + fn empty_state_has_no_entries() { + let taken = Taken::default(); + assert_eq!(taken.len(), 0); + assert!(taken.is_empty()); + assert_eq!(taken.cancel_key(BackendPid::random()), None); + assert_eq!(taken.cancel_keys().count(), 0); + } + + #[test] + fn take_then_check_in_round_trip() { + let mut taken = Taken::default(); + let frontend = BackendPid::random(); + let backend = BackendPid::random(); + let cancel_key = key(backend); + + taken.take(frontend, cancel_key.clone()); + assert_eq!(taken.len(), 1); + assert_eq!(taken.cancel_key(frontend), Some(&cancel_key)); + assert_eq!(taken.cancel_keys().count(), 1); + + taken.check_in(backend).unwrap(); + assert!(taken.is_empty()); + assert_eq!(taken.cancel_key(frontend), None); + } + + #[test] + fn check_in_unknown_backend_errors() { + let mut taken = Taken::default(); + let unknown = BackendPid::random(); + assert_eq!( + taken.check_in(unknown).unwrap_err(), + Error::UntrackedConnCheckin(unknown), + ); + } + + #[test] + fn cancel_key_recovers_server_pid() { + // The map relies on cancel_key.pid() == backend pid as an invariant. 
+ let mut taken = Taken::default(); + let frontend = BackendPid::random(); + let backend = BackendPid::random(); + + taken.take(frontend, key(backend)); + assert_eq!(taken.cancel_key(frontend).map(|k| k.pid()), Some(backend)); + } + + #[test] + fn distinct_frontends_are_independent() { + let mut taken = Taken::default(); + let (fa, ba) = (BackendPid::random(), BackendPid::random()); + let (fb, bb) = (BackendPid::random(), BackendPid::random()); + + taken.take(fa, key(ba)); + taken.take(fb, key(bb)); + assert_eq!(taken.len(), 2); + + taken.check_in(ba).unwrap(); + assert_eq!(taken.len(), 1); + assert_eq!(taken.cancel_key(fa), None); + assert_eq!(taken.cancel_key(fb).map(|k| k.pid()), Some(bb)); + + taken.check_in(bb).unwrap(); + assert!(taken.is_empty()); + } + + /// Regression: the `Server::drop` race documented on `check_in`. + /// + /// Sequence reproduced here: + /// 1. Frontend F takes backend A. + /// 2. F's guard drops; `Server::drop` defers the check-in to a tokio task. + /// 3. Before that task runs, F takes backend B (entry for F overwritten). + /// 4. The deferred check-in for A finally fires. + /// + /// After step 4, F is still actively using B, so cancel routing for F + /// must still resolve to B. Final check-in of B clears everything. + #[test] + fn deferred_check_in_after_same_frontend_retake() { + let mut taken = Taken::default(); + let frontend = BackendPid::random(); + let backend_a = BackendPid::random(); + let backend_b = BackendPid::random(); + let key_a = key(backend_a); + let key_b = key(backend_b); + + // Step 1: take A. + taken.take(frontend, key_a.clone()); + assert_eq!(taken.cancel_key(frontend), Some(&key_a)); + + // Step 3: F retakes with B before A's deferred check-in fires. + taken.take(frontend, key_b.clone()); + assert_eq!(taken.len(), 2, "both backends still tracked"); + assert_eq!(taken.cancel_key(frontend), Some(&key_b), "latest wins"); + + // Step 4: deferred check-in for A. 
Must NOT touch F's entry, + // since it now belongs to B. + taken.check_in(backend_a).unwrap(); + assert_eq!(taken.len(), 1); + assert_eq!( + taken.cancel_key(frontend), + Some(&key_b), + "cancel routing for F must still target the live backend B", + ); + + // Normal check-in of B clears the entry. + taken.check_in(backend_b).unwrap(); + assert!(taken.is_empty()); + assert_eq!(taken.cancel_key(frontend), None); + } + + /// Reverse order of the race: A's deferred check-in fires *before* F + /// retakes. Sanity check that the normal path still works. + #[test] + fn deferred_check_in_before_same_frontend_retake() { + let mut taken = Taken::default(); + let frontend = BackendPid::random(); + let backend_a = BackendPid::random(); + let backend_b = BackendPid::random(); + + taken.take(frontend, key(backend_a)); + taken.check_in(backend_a).unwrap(); + assert!(taken.is_empty()); + + taken.take(frontend, key(backend_b)); + assert_eq!(taken.cancel_key(frontend).map(|k| k.pid()), Some(backend_b)); + taken.check_in(backend_b).unwrap(); + assert!(taken.is_empty()); + } + + #[test] + fn double_check_in_second_errors() { + let mut taken = Taken::default(); + let frontend = BackendPid::random(); + let backend = BackendPid::random(); + + taken.take(frontend, key(backend)); + taken.check_in(backend).unwrap(); + assert_eq!( + taken.check_in(backend).unwrap_err(), + Error::UntrackedConnCheckin(backend), + ); } } diff --git a/pgdog/src/backend/pool/test/mod.rs b/pgdog/src/backend/pool/test/mod.rs index 8c2a76bb5..db6521f4e 100644 --- a/pgdog/src/backend/pool/test/mod.rs +++ b/pgdog/src/backend/pool/test/mod.rs @@ -74,7 +74,7 @@ async fn test_pool_checkout() { let pool = pool(); let conn = pool.get(&Request::default()).await.unwrap(); - let id = *(conn.id()); + let id = conn.id(); assert!(conn.done()); assert!(conn.done()); @@ -93,7 +93,7 @@ async fn test_pool_checkout() { drop(conn); // Return conn to the pool. 
let conn = pool.get(&Request::default()).await.unwrap(); - assert_eq!(conn.id(), &id); + assert_eq!(conn.id(), id); } // This test flakes in CI because of iffy hardware I think. @@ -279,7 +279,7 @@ async fn test_incomplete_request_recovery() { for query in ["SELECT 1", "BEGIN"] { let mut conn = pool.get(&Request::default()).await.unwrap(); - let conn_id = *(conn.id()); + let conn_id = conn.id(); conn.send(&vec![ProtocolMessage::from(Query::new(query))].into()) .await @@ -299,7 +299,7 @@ async fn test_incomplete_request_recovery() { // Verify the same connection is reused let conn = pool.get(&Request::default()).await.unwrap(); - assert_eq!(conn.id(), &conn_id); + assert_eq!(conn.id(), conn_id); } } diff --git a/pgdog/src/backend/pool/waiting.rs b/pgdog/src/backend/pool/waiting.rs index 741aec70a..2f26ae929 100644 --- a/pgdog/src/backend/pool/waiting.rs +++ b/pgdog/src/backend/pool/waiting.rs @@ -13,7 +13,7 @@ pub(super) struct Waiting { impl Drop for Waiting { fn drop(&mut self) { if self.waiting { - self.pool.lock().remove_waiter(&self.request.id); + self.pool.lock().remove_waiter(self.request.id); } } } @@ -91,7 +91,7 @@ pub(super) struct Waiter { mod tests { use super::*; use crate::backend::pool::Pool; - use crate::net::messages::BackendKeyData; + use crate::net::messages::BackendPid; use tokio::time::{sleep, timeout, Duration}; #[tokio::test] @@ -104,7 +104,7 @@ mod tests { for i in 0..num_tasks { let pool_clone = pool.clone(); - let request = Request::unrouted(BackendKeyData::new()); + let request = Request::unrouted(BackendPid::random()); let mut waiting = Waiting::new(pool_clone, &request).unwrap(); let wait_task = tokio::spawn(async move { waiting.wait().await }); @@ -167,7 +167,7 @@ mod tests { let _conn = pool.get(&Request::default()).await.unwrap(); - let request = Request::unrouted(BackendKeyData::new()); + let request = Request::unrouted(BackendPid::random()); let waiter_pool = pool.clone(); let get_conn = async move { let mut waiting = 
Waiting::new(waiter_pool.clone(), &request).unwrap(); diff --git a/pgdog/src/backend/pub_sub/listener.rs b/pgdog/src/backend/pub_sub/listener.rs index 1e2559a32..7837fb64a 100644 --- a/pgdog/src/backend/pub_sub/listener.rs +++ b/pgdog/src/backend/pub_sub/listener.rs @@ -18,7 +18,7 @@ use crate::{ backend::{self, pool::Error, ConnectReason, DisconnectReason, Pool}, config::config, net::{ - BackendKeyData, FromBytes, NotificationResponse, Parameter, Parameters, Protocol, + BackendPid, FromBytes, NotificationResponse, Parameter, Parameters, Protocol, ProtocolMessage, Query, ToBytes, }, }; @@ -164,7 +164,7 @@ impl PubSubListener { server .link_client( - &BackendKeyData::new(), + BackendPid::random(), &Parameters::from(vec![Parameter { name: "application_name".into(), value: "PgDog Pub/Sub Listener".into(), diff --git a/pgdog/src/backend/server.rs b/pgdog/src/backend/server.rs index fff6f219d..04e66fe0e 100644 --- a/pgdog/src/backend/server.rs +++ b/pgdog/src/backend/server.rs @@ -23,8 +23,9 @@ use crate::{ frontend::ClientRequest, net::{ messages::{ - hello::SslReply, Authentication, BackendKeyData, ErrorResponse, FromBytes, Message, - ParameterStatus, Password, Protocol, Query, ReadyForQuery, Startup, Terminate, ToBytes, + hello::SslReply, Authentication, BackendKeyData, BackendPid, ErrorResponse, FromBytes, + Message, ParameterStatus, Password, Protocol, Query, ReadyForQuery, Startup, Terminate, + ToBytes, }, Close, MessageBuffer, Parameter, ProtocolMessage, Sync, }, @@ -46,7 +47,7 @@ use crate::{net::tweak, state::State}; pub struct Server { addr: Address, stream: Option, - id: BackendKeyData, + key: BackendKeyData, params: Parameters, changed_params: Parameters, client_params: Parameters, @@ -313,7 +314,7 @@ impl Server { // so they don't send BackendKeyData. // Generating a random one is fine, it just won't work when we try to // cancel a query with this secret. 
- let id = key_data.unwrap_or(BackendKeyData::new()); + let key = key_data.unwrap_or_else(BackendKeyData::random_legacy); let params: Parameters = params.into(); info!( @@ -325,11 +326,12 @@ impl Server { if stream.is_tls() { "🔒" } else { "" }, ); + let pid = key.pid(); let mut server = Server { addr: addr.clone(), stream: Some(stream), - id, - stats: Stats::connect(id, addr, &params, &options, &config.config.memory), + key, + stats: Stats::connect(pid, addr, &params, &options, &config.config.memory), replication_mode: options.replication_mode(), params, changed_params: Parameters::default(), @@ -356,11 +358,9 @@ impl Server { } /// Request query cancellation for the given backend server identifier. - pub async fn cancel(addr: &Address, id: &BackendKeyData) -> Result<(), Error> { + pub async fn cancel(addr: &Address, id: BackendKeyData) -> Result<(), Error> { let mut stream = TcpStream::connect(addr.addr().await?).await?; - stream - .write_all(&Startup::Cancel { id: *id }.to_bytes()) - .await?; + stream.write_all(&Startup::Cancel { id }.to_bytes()).await?; stream.flush().await?; Ok(()) @@ -463,11 +463,11 @@ impl Server { pub async fn read(&mut self) -> Result { let message = loop { if let Some(message) = self.prepared_statements.state_mut().get_simulated() { - return Ok(message.backend(self.id)); + return Ok(message.backend(self.key.pid())); } match self.stream_buffer.read(self.stream.as_mut().unwrap()).await { Ok(message) => { - let message = message.stream(self.streaming).backend(self.id); + let message = message.stream(self.streaming).backend(self.key.pid()); match self.prepared_statements.forward(&message) { Ok(forward) => { if forward { @@ -579,7 +579,7 @@ impl Server { /// Synchronize parameters between client and server. pub async fn link_client( &mut self, - id: &BackendKeyData, + id: BackendPid, params: &Parameters, start_transaction: Option<&str>, ) -> Result { @@ -994,8 +994,14 @@ impl Server { /// Server connection unique identifier.
#[inline] - pub fn id(&self) -> &BackendKeyData { - &self.id + pub fn id(&self) -> BackendPid { + self.key.pid() + } + + /// Backend key data for query cancellation. + #[inline] + pub fn key(&self) -> &BackendKeyData { + &self.key } /// Number of password attempts it took to authenticate this connection. @@ -1204,16 +1210,17 @@ pub mod test { impl Default for Server { fn default() -> Self { - let id = BackendKeyData::new(); + let cancel_key = BackendKeyData::random_legacy(); + let pid = cancel_key.pid(); let addr = Address::default(); Self { stream: None, - id, + key: cancel_key, params: Parameters::default(), changed_params: Parameters::default(), client_params: Parameters::default(), stats: Stats::connect( - id, + pid, &addr, &Parameters::default(), &ServerOptions::default(), @@ -1316,7 +1323,7 @@ pub mod test { .await .unwrap(); socket - .write_all(&BackendKeyData::new().to_bytes()) + .write_all(&BackendKeyData::random_legacy().to_bytes()) .await .unwrap(); socket @@ -1377,7 +1384,7 @@ pub mod test { .await .unwrap(); socket - .write_all(&BackendKeyData::new().to_bytes()) + .write_all(&BackendKeyData::random_legacy().to_bytes()) .await .unwrap(); socket @@ -2159,7 +2166,7 @@ pub mod test { params.insert("application_name", "test_sync_params"); params.insert("is_superuser", opposite); let changed = server - .link_client(&BackendKeyData::new(), &params, None) + .link_client(BackendPid::random(), &params, None) .await .unwrap(); assert_eq!(changed, 1); @@ -2181,7 +2188,7 @@ pub mod test { ); let changed = server - .link_client(&BackendKeyData::new(), &params, None) + .link_client(BackendPid::random(), &params, None) .await .unwrap(); assert_eq!(changed, 0); @@ -2270,12 +2277,12 @@ pub mod test { let mut server = test_server().await; let changed = server - .link_client(&BackendKeyData::new(), &params, None) + .link_client(BackendPid::random(), &params, None) .await?; assert_eq!(changed, 1); let changed = server - .link_client(&BackendKeyData::new(), &params, None) +
.link_client(BackendPid::random(), &params, None) .await?; assert_eq!(changed, 0); @@ -2284,12 +2291,12 @@ pub mod test { params.insert("application_name", value); let changed = server - .link_client(&BackendKeyData::new(), &params, None) + .link_client(BackendPid::random(), &params, None) .await?; assert_eq!(changed, 2); // RESET, SET. let changed = server - .link_client(&BackendKeyData::new(), &params, None) + .link_client(BackendPid::random(), &params, None) .await?; assert_eq!(changed, 0); } @@ -2867,14 +2874,14 @@ pub mod test { // Sync params to server let changed = server - .link_client(&BackendKeyData::new(), &params, None) + .link_client(BackendPid::random(), &params, None) .await .unwrap(); assert_eq!(changed, 1); // Same params should not need re-sync let changed = server - .link_client(&BackendKeyData::new(), &params, None) + .link_client(BackendPid::random(), &params, None) .await .unwrap(); assert_eq!(changed, 0); @@ -2884,7 +2891,7 @@ pub mod test { // Now link_client should need to re-sync because client_params was cleared let changed = server - .link_client(&BackendKeyData::new(), &params, None) + .link_client(BackendPid::random(), &params, None) .await .unwrap(); assert!( diff --git a/pgdog/src/backend/stats.rs b/pgdog/src/backend/stats.rs index ceb7018fe..535e3635d 100644 --- a/pgdog/src/backend/stats.rs +++ b/pgdog/src/backend/stats.rs @@ -12,22 +12,26 @@ use tokio::time::Instant; use crate::{ backend::{pool::stats::MemoryStats, Pool, ServerOptions}, config::Memory, - net::{messages::BackendKeyData, Parameters}, + net::{messages::BackendPid, Parameters}, state::State, }; use super::pool::Address; -static STATS: Lazy>>>> = +/// Per-connection key for the global stats map. +/// +/// Keyed by `(Address, BackendPid)` rather than `BackendPid` alone: Postgres +/// pids are only unique within a single backend instance, so a pgdog proxying +/// multiple backends (different hosts/dbs/users) can otherwise see two live +/// connections with the same pid collide and silently evict each other.
+type ServerKey = (Address, BackendPid); + +static STATS: Lazy>>>> = Lazy::new(|| RwLock::new(HashMap::default())); -/// Get a copy of latest stats. -pub fn stats() -> HashMap { - STATS - .read() - .iter() - .map(|(k, v)| (*k, v.lock().clone())) - .collect() +/// Get a snapshot of all connected-server stats. +pub fn stats() -> Vec { + STATS.read().values().map(|v| v.lock().clone()).collect() } /// Get idle-in-transaction server connections for connection pool. @@ -46,18 +50,18 @@ pub fn idle_in_transaction(pool: &Pool) -> usize { #[derive(Clone, Debug, Copy)] pub struct ServerStats { pub inner: pgdog_stats::server::Stats, - pub id: BackendKeyData, + pub id: BackendPid, pub last_used: Instant, pub last_healthcheck: Option, pub created_at: Instant, - pub client_id: Option, + pub client_id: Option, query_timer: Option, transaction_timer: Option, idle_in_transaction_timer: Option, } impl ServerStats { - fn new(id: BackendKeyData, options: &ServerOptions, config: &Memory) -> Self { + fn new(id: BackendPid, options: &ServerOptions, config: &Memory) -> Self { let now = Instant::now(); let inner = pgdog_stats::server::Stats { memory: *MemoryStats::new(config), @@ -99,7 +103,7 @@ pub struct ConnectedServer { pub stats: ServerStats, pub addr: Address, pub application_name: String, - pub client: Option, + pub client: Option, } /// Server statistics handle. @@ -111,12 +115,13 @@ pub struct ConnectedServer { pub struct Stats { local: ServerStats, shared: Arc>, + key: ServerKey, } impl Stats { /// Register new server with statistics. 
pub fn connect( - id: BackendKeyData, + id: BackendPid, addr: &Address, params: &Parameters, options: &ServerOptions, @@ -132,9 +137,10 @@ impl Stats { }; let shared = Arc::new(Mutex::new(server)); - STATS.write().insert(id, Arc::clone(&shared)); + let key: ServerKey = (addr.clone(), id); + STATS.write().insert(key.clone(), Arc::clone(&shared)); - Stats { local, shared } + Stats { local, shared, key } } /// Sync local stats to shared (called on I/O operations). @@ -156,8 +162,8 @@ impl Stats { self.sync_to_shared(); } - pub fn link_client(&mut self, client_name: &str, server_name: &str, id: &BackendKeyData) { - self.local.client_id = Some(*id); + pub fn link_client(&mut self, client_name: &str, server_name: &str, id: BackendPid) { + self.local.client_id = Some(id); if client_name != server_name { let mut guard = self.shared.lock(); guard.stats.client_id = self.local.client_id; @@ -307,7 +313,7 @@ impl Stats { /// Server is closing. pub(super) fn disconnect(&self) { - STATS.write().remove(&self.local.id); + STATS.write().remove(&self.key); } /// Reset last_checkout counts. diff --git a/pgdog/src/frontend/client/mod.rs b/pgdog/src/frontend/client/mod.rs index 6844784c3..59b169764 100644 --- a/pgdog/src/frontend/client/mod.rs +++ b/pgdog/src/frontend/client/mod.rs @@ -232,7 +232,7 @@ impl Client { let auth_type = &config.config.general.auth_type; let passthrough = config.config.general.passthrough_auth(); let id = BackendKeyData::new_client(protocol_version); - let comms = ClientComms::new(&id); + let comms = ClientComms::new(id.pid()); let log_connections = config.config.general.log_connections; // Check if we need to ask the client for its password in plaintext @@ -320,7 +320,7 @@ impl Client { // Get connection parameters. These will be most likely cached, // unless the pool was just created. 
- let server_params = match conn.parameters(&Request::unrouted(id)).await { + let server_params = match conn.parameters(&Request::unrouted(id.pid())).await { Ok(params) => params, Err(err) => { if err.no_server() { @@ -344,7 +344,7 @@ impl Client { stream.send(&id).await?; stream.send_flush(&ReadyForQuery::idle()).await?; - comms.connect(addr, &params); + comms.connect(id.clone(), addr, &params); if config.config.general.log_connections { info!( @@ -394,7 +394,8 @@ impl Client { connect_params.insert("database", "pgdog"); connect_params.merge(params); - let id = BackendKeyData::new(); + let id = BackendKeyData::random_legacy(); + let pid = id.pid(); let mut prepared_statements = PreparedStatements::new(); prepared_statements.level = config().config.general.prepared_statements; @@ -402,7 +403,7 @@ impl Client { stream, addr: SocketAddr::from(([127, 0, 0, 1], 1234)), id, - comms: ClientComms::new(&id), + comms: ClientComms::new(pid), streaming: false, prepared_statements, admin: false, @@ -417,11 +418,6 @@ impl Client { } } - /// Get client's identifier. - pub fn id(&self) -> BackendKeyData { - self.id - } - /// Run the client and log disconnect.
async fn spawn_internal(&mut self) { match self.run().await { diff --git a/pgdog/src/frontend/client/query_engine/connect.rs b/pgdog/src/frontend/client/query_engine/connect.rs index 57bcc6026..f21fd6194 100644 --- a/pgdog/src/frontend/client/query_engine/connect.rs +++ b/pgdog/src/frontend/client/query_engine/connect.rs @@ -30,7 +30,7 @@ impl QueryEngine { let connect_route = connect_route.unwrap_or(context.client_request.route()); - let request = Request::new(*context.id, connect_route.is_read()); + let request = Request::new(context.id, connect_route.is_read()); self.stats.waiting(request.created_at); self.comms.update_stats(self.stats); diff --git a/pgdog/src/frontend/client/query_engine/context.rs b/pgdog/src/frontend/client/query_engine/context.rs index 8b42ef666..5a0ec4d1d 100644 --- a/pgdog/src/frontend/client/query_engine/context.rs +++ b/pgdog/src/frontend/client/query_engine/context.rs @@ -5,14 +5,14 @@ use crate::{ router::parser::rewrite::statement::plan::RewriteResult, Client, ClientRequest, PreparedStatements, }, - net::{BackendKeyData, Parameters, Stream}, + net::{BackendPid, Parameters, Stream}, }; #[allow(dead_code)] /// Context passed to the query engine to execute a query. pub struct QueryEngineContext<'a> { /// Client ID running the query. - pub(super) id: &'a BackendKeyData, + pub(super) id: BackendPid, /// Prepared statements cache. pub(super) prepared_statements: &'a mut PreparedStatements, /// Client session parameters. @@ -48,7 +48,7 @@ impl<'a> QueryEngineContext<'a> { let memory_stats = client.memory_stats(); Self { - id: &client.id, + id: client.id.pid(), prepared_statements: &mut client.prepared_statements, params: &mut client.params, client_request: &mut client.client_request, @@ -75,7 +75,7 @@ impl<'a> QueryEngineContext<'a> { /// Create context from mirror. 
pub fn new_mirror(mirror: &'a mut Mirror, buffer: &'a mut ClientRequest) -> Self { Self { - id: &mirror.id, + id: mirror.id, prepared_statements: &mut mirror.prepared_statements, params: &mut mirror.params, client_request: buffer, diff --git a/pgdog/src/frontend/client/query_engine/discard.rs b/pgdog/src/frontend/client/query_engine/discard.rs index 1715addcf..05a2295ca 100644 --- a/pgdog/src/frontend/client/query_engine/discard.rs +++ b/pgdog/src/frontend/client/query_engine/discard.rs @@ -1,4 +1,4 @@ -use crate::net::{BackendKeyData, CommandComplete, Protocol, ReadyForQuery}; +use crate::net::{CommandComplete, Protocol, ReadyForQuery}; use super::*; @@ -13,9 +13,7 @@ impl QueryEngine { let bytes_sent = context .stream .send_many(&[ - CommandComplete::new("DISCARD") - .message()? - .backend(BackendKeyData::default()), + CommandComplete::new("DISCARD").message()?, ReadyForQuery::in_transaction(context.in_transaction()).message()?, ]) .await?; diff --git a/pgdog/src/frontend/client/query_engine/end_transaction.rs b/pgdog/src/frontend/client/query_engine/end_transaction.rs index b8d5b1cfd..fc74b149f 100644 --- a/pgdog/src/frontend/client/query_engine/end_transaction.rs +++ b/pgdog/src/frontend/client/query_engine/end_transaction.rs @@ -1,4 +1,4 @@ -use crate::net::{BackendKeyData, CommandComplete, NoticeResponse, Protocol, ReadyForQuery}; +use crate::net::{CommandComplete, NoticeResponse, Protocol, ReadyForQuery}; use super::*; @@ -23,7 +23,7 @@ impl QueryEngine { } else { vec![] }; - messages.push(cmd.message()?.backend(BackendKeyData::default())); + messages.push(cmd.message()?); messages.push(ReadyForQuery::idle().message()?); context.stream.send_many(&messages).await? 
diff --git a/pgdog/src/frontend/client/query_engine/pub_sub.rs b/pgdog/src/frontend/client/query_engine/pub_sub.rs index c9544c8d5..9d1240913 100644 --- a/pgdog/src/frontend/client/query_engine/pub_sub.rs +++ b/pgdog/src/frontend/client/query_engine/pub_sub.rs @@ -1,4 +1,4 @@ -use crate::net::{BackendKeyData, CommandComplete, Protocol, ReadyForQuery}; +use crate::net::{CommandComplete, Protocol, ReadyForQuery}; use super::*; @@ -61,9 +61,7 @@ impl QueryEngine { let bytes_sent = context .stream .send_many(&[ - CommandComplete::new(command) - .message()? - .backend(BackendKeyData::default()), + CommandComplete::new(command).message()?, ReadyForQuery::in_transaction(context.in_transaction()).message()?, ]) .await?; diff --git a/pgdog/src/frontend/client/query_engine/start_transaction.rs b/pgdog/src/frontend/client/query_engine/start_transaction.rs index f525466d8..91ad3b8f5 100644 --- a/pgdog/src/frontend/client/query_engine/start_transaction.rs +++ b/pgdog/src/frontend/client/query_engine/start_transaction.rs @@ -1,9 +1,6 @@ use crate::{ frontend::client::TransactionType, - net::{ - BackendKeyData, BindComplete, CommandComplete, NoticeResponse, ParseComplete, Protocol, - ReadyForQuery, - }, + net::{BindComplete, CommandComplete, NoticeResponse, ParseComplete, Protocol, ReadyForQuery}, }; use super::*; @@ -29,9 +26,7 @@ impl QueryEngine { context .stream .send_many(&[ - CommandComplete::new_begin() - .message()? - .backend(BackendKeyData::default()), + CommandComplete::new_begin().message()?, ReadyForQuery::in_transaction(context.in_transaction()).message()?, ]) .await? @@ -57,17 +52,11 @@ impl QueryEngine { 'B' => reply.push(BindComplete.message()?), 'D' | 'H' => (), 'E' => reply.push(if in_transaction { - CommandComplete::new_begin() - .message()? - .backend(BackendKeyData::default()) + CommandComplete::new_begin().message()? } else if !rollback { - CommandComplete::new_commit() - .message()? 
- .backend(BackendKeyData::default()) + CommandComplete::new_commit().message()? } else { - CommandComplete::new_rollback() - .message()? - .backend(BackendKeyData::default()) + CommandComplete::new_rollback().message()? }), 'S' => { if rollback && !context.in_transaction() { diff --git a/pgdog/src/frontend/client/test/test_client.rs b/pgdog/src/frontend/client/test/test_client.rs index f23c766e2..eb4ea52f1 100644 --- a/pgdog/src/frontend/client/test/test_client.rs +++ b/pgdog/src/frontend/client/test/test_client.rs @@ -16,7 +16,7 @@ use crate::{ router::{parser::Shard, sharding::ContextBuilder}, Client, }, - net::{BackendKeyData, ErrorResponse, Message, Parameters, Protocol, Stream}, + net::{ErrorResponse, Message, Parameters, Protocol, Stream}, }; /// Try to convert a Message to the specified type. @@ -57,7 +57,7 @@ pub async fn read_message(conn: &mut TcpStream) -> Message { payload.put_i32(len); payload.put(Bytes::from(rest)); - Message::new(payload.freeze()).backend(BackendKeyData::default()) + Message::new(payload.freeze()) } /// Send a protocol message to a TCP stream. diff --git a/pgdog/src/frontend/comms.rs b/pgdog/src/frontend/comms.rs index 2887e2d16..c82725c84 100644 --- a/pgdog/src/frontend/comms.rs +++ b/pgdog/src/frontend/comms.rs @@ -13,7 +13,7 @@ use parking_lot::Mutex; use tokio::sync::Notify; use tokio_util::task::TaskTracker; -use crate::net::messages::BackendKeyData; +use crate::net::messages::{BackendKeyData, BackendPid}; use crate::net::Parameters; use super::{ConnectedClient, Stats}; @@ -31,9 +31,9 @@ struct Global { shutdown: Arc, offline: AtomicBool, // This uses the FNV hasher, which is safe, - // because BackendKeyData is randomly generated by us, + // because BackendPid is randomly generated by us, // not by the client. - clients: Mutex>, + clients: Mutex>, tracker: TaskTracker, } @@ -63,7 +63,7 @@ impl Comms { } /// Get all connected clients. 
- pub fn clients(&self) -> HashMap { + pub fn clients(&self) -> HashMap { self.global.clients.lock().clone() } @@ -81,40 +81,47 @@ impl Comms { self.global.clients.lock().len() } - /// There are no connected clients. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - /// New client connected. - pub fn connect(&self, id: &BackendKeyData, addr: SocketAddr, params: &Parameters) { + pub fn connect(&self, key: BackendKeyData, addr: SocketAddr, params: &Parameters) { + let pid = key.pid(); self.global .clients .lock() - .insert(*id, ConnectedClient::new(id, addr, params)); + .insert(pid, ConnectedClient::new(key, addr, params)); } /// Update client parameters. - pub fn update_params(&self, id: &BackendKeyData, params: Parameters) { + pub fn update_params(&self, id: BackendPid, params: Parameters) { let mut guard = self.global.clients.lock(); - if let Some(entry) = guard.get_mut(id) { + if let Some(entry) = guard.get_mut(&id) { entry.paramters = params; } } /// Client disconnected. - pub fn disconnect(&self, id: &BackendKeyData) { - self.global.clients.lock().remove(id); + pub fn disconnect(&self, id: BackendPid) { + self.global.clients.lock().remove(&id); } /// Update stats. - pub fn update_stats(&self, id: &BackendKeyData, stats: Stats) { + pub fn update_stats(&self, id: BackendPid, stats: Stats) { let mut guard = self.global.clients.lock(); - if let Some(entry) = guard.get_mut(id) { + if let Some(entry) = guard.get_mut(&id) { entry.stats = stats; } } + /// Verify that a cancel request has a valid secret for the given client. + pub fn verify_cancel(&self, key: &BackendKeyData) -> bool { + let pid = key.pid; + self.global + .clients + .lock() + .get(&pid) + .map(|client| client.key.secret == key.secret) + .unwrap_or(false) + } + /// Notify clients pgDog is shutting down. 
pub fn shutdown(&self) { self.global.offline.store(true, Ordering::Relaxed); @@ -136,7 +143,7 @@ impl Comms { #[derive(Debug, Clone)] pub struct ClientComms { comms: Comms, - id: BackendKeyData, + id: BackendPid, } impl Deref for ClientComms { @@ -149,25 +156,71 @@ impl Deref for ClientComms { impl ClientComms { pub fn disconnect(&self) { - self.comms.disconnect(&self.id); + self.comms.disconnect(self.id); } pub fn update_stats(&self, stats: Stats) { - self.comms.update_stats(&self.id, stats); + self.comms.update_stats(self.id, stats); } - pub fn new(id: &BackendKeyData) -> Self { - Self { - id: *id, - comms: comms(), - } + pub fn new(id: BackendPid) -> Self { + Self { id, comms: comms() } } - pub fn connect(&self, addr: SocketAddr, params: &Parameters) { - self.comms.connect(&self.id, addr, params) + pub fn connect(&self, key: BackendKeyData, addr: SocketAddr, params: &Parameters) { + self.comms.connect(key, addr, params) } pub fn update_params(&self, params: &Parameters) { - self.comms.update_params(&self.id, params.clone()); + self.comms.update_params(self.id, params.clone()); + } +} + +#[cfg(test)] +mod tests { + use std::net::SocketAddr; + + use super::*; + use crate::net::{messages::BackendKeyData, Parameters}; + + fn addr() -> SocketAddr { + "127.0.0.1:5432".parse().unwrap() + } + + #[test] + fn test_verify_cancel_correct_secret() { + let comms = Comms::default(); + let key = BackendKeyData::random_legacy(); + comms.connect(key.clone(), addr(), &Parameters::default()); + assert!(comms.verify_cancel(&key)); + } + + #[test] + fn test_verify_cancel_wrong_secret() { + let comms = Comms::default(); + let key = BackendKeyData::random_legacy(); + comms.connect(key.clone(), addr(), &Parameters::default()); + + // Same pid, different secret. 
+ let wrong = BackendKeyData::legacy(key.pid(), 0); + assert!(!comms.verify_cancel(&wrong)); + } + + #[test] + fn test_verify_cancel_unknown_pid() { + let comms = Comms::default(); + // Nothing registered — any key must be rejected. + assert!(!comms.verify_cancel(&BackendKeyData::random_legacy())); + } + + #[test] + fn test_verify_cancel_after_disconnect() { + let comms = Comms::default(); + let key = BackendKeyData::random_legacy(); + comms.connect(key.clone(), addr(), &Parameters::default()); + assert!(comms.verify_cancel(&key)); + + comms.disconnect(key.pid()); + assert!(!comms.verify_cancel(&key)); } } diff --git a/pgdog/src/frontend/connected_client.rs b/pgdog/src/frontend/connected_client.rs index dd7318e32..74c1c5c43 100644 --- a/pgdog/src/frontend/connected_client.rs +++ b/pgdog/src/frontend/connected_client.rs @@ -1,7 +1,7 @@ use chrono::{DateTime, Local}; use std::net::SocketAddr; -use crate::net::{BackendKeyData, Parameters}; +use crate::net::{messages::BackendKeyData, Parameters}; use super::Stats; @@ -16,15 +16,15 @@ pub struct ConnectedClient { pub connected_at: DateTime, /// Client connection parameters. pub paramters: Parameters, - /// Identifier. - pub id: BackendKeyData, + /// Cancel key identifying this client and its secret. + pub key: BackendKeyData, } impl ConnectedClient { /// New connected client. - pub fn new(id: &BackendKeyData, addr: SocketAddr, params: &Parameters) -> Self { + pub fn new(key: BackendKeyData, addr: SocketAddr, params: &Parameters) -> Self { Self { - id: *id, + key, stats: Stats::new(), addr, connected_at: Local::now(), diff --git a/pgdog/src/frontend/listener.rs b/pgdog/src/frontend/listener.rs index dbc314b21..bc1e037ef 100644 --- a/pgdog/src/frontend/listener.rs +++ b/pgdog/src/frontend/listener.rs @@ -133,7 +133,7 @@ impl Listener { { // Shutdown timeout elapsed; cancel any still-running queries before tearing pools down. 
let cancel_futures = comms.clients().into_keys().map(|id| async move { - if let Err(err) = databases().cancel(&id).await { + if let Err(err) = databases().cancel(id).await { error!(?id, "cancel request failed during shutdown: {err}"); } }); @@ -239,7 +239,9 @@ impl Listener { } Startup::Cancel { id } => { - let _ = databases().cancel(&id).await; + if comms().verify_cancel(&id) { + let _ = databases().cancel(id.pid).await; + } break; } } diff --git a/pgdog/src/lib.rs b/pgdog/src/lib.rs index 2aa95cc9a..f7a9376fc 100644 --- a/pgdog/src/lib.rs +++ b/pgdog/src/lib.rs @@ -2,6 +2,9 @@ #![allow(clippy::result_unit_err)] #![deny(clippy::print_stdout)] +#[macro_use] +extern crate derive_more; + pub mod admin; pub mod auth; pub mod backend; diff --git a/pgdog/src/net/messages/backend_key.rs b/pgdog/src/net/messages/backend_key.rs index 2c3b20350..76816d503 100644 --- a/pgdog/src/net/messages/backend_key.rs +++ b/pgdog/src/net/messages/backend_key.rs @@ -1,47 +1,31 @@ //! BackendKeyData (B) message. use std::fmt::Display; -use std::sync::atomic::AtomicI32; -use std::sync::atomic::Ordering; use crate::net::messages::code; use crate::net::messages::prelude::*; use crate::net::messages::protocol_version::ProtocolVersion; use bytes::Buf; -use once_cell::sync::Lazy; -use rand::Rng; +use smallvec::SmallVec; + +use super::backend_pid::BackendPid; -static COUNTER: Lazy = Lazy::new(|| AtomicI32::new(0)); +use rand::Rng; const LEGACY_SECRET_LEN: usize = std::mem::size_of::(); const EXTENDED_SECRET_LEN: usize = 32; -pub const MAX_SECRET_LEN: usize = 256; - -// This wraps around. -fn next_counter() -> i32 { - COUNTER.fetch_add(1, Ordering::SeqCst) -} +const MAX_SECRET_LEN: usize = 256; /// Variable-length cancel secret. 
-#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +#[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct SecretKey { - len: u16, - bytes: [u8; MAX_SECRET_LEN], -} - -impl Default for SecretKey { - fn default() -> Self { - Self::legacy(0) - } + bytes: SmallVec<[u8; EXTENDED_SECRET_LEN]>, } impl SecretKey { /// Create a 3.0-compatible secret key from a 4-byte integer. pub fn legacy(secret: i32) -> Self { - let mut bytes = [0; MAX_SECRET_LEN]; - bytes[..LEGACY_SECRET_LEN].copy_from_slice(&secret.to_be_bytes()); Self { - len: LEGACY_SECRET_LEN as u16, - bytes, + bytes: SmallVec::from_slice(&secret.to_be_bytes()), } } @@ -52,12 +36,10 @@ impl SecretKey { "cancel secret must be between 1 and {MAX_SECRET_LEN} bytes" ); - let mut bytes = [0; MAX_SECRET_LEN]; - rand::rng().fill(&mut bytes[..len]); - Self { - len: len as u16, - bytes, - } + let mut bytes = SmallVec::with_capacity(len); + bytes.resize(len, 0); + rand::rng().fill(bytes.as_mut_slice()); + Self { bytes } } /// Create a secret key from raw wire bytes. @@ -66,22 +48,19 @@ impl SecretKey { return Err(crate::net::Error::UnexpectedPayload); } - let mut bytes = [0; MAX_SECRET_LEN]; - bytes[..secret.len()].copy_from_slice(secret); Ok(Self { - len: secret.len() as u16, - bytes, + bytes: SmallVec::from_slice(secret), }) } /// Secret bytes as they appear on the wire. pub fn as_slice(&self) -> &[u8] { - &self.bytes[..self.len()] + self.bytes.as_slice() } /// Secret length in bytes. pub fn len(&self) -> usize { - self.len as usize + self.bytes.len() } } @@ -100,10 +79,13 @@ impl Display for SecretKey { } /// BackendKeyData (B) -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Default)] +/// +/// Holds the full cancel secret alongside the pid. Use `BackendPid` instead +/// when only the process identity is needed (HashMap keys, routing, stats). +#[derive(Clone, Debug, PartialEq, Eq)] pub struct BackendKeyData { /// Process ID. - pub pid: i32, + pub pid: BackendPid, /// Process secret. 
pub secret: SecretKey, } @@ -115,10 +97,15 @@ impl Display for BackendKeyData { } impl BackendKeyData { + /// Return the `BackendPid` for this key (pid only, no secret). + pub fn pid(&self) -> BackendPid { + self.pid + } + /// Create new random BackendKeyData (B) message. - pub fn new() -> Self { + pub fn random_legacy() -> Self { Self { - pid: rand::rng().random(), + pid: BackendPid::random(), secret: SecretKey::random(LEGACY_SECRET_LEN), } } @@ -129,20 +116,20 @@ impl BackendKeyData { pub fn new_client(protocol_version: ProtocolVersion) -> Self { // The client must echo this secret back in CancelRequest, so its shape // has to match the negotiated frontend protocol version. - let secret_len = if protocol_version == ProtocolVersion::V3_2 { + let secret_len = if protocol_version.supports_extended_cancel_key() { EXTENDED_SECRET_LEN } else { LEGACY_SECRET_LEN }; Self { - pid: next_counter(), + pid: BackendPid::next(), secret: SecretKey::random(secret_len), } } /// Create legacy 3.0-compatible backend key data. 
- pub fn legacy(pid: i32, secret: i32) -> Self { + pub fn legacy(pid: BackendPid, secret: i32) -> Self { Self { pid, secret: SecretKey::legacy(secret), @@ -154,7 +141,7 @@ impl ToBytes for BackendKeyData { fn to_bytes(&self) -> bytes::Bytes { let mut payload = Payload::named(self.code()); - payload.put_i32(self.pid); + payload.put_i32(i32::from(self.pid)); payload.put_slice(self.secret.as_slice()); payload.freeze() @@ -176,7 +163,7 @@ impl FromBytes for BackendKeyData { return Err(Error::UnexpectedPayload); } - let pid = bytes.get_i32(); + let pid = BackendPid::from(bytes.get_i32()); let secret = SecretKey::from_slice(&bytes.copy_to_bytes(secret_len))?; Ok(Self { pid, secret }) @@ -191,12 +178,12 @@ impl Protocol for BackendKeyData { #[cfg(test)] mod tests { - use super::{BackendKeyData, ProtocolVersion, SecretKey}; + use super::{BackendKeyData, BackendPid, ProtocolVersion, SecretKey}; use crate::net::messages::{FromBytes, ToBytes}; #[test] fn test_backend_key_roundtrip_legacy() { - let key = BackendKeyData::legacy(42, 1234); + let key = BackendKeyData::legacy(BackendPid::from(42), 1234); let roundtrip = BackendKeyData::from_bytes(key.to_bytes()).unwrap(); assert_eq!(roundtrip, key); assert_eq!(roundtrip.secret.len(), 4); @@ -205,7 +192,7 @@ mod tests { #[test] fn test_backend_key_roundtrip_extended() { let key = BackendKeyData { - pid: 7, + pid: BackendPid::from(7), secret: SecretKey::random(32), }; let roundtrip = BackendKeyData::from_bytes(key.to_bytes()).unwrap(); @@ -213,6 +200,17 @@ mod tests { assert_eq!(roundtrip.secret.len(), 32); } + #[test] + fn test_backend_key_roundtrip_max_secret_len() { + let key = BackendKeyData { + pid: BackendPid::from(9), + secret: SecretKey::random(256), + }; + let roundtrip = BackendKeyData::from_bytes(key.to_bytes()).unwrap(); + assert_eq!(roundtrip, key); + assert_eq!(roundtrip.secret.len(), 256); + } + #[test] fn test_new_client_uses_protocol_specific_secret_length() { assert_eq!( diff --git 
a/pgdog/src/net/messages/backend_pid.rs b/pgdog/src/net/messages/backend_pid.rs new file mode 100644 index 000000000..3816338eb --- /dev/null +++ b/pgdog/src/net/messages/backend_pid.rs @@ -0,0 +1,45 @@ +//! BackendPid — lightweight process identity for routing and statistics. +//! +//! This is the canonical identifier for a backend server or client connection. + +use std::sync::atomic::{AtomicI32, Ordering}; + +use pgdog_postgres_types::ToDataRowColumn; + +use once_cell::sync::Lazy; +use rand::Rng; + +static COUNTER: Lazy = Lazy::new(|| AtomicI32::new(0)); + +/// Increment the global connection counter and return the old value. +fn next_counter() -> i32 { + COUNTER.fetch_add(1, Ordering::SeqCst) +} + +/// Opaque backend-process identifier. +/// +/// Used as `HashMap` keys, struct fields, and function arguments everywhere the +/// cancel *secret* is not needed. +#[derive(Copy, Clone, Debug, Display, Hash, PartialEq, Eq, PartialOrd, Ord, From, Into)] +pub struct BackendPid(i32); + +impl BackendPid { + /// Create a random `BackendPid` (used for server connections). + pub fn random() -> Self { + Self(rand::rng().random()) + } + + /// Create the next sequential `BackendPid` (used for client connections). + pub fn next() -> Self { + Self(next_counter()) + } +} + +impl ToDataRowColumn for BackendPid { + fn to_data_row_column(&self) -> pgdog_postgres_types::Data { + // Displayed as a decimal integer, same as i64. 
+ pgdog_postgres_types::Data::from(bytes::Bytes::copy_from_slice( + self.0.to_string().as_bytes(), + )) + } +} diff --git a/pgdog/src/net/messages/hello.rs b/pgdog/src/net/messages/hello.rs index 1b1cd0f91..004e15f56 100644 --- a/pgdog/src/net/messages/hello.rs +++ b/pgdog/src/net/messages/hello.rs @@ -2,7 +2,7 @@ use crate::net::{ c_string, - messages::{BackendKeyData, ProtocolVersion}, + messages::{BackendKeyData, BackendPid, ProtocolVersion}, parameter::{ParameterValue, Parameters}, Error, }; @@ -49,7 +49,7 @@ impl Startup { 80877104 => Ok(Startup::GssEnc), // CancelRequest (F) 80877102 => { - let pid = stream.read_i32().await?; + let pid = BackendPid::from(stream.read_i32().await?); // CancelRequest secrets became variable-length in protocol 3.2. let secret_len = usize::try_from(len) .ok() @@ -204,7 +204,7 @@ impl super::ToBytes for Startup { let mut payload = Payload::new(); payload.put_i32(80877102); - payload.put_i32(id.pid); + payload.put_i32(i32::from(id.pid)); payload.put_slice(id.secret.as_slice()); payload.freeze() diff --git a/pgdog/src/net/messages/mod.rs b/pgdog/src/net/messages/mod.rs index a125e7ab7..4988f2834 100644 --- a/pgdog/src/net/messages/mod.rs +++ b/pgdog/src/net/messages/mod.rs @@ -1,6 +1,7 @@ //! PostgreSQL wire protocol messages. pub mod auth; pub mod backend_key; +pub mod backend_pid; pub mod bind; pub mod bind_complete; pub mod buffer; @@ -39,6 +40,7 @@ pub mod terminate; pub use auth::{Authentication, Password}; pub use backend_key::BackendKeyData; +pub use backend_pid::BackendPid; pub use bind::{Bind, Format, Parameter, ParameterWithFormat}; pub use bind_complete::BindComplete; pub use buffer::MessageBuffer; @@ -109,13 +111,18 @@ pub trait Protocol: ToBytes + FromBytes + std::fmt::Debug { #[derive(Clone, PartialEq, Default, Copy, Debug)] pub enum Source { - Backend(BackendKeyData), + /// Message synthesised by pgdog itself (not from any real connection). 
+ /// This is the default: any message constructed without an explicit source is internal. #[default] + Internal, + /// Message received from a PostgreSQL backend connection. + Backend(BackendPid), + /// Message received from the client (frontend). Frontend, } impl Source { - pub fn backend_id(&self) -> Option { + pub fn backend_id(&self) -> Option { if let Self::Backend(id) = self { Some(*id) } else { @@ -144,27 +151,35 @@ impl std::fmt::Debug for Message { match self.code() { 'Q' => Query::from_bytes(self.payload()).unwrap().fmt(f), 'D' => match self.source { - Source::Backend(_) => DataRow::from_bytes(self.payload()).unwrap().fmt(f), Source::Frontend => Describe::from_bytes(self.payload()).unwrap().fmt(f), + Source::Backend(_) | Source::Internal => { + DataRow::from_bytes(self.payload()).unwrap().fmt(f) + } }, 'P' => Parse::from_bytes(self.payload()).unwrap().fmt(f), 'B' => Bind::from_bytes(self.payload()).unwrap().fmt(f), 'S' => match self.source { Source::Frontend => f.debug_struct("Sync").finish(), - Source::Backend(_) => ParameterStatus::from_bytes(self.payload()).unwrap().fmt(f), + Source::Backend(_) | Source::Internal => { + ParameterStatus::from_bytes(self.payload()).unwrap().fmt(f) + } }, '1' => ParseComplete::from_bytes(self.payload()).unwrap().fmt(f), '2' => BindComplete::from_bytes(self.payload()).unwrap().fmt(f), '3' => f.debug_struct("CloseComplete").finish(), 'E' => match self.source { Source::Frontend => f.debug_struct("Execute").finish(), - Source::Backend(_) => ErrorResponse::from_bytes(self.payload()).unwrap().fmt(f), + Source::Backend(_) | Source::Internal => { + ErrorResponse::from_bytes(self.payload()).unwrap().fmt(f) + } }, 'T' => RowDescription::from_bytes(self.payload()).unwrap().fmt(f), 'Z' => ReadyForQuery::from_bytes(self.payload()).unwrap().fmt(f), 'C' => match self.source { - Source::Backend(_) => CommandComplete::from_bytes(self.payload()).unwrap().fmt(f), Source::Frontend => Close::from_bytes(self.payload()).unwrap().fmt(f), + 
Source::Backend(_) | Source::Internal => { + CommandComplete::from_bytes(self.payload()).unwrap().fmt(f) + } }, 'd' => CopyData::from_bytes(self.payload()).unwrap().fmt(f), 'v' => NegotiateProtocolVersion::from_bytes(self.payload()) @@ -242,7 +257,7 @@ impl Message { } /// This message is coming from the backend. - pub fn backend(mut self, id: BackendKeyData) -> Self { + pub fn backend(mut self, id: BackendPid) -> Self { self.source = Source::Backend(id); self } @@ -253,6 +268,12 @@ impl Message { self } + /// This message was synthesised by pgdog (not from any real connection). + pub fn internal(mut self) -> Self { + self.source = Source::Internal; + self + } + /// Where is this message coming from? pub fn source(&self) -> Source { self.source diff --git a/pgdog/src/net/messages/protocol_version.rs b/pgdog/src/net/messages/protocol_version.rs index 065d297d2..d5446f7c4 100644 --- a/pgdog/src/net/messages/protocol_version.rs +++ b/pgdog/src/net/messages/protocol_version.rs @@ -49,6 +49,12 @@ impl ProtocolVersion { matches!(self, Self::V3_0 | Self::V3_2) } + /// Whether this protocol version uses the extended (variable-length) + /// `BackendKeyData` cancel secret introduced in 3.2. + pub fn supports_extended_cancel_key(self) -> bool { + self >= Self::V3_2 + } + /// Highest supported protocol version that can satisfy this request. /// /// PostgreSQL minor-version negotiation is a downgrade mechanism, so a @@ -101,4 +107,13 @@ mod tests { ); assert_eq!(ProtocolVersion::new(4, 0).negotiated(), None); } + + #[test] + fn test_supports_extended_cancel_key() { + assert!(!ProtocolVersion::V3_0.supports_extended_cancel_key()); + assert!(!ProtocolVersion::new(3, 1).supports_extended_cancel_key()); + assert!(ProtocolVersion::V3_2.supports_extended_cancel_key()); + assert!(ProtocolVersion::new(3, 3).supports_extended_cancel_key()); + assert!(ProtocolVersion::new(4, 0).supports_extended_cancel_key()); + } }