diff --git a/controlplane/k8s_pool.go b/controlplane/k8s_pool.go index 1c9de2e..0ffd497 100644 --- a/controlplane/k8s_pool.go +++ b/controlplane/k8s_pool.go @@ -643,33 +643,10 @@ func (p *K8sWorkerPool) AcquireWorker(ctx context.Context) (*ManagedWorker, erro return idle, nil } - // 2. No idle worker — check if we have any live workers at all + // 2. No idle worker — spawn a new one if below capacity. liveCount := p.liveWorkerCountLocked() canSpawn := p.maxWorkers == 0 || liveCount < p.maxWorkers - if liveCount > 0 { - // We have live workers. Assign to the least-loaded one immediately - // and spawn a new worker in the background if below capacity. - w := p.leastLoadedWorkerLocked() - if w != nil { - w.activeSessions++ - if canSpawn { - id := p.allocateWorkerIDLocked() - p.spawning++ - p.mu.Unlock() - slog.Debug("Assigned to least-loaded worker, spawning new worker in background.", - "worker", w.ID, "active_sessions", w.activeSessions, "background_worker", id) - go p.spawnWorkerBackground(id) - } else { - p.mu.Unlock() - slog.Debug("Assigned to least-loaded worker (at capacity).", - "worker", w.ID, "active_sessions", w.activeSessions) - } - return w, nil - } - } - - // 3. No live workers at all (cold start or all dead) — must block on spawn if canSpawn { id := p.allocateWorkerIDLocked() p.spawning++ @@ -696,33 +673,17 @@ func (p *K8sWorkerPool) AcquireWorker(ctx context.Context) (*ManagedWorker, erro return w, nil } - // At capacity with all workers dead (spawning in progress) — wait and retry + // 3. At capacity — wait for a worker to become idle. p.mu.Unlock() select { case <-ctx.Done(): return nil, ctx.Err() - case <-time.After(100 * time.Millisecond): + case <-time.After(200 * time.Millisecond): + // Retry } } } -// spawnWorkerBackground spawns a worker pod without blocking AcquireWorker. -// The new worker becomes available for future sessions once ready. -func (p *K8sWorkerPool) spawnWorkerBackground(id int) { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) - defer cancel() - - err := p.SpawnWorker(ctx, id) - - p.mu.Lock() - p.spawning-- - p.mu.Unlock() - - if err != nil { - slog.Warn("Background worker spawn failed.", "worker", id, "error", err) - } -} - // ReleaseWorker decrements the active session count for a worker. func (p *K8sWorkerPool) ReleaseWorker(id int) { p.mu.Lock() @@ -1201,24 +1162,6 @@ func (p *K8sWorkerPool) findIdleWorkerLocked() *ManagedWorker { return nil } -func (p *K8sWorkerPool) leastLoadedWorkerLocked() *ManagedWorker { - var best *ManagedWorker - for _, w := range p.workers { - select { - case <-w.done: - continue - default: - } - if !p.isGenericSessionSchedulableWorkerLocked(w) { - continue - } - if best == nil || w.activeSessions < best.activeSessions { - best = w - } - } - return best -} - func (p *K8sWorkerPool) liveWorkerCountLocked() int { count := p.spawning for _, w := range p.workers { @@ -1376,6 +1319,23 @@ func (p *K8sWorkerPool) spawnWarmWorker(ctx context.Context, id int) error { return p.SpawnWorker(ctx, id) } +// spawnWorkerBackground spawns a worker pod without blocking the caller. +// Used for warm pool replenishment after a worker is reserved. +func (p *K8sWorkerPool) spawnWorkerBackground(id int) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + + err := p.SpawnWorker(ctx, id) + + p.mu.Lock() + p.spawning-- + p.mu.Unlock() + + if err != nil { + slog.Warn("Background worker spawn failed.", "worker", id, "error", err) + } +} + func (p *K8sWorkerPool) spawnWarmWorkerBackground(id int) { if p.spawnWarmWorkerBackgroundFunc != nil { p.spawnWarmWorkerBackgroundFunc(id) diff --git a/controlplane/k8s_pool_test.go b/controlplane/k8s_pool_test.go index 706b565..0435177 100644 --- a/controlplane/k8s_pool_test.go +++ b/controlplane/k8s_pool_test.go @@ -301,20 +301,6 @@ func TestK8sPool_FindIdleWorker(t *testing.T) { } } -func TestK8sPool_LeastLoadedWorker(t *testing.T) { - pool, _ := newTestK8sPool(t, 5) - - done := make(chan struct{}) - pool.workers[1] = &ManagedWorker{ID: 1, activeSessions: 5, done: done} - pool.workers[2] = &ManagedWorker{ID: 2, activeSessions: 2, done: done} - pool.workers[3] = &ManagedWorker{ID: 3, activeSessions: 3, done: done} - - w := pool.leastLoadedWorkerLocked() - if w == nil || w.ID != 2 { - t.Fatalf("expected least loaded worker 2, got %v", w) - } -} - func TestK8sPool_LiveWorkerCount(t *testing.T) { pool, _ := newTestK8sPool(t, 5) diff --git a/controlplane/session_mgr.go b/controlplane/session_mgr.go index d6a3f8d..8974643 100644 --- a/controlplane/session_mgr.go +++ b/controlplane/session_mgr.go @@ -169,7 +169,7 @@ func (sm *SessionManager) DestroySession(pid int32) { case <-worker.done: // Worker already dead, skip RPC default: - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) _ = worker.DestroySession(ctx, session.SessionToken) cancel() } @@ -178,7 +178,7 @@ func (sm *SessionManager) DestroySession(pid int32) { // Release the worker for reuse after cleanup is complete. sm.pool.ReleaseWorker(session.WorkerID) - slog.Debug("Session destroyed.", "pid", pid, "worker", session.WorkerID) + slog.Info("Session destroyed, worker recycled.", "pid", pid, "worker", session.WorkerID) // Rebalance remaining sessions if sm.rebalancer != nil { diff --git a/controlplane/worker_mgr.go b/controlplane/worker_mgr.go index 4acb25e..ad0f0b6 100644 --- a/controlplane/worker_mgr.go +++ b/controlplane/worker_mgr.go @@ -514,108 +514,79 @@ func (p *FlightWorkerPool) SpawnMinWorkers(count int) error { // AcquireWorker returns a worker for a new session. // -// Strategy: +// Strategy (1:1 worker-to-session model): // 1. Reuse an idle worker (0 active sessions) if available. // 2. If the pool has fewer live workers than maxWorkers (or maxWorkers is 0), // spawn a new worker process. -// 3. If the pool is at capacity, assign to the least-loaded live worker. -// -// This ensures the number of worker processes never exceeds maxWorkers while -// allowing unlimited concurrent sessions across the fixed pool. +// 3. If the pool is at capacity, wait with backoff until a worker becomes idle. func (p *FlightWorkerPool) AcquireWorker(ctx context.Context) (*ManagedWorker, error) { acquireStart := time.Now() defer func() { observeControlPlaneWorkerAcquire(time.Since(acquireStart)) }() - p.mu.Lock() - if p.shuttingDown { - p.mu.Unlock() - return nil, fmt.Errorf("pool is shutting down") - } - - // Remove dead worker entries so they don't inflate the count. - p.cleanDeadWorkersLocked() - - // 1. Try to claim an idle worker before spawning a new one. - idle := p.findIdleWorkerLocked() - if idle != nil { - idle.activeSessions++ - p.mu.Unlock() - return idle, nil - } - - // 2. If below the process cap (or unlimited), spawn a new worker. - liveCount := p.liveWorkerCountLocked() - if p.maxWorkers == 0 || liveCount < p.maxWorkers { - id := p.nextWorkerID - p.nextWorkerID++ - p.spawning++ - p.mu.Unlock() - - err := p.SpawnWorker(id) + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } p.mu.Lock() - p.spawning-- - p.mu.Unlock() - - if err != nil { - return nil, err + if p.shuttingDown { + p.mu.Unlock() + return nil, fmt.Errorf("pool is shutting down") } - w, ok := p.Worker(id) - if !ok { - return nil, fmt.Errorf("worker %d not found after spawn", id) + // Remove dead worker entries so they don't inflate the count. + p.cleanDeadWorkersLocked() + + // 1. Try to claim an idle worker before spawning a new one. + idle := p.findIdleWorkerLocked() + if idle != nil { + idle.activeSessions++ + p.mu.Unlock() + return idle, nil } - p.mu.Lock() - w.activeSessions++ - p.mu.Unlock() - return w, nil - } + // 2. If below the process cap (or unlimited), spawn a new worker. + liveCount := p.liveWorkerCountLocked() + if p.maxWorkers == 0 || liveCount < p.maxWorkers { + id := p.nextWorkerID + p.nextWorkerID++ + p.spawning++ + p.mu.Unlock() - // 3. At capacity — assign to the least-loaded live worker. - w := p.leastLoadedWorkerLocked() - if w != nil { - w.activeSessions++ - p.mu.Unlock() - return w, nil - } + err := p.SpawnWorker(id) - // All workers are dead (already cleaned above). Spawn a replacement. - // Still respect maxWorkers — another goroutine may already be spawning. - liveCount = p.liveWorkerCountLocked() - if p.maxWorkers > 0 && liveCount >= p.maxWorkers { - // A spawn is already in progress; wait for it to finish and use that worker. - p.mu.Unlock() - // Brief backoff then retry — the in-progress spawn will add a worker shortly. - time.Sleep(100 * time.Millisecond) - return p.AcquireWorker(ctx) - } - id := p.nextWorkerID - p.nextWorkerID++ - p.spawning++ - p.mu.Unlock() + p.mu.Lock() + p.spawning-- + p.mu.Unlock() - err := p.SpawnWorker(id) + if err != nil { + return nil, err + } - p.mu.Lock() - p.spawning-- - p.mu.Unlock() + w, ok := p.Worker(id) + if !ok { + return nil, fmt.Errorf("worker %d not found after spawn", id) + } - if err != nil { - return nil, err - } + p.mu.Lock() + w.activeSessions++ + p.mu.Unlock() + return w, nil + } - w, ok := p.Worker(id) - if !ok { - return nil, fmt.Errorf("worker %d not found after spawn", id) + // 3. At capacity — wait for a worker to become idle. + p.mu.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(200 * time.Millisecond): + // Retry + } } - - p.mu.Lock() - w.activeSessions++ - p.mu.Unlock() - return w, nil } // ReleaseWorker decrements the active session count for a worker and updates its lastUsed time. @@ -704,22 +675,7 @@ func (p *FlightWorkerPool) findIdleWorkerLocked() *ManagedWorker { return nil } -// leastLoadedWorkerLocked returns the live worker with the fewest active -// sessions, or nil if all workers are dead. Caller must hold p.mu. -func (p *FlightWorkerPool) leastLoadedWorkerLocked() *ManagedWorker { - var best *ManagedWorker - for _, w := range p.workers { - select { - case <-w.done: - continue // dead - default: - } - if best == nil || w.activeSessions < best.activeSessions { - best = w - } - } - return best -} + // liveWorkerCountLocked returns the number of workers whose process is still // running (done channel not closed) plus workers currently being spawned. diff --git a/controlplane/worker_mgr_test.go b/controlplane/worker_mgr_test.go index fda7db3..dfa306f 100644 --- a/controlplane/worker_mgr_test.go +++ b/controlplane/worker_mgr_test.go @@ -2,6 +2,7 @@ package controlplane import ( "context" + "errors" "fmt" "net" "os" @@ -292,33 +293,60 @@ func TestAcquireWorkerReusesIdleWorker(t *testing.T) { } } -func TestAcquireWorkerLeastLoadedAtCapacity(t *testing.T) { - pool := NewFlightWorkerPool(t.TempDir(), "", 0, 2) +func TestAcquireWorkerWaitsWhenAtCapacity(t *testing.T) { + pool := NewFlightWorkerPool(t.TempDir(), "", 0, 1) - // Pre-populate 2 busy workers (at capacity). + // Pre-populate 1 busy worker (at capacity). w0, cleanup0 := makeFakeWorker(t, 0) defer cleanup0() - w1, cleanup1 := makeFakeWorker(t, 1) - defer cleanup1() pool.mu.Lock() - w0.activeSessions = 3 - w1.activeSessions = 1 + w0.activeSessions = 1 pool.workers[0] = w0 - pool.workers[1] = w1 - pool.nextWorkerID = 2 + pool.nextWorkerID = 1 pool.mu.Unlock() - // Acquire should pick the least-loaded worker (w1 with 1 session). - w, err := pool.AcquireWorker(context.Background()) - if err != nil { - t.Fatalf("unexpected error: %v", err) + // AcquireWorker should block since the only worker is busy. + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + + _, err := pool.AcquireWorker(ctx) + if err == nil { + t.Fatal("expected timeout error when all workers are busy") } - if w.ID != 1 { - t.Fatalf("expected worker 1 (least loaded), got worker %d", w.ID) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected DeadlineExceeded, got: %v", err) + } +} + +func TestAcquireWorkerSucceedsAfterRelease(t *testing.T) { + pool := NewFlightWorkerPool(t.TempDir(), "", 0, 1) + + // Pre-populate 1 busy worker. + w0, cleanup0 := makeFakeWorker(t, 0) + defer cleanup0() + + pool.mu.Lock() + w0.activeSessions = 1 + pool.workers[0] = w0 + pool.nextWorkerID = 1 + pool.mu.Unlock() + + // Release the worker after a short delay. + go func() { + time.Sleep(100 * time.Millisecond) + pool.ReleaseWorker(0) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + w, err := pool.AcquireWorker(ctx) + if err != nil { + t.Fatalf("expected acquire to succeed after release, got: %v", err) } - if w.activeSessions != 2 { - t.Fatalf("expected 2 active sessions after acquire, got %d", w.activeSessions) + if w.ID != 0 { + t.Fatalf("expected worker 0, got worker %d", w.ID) } } @@ -639,84 +667,51 @@ func TestCleanDeadWorkersLocked(t *testing.T) { } } -func TestLeastLoadedWorkerLocked(t *testing.T) { - pool := NewFlightWorkerPool(t.TempDir(), "", 0, 5) - - w0, cleanup0 := makeFakeWorker(t, 0) - defer cleanup0() - w1, cleanup1 := makeFakeWorker(t, 1) - defer cleanup1() - w2, cleanup2 := makeFakeWorker(t, 2) - defer cleanup2() - - w0.activeSessions = 5 - w1.activeSessions = 2 - w2.activeSessions = 8 - - pool.mu.Lock() - pool.workers[0] = w0 - pool.workers[1] = w1 - pool.workers[2] = w2 - best := pool.leastLoadedWorkerLocked() - pool.mu.Unlock() +func TestAcquireWorkerConcurrentOneToOne(t *testing.T) { + // With maxWorkers=3 and 3 idle workers, 3 concurrent acquires should + // each get a different worker (1:1 model). + pool := NewFlightWorkerPool(t.TempDir(), "", 0, 3) - if best == nil { - t.Fatal("expected a worker") - return - } - if best.ID != 1 { - t.Fatalf("expected worker 1 (least loaded with 2 sessions), got worker %d", best.ID) + for i := 0; i < 3; i++ { + w, cleanup := makeFakeWorker(t, i) + defer cleanup() + pool.mu.Lock() + pool.workers[i] = w + pool.mu.Unlock() } -} - -func TestAcquireWorkerConcurrentSharing(t *testing.T) { - // With maxWorkers=2 and 2 busy workers, 10 concurrent acquires should - // all succeed by sharing the existing workers. - pool := NewFlightWorkerPool(t.TempDir(), "", 0, 2) - - w0, cleanup0 := makeFakeWorker(t, 0) - defer cleanup0() - w1, cleanup1 := makeFakeWorker(t, 1) - defer cleanup1() - pool.mu.Lock() - w0.activeSessions = 1 - w1.activeSessions = 1 - pool.workers[0] = w0 - pool.workers[1] = w1 - pool.nextWorkerID = 2 + pool.nextWorkerID = 3 pool.mu.Unlock() - const concurrency = 10 + const concurrency = 3 var wg sync.WaitGroup - errors := make(chan error, concurrency) + results := make(chan *ManagedWorker, concurrency) for i := 0; i < concurrency; i++ { wg.Add(1) go func() { defer wg.Done() - _, err := pool.AcquireWorker(context.Background()) + w, err := pool.AcquireWorker(context.Background()) if err != nil { - errors <- err + t.Errorf("unexpected error: %v", err) + return } + results <- w }() } wg.Wait() - close(errors) + close(results) - for err := range errors { - t.Fatalf("unexpected error: %v", err) + workers := make(map[int]bool) + for w := range results { + if workers[w.ID] { + t.Errorf("worker %d was assigned multiple times", w.ID) + } + workers[w.ID] = true } - - // All 10 sessions should be spread across the 2 workers. - pool.mu.RLock() - total := w0.activeSessions + w1.activeSessions - pool.mu.RUnlock() - - // 2 original + 10 new = 12 - if total != 12 { - t.Fatalf("expected 12 total sessions across 2 workers, got %d", total) + if len(workers) != 3 { + t.Fatalf("expected 3 unique workers, got %d", len(workers)) } } diff --git a/controlplane/worker_pool.go b/controlplane/worker_pool.go index 0f09c69..fe4a934 100644 --- a/controlplane/worker_pool.go +++ b/controlplane/worker_pool.go @@ -11,8 +11,9 @@ import ( // - FlightWorkerPool: spawns workers as local child processes (default) // - K8sWorkerPool: creates workers as Kubernetes pods (build tag: kubernetes) type WorkerPool interface { - // AcquireWorker returns a worker for a new session. It may reuse an idle - // worker, spawn a new one, or assign to the least-loaded worker. + // AcquireWorker returns a worker for a new session (1:1 model). + // It reuses an idle worker or spawns a new one. If at capacity, it waits + // until a worker becomes idle. AcquireWorker(ctx context.Context) (*ManagedWorker, error) // ReleaseWorker decrements the active session count for a worker. diff --git a/duckdbservice/service.go b/duckdbservice/service.go index 5b58ab9..f7e1619 100644 --- a/duckdbservice/service.go +++ b/duckdbservice/service.go @@ -441,7 +441,8 @@ func (p *SessionPool) GetSession(token string) (*Session, bool) { return s, ok } -// DestroySession closes and removes a session. +// DestroySession closes and removes a session, then resets the shared DuckDB +// instance in-place so the next session starts with clean state. func (p *SessionPool) DestroySession(token string) error { p.mu.Lock() session, ok := p.sessions[token] @@ -469,30 +470,19 @@ func (p *SessionPool) DestroySession(token string) error { stop() } if session.Conn != nil { - // Drop temporary tables before returning the connection to the pool. - // sql.Conn.Close() returns the underlying driver connection to sql.DB's - // pool rather than closing it. DuckDB temp tables are connection-scoped, - // so they'd leak into the next session that gets the same connection. - cleanupStart := time.Now() - slog.Debug("Cleaning up session state.", "user", session.Username) - cleanupSessionState(session.Conn) - slog.Debug("Session state cleaned up.", "user", session.Username, "duration", time.Since(cleanupStart)) - connCloseStart := time.Now() _ = session.Conn.Close() - slog.Debug("Session connection closed (returned to pool).", "user", session.Username, "duration", time.Since(connCloseStart)) } - // Do NOT close session.DB if it is a shared DB (warmup or fallback) - p.mu.RLock() - isShared := session.DB == p.warmupDB || session.DB == p.fallbackDB - p.mu.RUnlock() - if session.DB != nil && !isShared { - if err := session.DB.Close(); err != nil { - slog.Warn("Failed to close session database", "error", err) + // Reset the shared DuckDB instance in-place: drop user objects, reset + // settings, and re-apply warmup config. This avoids the ~90ms cost of + // closing and reopening the DB while still guaranteeing clean state. + if session.DB != nil { + if err := p.resetSessionState(session.DB); err != nil { + slog.Warn("Failed to reset session state.", "user", session.Username, "error", err) } } - slog.Debug("Destroyed DuckDB session", "user", session.Username) + slog.Info("Session destroyed.", "user", session.Username) return nil } @@ -525,6 +515,9 @@ func (p *SessionPool) CloseAll() { } } + // Track which DBs we've already closed to avoid double-close. + closedDBs := make(map[*sql.DB]bool) + for _, session := range sessions { session.mu.Lock() for id, ttx := range session.txns { @@ -536,15 +529,17 @@ func (p *SessionPool) CloseAll() { if session.Conn != nil { _ = session.Conn.Close() } - if session.DB != nil && session.DB != p.warmupDB && session.DB != p.fallbackDB { + if session.DB != nil && !closedDBs[session.DB] { _ = session.DB.Close() + closedDBs[session.DB] = true } } - if p.warmupDB != nil { + if p.warmupDB != nil && !closedDBs[p.warmupDB] { _ = p.warmupDB.Close() + closedDBs[p.warmupDB] = true } - if p.fallbackDB != nil && p.fallbackDB != p.warmupDB { + if p.fallbackDB != nil && !closedDBs[p.fallbackDB] { _ = p.fallbackDB.Close() } if p.activation != nil && p.activation.db != nil && p.activation.db != p.warmupDB && p.activation.db != p.fallbackDB { @@ -552,51 +547,6 @@ func (p *SessionPool) CloseAll() { } } -// cleanupSessionState drops temporary tables and views on the connection so -// that session-scoped state doesn't leak when the connection is returned to -// the pool. -func cleanupSessionState(conn *sql.Conn) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - // Drop temporary tables - dropTemporary(ctx, conn, - "SELECT table_name FROM duckdb_tables() WHERE temporary = true", - `DROP TABLE IF EXISTS temp."%s"`, - ) - - // Drop temporary views - dropTemporary(ctx, conn, - "SELECT view_name FROM duckdb_views() WHERE temporary = true", - `DROP VIEW IF EXISTS temp."%s"`, - ) -} - -func dropTemporary(ctx context.Context, conn *sql.Conn, query, dropFmt string) { - rows, err := conn.QueryContext(ctx, query) - if err != nil { - slog.Warn("Failed to query temporary objects for cleanup.", "error", err) - return - } - var names []string - for rows.Next() { - var name string - if rows.Scan(&name) == nil { - names = append(names, name) - } - } - if err := rows.Err(); err != nil { - slog.Warn("Error iterating temporary objects for cleanup.", "error", err) - } - _ = rows.Close() - - for _, name := range names { - if _, err := conn.ExecContext(ctx, fmt.Sprintf(dropFmt, name)); err != nil { - slog.Warn("Failed to drop temporary object during cleanup.", "name", name, "error", err) - } - } -} - // initSearchPath sets the DuckDB search_path for a session connection. // It tries to include the user's schema first; if that schema doesn't exist, // it falls back to just 'main' (DuckDB's default schema). diff --git a/duckdbservice/session_reset.go b/duckdbservice/session_reset.go new file mode 100644 index 0000000..f8bf852 --- /dev/null +++ b/duckdbservice/session_reset.go @@ -0,0 +1,434 @@ +package duckdbservice + +import ( + "context" + "database/sql" + "fmt" + "log/slog" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/posthog/duckgres/server" +) + +// Allowlists of objects created during DuckDB warmup (initPgCatalog, +// initClickHouseMacros, initInformationSchema, openBaseDB, AttachDuckLake). +// Objects not in these lists are considered user-created and will be +// dropped during session reset to prevent state leakage between sessions. +// +// All names are lowercase because DuckDB folds unquoted identifiers to lowercase. + +var systemMacros = map[string]bool{ + // PostgreSQL compatibility (server/catalog.go initPgCatalog) + "pg_get_userbyid": true, + "pg_table_is_visible": true, + "has_schema_privilege": true, + "has_table_privilege": true, + "has_any_column_privilege": true, + "has_database_privilege": true, + "pg_encoding_to_char": true, + "format_type": true, + "obj_description": true, + "col_description": true, + "shobj_description": true, + "pg_get_indexdef": true, + "pg_get_partkeydef": true, + "pg_get_serial_sequence": true, + "pg_get_statisticsobjdef_columns": true, + "pg_relation_is_publishable": true, + "current_setting": true, + "pg_is_in_recovery": true, + "similar_to_escape": true, + "version": true, + "div": true, + "array_remove": true, + "to_number": true, + "pg_backend_pid": true, + "pg_total_relation_size": true, + "pg_relation_size": true, + "pg_table_size": true, + "pg_stat_get_numscans": true, + "pg_indexes_size": true, + "pg_database_size": true, + "pg_size_pretty": true, + "txid_current": true, + "pg_current_xact_id": true, + "quote_ident": true, + "quote_literal": true, + "quote_nullable": true, + // Utility macros (server/catalog.go initUtilityMacros) + "uptime": true, + "worker_uptime": true, + "control_plane_version": true, + "worker_version": true, + // ClickHouse compatibility (server/chsql.go) + "tostring": true, + "toint32": true, + "toint64": true, + "tofloat": true, + "toint32ornull": true, + "toint32orzero": true, + "intdiv": true, + "modulo": true, + "empty": true, + "notempty": true, + "splitbychar": true, + "lengthutf8": true, + "toyear": true, + "tomonth": true, + "todayofmonth": true, + "toyyyymmdd": true, + "toyyyymm": true, + "protocol": true, + "domain": true, + "topleveldomain": true, + "ipv4numtostring": true, + "jsonextractstring": true, + "jsonhas": true, + "generateuuidv4": true, + "ifnull": true, +} + +var systemViews = map[string]bool{ + // pg_catalog views (server/catalog.go initPgCatalog) + "pg_database": true, + "pg_class_full": true, + "pg_collation": true, + "pg_policy": true, + "pg_roles": true, + "pg_statistic_ext": true, + "pg_publication_tables": true, + "pg_rules": true, + "pg_publication": true, + "pg_publication_rel": true, + "pg_inherits": true, + "pg_matviews": true, + "pg_stat_statements": true, + "pg_partitioned_table": true, + "pg_rewrite": true, + "pg_stat_user_tables": true, + "pg_statio_user_tables": true, + "pg_stat_activity": true, + "pg_namespace": true, + "pg_type": true, + "pg_attribute": true, + "pg_constraint": true, + "pg_enum": true, + "pg_indexes": true, + "pg_shdescription": true, + "pg_extension": true, + // Stub views + "pg_auth_members": true, + "pg_opclass": true, + "pg_conversion": true, + "pg_language": true, + "pg_foreign_server": true, + "pg_foreign_data_wrapper": true, + "pg_foreign_table": true, + "pg_trigger": true, + "pg_locks": true, + // Information schema wrappers (server/catalog.go initInformationSchema) + "information_schema_columns_compat": true, + "information_schema_tables_compat": true, + "information_schema_schemata_compat": true, + "information_schema_views_compat": true, +} + +var systemTables = map[string]bool{ + "__duckgres_column_metadata": true, +} + +var systemDatabases = map[string]bool{ + "memory": true, + "system": true, + "temp": true, + "ducklake": true, +} + +var systemSchemas = map[string]bool{ + "main": true, + "pg_catalog": true, + "information_schema": true, +} + +var systemSecrets = map[string]bool{ + "ducklake_s3": true, +} + +// resetSessionState performs exhaustive in-place cleanup of the shared DuckDB +// instance after a session ends. It drops all user-created state (tables, views, +// macros, settings, temp objects, attached databases) while preserving the +// warmup-created objects via allowlists. This avoids the ~90ms cost of closing +// and reopening the DuckDB instance. +func (p *SessionPool) resetSessionState(db *sql.DB) error { + ctx := context.Background() + conn, err := db.Conn(ctx) + if err != nil { + return fmt.Errorf("reset session state: %w", err) + } + defer func() { _ = conn.Close() }() + + start := time.Now() + + // 1. RESET all DuckDB settings to defaults. + resetAllSettings(ctx, conn) + + // 2. Drop all temporary objects. + dropTempObjects(ctx, conn) + + // 3. Drop user-created objects in memory.main not in allowlists. + dropUserObjects(ctx, conn) + + // 4. Drop user-created schemas in memory. + dropUserSchemas(ctx, conn) + + // 5. Detach user-attached databases. + detachUserDatabases(ctx, conn) + + // 6. Drop user-created secrets. + dropUserSecrets(ctx, conn) + + // 7. Re-apply warmup settings (threads, memory_limit, paths, DuckLake). + p.reapplySettings(ctx, conn) + + slog.Info("Session state reset.", "duration", time.Since(start)) + return nil +} + +// resetAllSettings resets all DuckDB settings to their defaults. +// Read-only settings will fail to reset; errors are silently ignored. +func resetAllSettings(ctx context.Context, conn *sql.Conn) { + rows, err := conn.QueryContext(ctx, "SELECT name FROM duckdb_settings()") + if err != nil { + slog.Warn("Failed to query settings for reset.", "error", err) + return + } + defer func() { _ = rows.Close() }() + + var names []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + continue + } + names = append(names, name) + } + + for _, name := range names { + _, _ = conn.ExecContext(ctx, "RESET "+name) + } +} + +// dropTempObjects drops all tables and views in the temp database. +func dropTempObjects(ctx context.Context, conn *sql.Conn) { + // Views first (they may reference tables). + for _, v := range queryPairs(ctx, conn, + "SELECT schema_name, view_name FROM duckdb_views() WHERE database_name = 'temp'") { + if _, err := conn.ExecContext(ctx, fmt.Sprintf(`DROP VIEW IF EXISTS temp."%s"."%s" CASCADE`, v[0], v[1])); err != nil { + slog.Warn("Failed to drop temp view.", "view", v[1], "error", err) + } + } + for _, t := range queryPairs(ctx, conn, + "SELECT schema_name, table_name FROM duckdb_tables() WHERE database_name = 'temp'") { + if _, err := conn.ExecContext(ctx, fmt.Sprintf(`DROP TABLE IF EXISTS temp."%s"."%s" CASCADE`, t[0], t[1])); err != nil { + slog.Warn("Failed to drop temp table.", "table", t[1], "error", err) + } + } +} + +// dropUserObjects drops non-allowlisted views, tables, macros, sequences, and +// types in the memory.main schema. +func dropUserObjects(ctx context.Context, conn *sql.Conn) { + // Views + for _, v := range queryPairs(ctx, conn, + "SELECT schema_name, view_name FROM duckdb_views() WHERE database_name = 'memory' AND schema_name = 'main'") { + if !systemViews[strings.ToLower(v[1])] { + if _, err := conn.ExecContext(ctx, fmt.Sprintf(`DROP VIEW IF EXISTS memory.main."%s" CASCADE`, v[1])); err != nil { + slog.Warn("Failed to drop user view.", "view", v[1], "error", err) + } + } + } + + // Tables + for _, t := range queryPairs(ctx, conn, + "SELECT schema_name, table_name FROM duckdb_tables() WHERE database_name = 'memory' AND schema_name = 'main'") { + if !systemTables[strings.ToLower(t[1])] { + if _, err := conn.ExecContext(ctx, fmt.Sprintf(`DROP TABLE IF EXISTS memory.main."%s" CASCADE`, t[1])); err != nil { + slog.Warn("Failed to drop user table.", "table", t[1], "error", err) + } + } + } + + // Macros (scalar and table macros) + for _, m := range queryMacros(ctx, conn) { + if !systemMacros[strings.ToLower(m.name)] { + stmt := fmt.Sprintf(`DROP MACRO IF EXISTS "%s"`, m.name) + if m.isTable { + stmt = fmt.Sprintf(`DROP MACRO TABLE IF EXISTS "%s"`, m.name) + } + if _, err := conn.ExecContext(ctx, stmt); err != nil { + slog.Warn("Failed to drop user macro.", "macro", m.name, "error", err) + } + } + } + + // Sequences + for _, s := range queryPairs(ctx, conn, + "SELECT schema_name, sequence_name FROM duckdb_sequences() WHERE database_name = 'memory' AND schema_name = 'main'") { + if _, err := conn.ExecContext(ctx, fmt.Sprintf(`DROP SEQUENCE IF EXISTS memory.main."%s"`, s[1])); err != nil { + slog.Warn("Failed to drop user sequence.", "sequence", s[1], "error", err) + } + } + + // User-defined types + for _, t := range queryPairs(ctx, conn, + "SELECT schema_name, type_name FROM duckdb_types() WHERE database_name = 'memory' AND schema_name = 'main' AND internal = false") { + if _, err := conn.ExecContext(ctx, fmt.Sprintf(`DROP TYPE IF EXISTS memory.main."%s"`, t[1])); err != nil { + slog.Warn("Failed to drop user type.", "type", t[1], "error", err) + } + } +} + +// dropUserSchemas drops schemas in memory that aren't in the system allowlist. +func dropUserSchemas(ctx context.Context, conn *sql.Conn) { + for _, s := range queryPairs(ctx, conn, + "SELECT database_name, schema_name FROM duckdb_schemas() WHERE database_name = 'memory'") { + if !systemSchemas[strings.ToLower(s[1])] { + if _, err := conn.ExecContext(ctx, fmt.Sprintf(`DROP SCHEMA IF EXISTS memory."%s" CASCADE`, s[1])); err != nil { + slog.Warn("Failed to drop user schema.", "schema", s[1], "error", err) + } + } + } +} + +// detachUserDatabases detaches databases not in the system allowlist. +func detachUserDatabases(ctx context.Context, conn *sql.Conn) { + for _, d := range queryPairs(ctx, conn, + "SELECT database_name, database_name FROM duckdb_databases()") { + if !systemDatabases[strings.ToLower(d[0])] { + if _, err := conn.ExecContext(ctx, fmt.Sprintf(`DETACH "%s"`, d[0])); err != nil { + slog.Warn("Failed to detach user database.", "database", d[0], "error", err) + } + } + } +} + +// dropUserSecrets drops secrets not in the system allowlist. +func dropUserSecrets(ctx context.Context, conn *sql.Conn) { + rows, err := conn.QueryContext(ctx, "SELECT name FROM duckdb_secrets()") + if err != nil { + return // duckdb_secrets() may not exist + } + defer func() { _ = rows.Close() }() + + var names []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + continue + } + if !systemSecrets[strings.ToLower(name)] { + names = append(names, name) + } + } + + for _, name := range names { + if _, err := conn.ExecContext(ctx, fmt.Sprintf(`DROP SECRET IF EXISTS "%s"`, name)); err != nil { + slog.Warn("Failed to drop user secret.", "secret", name, "error", err) + } + } +} + +// reapplySettings re-applies the warmup settings that were cleared by RESET. +func (p *SessionPool) reapplySettings(ctx context.Context, conn *sql.Conn) { + // threads + threads := p.cfg.Threads + if threads == 0 { + threads = runtime.NumCPU() * 2 + } + _, _ = conn.ExecContext(ctx, fmt.Sprintf("SET threads = %d", threads)) + + // memory_limit + memLimit := p.cfg.MemoryLimit + if memLimit == "" { + memLimit = server.AutoMemoryLimit() + } + _, _ = conn.ExecContext(ctx, fmt.Sprintf("SET memory_limit = '%s'", memLimit)) + + // temp_directory + tempDir := filepath.Join(p.cfg.DataDir, "tmp") + _, _ = conn.ExecContext(ctx, fmt.Sprintf("SET temp_directory = '%s'", tempDir)) + + // extension_directory + extDir := filepath.Join(p.cfg.DataDir, "extensions") + _, _ = conn.ExecContext(ctx, fmt.Sprintf("SET extension_directory = '%s'", extDir)) + + // cache_httpfs_cache_directory (if the extension is loaded) + for _, ext := range p.cfg.Extensions { + name := ext + if idx := strings.Index(name, ":"); idx >= 0 { + name = name[:idx] + } + if name == "cache_httpfs" { + cacheDir := filepath.Join(p.cfg.DataDir, "cache") + _, _ = conn.ExecContext(ctx, fmt.Sprintf("SET cache_httpfs_cache_directory = '%s/'", cacheDir)) + break + } + } + + // DuckLake settings + if p.cfg.DuckLake.MetadataStore != "" { + _, _ = conn.ExecContext(ctx, "SET ducklake_max_retry_count = 100") + _, _ = conn.ExecContext(ctx, "USE ducklake") + } +} + +type macroInfo struct { + name string + isTable bool +} + +// queryMacros returns user-created macros (scalar and table) in memory.main. +func queryMacros(ctx context.Context, conn *sql.Conn) []macroInfo { + rows, err := conn.QueryContext(ctx, + "SELECT DISTINCT function_name, function_type FROM duckdb_functions() "+ + "WHERE database_name = 'memory' AND schema_name = 'main' "+ + "AND function_type IN ('macro', 'table_macro')") + if err != nil { + slog.Warn("Failed to query macros.", "error", err) + return nil + } + defer func() { _ = rows.Close() }() + + var macros []macroInfo + for rows.Next() { + var name, ftype string + if err := rows.Scan(&name, &ftype); err != nil { + continue + } + macros = append(macros, macroInfo{name: name, isTable: ftype == "table_macro"}) + } + return macros +} + +// queryPairs runs a two-column query and returns the results. +func queryPairs(ctx context.Context, conn *sql.Conn, query string) [][2]string { + rows, err := conn.QueryContext(ctx, query) + if err != nil { + slog.Debug("Catalog query returned no results.", "query", query, "error", err) + return nil + } + defer func() { _ = rows.Close() }() + + var results [][2]string + for rows.Next() { + var a, b string + if err := rows.Scan(&a, &b); err != nil { + continue + } + results = append(results, [2]string{a, b}) + } + return results +} diff --git a/duckdbservice/session_reset_test.go b/duckdbservice/session_reset_test.go new file mode 100644 index 0000000..cb65221 --- /dev/null +++ b/duckdbservice/session_reset_test.go @@ -0,0 +1,471 @@ +package duckdbservice + +import ( + "context" + "database/sql" + "strings" + "testing" + + _ "github.com/duckdb/duckdb-go/v2" + "github.com/posthog/duckgres/server" +) + +func newTestPool(t *testing.T) (*SessionPool, *sql.DB) { + t.Helper() + db, err := sql.Open("duckdb", ":memory:") + if err != nil { + t.Fatalf("failed to open DuckDB: %v", err) + } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + t.Cleanup(func() { _ = db.Close() }) + + pool := &SessionPool{ + sessions: make(map[string]*Session), + stopRefresh: make(map[string]func()), + duckLakeSem: make(chan struct{}, 1), + cfg: server.Config{ + DataDir: t.TempDir(), + }, + stopCh: make(chan struct{}), + warmupDone: make(chan struct{}), + warmupDB: db, + } + close(pool.warmupDone) + return pool, db +} + +func TestResetSessionState_ClearsSetVariables(t *testing.T) { + pool, db := newTestPool(t) + ctx := context.Background() + + // Simulate a session setting a variable. + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "SET default_null_order = 'NULLS_FIRST'"); err != nil { + t.Fatal(err) + } + _ = conn.Close() + + // Verify the setting persists on the shared connection. + conn2, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + var val string + if err := conn2.QueryRowContext(ctx, "SELECT value FROM duckdb_settings() WHERE name = 'default_null_order'").Scan(&val); err != nil { + t.Fatal(err) + } + if !strings.Contains(strings.ToUpper(val), "NULLS_FIRST") { + t.Fatalf("expected NULLS_FIRST, got %s", val) + } + _ = conn2.Close() + + // Reset session state. + if err := pool.resetSessionState(db); err != nil { + t.Fatal(err) + } + + // Verify the setting was reset. + conn3, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn3.Close() }() + + if err := conn3.QueryRowContext(ctx, "SELECT value FROM duckdb_settings() WHERE name = 'default_null_order'").Scan(&val); err != nil { + t.Fatal(err) + } + if strings.Contains(strings.ToUpper(val), "NULLS_FIRST") { + t.Fatalf("setting should have been reset, got %s", val) + } +} + +func TestResetSessionState_ClearsTempTables(t *testing.T) { + pool, db := newTestPool(t) + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "CREATE TEMP TABLE tmp_test (id INTEGER)"); err != nil { + t.Fatal(err) + } + _ = conn.Close() + + if err := pool.resetSessionState(db); err != nil { + t.Fatal(err) + } + + // Verify temp table is gone. + conn2, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn2.Close() }() + + var count int + err = conn2.QueryRowContext(ctx, "SELECT COUNT(*) FROM duckdb_tables() WHERE database_name = 'temp' AND table_name = 'tmp_test'").Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 0 { + t.Fatal("temp table should have been dropped") + } +} + +func TestResetSessionState_ClearsUserTables(t *testing.T) { + pool, db := newTestPool(t) + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "CREATE TABLE user_data (id INTEGER, name VARCHAR)"); err != nil { + t.Fatal(err) + } + _ = conn.Close() + + if err := pool.resetSessionState(db); err != nil { + t.Fatal(err) + } + + conn2, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn2.Close() }() + + var count int + err = conn2.QueryRowContext(ctx, "SELECT COUNT(*) FROM duckdb_tables() WHERE database_name = 'memory' AND schema_name = 'main' AND table_name = 'user_data'").Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 0 { + t.Fatal("user table should have been dropped") + } +} + +func TestResetSessionState_PreservesSystemTable(t *testing.T) { + pool, db := newTestPool(t) + ctx := context.Background() + + // Create the system table that warmup would create. + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS __duckgres_column_metadata ( + table_schema VARCHAR, table_name VARCHAR, column_name VARCHAR, + character_maximum_length INTEGER, PRIMARY KEY (table_schema, table_name, column_name) + )`); err != nil { + t.Fatal(err) + } + _ = conn.Close() + + if err := pool.resetSessionState(db); err != nil { + t.Fatal(err) + } + + conn2, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn2.Close() }() + + var count int + err = conn2.QueryRowContext(ctx, "SELECT COUNT(*) FROM duckdb_tables() WHERE table_name = '__duckgres_column_metadata'").Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 1 { + t.Fatal("system table __duckgres_column_metadata should be preserved") + } +} + +func TestResetSessionState_ClearsUserViews(t *testing.T) { + pool, db := newTestPool(t) + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "CREATE VIEW user_view AS SELECT 1 AS x"); err != nil { + t.Fatal(err) + } + _ = conn.Close() + + if err := pool.resetSessionState(db); err != nil { + t.Fatal(err) + } + + conn2, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn2.Close() }() + + var count int + err = conn2.QueryRowContext(ctx, "SELECT COUNT(*) FROM duckdb_views() WHERE database_name = 'memory' AND schema_name = 'main' AND view_name = 'user_view'").Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 0 { + t.Fatal("user view should have been dropped") + } +} + +func TestResetSessionState_PreservesSystemViews(t *testing.T) { + pool, db := newTestPool(t) + ctx := context.Background() + + // Create a system view (simulating warmup). + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "CREATE OR REPLACE VIEW pg_database AS SELECT 1 AS oid, 'postgres' AS datname"); err != nil { + t.Fatal(err) + } + _ = conn.Close() + + if err := pool.resetSessionState(db); err != nil { + t.Fatal(err) + } + + conn2, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn2.Close() }() + + var count int + err = conn2.QueryRowContext(ctx, "SELECT COUNT(*) FROM duckdb_views() WHERE database_name = 'memory' AND schema_name = 'main' AND view_name = 'pg_database'").Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 1 { + t.Fatal("system view pg_database should be preserved") + } +} + +func TestResetSessionState_ClearsUserMacros(t *testing.T) { + pool, db := newTestPool(t) + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "CREATE MACRO my_custom_macro(x) AS x * 2"); err != nil { + t.Fatal(err) + } + _ = conn.Close() + + if err := pool.resetSessionState(db); err != nil { + t.Fatal(err) + } + + conn2, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn2.Close() }() + + var count int + err = conn2.QueryRowContext(ctx, + "SELECT COUNT(*) FROM duckdb_functions() WHERE database_name = 'memory' AND schema_name = 'main' AND function_name = 'my_custom_macro'").Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 0 { + t.Fatal("user macro should have been dropped") + } +} + +func TestResetSessionState_PreservesSystemMacros(t *testing.T) { + pool, db := newTestPool(t) + ctx := context.Background() + + // Create a system macro (simulating warmup). + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "CREATE OR REPLACE MACRO pg_backend_pid() AS 0"); err != nil { + t.Fatal(err) + } + _ = conn.Close() + + if err := pool.resetSessionState(db); err != nil { + t.Fatal(err) + } + + conn2, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn2.Close() }() + + var count int + err = conn2.QueryRowContext(ctx, + "SELECT COUNT(*) FROM duckdb_functions() WHERE database_name = 'memory' AND schema_name = 'main' AND function_name = 'pg_backend_pid'").Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 1 { + t.Fatal("system macro pg_backend_pid should be preserved") + } +} + +func TestResetSessionState_ClearsUserSchemas(t *testing.T) { + pool, db := newTestPool(t) + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "CREATE SCHEMA user_schema"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "CREATE TABLE user_schema.data (id INTEGER)"); err != nil { + t.Fatal(err) + } + _ = conn.Close() + + if err := pool.resetSessionState(db); err != nil { + t.Fatal(err) + } + + conn2, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn2.Close() }() + + var count int + err = conn2.QueryRowContext(ctx, "SELECT COUNT(*) FROM duckdb_schemas() WHERE database_name = 'memory' AND schema_name = 'user_schema'").Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 0 { + t.Fatal("user schema should have been dropped") + } +} + +func TestResetSessionState_DetachesUserDatabases(t *testing.T) { + pool, db := newTestPool(t) + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "ATTACH ':memory:' AS user_db"); err != nil { + t.Fatal(err) + } + _ = conn.Close() + + if err := pool.resetSessionState(db); err != nil { + t.Fatal(err) + } + + conn2, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn2.Close() }() + + var count int + err = conn2.QueryRowContext(ctx, "SELECT COUNT(*) FROM duckdb_databases() WHERE database_name = 'user_db'").Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 0 { + t.Fatal("user database should have been detached") + } +} + +func TestResetSessionState_ClearsUserSequences(t *testing.T) { + pool, db := newTestPool(t) + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "CREATE SEQUENCE user_seq START 1"); err != nil { + t.Fatal(err) + } + _ = conn.Close() + + if err := pool.resetSessionState(db); err != nil { + t.Fatal(err) + } + + conn2, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn2.Close() }() + + var count int + err = conn2.QueryRowContext(ctx, "SELECT COUNT(*) FROM duckdb_sequences() WHERE database_name = 'memory' AND schema_name = 'main' AND sequence_name = 'user_seq'").Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 0 { + t.Fatal("user sequence should have been dropped") + } +} + +func TestResetSessionState_ReappliesSettings(t *testing.T) { + pool, db := newTestPool(t) + pool.cfg.Threads = 4 + pool.cfg.MemoryLimit = "512MB" + ctx := context.Background() + + // Change settings to non-default values. + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "SET threads = 1"); err != nil { + t.Fatal(err) + } + _ = conn.Close() + + if err := pool.resetSessionState(db); err != nil { + t.Fatal(err) + } + + // Verify warmup settings were re-applied. + conn2, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn2.Close() }() + + var threads string + if err := conn2.QueryRowContext(ctx, "SELECT value FROM duckdb_settings() WHERE name = 'threads'").Scan(&threads); err != nil { + t.Fatal(err) + } + if threads != "4" { + t.Fatalf("expected threads=4 after reset, got %s", threads) + } + + var memLimit string + if err := conn2.QueryRowContext(ctx, "SELECT value FROM duckdb_settings() WHERE name = 'memory_limit'").Scan(&memLimit); err != nil { + t.Fatal(err) + } + // 512MB (base 10) = 488.2 MiB (base 2) + if !strings.Contains(memLimit, "488") { + t.Fatalf("expected memory_limit ~488 MiB (512MB) after reset, got %s", memLimit) + } +} diff --git a/duckgres_local_test.yaml b/duckgres_local_test.yaml new file mode 100644 index 0000000..776454e --- /dev/null +++ b/duckgres_local_test.yaml @@ -0,0 +1,17 @@ +host: "0.0.0.0" +port: 35437 +data_dir: "./data" +users: + postgres: "postgres" +extensions: + - ducklake +ducklake: + metadata_store: "postgres:host=localhost port=5433 user=ducklake password=ducklake dbname=ducklake" + object_store: "s3://ducklake/data/" + s3_provider: "config" + s3_endpoint: "localhost:9000" + s3_access_key: "minioadmin" + s3_secret_key: "minioadmin" + s3_region: "us-east-1" + s3_use_ssl: false + s3_url_style: "path" diff --git a/server/sysinfo.go b/server/sysinfo.go index d1ad47f..18cf3f1 100644 --- a/server/sysinfo.go +++ b/server/sysinfo.go @@ -52,6 +52,9 @@ var ( autoMemoryLimitValue string ) +// AutoMemoryLimit returns the automatically computed DuckDB memory limit. +func AutoMemoryLimit() string { return autoMemoryLimit() } + // autoMemoryLimit computes a DuckDB memory_limit based on system memory. // Formula: totalMem * 0.75, with a floor of 256MB. // Every session gets the full budget — DuckDB will spill to disk/swap if