diff --git a/controlplane/session_mgr.go b/controlplane/session_mgr.go index 70a8051..586ed3d 100644 --- a/controlplane/session_mgr.go +++ b/controlplane/session_mgr.go @@ -21,6 +21,15 @@ type SessionProgress struct { Stalled bool } +// SessionConn is the client transport for a managed session. It supports both +// writing pgwire packets (used to deliver a FATAL ErrorResponse on worker +// crash before closing) and closing the underlying TCP. *tls.Conn — the type +// the pgwire handshake hands us — satisfies this interface naturally. +type SessionConn interface { + io.Writer + io.Closer +} + // ManagedSession tracks a client session bound to a worker. type ManagedSession struct { PID int32 @@ -28,7 +37,12 @@ type ManagedSession struct { Protocol string // "postgres" or "flight" SessionToken string Executor *flightclient.FlightExecutor - connCloser io.Closer // TCP connection, closed on worker crash to unblock the message loop + // conn is the client TCP/TLS connection. Used by the worker-crash path + // to deliver a FATAL ErrorResponse and then close the socket — the + // FATAL is what lets clients (libpq, dbt's psycopg2 adapter) cleanly + // surface "your session was lost" instead of waiting forever on a + // half-open TCP that just got reset. + conn SessionConn // Cached query progress from worker health checks. queryProgress atomic.Value // stores *SessionProgress (or nil) @@ -296,23 +310,42 @@ func (sm *SessionManager) OnWorkerCrash(workerID int, errorFn func(pid int32)) { errorFn(pid) sm.mu.Lock() session, ok := sm.sessions[pid] + var executor *flightclient.FlightExecutor + var conn SessionConn if ok { delete(sm.sessions, pid) - if session.Executor != nil { - _ = session.Executor.Close() - } - // Close the TCP connection to unblock the message loop's read. - // This causes the session goroutine to exit instead of looping - // with ErrWorkerDead on every query. 
The deferred close in - // handleConnection will also call Close() on the same conn; - // that's harmless (net.Conn.Close on a closed socket returns - // an error which is discarded). - if session.connCloser != nil { - _ = session.connCloser.Close() - } + executor = session.Executor + conn = session.conn } remainingSessions := len(sm.sessions) sm.mu.Unlock() + + if executor != nil { + _ = executor.Close() + } + // Deliver a pgwire FATAL ErrorResponse before closing the TCP. + // Without the FATAL, libpq-based clients (psql, dbt's psycopg2 + // adapter) can hang on a half-open socket — psql happens to + // handle the bare TCP close OK because its read loop returns, + // but dbt's libpq-async + disabled keepalives setup leaves + // PQconsumeInput parked indefinitely. The FATAL gives every + // client a structured "your session was lost" they can surface. + // + // Concurrency: the message loop also writes to this conn via + // its own bufio.Writer, but *tls.Conn / net.Conn serialize + // underlying Write calls internally — so we may interleave at + // the message boundary (corrupting an in-flight DataRow), but + // not at the byte level. A client that sees a malformed packet + // followed by a FATAL still surfaces an error cleanly, which + // is strictly better than a silent half-open socket. + // + // Write and Close happen outside sm.mu so a slow or wedged client + // socket cannot block unrelated session-manager operations. + if conn != nil { + _ = server.WriteErrorResponse(conn, "FATAL", "08006", + fmt.Sprintf("worker %d for this session became unresponsive and was reaped", workerID)) + _ = conn.Close() + } slog.Info("Worker crash session cleanup completed.", "pid", pid, "worker", workerID, @@ -332,17 +365,40 @@ func (sm *SessionManager) OnWorkerCrash(workerID int, errorFn func(pid int32)) { } } -// SetConnCloser registers the client's TCP connection so it can be closed -// when the backing worker crashes. 
This unblocks the message loop's read, -// causing it to exit cleanly instead of looping on ErrWorkerDead. -func (sm *SessionManager) SetConnCloser(pid int32, closer io.Closer) { +// SetSessionConn registers the client's TCP/TLS transport so the worker-crash +// path can deliver a FATAL ErrorResponse and close the socket. *tls.Conn from +// the pgwire handshake satisfies SessionConn (io.Writer + io.Closer). +func (sm *SessionManager) SetSessionConn(pid int32, conn SessionConn) { sm.mu.Lock() defer sm.mu.Unlock() if s, ok := sm.sessions[pid]; ok { - s.connCloser = closer + s.conn = conn } } +// SetConnCloser is a back-compat shim for callers that previously passed an +// io.Closer. The real type they pass (*tls.Conn) is also an io.Writer, so we +// type-assert it to SessionConn here. New callers should use SetSessionConn directly. +// +// Deprecated: use SetSessionConn. +func (sm *SessionManager) SetConnCloser(pid int32, closer io.Closer) { + conn, ok := closer.(SessionConn) + if !ok { + // Caller passed a closer that isn't also a Writer — we can still + // close on crash, just can't deliver a FATAL. Wrap in a discarding + // writer so the type satisfies SessionConn. + conn = closeOnlyConn{closer} + } + sm.SetSessionConn(pid, conn) +} + +// closeOnlyConn adapts an io.Closer with no Writer into a SessionConn whose +// Write is a no-op. Used by the deprecated SetConnCloser path for callers that +// genuinely don't have a Writer; modern callers pass *tls.Conn directly. +type closeOnlyConn struct{ io.Closer } + +func (closeOnlyConn) Write(p []byte) (int, error) { return len(p), nil } + // SessionCount returns the number of active sessions.
func (sm *SessionManager) SessionCount() int { sm.mu.RLock() diff --git a/controlplane/session_mgr_test.go b/controlplane/session_mgr_test.go index 5488a13..dc52d4e 100644 --- a/controlplane/session_mgr_test.go +++ b/controlplane/session_mgr_test.go @@ -3,24 +3,63 @@ package controlplane import ( + "bytes" + "errors" "runtime" + "slices" "strings" + "sync" "sync/atomic" "testing" "github.com/posthog/duckgres/server/flightclient" ) -// mockCloser tracks whether Close was called. +// mockCloser stands in for the client TCP/TLS conn — captures bytes written +// (so tests can assert a FATAL ErrorResponse was delivered) and tracks whether +// Close was called. type mockCloser struct { - closed atomic.Bool + closed atomic.Bool + writeMu sync.Mutex + written []byte + events []string +} + +func (m *mockCloser) Write(p []byte) (int, error) { + m.writeMu.Lock() + defer m.writeMu.Unlock() + if m.closed.Load() { + return 0, errors.New("write after close") + } + m.events = append(m.events, "write") + m.written = append(m.written, p...) 
+ return len(p), nil +} + +func (m *mockCloser) Bytes() []byte { + m.writeMu.Lock() + defer m.writeMu.Unlock() + out := make([]byte, len(m.written)) + copy(out, m.written) + return out } func (m *mockCloser) Close() error { + m.writeMu.Lock() + defer m.writeMu.Unlock() + m.events = append(m.events, "close") m.closed.Store(true) return nil } +func (m *mockCloser) Events() []string { + m.writeMu.Lock() + defer m.writeMu.Unlock() + out := make([]string, len(m.events)) + copy(out, m.events) + return out +} + func TestOnWorkerCrash_MarksExecutorsDead(t *testing.T) { pool := &FlightWorkerPool{ workers: make(map[int]*ManagedWorker), @@ -72,10 +111,10 @@ func TestOnWorkerCrash_ClosesConnections(t *testing.T) { sm.mu.Lock() sm.sessions[pid] = &ManagedSession{ - PID: pid, - WorkerID: 7, - Executor: executor, - connCloser: conn, + PID: pid, + WorkerID: 7, + Executor: executor, + conn: conn, } sm.byWorker[7] = []int32{pid} sm.mu.Unlock() @@ -99,8 +138,8 @@ func TestOnWorkerCrash_MultipleSessions(t *testing.T) { conn2 := &mockCloser{} sm.mu.Lock() - sm.sessions[1001] = &ManagedSession{PID: 1001, WorkerID: 3, Executor: exec1, connCloser: conn1} - sm.sessions[1002] = &ManagedSession{PID: 1002, WorkerID: 3, Executor: exec2, connCloser: conn2} + sm.sessions[1001] = &ManagedSession{PID: 1001, WorkerID: 3, Executor: exec1, conn: conn1} + sm.sessions[1002] = &ManagedSession{PID: 1002, WorkerID: 3, Executor: exec2, conn: conn2} sm.byWorker[3] = []int32{1001, 1002} sm.mu.Unlock() @@ -117,6 +156,56 @@ func TestOnWorkerCrash_MultipleSessions(t *testing.T) { } } +func TestOnWorkerCrash_WritesFATALBeforeClose(t *testing.T) { + // Asserts the new behavior: when a worker is reaped, the CP delivers a + // pgwire FATAL ErrorResponse on the client conn before closing it. This + // is the difference between psql cleanly surfacing "connection lost" and + // dbt's libpq state machine hanging silently on a half-open socket. 
+ pool := &FlightWorkerPool{workers: make(map[int]*ManagedWorker)} + sm := NewSessionManager(pool, nil) + + conn := &mockCloser{} + pid := int32(1500) + + sm.mu.Lock() + sm.sessions[pid] = &ManagedSession{ + PID: pid, + WorkerID: 42, + Executor: &flightclient.FlightExecutor{}, + conn: conn, + } + sm.byWorker[42] = []int32{pid} + sm.mu.Unlock() + + sm.OnWorkerCrash(42, func(int32) {}) + + if !conn.closed.Load() { + t.Fatal("conn was not closed on worker crash") + } + + // Inspect the bytes the crash handler wrote: pgwire ErrorResponse starts + // with the byte 'E', followed by a 4-byte length, then field-tagged + // strings. We just assert FATAL + the worker ID appear so we know a + // FATAL packet was emitted before the close — not testing the wire + // encoding in detail (server/wire owns that). + got := conn.Bytes() + if len(got) == 0 { + t.Fatal("no bytes written before close — FATAL not delivered") + } + if got[0] != 'E' { + t.Errorf("expected first byte 'E' (ErrorResponse), got %q", got[0]) + } + if !bytes.Contains(got, []byte("FATAL")) { + t.Errorf("expected 'FATAL' in payload, got %q", got) + } + if !bytes.Contains(got, []byte("42")) { + t.Errorf("expected worker id '42' in payload, got %q", got) + } + if got := conn.Events(); !slices.Equal(got, []string{"write", "write", "write", "close"}) { + t.Fatalf("expected FATAL message writes before close, got event order %v", got) + } +} + func TestSetConnCloser(t *testing.T) { pool := &FlightWorkerPool{ workers: make(map[int]*ManagedWorker), @@ -233,10 +322,10 @@ func TestDestroySessionAfterOnWorkerCrash(t *testing.T) { sm.mu.Lock() sm.sessions[pid] = &ManagedSession{ - PID: pid, - WorkerID: 9, - Executor: executor, - connCloser: conn, + PID: pid, + WorkerID: 9, + Executor: executor, + conn: conn, } sm.byWorker[9] = []int32{pid} sm.mu.Unlock()