diff --git a/config_resolution.go b/config_resolution.go
index 63fa1dd..486872e 100644
--- a/config_resolution.go
+++ b/config_resolution.go
@@ -40,7 +40,7 @@ type configCLIInputs struct {
 	MaxConnections     int
 	ConfigStoreConn    string
 	ConfigPollInterval string
-	AdminToken         string
+	InternalSecret     string
 	WorkerBackend      string
 	K8sWorkerImage     string
 	K8sWorkerNamespace string
@@ -52,6 +52,7 @@ type configCLIInputs struct {
 	K8sWorkerServiceAccount string
 	K8sMaxWorkers           int
 	K8sSharedWarmTarget     int
+	AWSRegion               string
 	QueryLog                bool
 }

@@ -73,9 +74,10 @@ type resolvedConfig struct {
 	K8sWorkerServiceAccount string
 	K8sMaxWorkers           int
 	K8sSharedWarmTarget     int
+	AWSRegion               string
 	ConfigStoreConn         string
 	ConfigPollInterval      time.Duration
-	AdminToken              string
+	InternalSecret          string
 }

 func defaultServerConfig() server.Config {
@@ -129,9 +131,10 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun
 	var k8sWorkerPort int
 	var k8sWorkerSecret, k8sWorkerConfigMap, k8sWorkerImagePullPolicy, k8sWorkerServiceAccount string
 	var k8sMaxWorkers, k8sSharedWarmTarget int
+	var awsRegion string
 	var configStoreConn string
 	var configPollInterval time.Duration
-	var adminToken string
+	var internalSecret string

 	if fileCfg != nil {
 		if fileCfg.Host != "" {
@@ -581,8 +584,8 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun
 			warn("Invalid DUCKGRES_CONFIG_POLL_INTERVAL duration: " + err.Error())
 		}
 	}
-	if v := getenv("DUCKGRES_ADMIN_TOKEN"); v != "" {
-		adminToken = v
+	if v := getenv("DUCKGRES_INTERNAL_SECRET"); v != "" {
+		internalSecret = v
 	}
 	if v := getenv("DUCKGRES_WORKER_BACKEND"); v != "" {
 		workerBackend = v
@@ -629,6 +632,10 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun
 			warn("Invalid DUCKGRES_K8S_SHARED_WARM_TARGET: " + err.Error())
 		}
 	}
+	if v := getenv("DUCKGRES_AWS_REGION"); v != "" {
+		awsRegion = v
+	}
+
 	// Query log env vars
 	if v := getenv("DUCKGRES_QUERY_LOG_ENABLED"); v != "" {
 		if b, err := strconv.ParseBool(v); err == nil {
@@ -789,8 +796,8 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun
 			warn("Invalid --config-poll-interval duration: " + err.Error())
 		}
 	}
-	if cli.Set["admin-token"] {
-		adminToken = cli.AdminToken
+	if cli.Set["internal-secret"] {
+		internalSecret = cli.InternalSecret
 	}
 	if cli.Set["worker-backend"] {
 		workerBackend = cli.WorkerBackend
@@ -825,6 +832,9 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun
 	if cli.Set["k8s-shared-warm-target"] {
 		k8sSharedWarmTarget = cli.K8sSharedWarmTarget
 	}
+	if cli.Set["aws-region"] {
+		awsRegion = cli.AWSRegion
+	}
 	if cli.Set["query-log"] {
 		cfg.QueryLog.Enabled = cli.QueryLog
 	}
@@ -894,8 +904,9 @@ func resolveEffectiveConfig(fileCfg *FileConfig, cli configCLIInputs, getenv fun
 		K8sWorkerServiceAccount: k8sWorkerServiceAccount,
 		K8sMaxWorkers:           k8sMaxWorkers,
 		K8sSharedWarmTarget:     k8sSharedWarmTarget,
+		AWSRegion:               awsRegion,
 		ConfigStoreConn:         configStoreConn,
 		ConfigPollInterval:      configPollInterval,
-		AdminToken:              adminToken,
+		InternalSecret:          internalSecret,
 	}
 }
diff --git a/controlplane/activation_payload_test.go b/controlplane/activation_payload_test.go
index be72889..c75499b 100644
--- a/controlplane/activation_payload_test.go
+++ b/controlplane/activation_payload_test.go
@@ -62,7 +62,7 @@ func TestBuildTenantActivationPayloadBuildsDuckLakeRuntimeFromWarehouseSecrets(t
 		},
 	}

-	payload, err := BuildTenantActivationPayload(context.Background(), pool.clientset, pool.namespace, org)
+	payload, err := BuildTenantActivationPayload(context.Background(), pool.clientset, pool.namespace, org, nil)
 	if err != nil {
 		t.Fatalf("BuildTenantActivationPayload: %v", err)
 	}
diff --git a/controlplane/admin/api.go b/controlplane/admin/api.go
index 1d08db2..ccb32bb 100644
--- a/controlplane/admin/api.go
+++ b/controlplane/admin/api.go
@@ -14,7 +14,7 @@ import (
 	"gorm.io/gorm/clause"
 )

-var errWarehousePayloadNotAllowed = errors.New("warehouse payload must be updated via /orgs/:name/warehouse")
+var errWarehousePayloadNotAllowed = errors.New("warehouse payload must be updated via /orgs/:id/warehouse")

 // WorkerStatus represents a worker's current status for the API.
 type WorkerStatus struct {
@@ -71,11 +71,11 @@ func registerAPIWithStore(r *gin.RouterGroup, store apiStore, info OrgStackInfo)
 	// Orgs CRUD
 	r.GET("/orgs", h.listOrgs)
 	r.POST("/orgs", h.createOrg)
-	r.GET("/orgs/:name", h.getOrg)
-	r.PUT("/orgs/:name", h.updateOrg)
-	r.DELETE("/orgs/:name", h.deleteOrg)
-	r.GET("/orgs/:name/warehouse", h.getManagedWarehouse)
-	r.PUT("/orgs/:name/warehouse", h.putManagedWarehouse)
+	r.GET("/orgs/:id", h.getOrg)
+	r.PUT("/orgs/:id", h.updateOrg)
+	r.DELETE("/orgs/:id", h.deleteOrg)
+	r.GET("/orgs/:id/warehouse", h.getManagedWarehouse)
+	r.PUT("/orgs/:id/warehouse", h.putManagedWarehouse)

 	// Users CRUD
 	r.GET("/users", h.listUsers)
@@ -285,6 +285,9 @@ func (s *gormAPIStore) UpsertManagedWarehouse(orgID string, warehouse *configsto

 func managedWarehouseUpsertColumns() []string {
 	return []string{
+		"image",
+		"aurora_min_acu",
+		"aurora_max_acu",
 		"warehouse_database_region",
 		"warehouse_database_endpoint",
 		"warehouse_database_port",
@@ -481,7 +484,7 @@ func (h *apiHandler) createOrg(c *gin.Context) {
 }

 func (h *apiHandler) getOrg(c *gin.Context) {
-	name := c.Param("name")
+	name := c.Param("id")
 	org, err := h.store.GetOrg(name)
 	if err != nil {
 		c.JSON(http.StatusNotFound, gin.H{"error": "org not found"})
@@ -491,7 +494,7 @@ func (h *apiHandler) getOrg(c *gin.Context) {
 }

 func (h *apiHandler) updateOrg(c *gin.Context) {
-	name := c.Param("name")
+	name := c.Param("id")
 	var updates configstore.Org
 	if err := c.ShouldBindJSON(&updates); err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -514,7 +517,7 @@ func (h *apiHandler) updateOrg(c *gin.Context) {
 }

 func (h *apiHandler) deleteOrg(c *gin.Context) {
-	name := c.Param("name")
+	name := c.Param("id")
 	ok, err := h.store.DeleteOrg(name)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -535,7 +538,7 @@ func validateOrgMutationPayload(org *configstore.Org) error {
 }

 func (h *apiHandler) getManagedWarehouse(c *gin.Context) {
-	warehouse, err := h.store.GetManagedWarehouse(c.Param("name"))
+	warehouse, err := h.store.GetManagedWarehouse(c.Param("id"))
 	if err != nil {
 		if errors.Is(err, gorm.ErrRecordNotFound) {
 			c.JSON(http.StatusNotFound, gin.H{"error": "managed warehouse not found"})
@@ -548,7 +551,7 @@ func (h *apiHandler) getManagedWarehouse(c *gin.Context) {
 }

 func (h *apiHandler) putManagedWarehouse(c *gin.Context) {
-	orgID := c.Param("name")
+	orgID := c.Param("id")
 	var req managedWarehouseRequest
 	if err := decodeStrictWarehouseRequest(c, &req); err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
diff --git a/controlplane/admin/dashboard.go b/controlplane/admin/dashboard.go
index 6984663..0f7de44 100644
--- a/controlplane/admin/dashboard.go
+++ b/controlplane/admin/dashboard.go
@@ -90,10 +90,11 @@ func renderLoginPage(c *gin.Context, next, errMsg string) {
 }

 func requestAdminToken(c *gin.Context) string {
-	auth := c.GetHeader("Authorization")
-	if strings.HasPrefix(auth, "Bearer ") {
-		return strings.TrimPrefix(auth, "Bearer ")
+	// Primary: X-Duckgres-Internal-Secret header (service-to-service)
+	if secret := c.GetHeader("X-Duckgres-Internal-Secret"); secret != "" {
+		return secret
 	}
+	// Fallback: cookie (dashboard UI)
 	if cookie, err := c.Cookie(adminTokenCookieName); err == nil {
 		return cookie
 	}
diff --git a/controlplane/configstore/models.go b/controlplane/configstore/models.go
index 8215ee7..439190d 100644
--- a/controlplane/configstore/models.go
+++ b/controlplane/configstore/models.go
@@ -90,6 +90,10 @@ type ManagedWarehouseWorkerIdentity struct {
 type ManagedWarehouse struct {
 	OrgID string `gorm:"primaryKey;size:255" json:"org_id"`

+	Image        string  `gorm:"size:512" json:"image"`
+	AuroraMinACU float64 `json:"aurora_min_acu"`
+	AuroraMaxACU float64 `json:"aurora_max_acu"`
+
 	WarehouseDatabase ManagedWarehouseDatabase      `gorm:"embedded;embeddedPrefix:warehouse_database_" json:"warehouse_database"`
 	MetadataStore     ManagedWarehouseMetadataStore `gorm:"embedded;embeddedPrefix:metadata_store_" json:"metadata_store"`
 	S3                ManagedWarehouseS3            `gorm:"embedded;embeddedPrefix:s3_" json:"s3"`
@@ -112,6 +116,7 @@ type ManagedWarehouse struct {
 	IdentityStatusMessage string                            `gorm:"size:1024" json:"identity_status_message"`
 	SecretsState          ManagedWarehouseProvisioningState `gorm:"size:32" json:"secrets_state"`
 	SecretsStatusMessage  string                            `gorm:"size:1024" json:"secrets_status_message"`
+	ProvisioningStartedAt *time.Time                        `json:"provisioning_started_at"`
 	ReadyAt               *time.Time                        `json:"ready_at"`
 	FailedAt              *time.Time                        `json:"failed_at"`
 	CreatedAt             time.Time                         `json:"created_at"`
@@ -195,6 +200,10 @@ type OrgConfig struct {
 type ManagedWarehouseConfig struct {
 	OrgID string

+	Image        string
+	AuroraMinACU float64
+	AuroraMaxACU float64
+
 	WarehouseDatabase ManagedWarehouseDatabase
 	MetadataStore     ManagedWarehouseMetadataStore
 	S3                ManagedWarehouseS3
@@ -228,6 +237,9 @@ func copyManagedWarehouseConfig(warehouse *ManagedWarehouse) *ManagedWarehouseCo
 	cfg := &ManagedWarehouseConfig{
 		OrgID: warehouse.OrgID,

+		Image:        warehouse.Image,
+		AuroraMinACU: warehouse.AuroraMinACU,
+		AuroraMaxACU: warehouse.AuroraMaxACU,
+
 		WarehouseDatabase: warehouse.WarehouseDatabase,
 		MetadataStore:     warehouse.MetadataStore,
 		S3:                warehouse.S3,
diff --git a/controlplane/configstore/store.go b/controlplane/configstore/store.go
index a50aaf5..88353e6 100644
--- a/controlplane/configstore/store.go
+++ b/controlplane/configstore/store.go
@@ -217,6 +217,31 @@ func (cs *ConfigStore) OnChange(fn func(old, new *Snapshot)) {
 	cs.onChange = append(cs.onChange, fn)
 }

+// ListWarehousesByStates returns all warehouses with a state matching one of the given values.
+// This is a direct DB query, not snapshot-based, for use by the provisioning controller.
+func (cs *ConfigStore) ListWarehousesByStates(states []ManagedWarehouseProvisioningState) ([]ManagedWarehouse, error) {
+	var warehouses []ManagedWarehouse
+	if err := cs.db.Where("state IN ?", states).Find(&warehouses).Error; err != nil {
+		return nil, fmt.Errorf("list warehouses by states: %w", err)
+	}
+	return warehouses, nil
+}
+
+// UpdateWarehouseState performs a compare-and-swap update on a warehouse row.
+// Only updates if the current state matches expectedState, preventing races.
+func (cs *ConfigStore) UpdateWarehouseState(orgID string, expectedState ManagedWarehouseProvisioningState, updates map[string]interface{}) error {
+	result := cs.db.Model(&ManagedWarehouse{}).
+		Where("org_id = ? AND state = ?", orgID, expectedState).
+		Updates(updates)
+	if result.Error != nil {
+		return fmt.Errorf("update warehouse state: %w", result.Error)
+	}
+	if result.RowsAffected == 0 {
+		return fmt.Errorf("warehouse %q not in expected state %q", orgID, expectedState)
+	}
+	return nil
+}
+
 // DB exposes the GORM database for direct CRUD operations (used by admin API).
 func (cs *ConfigStore) DB() *gorm.DB {
 	return cs.db
diff --git a/controlplane/control.go b/controlplane/control.go
index c3af279..9a638db 100644
--- a/controlplane/control.go
+++ b/controlplane/control.go
@@ -58,9 +58,10 @@ type ControlPlaneConfig struct {
 	// Default: 30s.
 	ConfigPollInterval time.Duration

-	// AdminToken is the bearer token required for admin API requests.
-	// When empty, a random token is generated and logged at startup.
-	AdminToken string
+	// InternalSecret is the shared secret for API authentication.
+	// When empty, a random secret is generated and logged at startup.
+	InternalSecret string
+
 }

 type ProcessConfig struct {
@@ -80,6 +81,7 @@ type K8sConfig struct {
 	ServiceAccount   string // ServiceAccount name for worker pods (default: "default")
 	MaxWorkers       int    // Global cap for the shared K8s worker pool (0 = auto-derived)
 	SharedWarmTarget int    // Neutral shared warm-worker target for K8s multi-tenant mode (0 = disabled)
+	AWSRegion        string // AWS region for STS client
 }

 // ControlPlane manages the TCP listener and routes connections to Flight SQL workers.
@@ -110,6 +112,7 @@ type ControlPlane struct {
 	// Multi-tenant fields (non-nil in remote multitenant mode)
 	orgRouter   OrgRouterInterface
 	configStore ConfigStoreInterface
+	apiServer   *http.Server // API server on :8080 (shut down on graceful exit)
 }

 // ConfigStoreInterface abstracts the config store for the control plane.
@@ -317,20 +320,14 @@ func RunControlPlane(cfg ControlPlaneConfig) {

 	// Multi-tenant mode: config store + per-org pools (K8s remote backend only)
 	if cfg.WorkerBackend == "remote" {
-		store, adapter, adminSrv, err := SetupMultiTenant(cfg, srv, memBudget, k8sMaxWorkers)
+		store, adapter, apiServer, err := SetupMultiTenant(cfg, srv, memBudget, k8sMaxWorkers)
 		if err != nil {
 			slog.Error("Failed to set up multi-tenant config store.", "error", err)
 			os.Exit(1)
 		}
 		cp.configStore = store
 		cp.orgRouter = adapter
-		// Replace the simple metrics server with the Gin admin server
-		if cfg.MetricsServer != nil {
-			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-			_ = cfg.MetricsServer.Shutdown(ctx)
-			cancel()
-		}
-		cfg.MetricsServer = adminSrv
+		cp.apiServer = apiServer
 		cp.cfg = cfg
 		_ = store // keep linter happy
 	} else {
@@ -961,6 +958,13 @@ func (cp *ControlPlane) handleUpgrade() {
 		}
 		cancel()
 	}
+	if cp.apiServer != nil {
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		if err := cp.apiServer.Shutdown(ctx); err != nil {
+			slog.Warn("API server shutdown failed.", "error", err)
+		}
+		cancel()
+	}

 	// Stop ACME managers so the new CP can bind port 80 (HTTP-01) or
 	// manage DNS records. Nil out after close so drainAfterUpgrade
@@ -1181,9 +1185,6 @@ func (cp *ControlPlane) recoverMetricsAfterFailedReload() {
 	addr := cp.cfg.MetricsServer.Addr
 	mux := http.NewServeMux()
 	mux.Handle("/metrics", promhttp.Handler())
-	mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
-		w.WriteHeader(http.StatusOK)
-	})
 	newSrv := &http.Server{Addr: addr, Handler: mux}
 	cp.cfg.MetricsServer = newSrv
 	go func() {
diff --git a/controlplane/multitenant.go b/controlplane/multitenant.go
index 8844b16..8e94700 100644
--- a/controlplane/multitenant.go
+++ b/controlplane/multitenant.go
@@ -14,8 +14,9 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/posthog/duckgres/controlplane/admin"
 	"github.com/posthog/duckgres/controlplane/configstore"
+	"github.com/posthog/duckgres/controlplane/provisioner"
+	"github.com/posthog/duckgres/controlplane/provisioning"
 	"github.com/posthog/duckgres/server"
-	"github.com/prometheus/client_golang/prometheus/promhttp"
 )

 // orgRouterAdapter wraps OrgRouter to implement both OrgRouterInterface
@@ -106,8 +107,9 @@ func (a *orgRouterAdapter) AllSessionStatuses() []admin.SessionStatus {
 var _ OrgRouterInterface = (*orgRouterAdapter)(nil)
 var _ admin.OrgStackInfo = (*orgRouterAdapter)(nil)

-// SetupMultiTenant initializes the config store, org router, and Gin admin server.
+// SetupMultiTenant initializes the config store, org router, and API server.
 // Called from RunControlPlane when --config-store is set with remote backend.
+// Returns the API server for graceful shutdown.
 func SetupMultiTenant(
 	cfg ControlPlaneConfig,
 	srv *server.Server,
@@ -139,13 +141,31 @@ func SetupMultiTenant(
 		MemoryBudget: int64(memBudget),
 	}

-	router, err := NewOrgRouter(store, baseCfg, cfg, srv)
+	// Initialize STS broker for credential brokering (best-effort)
+	var stsBroker *STSBroker
+	if cfg.K8s.AWSRegion != "" {
+		var err error
+		stsBroker, err = NewSTSBroker(context.Background(), cfg.K8s.AWSRegion)
+		if err != nil {
+			slog.Warn("STS broker unavailable, workers will use pod identity for S3.", "error", err)
+		}
+	}
+
+	router, err := NewOrgRouter(store, baseCfg, cfg, srv, stsBroker)
 	if err != nil {
 		return nil, nil, nil, err
 	}
 	adpt := &orgRouterAdapter{router: router}

+	// Start provisioning controller (best-effort — K8s API may not be available locally)
+	provCtrl, err := provisioner.NewController(store, 10*time.Second)
+	if err != nil {
+		slog.Warn("Provisioning controller unavailable.", "error", err)
+	} else {
+		go provCtrl.Run(context.Background())
+	}
+
 	// Register config change handler
 	store.OnChange(router.HandleConfigChange)
@@ -153,44 +173,45 @@ func SetupMultiTenant(
 	// Start config store polling
 	store.Start(context.Background())

-	// Resolve admin bearer token
-	adminToken := cfg.AdminToken
-	if adminToken == "" {
+	// Resolve the internal shared secret
+	internalSecret := cfg.InternalSecret
+	if internalSecret == "" {
 		tokenBytes := make([]byte, 32)
 		if _, err := rand.Read(tokenBytes); err != nil {
-			return nil, nil, nil, fmt.Errorf("generate admin token: %w", err)
+			return nil, nil, nil, fmt.Errorf("generate internal secret: %w", err)
 		}
-		adminToken = hex.EncodeToString(tokenBytes)
-		slog.Info("Generated admin API token (pass via --admin-token or DUCKGRES_ADMIN_TOKEN to set explicitly).", "token", adminToken)
+		internalSecret = hex.EncodeToString(tokenBytes)
+		slog.Info("Generated internal secret (pass via --internal-secret or DUCKGRES_INTERNAL_SECRET to set explicitly).", "secret", internalSecret)
 	}

-	// Set up Gin admin server (replaces the simple metrics server)
+	// Set up API server (admin + provisioning + dashboard on :8080).
+	// The existing metrics server on :9090 stays running separately.
 	gin.SetMode(gin.ReleaseMode)
 	engine := gin.New()
 	engine.Use(gin.Recovery())

-	// Existing endpoints (unauthenticated)
-	engine.GET("/metrics", gin.WrapH(promhttp.Handler()))
+	// Health endpoint (unauthenticated, used by K8s probes)
 	engine.GET("/health", func(c *gin.Context) {
 		c.String(http.StatusOK, "ok")
 	})

-	// Admin API (authenticated)
-	api := engine.Group("/api/v1", admin.APIAuthMiddleware(adminToken))
+	// Authenticated API
+	api := engine.Group("/api/v1", admin.APIAuthMiddleware(internalSecret))
 	admin.RegisterAPI(api, store, adpt)
+	provisioning.RegisterAPI(api, provisioning.NewGormStore(store))

 	// Dashboard
-	admin.RegisterDashboard(engine, adminToken)
+	admin.RegisterDashboard(engine, internalSecret)

-	adminServer := &http.Server{
-		Addr:    ":9090",
+	apiServer := &http.Server{
+		Addr:    ":8080",
 		Handler: engine,
 	}
 	go func() {
-		slog.Info("Starting admin server with dashboard.", "addr", adminServer.Addr)
-		if err := adminServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
-			slog.Warn("Admin server error.", "error", err)
+		slog.Info("Starting API server.", "addr", apiServer.Addr)
+		if err := apiServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+			slog.Warn("API server error.", "error", err)
 		}
 	}()

-	return store, adpt, adminServer, nil
+	return store, adpt, apiServer, nil
 }
diff --git a/controlplane/org_reserved_pool.go b/controlplane/org_reserved_pool.go
index 3482567..7378fea 100644
--- a/controlplane/org_reserved_pool.go
+++ b/controlplane/org_reserved_pool.go
@@ -19,15 +19,17 @@ type OrgReservedPool struct {
 	orgID         string
 	maxWorkers    int
 	leaseDuration time.Duration
+	stsBroker     *STSBroker

 	activateReservedWorker func(context.Context, *ManagedWorker) error
 }

-func NewOrgReservedPool(shared *K8sWorkerPool, orgID string, maxWorkers int) *OrgReservedPool {
+func NewOrgReservedPool(shared *K8sWorkerPool, orgID string, maxWorkers int, stsBroker *STSBroker) *OrgReservedPool {
 	pool := &OrgReservedPool{
 		shared:        shared,
 		orgID:         orgID,
 		maxWorkers:    maxWorkers,
 		leaseDuration: defaultSharedWorkerReservationLease,
+		stsBroker:     stsBroker,
 	}
 	pool.activateReservedWorker = pool.activateReservedWorkerDefault
 	return pool
diff --git a/controlplane/org_reserved_pool_test.go b/controlplane/org_reserved_pool_test.go
index c270d72..e5ce8a8 100644
--- a/controlplane/org_reserved_pool_test.go
+++ b/controlplane/org_reserved_pool_test.go
@@ -17,7 +17,7 @@ func TestOrgReservedPoolAcquireReservesOrgWorker(t *testing.T) {
 		return nil
 	}

-	pool := NewOrgReservedPool(shared, "analytics", 2)
+	pool := NewOrgReservedPool(shared, "analytics", 2, nil)
 	pool.activateReservedWorker = func(ctx context.Context, worker *ManagedWorker) error {
 		return nil
 	}
@@ -59,7 +59,7 @@ func TestOrgReservedPoolAcquireSkipsOtherOrgsWorkers(t *testing.T) {
 		return nil
 	}

-	pool := NewOrgReservedPool(shared, "analytics", 2)
+	pool := NewOrgReservedPool(shared, "analytics", 2, nil)
 	pool.activateReservedWorker = func(ctx context.Context, worker *ManagedWorker) error {
 		return nil
 	}
@@ -89,7 +89,7 @@ func TestOrgReservedPoolReleaseWorkerRetiresOnLastSession(t *testing.T) {
 	}
 	shared.workers[worker.ID] = worker

-	pool := NewOrgReservedPool(shared, "analytics", 1)
+	pool := NewOrgReservedPool(shared, "analytics", 1, nil)
 	pool.ReleaseWorker(worker.ID)

 	time.Sleep(100 * time.Millisecond)
@@ -108,7 +108,7 @@ func TestOrgReservedWorkerPoolAcquireActivatesReservedWorkerWhenEnabledWithOrgCo
 	}

 	activated := false
-	pool := NewOrgReservedPool(shared, "analytics", 2)
+	pool := NewOrgReservedPool(shared, "analytics", 2, nil)
 	pool.activateReservedWorker = func(ctx context.Context, worker *ManagedWorker) error {
 		activated = true
 		return nil
@@ -135,7 +135,7 @@ func TestOrgReservedWorkerPoolAcquireDelegatesActivationWithoutCachedTenantRunti
 		return nil
 	}

-	pool := NewOrgReservedPool(shared, "analytics", 2)
+	pool := NewOrgReservedPool(shared, "analytics", 2, nil)
 	activated := 0
 	pool.activateReservedWorker = func(ctx context.Context, worker *ManagedWorker) error {
 		activated++
@@ -171,7 +171,7 @@ func TestOrgReservedPoolAcquireWaitsWhenSharedWarmWorkerBusyAtCapacity(t *testin
 	}
 	shared.workers[worker.ID] = worker

-	pool := NewOrgReservedPool(shared, "analytics", 1)
+	pool := NewOrgReservedPool(shared, "analytics", 1, nil)

 	ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
 	defer cancel()
diff --git a/controlplane/org_router.go b/controlplane/org_router.go
index 5378fd2..345556d 100644
--- a/controlplane/org_router.go
+++ b/controlplane/org_router.go
@@ -32,18 +32,20 @@ type OrgRouter struct {
 	sharedPool *K8sWorkerPool
 	globalCfg  ControlPlaneConfig
 	srv        *server.Server
+	stsBroker  *STSBroker

 	nextWorkerID atomic.Int32
 	sharedCancel context.CancelFunc
 }

 // NewOrgRouter creates an OrgRouter from the initial config snapshot.
-func NewOrgRouter(store *configstore.ConfigStore, baseCfg K8sWorkerPoolConfig, globalCfg ControlPlaneConfig, srv *server.Server) (*OrgRouter, error) {
+func NewOrgRouter(store *configstore.ConfigStore, baseCfg K8sWorkerPoolConfig, globalCfg ControlPlaneConfig, srv *server.Server, stsBroker *STSBroker) (*OrgRouter, error) {
 	tr := &OrgRouter{
 		orgs:        make(map[string]*OrgStack),
 		configStore: store,
 		baseCfg:     baseCfg,
 		globalCfg:   globalCfg,
 		srv:         srv,
+		stsBroker:   stsBroker,
 	}

 	sharedCfg := baseCfg
@@ -68,6 +70,11 @@ func NewOrgRouter(store *configstore.ConfigStore, baseCfg K8sWorkerPoolConfig, g
 	snap := store.Snapshot()
 	for _, tc := range snap.Orgs {
+		// Only create stacks for orgs with ready warehouses (or no warehouse at all for backwards compat)
+		if tc.Warehouse != nil && tc.Warehouse.State != configstore.ManagedWarehouseStateReady {
+			slog.Info("Skipping org stack creation (warehouse not ready).", "org", tc.Name, "state", tc.Warehouse.State)
+			continue
+		}
 		if _, err := tr.createOrgStack(tc); err != nil {
 			slog.Error("Failed to create org stack.", "org", tc.Name, "error", err)
 			continue
@@ -93,8 +100,8 @@ func (tr *OrgRouter) createOrgStack(tc *configstore.OrgConfig) (*OrgStack, error
 		memoryBudget = int64(server.ParseMemoryBytes(tc.MemoryBudget))
 	}

-	pool := NewOrgReservedPool(tr.sharedPool, tc.Name, maxWorkers)
-	activator := NewSharedWorkerActivator(tr.sharedPool, func(orgID string) (*configstore.OrgConfig, error) {
+	pool := NewOrgReservedPool(tr.sharedPool, tc.Name, maxWorkers, tr.stsBroker)
+	activator := NewSharedWorkerActivator(tr.sharedPool, tr.stsBroker, func(orgID string) (*configstore.OrgConfig, error) {
 		snap := tr.configStore.Snapshot()
 		if snap == nil {
 			return nil, fmt.Errorf("config snapshot unavailable for org %s", orgID)
@@ -177,13 +184,50 @@ func (tr *OrgRouter) StackForUser(username string) (*OrgStack, bool) {

 // HandleConfigChange reconciles org stacks when the config snapshot changes.
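+// In summary: orgs with no warehouse keep the legacy behavior (stack created
+// immediately); orgs whose warehouse is not ready are skipped until it becomes
+// ready; deleting/deleted warehouses tear down any existing stack.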
 func (tr *OrgRouter) HandleConfigChange(old, new *configstore.Snapshot) {
-	// Detect new orgs
+	// Detect new orgs or orgs whose warehouse just became ready
 	for name, tc := range new.Orgs {
-		if _, existed := old.Orgs[name]; !existed {
+		oldTC, existed := old.Orgs[name]
+
+		// Skip orgs with warehouses that aren't ready
+		if tc.Warehouse != nil && tc.Warehouse.State != configstore.ManagedWarehouseStateReady {
+			// If warehouse is being deleted, destroy existing stack
+			if tc.Warehouse.State == configstore.ManagedWarehouseStateDeleting ||
+				tc.Warehouse.State == configstore.ManagedWarehouseStateDeleted {
+				tr.mu.RLock()
+				_, hasStack := tr.orgs[name]
+				tr.mu.RUnlock()
+				if hasStack {
+					slog.Info("Warehouse deprovisioning, destroying stack.", "org", name)
+					tr.DestroyOrgStack(name)
+				}
+			}
+			continue
+		}
+
+		tr.mu.RLock()
+		_, hasStack := tr.orgs[name]
+		tr.mu.RUnlock()
+
+		if !existed && !hasStack {
+			// Brand new org -- create stack
 			slog.Info("New org detected, creating stack.", "org", name)
 			if _, err := tr.createOrgStack(tc); err != nil {
 				slog.Error("Failed to create org stack on config change.", "org", name, "error", err)
 			}
+		} else if existed && !hasStack {
+			// Existing org whose warehouse just became ready
+			warehouseJustReady := oldTC.Warehouse != nil &&
+				oldTC.Warehouse.State != configstore.ManagedWarehouseStateReady &&
+				tc.Warehouse != nil &&
+				tc.Warehouse.State == configstore.ManagedWarehouseStateReady
+			noWarehouse := tc.Warehouse == nil
+
+			if warehouseJustReady || noWarehouse {
+				slog.Info("Org warehouse ready, creating stack.", "org", name)
+				if _, err := tr.createOrgStack(tc); err != nil {
+					slog.Error("Failed to create org stack on config change.", "org", name, "error", err)
+				}
+			}
 		}
 	}
diff --git a/controlplane/org_router_test.go b/controlplane/org_router_test.go
index 8662a06..491f2e4 100644
--- a/controlplane/org_router_test.go
+++ b/controlplane/org_router_test.go
@@ -47,7 +47,7 @@ func TestOrgRouterReconcileWarmCapacityUsesExplicitSharedWarmTarget(t *testing.T

 func TestOrgRouterHandleConfigChangeRefreshesRuntimeOnlyUpdates(t *testing.T) {
 	sharedPool, _ := newTestK8sPool(t, 10)
-	pool := NewOrgReservedPool(sharedPool, "analytics", 2)
+	pool := NewOrgReservedPool(sharedPool, "analytics", 2, nil)

 	oldTC := &configstore.OrgConfig{
 		Name: "analytics",
diff --git a/controlplane/provisioner/controller.go b/controlplane/provisioner/controller.go
new file mode 100644
index 0000000..5e2c314
--- /dev/null
+++ b/controlplane/provisioner/controller.go
@@ -0,0 +1,253 @@
+//go:build kubernetes
+
+package provisioner
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"time"
+
+	"github.com/posthog/duckgres/controlplane/configstore"
+	apierrors "k8s.io/apimachinery/pkg/api/errors"
+)
+
+// WarehouseStore is the subset of configstore.ConfigStore that the controller needs.
+type WarehouseStore interface {
+	ListWarehousesByStates(states []configstore.ManagedWarehouseProvisioningState) ([]configstore.ManagedWarehouse, error)
+	UpdateWarehouseState(orgID string, expectedState configstore.ManagedWarehouseProvisioningState, updates map[string]interface{}) error
+}
+
+// Controller polls the config store for actionable warehouses and reconciles
+// their state against Duckling CRs in Kubernetes.
+type Controller struct {
+	store        WarehouseStore
+	duckling     *DucklingClient
+	pollInterval time.Duration
+}
+
+// NewController creates a provisioning controller. Returns an error if the
+// Kubernetes client cannot be initialized (e.g., not running in-cluster).
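+//
+// A minimal wiring sketch, mirroring the call site in SetupMultiTenant
+// (ctx is caller-provided):
+//
+//	ctrl, err := provisioner.NewController(store, 10*time.Second)
+//	if err != nil {
+//		slog.Warn("Provisioning controller unavailable.", "error", err)
+//	} else {
+//		go ctrl.Run(ctx)
+//	}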
+func NewController(store WarehouseStore, pollInterval time.Duration) (*Controller, error) {
+	dc, err := NewDucklingClient()
+	if err != nil {
+		return nil, fmt.Errorf("create duckling client: %w", err)
+	}
+	return &Controller{
+		store:        store,
+		duckling:     dc,
+		pollInterval: pollInterval,
+	}, nil
+}
+
+// NewControllerWithClient creates a Controller with a pre-built DucklingClient (for testing).
+func NewControllerWithClient(store WarehouseStore, dc *DucklingClient, pollInterval time.Duration) *Controller {
+	return &Controller{
+		store:        store,
+		duckling:     dc,
+		pollInterval: pollInterval,
+	}
+}
+
+// Run starts the reconciliation loop. Blocks until ctx is cancelled.
+func (c *Controller) Run(ctx context.Context) {
+	slog.Info("Provisioning controller started.", "poll_interval", c.pollInterval)
+	ticker := time.NewTicker(c.pollInterval)
+	defer ticker.Stop()
+
+	// Run once immediately at startup
+	c.reconcile(ctx)
+
+	for {
+		select {
+		case <-ctx.Done():
+			slog.Info("Provisioning controller stopped.")
+			return
+		case <-ticker.C:
+			c.reconcile(ctx)
+		}
+	}
+}
+
+// actionableStates are the warehouse states the controller acts on.
+var actionableStates = []configstore.ManagedWarehouseProvisioningState{
+	configstore.ManagedWarehouseStatePending,
+	configstore.ManagedWarehouseStateProvisioning,
+	configstore.ManagedWarehouseStateDeleting,
+}
+
+func (c *Controller) reconcile(ctx context.Context) {
+	warehouses, err := c.store.ListWarehousesByStates(actionableStates)
+	if err != nil {
+		slog.Warn("Provisioning controller: failed to list warehouses.", "error", err)
+		return
+	}
+
+	for _, w := range warehouses {
+		if ctx.Err() != nil {
+			return
+		}
+		switch w.State {
+		case configstore.ManagedWarehouseStatePending:
+			c.reconcilePending(ctx, &w)
+		case configstore.ManagedWarehouseStateProvisioning:
+			c.reconcileProvisioning(ctx, &w)
+		case configstore.ManagedWarehouseStateDeleting:
+			c.reconcileDeleting(ctx, &w)
+		}
+	}
+}
+
+func (c *Controller) reconcilePending(ctx context.Context, w *configstore.ManagedWarehouse) {
+	log := slog.With("org", w.OrgID, "phase", "pending")
+
+	now := time.Now().UTC()
+
+	// Check if a Duckling CR already exists (e.g., controller restart)
+	_, err := c.duckling.Get(ctx, w.OrgID)
+	if err == nil {
+		// CR exists — transition directly to provisioning
+		log.Info("Duckling CR already exists, transitioning to provisioning.")
+		if err := c.store.UpdateWarehouseState(w.OrgID, configstore.ManagedWarehouseStatePending, map[string]interface{}{
+			"state":                   configstore.ManagedWarehouseStateProvisioning,
+			"status_message":          "Duckling CR exists, polling status",
+			"provisioning_started_at": now,
+		}); err != nil {
+			log.Warn("Failed to update state to provisioning.", "error", err)
+		}
+		return
+	}
+
+	// Create the Duckling CR
+	log.Info("Creating Duckling CR.")
+	if err := c.duckling.Create(ctx, w.OrgID, w.AuroraMinACU, w.AuroraMaxACU); err != nil {
+		log.Error("Failed to create Duckling CR.", "error", err)
+		_ = c.store.UpdateWarehouseState(w.OrgID, configstore.ManagedWarehouseStatePending, map[string]interface{}{
+			"state":          configstore.ManagedWarehouseStateFailed,
+			"status_message": fmt.Sprintf("Failed to create Duckling CR: %v", err),
+			"failed_at":      now,
+		})
+		return
+	}
+
+	if err := c.store.UpdateWarehouseState(w.OrgID, configstore.ManagedWarehouseStatePending, map[string]interface{}{
+		"state":                   configstore.ManagedWarehouseStateProvisioning,
+		"status_message":          "Duckling CR created, waiting for resources",
+		"provisioning_started_at": now,
+	}); err != nil {
+		log.Warn("Failed to update state to provisioning.", "error", err)
+	}
+}
+
+func (c *Controller) reconcileProvisioning(ctx context.Context, w *configstore.ManagedWarehouse) {
+	log := slog.With("org", w.OrgID, "phase", "provisioning")
+
+	// Use ProvisioningStartedAt if set (tracks when we entered provisioning state),
+	// fall back to CreatedAt for warehouses created before this field existed.
+	startedAt := w.CreatedAt
+	if w.ProvisioningStartedAt != nil {
+		startedAt = *w.ProvisioningStartedAt
+	}
+
+	// Check for timeout (30 minutes)
+	if time.Since(startedAt) > 30*time.Minute {
+		log.Warn("Provisioning timed out.")
+		_ = c.store.UpdateWarehouseState(w.OrgID, configstore.ManagedWarehouseStateProvisioning, map[string]interface{}{
+			"state":          configstore.ManagedWarehouseStateFailed,
+			"status_message": "Provisioning timed out after 30 minutes",
+			"failed_at":      time.Now().UTC(),
+		})
+		return
+	}
+
+	status, err := c.duckling.Get(ctx, w.OrgID)
+	if err != nil {
+		log.Warn("Failed to get Duckling CR status.", "error", err)
+		return
+	}
+
+	// Check for Crossplane failure — only fail on persistent sync errors.
+	// Crossplane resources commonly flap Synced=False transiently (e.g., IAM
+	// eventual consistency, Aurora cold start delays), so we only transition
+	// to failed if 10+ minutes have passed, giving transient errors time to resolve.
+	if status.SyncedFalseMessage != "" && time.Since(startedAt) > 10*time.Minute {
+		log.Warn("Crossplane sync failure.", "message", status.SyncedFalseMessage)
+		_ = c.store.UpdateWarehouseState(w.OrgID, configstore.ManagedWarehouseStateProvisioning, map[string]interface{}{
+			"state":          configstore.ManagedWarehouseStateFailed,
+			"status_message": fmt.Sprintf("Crossplane error: %s", status.SyncedFalseMessage),
+			"failed_at":      time.Now().UTC(),
+		})
+		return
+	}
+
+	// Update per-component states based on Duckling CR status fields.
+	// The Duckling composition provisions AWS infrastructure only (Aurora, S3, IAM).
+	// K8s workloads (namespace, deployment, service) are managed by the duckgres Helm chart.
+	updates := map[string]interface{}{}
+
+	if status.DataStore.BucketName != "" && w.S3State != configstore.ManagedWarehouseStateReady {
+		updates["s3_state"] = configstore.ManagedWarehouseStateReady
+		updates["s3_bucket"] = status.DataStore.BucketName
+	}
+
+	if status.MetadataStore.Endpoint != "" && w.MetadataStoreState != configstore.ManagedWarehouseStateReady {
+		updates["metadata_store_state"] = configstore.ManagedWarehouseStateReady
+		updates["metadata_store_endpoint"] = status.MetadataStore.Endpoint
+		updates["metadata_store_port"] = 5432
+		updates["metadata_store_kind"] = status.MetadataStore.Type
+		updates["metadata_store_engine"] = "postgres"
+		updates["metadata_store_username"] = status.MetadataStore.User
+		updates["metadata_store_database_name"] = status.MetadataStore.Database
+	}
+
+	if status.MetadataStore.Password != "" && w.SecretsState != configstore.ManagedWarehouseStateReady {
+		updates["secrets_state"] = configstore.ManagedWarehouseStateReady
+	}
+
+	if status.IAMRoleARN != "" && w.IdentityState != configstore.ManagedWarehouseStateReady {
+		updates["identity_state"] = configstore.ManagedWarehouseStateReady
+		updates["worker_identity_iam_role_arn"] = status.IAMRoleARN
+	}
+
+	// Infrastructure is ready when S3, Aurora, secrets, and IAM are all provisioned.
+	s3Ready := w.S3State == configstore.ManagedWarehouseStateReady || updates["s3_state"] == configstore.ManagedWarehouseStateReady
+	metaReady := w.MetadataStoreState == configstore.ManagedWarehouseStateReady || updates["metadata_store_state"] == configstore.ManagedWarehouseStateReady
+	secretsReady := w.SecretsState == configstore.ManagedWarehouseStateReady || updates["secrets_state"] == configstore.ManagedWarehouseStateReady
+	identReady := w.IdentityState == configstore.ManagedWarehouseStateReady || updates["identity_state"] == configstore.ManagedWarehouseStateReady
+
+	if s3Ready && metaReady && secretsReady && identReady {
+		now := time.Now().UTC()
+		updates["state"] = configstore.ManagedWarehouseStateReady
+		updates["status_message"] = "Infrastructure ready"
+		updates["ready_at"] = now
+		log.Info("Infrastructure ready, transitioning to ready.")
+	}
+
+	if len(updates) > 0 {
+		if err := c.store.UpdateWarehouseState(w.OrgID, configstore.ManagedWarehouseStateProvisioning, updates); err != nil {
+			log.Warn("Failed to update warehouse state.", "error", err)
+		}
+	}
+}
+
+func (c *Controller) reconcileDeleting(ctx context.Context, w *configstore.ManagedWarehouse) {
+	log := slog.With("org", w.OrgID, "phase", "deleting")
+
+	log.Info("Deleting Duckling CR.")
+	if err := c.duckling.Delete(ctx, w.OrgID); err != nil {
+		// Only proceed if the CR is already gone (NotFound). For other errors
+		// (network, RBAC, etc.) we retry on the next reconcile pass to avoid
+		// marking as deleted while AWS resources still exist.
+		if !apierrors.IsNotFound(err) {
+			log.Warn("Failed to delete Duckling CR, will retry.", "error", err)
+			return
+		}
+	}
+
+	if err := c.store.UpdateWarehouseState(w.OrgID, configstore.ManagedWarehouseStateDeleting, map[string]interface{}{
+		"state":          configstore.ManagedWarehouseStateDeleted,
+		"status_message": "Resources deleted",
+	}); err != nil {
+		log.Warn("Failed to update state to deleted.", "error", err)
+	}
+}
diff --git a/controlplane/provisioner/controller_stub.go b/controlplane/provisioner/controller_stub.go
new file mode 100644
index 0000000..73290ff
--- /dev/null
+++ b/controlplane/provisioner/controller_stub.go
@@ -0,0 +1,22 @@
+//go:build !kubernetes
+
+package provisioner
+
+import (
+	"context"
+	"errors"
+	"time"
+
+	"github.com/posthog/duckgres/controlplane/configstore"
+)
+
+// Controller is a stub for non-Kubernetes builds.
+type Controller struct{}
+
+// NewController returns an error on non-Kubernetes builds since it requires K8s API access.
+func NewController(_ *configstore.ConfigStore, _ time.Duration) (*Controller, error) {
+	return nil, errors.New("provisioning controller requires kubernetes build tag")
+}
+
+// Run is a no-op stub.
+func (c *Controller) Run(_ context.Context) {}
diff --git a/controlplane/provisioner/controller_test.go b/controlplane/provisioner/controller_test.go
new file mode 100644
index 0000000..b554708
--- /dev/null
+++ b/controlplane/provisioner/controller_test.go
@@ -0,0 +1,382 @@
+//go:build kubernetes
+
+package provisioner
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
+	"k8s.io/apimachinery/pkg/runtime"
+	"k8s.io/apimachinery/pkg/runtime/schema"
+	dynamicfake "k8s.io/client-go/dynamic/fake"
+
+	"github.com/posthog/duckgres/controlplane/configstore"
+)
+
+// fakeStore implements WarehouseStore for unit tests.
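+// Note: unlike the real store, a CAS miss (missing row or unexpected state)
+// is a silent no-op here rather than an error.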
+type fakeStore struct {
+	warehouses map[string]*configstore.ManagedWarehouse
+}
+
+func newFakeStore() *fakeStore {
+	return &fakeStore{warehouses: make(map[string]*configstore.ManagedWarehouse)}
+}
+
+func (s *fakeStore) ListWarehousesByStates(states []configstore.ManagedWarehouseProvisioningState) ([]configstore.ManagedWarehouse, error) {
+	var result []configstore.ManagedWarehouse
+	for _, w := range s.warehouses {
+		for _, st := range states {
+			if w.State == st {
+				result = append(result, *w)
+				break
+			}
+		}
+	}
+	return result, nil
+}
+
+func (s *fakeStore) UpdateWarehouseState(orgID string, expectedState configstore.ManagedWarehouseProvisioningState, updates map[string]interface{}) error {
+	w, ok := s.warehouses[orgID]
+	if !ok {
+		return nil
+	}
+	if w.State != expectedState {
+		return nil
+	}
+	for k, v := range updates {
+		switch k {
+		case "state":
+			w.State = v.(configstore.ManagedWarehouseProvisioningState)
+		case "status_message":
+			w.StatusMessage = v.(string)
+		case "s3_state":
+			w.S3State = v.(configstore.ManagedWarehouseProvisioningState)
+		case "s3_bucket":
+			w.S3.Bucket = v.(string)
+		case "metadata_store_state":
+			w.MetadataStoreState = v.(configstore.ManagedWarehouseProvisioningState)
+		case "metadata_store_endpoint":
+			w.MetadataStore.Endpoint = v.(string)
+		case "metadata_store_port":
+			w.MetadataStore.Port = v.(int)
+		case "metadata_store_kind":
+			w.MetadataStore.Kind = v.(string)
+		case "metadata_store_engine":
+			w.MetadataStore.Engine = v.(string)
+		case "identity_state":
+			w.IdentityState = v.(configstore.ManagedWarehouseProvisioningState)
+		case "worker_identity_iam_role_arn":
+			w.WorkerIdentity.IAMRoleARN = v.(string)
+		case "worker_identity_namespace":
+			w.WorkerIdentity.Namespace = v.(string)
+		case "metadata_store_username":
+			w.MetadataStore.Username = v.(string)
+		case "metadata_store_database_name":
+			w.MetadataStore.DatabaseName = v.(string)
+		case "secrets_state":
+			w.SecretsState = v.(configstore.ManagedWarehouseProvisioningState)
+		case "warehouse_database_state":
+			w.WarehouseDatabaseState = v.(configstore.ManagedWarehouseProvisioningState)
+		case "ready_at":
+			t := v.(time.Time)
+			w.ReadyAt = &t
+		case "failed_at":
+			t := v.(time.Time)
+			w.FailedAt = &t
+		case "provisioning_started_at":
+			t := v.(time.Time)
+			w.ProvisioningStartedAt = &t
+		}
+	}
+	return nil
+}
+
+// Compile-time check that fakeStore satisfies WarehouseStore.
+var _ WarehouseStore = (*fakeStore)(nil)
+
+func newFakeDucklingClient() (*DucklingClient, *dynamicfake.FakeDynamicClient) {
+	scheme := runtime.NewScheme()
+	scheme.AddKnownTypeWithName(schema.GroupVersionKind{
+		Group:   "k8s.posthog.com",
+		Version: "v1alpha1",
+		Kind:    "Duckling",
+	}, &unstructured.Unstructured{})
+	scheme.AddKnownTypeWithName(schema.GroupVersionKind{
+		Group:   "k8s.posthog.com",
+		Version: "v1alpha1",
+		Kind:    "DucklingList",
+	}, &unstructured.UnstructuredList{})
+
+	fakeClient := dynamicfake.NewSimpleDynamicClient(scheme)
+	return NewDucklingClientWithDynamic(fakeClient), fakeClient
+}
+
+func TestReconcilePendingCreatesCR(t *testing.T) {
+	dc, fakeK8s := newFakeDucklingClient()
+	fs := newFakeStore()
+	fs.warehouses["org-a"] = &configstore.ManagedWarehouse{
+		OrgID:        "org-a",
+		State:        configstore.ManagedWarehouseStatePending,
+		AuroraMinACU: 0.5,
+		AuroraMaxACU: 2,
+	}
+
+	ctrl := NewControllerWithClient(fs, dc, time.Second)
+	ctx := context.Background()
+
+	ctrl.reconcile(ctx)
+
+	// Verify CR was created
+	cr, err := fakeK8s.Resource(ducklingGVR).Namespace(ducklingNamespace).Get(ctx, "org-a", metav1.GetOptions{})
+	if err != nil {
+		t.Fatalf("expected CR to exist: %v", err)
+	}
+
+	spec, ok := cr.Object["spec"].(map[string]interface{})
+	if !ok {
+		t.Fatal("expected spec in CR")
+	}
+	metadataStore, ok := spec["metadataStore"].(map[string]interface{})
+	if !ok {
+		t.Fatal("expected metadataStore in spec")
+	}
+	if metadataStore["type"] != "aurora" {
+		t.Fatalf("expected metadataStore type aurora, got %v", metadataStore["type"])
+	}
+	aurora, ok := metadataStore["aurora"].(map[string]interface{})
+	if !ok {
+		t.Fatal("expected aurora in metadataStore")
+	}
+	if aurora["minACU"] != 0.5 {
+		t.Fatalf("expected minACU 0.5, got %v", aurora["minACU"])
+	}
+	if aurora["maxACU"] != 2.0 {
+		t.Fatalf("expected maxACU 2, got %v", aurora["maxACU"])
+	}
+
+	// Verify state transitioned to provisioning
+	if fs.warehouses["org-a"].State != configstore.ManagedWarehouseStateProvisioning {
+		t.Fatalf("expected provisioning state, got %q", fs.warehouses["org-a"].State)
+	}
+	if fs.warehouses["org-a"].ProvisioningStartedAt == nil {
+		t.Fatal("expected provisioning_started_at to be set")
+	}
+}
+
+func TestReconcileProvisioningAllReady(t *testing.T) {
+	dc, fakeK8s := newFakeDucklingClient()
+	fs := newFakeStore()
+	fs.warehouses["org-b"] = &configstore.ManagedWarehouse{
+		OrgID:     "org-b",
+		State:     configstore.ManagedWarehouseStateProvisioning,
+		CreatedAt: time.Now(),
+	}
+
+	// Create a Duckling CR with all status fields populated
+	cr := &unstructured.Unstructured{
+		Object: map[string]interface{}{
+			"apiVersion": "k8s.posthog.com/v1alpha1",
+			"kind":       "Duckling",
+			"metadata": map[string]interface{}{
+				"name":      "org-b",
+				"namespace": ducklingNamespace,
+			},
+			"status": map[string]interface{}{
+				"metadataStore": map[string]interface{}{
+					"type":     "aurora",
+					"endpoint": "org-b.cluster.us-east-1.rds.amazonaws.com",
+					"password": "supersecret123",
+					"user":     "postgres",
+					"database": "postgres",
+				},
+				"dataStore": map[string]interface{}{
+					"type":       "s3bucket",
+					"bucketName": "org-b-bucket",
+				},
+				"iamRoleArn": "arn:aws:iam::123456789012:role/duckling-org-b",
+				"conditions": []interface{}{
+					map[string]interface{}{
+						"type":   "Ready",
+						"status": "True",
+					},
+					map[string]interface{}{
+						"type":   "Synced",
+						"status": "True",
+					},
+				},
+			},
+		},
+	}
+
+	ctx := context.Background()
+	_, err := fakeK8s.Resource(ducklingGVR).Namespace(ducklingNamespace).Create(ctx, cr, metav1.CreateOptions{})
+	if err != nil {
+		t.Fatalf("failed to create test CR: %v", err)
+	}
+
+	ctrl := NewControllerWithClient(fs, dc, time.Second)
+	ctrl.reconcile(ctx)
+
+	// Verify state transitioned to ready
+	w := fs.warehouses["org-b"]
+	if w.State != configstore.ManagedWarehouseStateReady {
+		t.Fatalf("expected ready state, got %q", w.State)
+	}
+	if w.S3.Bucket != "org-b-bucket" {
+		t.Fatalf("expected bucket org-b-bucket, got %q", w.S3.Bucket)
+	}
+	if w.MetadataStore.Endpoint != "org-b.cluster.us-east-1.rds.amazonaws.com" {
+		t.Fatalf("expected aurora endpoint, got %q", w.MetadataStore.Endpoint)
+	}
+	if w.MetadataStore.Port != 5432 {
+		t.Fatalf("expected aurora port 5432, got %d", w.MetadataStore.Port)
+	}
+	if w.MetadataStore.Username != "postgres" {
+		t.Fatalf("expected username postgres, got %q", w.MetadataStore.Username)
+	}
+	if w.MetadataStore.DatabaseName != "postgres" {
+		t.Fatalf("expected database_name postgres, got %q", w.MetadataStore.DatabaseName)
+	}
+	if w.WorkerIdentity.IAMRoleARN != "arn:aws:iam::123456789012:role/duckling-org-b" {
+		t.Fatalf("expected IAM role ARN, got %q", w.WorkerIdentity.IAMRoleARN)
+	}
+	if w.ReadyAt == nil {
+		t.Fatal("expected ready_at to be set")
+	}
+}
+
+func TestReconcileDeletingDeletesCR(t *testing.T) {
+	dc, fakeK8s := newFakeDucklingClient()
+	fs := newFakeStore()
+	fs.warehouses["org-c"] = &configstore.ManagedWarehouse{
+		OrgID: "org-c",
+		State: configstore.ManagedWarehouseStateDeleting,
+	}
+	ctx := context.Background()
+
+	// Create a CR first
+	cr := &unstructured.Unstructured{
+		Object: map[string]interface{}{
+			"apiVersion": "k8s.posthog.com/v1alpha1",
+			"kind":       "Duckling",
+			"metadata": map[string]interface{}{
+				"name":      "org-c",
+				"namespace": ducklingNamespace,
+			},
+		},
+	}
+	_, err := fakeK8s.Resource(ducklingGVR).Namespace(ducklingNamespace).Create(ctx, cr, metav1.CreateOptions{})
+	if err != nil {
+		t.Fatalf("failed to create test CR: %v", err)
+	}
+
+	ctrl := NewControllerWithClient(fs, dc, time.Second)
+	ctrl.reconcile(ctx)
+
+	// Verify CR is gone
+	_, err = fakeK8s.Resource(ducklingGVR).Namespace(ducklingNamespace).Get(ctx, "org-c", metav1.GetOptions{})
+	if err == nil {
+		t.Fatal("expected CR to be deleted")
+	}
+
+	// Verify state transitioned to deleted
+	if fs.warehouses["org-c"].State != configstore.ManagedWarehouseStateDeleted {
+		t.Fatalf("expected deleted state, got %q", fs.warehouses["org-c"].State)
+	}
+}
+
+func TestReconcileDeletingTreatsNotFoundAsDeleted(t *testing.T) {
+	// When the CR doesn't exist (NotFound), deleting should still succeed.
+	// When it's a different error, it should NOT transition to deleted.
+	dc, _ := newFakeDucklingClient()
+	fs := newFakeStore()
+	fs.warehouses["org-d"] = &configstore.ManagedWarehouse{
+		OrgID: "org-d",
+		State: configstore.ManagedWarehouseStateDeleting,
+	}
+	ctx := context.Background()
+
+	// Don't create a CR — the fake client will return NotFound on delete.
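+	// (A non-NotFound failure, e.g. RBAC or a network error, could be simulated
+	// by prepending a delete reactor on fakeK8s; that path is not exercised here.)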
+	ctrl := NewControllerWithClient(fs, dc, time.Second)
+	ctrl.reconcile(ctx)
+
+	// NotFound on delete is fine — should still transition to deleted
+	if fs.warehouses["org-d"].State != configstore.ManagedWarehouseStateDeleted {
+		t.Fatalf("expected deleted state on NotFound, got %q", fs.warehouses["org-d"].State)
+	}
+}
+
+func TestParseDucklingStatusSyncedFalse(t *testing.T) {
+	cr := &unstructured.Unstructured{
+		Object: map[string]interface{}{
+			"status": map[string]interface{}{
+				"conditions": []interface{}{
+					map[string]interface{}{
+						"type":    "Synced",
+						"status":  "False",
+						"message": "cannot create Aurora cluster: InvalidParameterException",
+					},
+				},
+			},
+		},
+	}
+
+	status, err := parseDucklingStatus(cr)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if status.SyncedFalseMessage != "cannot create Aurora cluster: InvalidParameterException" {
+		t.Fatalf("expected synced false message, got %q", status.SyncedFalseMessage)
+	}
+	if status.ReadyCondition {
+		t.Fatal("expected Ready to be false")
+	}
+}
+
+func TestParseDucklingStatusEmpty(t *testing.T) {
+	cr := &unstructured.Unstructured{
+		Object: map[string]interface{}{},
+	}
+
+	status, err := parseDucklingStatus(cr)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if status.DataStore.BucketName != "" || status.MetadataStore.Endpoint != "" || status.MetadataStore.Password != "" {
+		t.Fatal("expected empty status for CR without status field")
+	}
+}
+
+func TestFakeStoreUpdateWarehouseState(t *testing.T) {
+	fs := newFakeStore()
+	fs.warehouses["org-x"] = &configstore.ManagedWarehouse{
+		OrgID: "org-x",
+		State: configstore.ManagedWarehouseStatePending,
+	}
+
+	// CAS update should succeed
+	err := fs.UpdateWarehouseState("org-x", configstore.ManagedWarehouseStatePending, map[string]interface{}{
+		"state":          configstore.ManagedWarehouseStateProvisioning,
+		"status_message": "transitioning",
+	})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if fs.warehouses["org-x"].State != configstore.ManagedWarehouseStateProvisioning {
+		t.Fatalf("expected provisioning state, got %q", fs.warehouses["org-x"].State)
+	}
+
+	// CAS update with wrong expected state should be a no-op
+	err = fs.UpdateWarehouseState("org-x", configstore.ManagedWarehouseStatePending, map[string]interface{}{
+		"state": configstore.ManagedWarehouseStateFailed,
+	})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if fs.warehouses["org-x"].State != configstore.ManagedWarehouseStateProvisioning {
+		t.Fatalf("expected state to remain provisioning, got %q", fs.warehouses["org-x"].State)
+	}
+}
diff --git a/controlplane/provisioner/k8s_client.go b/controlplane/provisioner/k8s_client.go
new file mode 100644
index 0000000..6221564
--- /dev/null
+++ b/controlplane/provisioner/k8s_client.go
@@ -0,0 +1,168 @@
+//go:build kubernetes
+
+package provisioner
+
+import (
+	"context"
+	"fmt"
+
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
+	"k8s.io/apimachinery/pkg/runtime/schema"
+	"k8s.io/client-go/dynamic"
+	"k8s.io/client-go/rest"
+)
+
+var ducklingGVR = schema.GroupVersionResource{
+	Group:    "k8s.posthog.com",
+	Version:  "v1alpha1",
+	Resource: "ducklings",
+}
+
+const ducklingNamespace = "crossplane-system"
+
+// DucklingStatus holds the parsed status from a Duckling CR.
+// The Duckling composition provisions AWS infrastructure (Aurora, S3, IAM)
+// but not K8s workloads — those are managed by the duckgres Helm chart.
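+//
+// The status block it parses looks roughly like this (see parseDucklingStatus
+// and the tests for the exact field names):
+//
+//	status:
+//	  metadataStore: {type: aurora, endpoint: ..., user: ..., database: ..., password: ...}
+//	  dataStore: {type: s3bucket, bucketName: ...}
+//	  iamRoleArn: arn:aws:iam::...:role/...
+//	  conditions:
+//	    - {type: Ready, status: "True"}
+//	    - {type: Synced, status: "False", message: ...}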
+type DucklingStatus struct {
+	MetadataStore struct {
+		Type     string
+		Endpoint string
+		Password string
+		User     string
+		Database string
+	}
+	DataStore struct {
+		Type       string
+		BucketName string
+	}
+	IAMRoleARN         string
+	ReadyCondition     bool
+	SyncedFalseMessage string
+}
+
+// DucklingClient wraps a Kubernetes dynamic client for Duckling CR operations.
+type DucklingClient struct {
+	client dynamic.Interface
+}
+
+// NewDucklingClient creates a DucklingClient using in-cluster config.
+func NewDucklingClient() (*DucklingClient, error) {
+	config, err := rest.InClusterConfig()
+	if err != nil {
+		return nil, fmt.Errorf("in-cluster config: %w", err)
+	}
+	dc, err := dynamic.NewForConfig(config)
+	if err != nil {
+		return nil, fmt.Errorf("dynamic client: %w", err)
+	}
+	return &DucklingClient{client: dc}, nil
+}
+
+// NewDucklingClientWithDynamic creates a DucklingClient with a provided dynamic.Interface (for testing).
+func NewDucklingClientWithDynamic(client dynamic.Interface) *DucklingClient {
+	return &DucklingClient{client: client}
+}
+
+// Create creates a Duckling CR for the given org.
+func (d *DucklingClient) Create(ctx context.Context, orgID string, minACU, maxACU float64) error {
+	cr := &unstructured.Unstructured{
+		Object: map[string]interface{}{
+			"apiVersion": "k8s.posthog.com/v1alpha1",
+			"kind":       "Duckling",
+			"metadata": map[string]interface{}{
+				"name":      orgID,
+				"namespace": ducklingNamespace,
+			},
+			"spec": map[string]interface{}{
+				"metadataStore": map[string]interface{}{
+					"type": "aurora",
+					"aurora": map[string]interface{}{
+						"minACU": minACU,
+						"maxACU": maxACU,
+					},
+				},
+				"dataStore": map[string]interface{}{
+					"type": "s3bucket",
+				},
+			},
+		},
+	}
+
+	_, err := d.client.Resource(ducklingGVR).Namespace(ducklingNamespace).Create(ctx, cr, metav1.CreateOptions{})
+	if err != nil {
+		return fmt.Errorf("create duckling CR %q: %w", orgID, err)
+	}
+	return nil
+}
+
+// Get fetches the Duckling CR and parses its status.
+func (d *DucklingClient) Get(ctx context.Context, orgID string) (*DucklingStatus, error) {
+	cr, err := d.client.Resource(ducklingGVR).Namespace(ducklingNamespace).Get(ctx, orgID, metav1.GetOptions{})
+	if err != nil {
+		return nil, fmt.Errorf("get duckling CR %q: %w", orgID, err)
+	}
+	return parseDucklingStatus(cr)
+}
+
+// Delete removes the Duckling CR for the given org.
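+// The error is wrapped with %w, so callers can still match it with
+// apierrors.IsNotFound, which reconcileDeleting uses to treat a missing CR
+// as already deleted.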
+func (d *DucklingClient) Delete(ctx context.Context, orgID string) error {
+	err := d.client.Resource(ducklingGVR).Namespace(ducklingNamespace).Delete(ctx, orgID, metav1.DeleteOptions{})
+	if err != nil {
+		return fmt.Errorf("delete duckling CR %q: %w", orgID, err)
+	}
+	return nil
+}
+
+func parseDucklingStatus(cr *unstructured.Unstructured) (*DucklingStatus, error) {
+	status, ok := cr.Object["status"].(map[string]interface{})
+	if !ok {
+		return &DucklingStatus{}, nil
+	}
+
+	ds := &DucklingStatus{
+		IAMRoleARN: getNestedString(status, "iamRoleArn"),
+	}
+
+	// Parse status.metadataStore
+	if md, ok := status["metadataStore"].(map[string]interface{}); ok {
+		ds.MetadataStore.Type = getNestedString(md, "type")
+		ds.MetadataStore.Endpoint = getNestedString(md, "endpoint")
+		ds.MetadataStore.Password = getNestedString(md, "password")
+		ds.MetadataStore.User = getNestedString(md, "user")
+		ds.MetadataStore.Database = getNestedString(md, "database")
+	}
+
+	// Parse status.dataStore
+	if store, ok := status["dataStore"].(map[string]interface{}); ok {
+		ds.DataStore.Type = getNestedString(store, "type")
+		ds.DataStore.BucketName = getNestedString(store, "bucketName")
+	}
+
+	// Parse conditions
+	conditions, _ := status["conditions"].([]interface{})
+	for _, cond := range conditions {
+		condMap, ok := cond.(map[string]interface{})
+		if !ok {
+			continue
+		}
+		condType := getNestedString(condMap, "type")
+		condStatus := getNestedString(condMap, "status")
+
+		switch condType {
+		case "Ready":
+			ds.ReadyCondition = condStatus == "True"
+		case "Synced":
+			if condStatus == "False" {
+				ds.SyncedFalseMessage = getNestedString(condMap, "message")
+			}
+		}
+	}
+
+	return ds, nil
+}
+
+func getNestedString(obj map[string]interface{}, key string) string {
+	v, _ := obj[key].(string)
+	return v
+}
diff --git a/controlplane/provisioning/api.go b/controlplane/provisioning/api.go
new file mode 100644
index 0000000..9a8b757
--- /dev/null
+++ b/controlplane/provisioning/api.go
@@ -0,0 +1,137 @@
+package provisioning
+
+import (
+	"errors"
+	"net/http"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/posthog/duckgres/controlplane/configstore"
+	"gorm.io/gorm"
+)
+
+// Store defines the config store operations needed by the provisioning API.
+type Store interface {
+	GetManagedWarehouse(orgID string) (*configstore.ManagedWarehouse, error)
+	CreatePendingWarehouse(orgID string, warehouse *configstore.ManagedWarehouse) error
+	SetWarehouseDeleting(orgID string, expectedState configstore.ManagedWarehouseProvisioningState) error
+}
+
+// RegisterAPI registers provisioning endpoints on the given router group.
+func RegisterAPI(r *gin.RouterGroup, store Store) {
+	h := &handler{store: store}
+	r.POST("/orgs/:id/provision", h.provisionWarehouse)
+	r.POST("/orgs/:id/deprovision", h.deprovisionWarehouse)
+	r.GET("/orgs/:id/warehouse/status", h.getWarehouseStatus)
+}
+
+type handler struct {
+	store Store
+}
+
+// warehouseStatusResponse is the public-facing view of warehouse state.
+// Only exposes lifecycle status — no infrastructure secrets or internal config.
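+//
+// Example response body (illustrative values; field names match the json tags below):
+//
+//	{"org_id":"org-a","state":"provisioning","status_message":"Duckling CR created, waiting for resources",
+//	 "s3_state":"ready","metadata_store_state":"pending","identity_state":"pending","secrets_state":"pending"}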
+type warehouseStatusResponse struct { + OrgID string `json:"org_id"` + State configstore.ManagedWarehouseProvisioningState `json:"state"` + StatusMessage string `json:"status_message"` + S3State configstore.ManagedWarehouseProvisioningState `json:"s3_state"` + MetadataStoreState configstore.ManagedWarehouseProvisioningState `json:"metadata_store_state"` + IdentityState configstore.ManagedWarehouseProvisioningState `json:"identity_state"` + SecretsState configstore.ManagedWarehouseProvisioningState `json:"secrets_state"` + ReadyAt *time.Time `json:"ready_at,omitempty"` + FailedAt *time.Time `json:"failed_at,omitempty"` +} + +type provisionRequest struct { + MetadataStore *provisionMetadataReq `json:"metadata_store,omitempty"` +} + +type provisionMetadataReq struct { + Type string `json:"type"` + Aurora *provisionAuroraReq `json:"aurora,omitempty"` +} + +type provisionAuroraReq struct { + MinACU float64 `json:"min_acu"` + MaxACU float64 `json:"max_acu"` +} + +func (h *handler) provisionWarehouse(c *gin.Context) { + orgID := c.Param("id") + + var req provisionRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.MetadataStore == nil || req.MetadataStore.Aurora == nil || req.MetadataStore.Aurora.MaxACU <= 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "metadata_store.aurora.max_acu must be greater than 0"}) + return + } + + warehouse := &configstore.ManagedWarehouse{ + AuroraMinACU: req.MetadataStore.Aurora.MinACU, + AuroraMaxACU: req.MetadataStore.Aurora.MaxACU, + } + + if err := h.store.CreatePendingWarehouse(orgID, warehouse); err != nil { + c.JSON(http.StatusConflict, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusAccepted, gin.H{"status": "provisioning started", "org": orgID}) +} + +func (h *handler) deprovisionWarehouse(c *gin.Context) { + orgID := c.Param("id") + + // Try CAS from each deprovisionable state. Order doesn't matter — + // only one will match. This avoids a read-then-write TOCTOU race. 
+	deprovisionableStates := []configstore.ManagedWarehouseProvisioningState{
+		configstore.ManagedWarehouseStateReady,
+		configstore.ManagedWarehouseStateFailed,
+		configstore.ManagedWarehouseStateProvisioning,
+	}
+
+	var err error
+	for _, state := range deprovisionableStates {
+		if err = h.store.SetWarehouseDeleting(orgID, state); err == nil {
+			c.JSON(http.StatusAccepted, gin.H{"status": "deprovisioning started", "org": orgID})
+			return
+		}
+	}
+
+	if errors.Is(err, gorm.ErrRecordNotFound) {
+		c.JSON(http.StatusNotFound, gin.H{"error": "warehouse not found"})
+		return
+	}
+	c.JSON(http.StatusConflict, gin.H{"error": "warehouse must be in ready, failed, or provisioning state to deprovision"})
+}
+
+func (h *handler) getWarehouseStatus(c *gin.Context) {
+	orgID := c.Param("id")
+
+	warehouse, err := h.store.GetManagedWarehouse(orgID)
+	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			c.JSON(http.StatusNotFound, gin.H{"error": "warehouse not found"})
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	c.JSON(http.StatusOK, warehouseStatusResponse{
+		OrgID:              warehouse.OrgID,
+		State:              warehouse.State,
+		StatusMessage:      warehouse.StatusMessage,
+		S3State:            warehouse.S3State,
+		MetadataStoreState: warehouse.MetadataStoreState,
+		IdentityState:      warehouse.IdentityState,
+		SecretsState:       warehouse.SecretsState,
+		ReadyAt:            warehouse.ReadyAt,
+		FailedAt:           warehouse.FailedAt,
+	})
+}
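+
+// A status response for a warehouse mid-provisioning might look like this
+// (illustrative values that mirror the tests; the exact state strings come
+// from configstore's serialization):
+//
+//	{
+//	  "org_id": "analytics",
+//	  "state": "provisioning",
+//	  "status_message": "",
+//	  "s3_state": "ready",
+//	  "metadata_store_state": "pending",
+//	  "identity_state": "pending",
+//	  "secrets_state": "pending"
+//	}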
diff --git a/controlplane/provisioning/api_test.go b/controlplane/provisioning/api_test.go
new file mode 100644
index 0000000..b1d9dcc
--- /dev/null
+++ b/controlplane/provisioning/api_test.go
@@ -0,0 +1,331 @@
+package provisioning
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/gin-gonic/gin"
+	"github.com/posthog/duckgres/controlplane/configstore"
+	"gorm.io/gorm"
+)
+
+type fakeStore struct {
+	orgs       map[string]*configstore.Org
+	warehouses map[string]*configstore.ManagedWarehouse
+}
+
+func newFakeStore() *fakeStore {
+	return &fakeStore{
+		orgs:       make(map[string]*configstore.Org),
+		warehouses: make(map[string]*configstore.ManagedWarehouse),
+	}
+}
+
+func (s *fakeStore) GetManagedWarehouse(orgID string) (*configstore.ManagedWarehouse, error) {
+	w, ok := s.warehouses[orgID]
+	if !ok {
+		return nil, gorm.ErrRecordNotFound
+	}
+	clone := *w
+	return &clone, nil
+}
+
+func (s *fakeStore) CreatePendingWarehouse(orgID string, warehouse *configstore.ManagedWarehouse) error {
+	// Auto-create org if needed (mirrors production behavior)
+	if _, ok := s.orgs[orgID]; !ok {
+		s.orgs[orgID] = &configstore.Org{Name: orgID}
+	}
+	existing, ok := s.warehouses[orgID]
+	if ok && existing.State != configstore.ManagedWarehouseStateFailed && existing.State != configstore.ManagedWarehouseStateDeleted {
+		return errors.New("warehouse already exists in non-terminal state")
+	}
+	clone := *warehouse
+	clone.OrgID = orgID
+	clone.State = configstore.ManagedWarehouseStatePending
+	clone.WarehouseDatabaseState = configstore.ManagedWarehouseStatePending
+	clone.MetadataStoreState = configstore.ManagedWarehouseStatePending
+	clone.S3State = configstore.ManagedWarehouseStatePending
+	clone.IdentityState = configstore.ManagedWarehouseStatePending
+	clone.SecretsState = configstore.ManagedWarehouseStatePending
+	s.warehouses[orgID] = &clone
+	return nil
+}
+
+func (s *fakeStore) SetWarehouseDeleting(orgID string, expectedState configstore.ManagedWarehouseProvisioningState) error {
+	w, ok := s.warehouses[orgID]
+	if !ok {
+		return gorm.ErrRecordNotFound
+	}
+	if w.State != expectedState {
+		return fmt.Errorf("warehouse %q not in expected state %q", orgID, expectedState)
+	}
+	w.State = configstore.ManagedWarehouseStateDeleting
+	return nil
+}
+
+func newTestRouter(store Store) *gin.Engine {
+	gin.SetMode(gin.TestMode)
+	r := gin.New()
+	RegisterAPI(r.Group("/api/v1"), store)
+	return r
+}
+
+func TestProvisionCreatesWarehouse(t *testing.T) {
+	store := newFakeStore()
+	store.orgs["analytics"] = &configstore.Org{Name: "analytics"}
+	router := newTestRouter(store)
+
+	body := []byte(`{
+		"metadata_store": {
+			"type": "aurora",
+			"aurora": {"min_acu": 0.5, "max_acu": 2}
+		}
+	}`)
+
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/orgs/analytics/provision", bytes.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+	rec := httptest.NewRecorder()
+	router.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusAccepted {
+		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusAccepted, rec.Body.String())
+	}
+
+	w := store.warehouses["analytics"]
+	if w == nil {
+		t.Fatal("expected warehouse to be created")
+	}
+	if w.State != configstore.ManagedWarehouseStatePending {
+		t.Fatalf("expected state pending, got %q", w.State)
+	}
+	if w.AuroraMinACU != 0.5 {
+		t.Fatalf("expected min_acu 0.5, got %f", w.AuroraMinACU)
+	}
+	if w.AuroraMaxACU != 2 {
+		t.Fatalf("expected max_acu 2, got %f", w.AuroraMaxACU)
+	}
+}
+
+func TestProvisionAutoCreatesOrg(t *testing.T) {
+	store := newFakeStore()
+	router := newTestRouter(store)
+
+	body := []byte(`{"metadata_store": {"type": "aurora", "aurora": {"max_acu": 1}}}`)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/orgs/new-org/provision", bytes.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+	rec := httptest.NewRecorder()
+	router.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusAccepted {
+		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusAccepted, rec.Body.String())
+	}
+	if _, ok := store.orgs["new-org"]; !ok {
+		t.Fatal("expected org to be auto-created")
+	}
+	if store.warehouses["new-org"] == nil {
+		t.Fatal("expected warehouse to be created")
+	}
+}
+
+func TestProvisionRejectsEmptyBody(t *testing.T) {
+	store := newFakeStore()
+	router := newTestRouter(store)
+
+	body := []byte(`{}`)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/orgs/analytics/provision", bytes.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+	rec := httptest.NewRecorder()
+	router.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusBadRequest {
+		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusBadRequest, rec.Body.String())
+	}
+}
+
+func TestProvisionRejectsZeroMaxACU(t *testing.T) {
+	store := newFakeStore()
+	router := newTestRouter(store)
+
+	body := []byte(`{"metadata_store": {"type": "aurora", "aurora": {"min_acu": 0.5, "max_acu": 0}}}`)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/orgs/analytics/provision", bytes.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+	rec := httptest.NewRecorder()
+	router.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusBadRequest {
+		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusBadRequest, rec.Body.String())
+	}
+}
+
+func TestProvisionRejectsExistingNonTerminal(t *testing.T) {
+	store := newFakeStore()
+	store.orgs["analytics"] = &configstore.Org{Name: "analytics"}
+	store.warehouses["analytics"] = &configstore.ManagedWarehouse{
+		OrgID: "analytics",
+		State: configstore.ManagedWarehouseStateProvisioning,
+	}
+	router := newTestRouter(store)
+
+	body := []byte(`{"metadata_store": {"type": "aurora", "aurora": {"max_acu": 1}}}`)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/orgs/analytics/provision", bytes.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+	rec := httptest.NewRecorder()
+	router.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusConflict {
+		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusConflict, rec.Body.String())
+	}
+}
+
+func TestProvisionAllowsRetryAfterFailure(t *testing.T) {
+	store := newFakeStore()
+	store.orgs["analytics"] = &configstore.Org{Name: "analytics"}
+	store.warehouses["analytics"] = &configstore.ManagedWarehouse{
+		OrgID: "analytics",
+		State: configstore.ManagedWarehouseStateFailed,
+	}
+	router := newTestRouter(store)
+
+	body := []byte(`{"metadata_store": {"type": "aurora", "aurora": {"min_acu": 0, "max_acu": 2}}}`)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/orgs/analytics/provision", bytes.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+	rec := httptest.NewRecorder()
+	router.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusAccepted {
+		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusAccepted, rec.Body.String())
+	}
+	if store.warehouses["analytics"].AuroraMaxACU != 2 {
+		t.Fatalf("expected max_acu 2, got %f", store.warehouses["analytics"].AuroraMaxACU)
+	}
+}
+
+func TestDeprovisionReadyWarehouse(t *testing.T) {
+	store := newFakeStore()
+	store.orgs["analytics"] = &configstore.Org{Name: "analytics"}
+	store.warehouses["analytics"] = &configstore.ManagedWarehouse{
+		OrgID: "analytics",
+		State: configstore.ManagedWarehouseStateReady,
+	}
+	router := newTestRouter(store)
+
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/orgs/analytics/deprovision", nil)
+	rec := httptest.NewRecorder()
+	router.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusAccepted {
+		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusAccepted, rec.Body.String())
+	}
+	if store.warehouses["analytics"].State != configstore.ManagedWarehouseStateDeleting {
+		t.Fatalf("expected deleting state, got %q", store.warehouses["analytics"].State)
+	}
+}
+
+func TestDeprovisionFailedWarehouse(t *testing.T) {
+	store := newFakeStore()
+	store.orgs["analytics"] = &configstore.Org{Name: "analytics"}
+	store.warehouses["analytics"] = &configstore.ManagedWarehouse{
+		OrgID: "analytics",
+		State: configstore.ManagedWarehouseStateFailed,
+	}
+	router := newTestRouter(store)
+
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/orgs/analytics/deprovision", nil)
+	rec := httptest.NewRecorder()
+	router.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusAccepted {
+		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusAccepted, rec.Body.String())
+	}
+	if store.warehouses["analytics"].State != configstore.ManagedWarehouseStateDeleting {
+		t.Fatalf("expected deleting state, got %q", store.warehouses["analytics"].State)
+	}
+}
+
+func TestDeprovisionProvisioningWarehouse(t *testing.T) {
+	store := newFakeStore()
+	store.orgs["analytics"] = &configstore.Org{Name: "analytics"}
+	store.warehouses["analytics"] = &configstore.ManagedWarehouse{
+		OrgID: "analytics",
+		State: configstore.ManagedWarehouseStateProvisioning,
+	}
+	router := newTestRouter(store)
+
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/orgs/analytics/deprovision", nil)
+	rec := httptest.NewRecorder()
+	router.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusAccepted {
+		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusAccepted, rec.Body.String())
+	}
+	if store.warehouses["analytics"].State != configstore.ManagedWarehouseStateDeleting {
+		t.Fatalf("expected deleting state, got %q", store.warehouses["analytics"].State)
+	}
+}
+
+func TestDeprovisionRejectsPendingWarehouse(t *testing.T) {
+	store := newFakeStore()
+	store.orgs["analytics"] = &configstore.Org{Name: "analytics"}
+	store.warehouses["analytics"] = &configstore.ManagedWarehouse{
+		OrgID: "analytics",
+		State: configstore.ManagedWarehouseStatePending,
+	}
+	router := newTestRouter(store)
+
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/orgs/analytics/deprovision", nil)
+	rec := httptest.NewRecorder()
+	router.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusConflict {
+		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusConflict, rec.Body.String())
+	}
+}
+
+func TestGetWarehouseStatus(t *testing.T) {
+	store := newFakeStore()
+	store.orgs["analytics"] = &configstore.Org{Name: "analytics"}
+	store.warehouses["analytics"] = &configstore.ManagedWarehouse{
+		OrgID:              "analytics",
+		State:              configstore.ManagedWarehouseStateProvisioning,
+		S3State:            configstore.ManagedWarehouseStateReady,
+		MetadataStoreState: configstore.ManagedWarehouseStatePending,
+	}
+	router := newTestRouter(store)
+
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/orgs/analytics/warehouse/status", nil)
+	rec := httptest.NewRecorder()
+	router.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusOK {
+		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
+	}
+
+	var resp warehouseStatusResponse
+	if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
+		t.Fatalf("unmarshal: %v", err)
+	}
+	if resp.State != configstore.ManagedWarehouseStateProvisioning {
+		t.Fatalf("expected provisioning state, got %q", resp.State)
+	}
+	if resp.S3State != configstore.ManagedWarehouseStateReady {
+		t.Fatalf("expected s3 ready, got %q", resp.S3State)
+	}
+}
+
+func TestGetWarehouseNotFound(t *testing.T) {
+	store := newFakeStore()
+	router := newTestRouter(store)
+
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/orgs/unknown/warehouse/status", nil)
+	rec := httptest.NewRecorder()
+	router.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusNotFound {
+		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusNotFound, rec.Body.String())
+	}
+}
diff --git a/controlplane/provisioning/store.go b/controlplane/provisioning/store.go
new file mode 100644
index 0000000..cedb213
--- /dev/null
+++ b/controlplane/provisioning/store.go
@@ -0,0 +1,82 @@
+package provisioning
+
+import (
+	"errors"
+	"fmt"
+
+	"github.com/posthog/duckgres/controlplane/configstore"
+	"gorm.io/gorm"
+)
+
+// gormStore implements Store using a ConfigStore's GORM DB.
+type gormStore struct {
+	cs *configstore.ConfigStore
+}
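+
+// Wiring sketch (hypothetical; the real server setup lives outside this
+// package): mount the provisioning routes next to the admin API using a
+// store backed by the shared ConfigStore.
+//
+//	store := provisioning.NewGormStore(cs)
+//	provisioning.RegisterAPI(router.Group("/api/v1"), store)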
+// NewGormStore creates a Store backed by the given ConfigStore.
+func NewGormStore(cs *configstore.ConfigStore) Store {
+	return &gormStore{cs: cs}
+}
+
+func (s *gormStore) GetManagedWarehouse(orgID string) (*configstore.ManagedWarehouse, error) {
+	var warehouse configstore.ManagedWarehouse
+	if err := s.cs.DB().First(&warehouse, "org_id = ?", orgID).Error; err != nil {
+		return nil, err
+	}
+	return &warehouse, nil
+}
+
+func (s *gormStore) CreatePendingWarehouse(orgID string, warehouse *configstore.ManagedWarehouse) error {
+	return s.cs.DB().Transaction(func(tx *gorm.DB) error {
+		// Auto-create org if it doesn't exist (PostHog calls provision, duckgres creates everything)
+		org := configstore.Org{Name: orgID}
+		if err := tx.Where("name = ?", orgID).FirstOrCreate(&org).Error; err != nil {
+			return err
+		}
+
+		// Check for existing warehouse in non-terminal state
+		var existing configstore.ManagedWarehouse
+		err := tx.First(&existing, "org_id = ?", orgID).Error
+		if err == nil {
+			if existing.State != configstore.ManagedWarehouseStateFailed &&
+				existing.State != configstore.ManagedWarehouseStateDeleted {
+				return errors.New("warehouse already exists in non-terminal state")
+			}
+			if err := tx.Delete(&existing).Error; err != nil {
+				return err
+			}
+		} else if !errors.Is(err, gorm.ErrRecordNotFound) {
+			return err
+		}
+
+		warehouse.OrgID = orgID
+		warehouse.State = configstore.ManagedWarehouseStatePending
+		warehouse.WarehouseDatabaseState = configstore.ManagedWarehouseStatePending
+		warehouse.MetadataStoreState = configstore.ManagedWarehouseStatePending
+		warehouse.S3State = configstore.ManagedWarehouseStatePending
+		warehouse.IdentityState = configstore.ManagedWarehouseStatePending
+		warehouse.SecretsState = configstore.ManagedWarehouseStatePending
+		return tx.Create(warehouse).Error
+	})
+}
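+
+// Taken together, the states used in this package suggest roughly this
+// lifecycle (a sketch; the provisioner that drives pending through ready
+// lives outside this store):
+//
+//	pending -> provisioning -> ready | failed
+//	ready | failed | provisioning -> deleting -> deleted
+//
+// CreatePendingWarehouse restarts the cycle only from failed or deleted;
+// SetWarehouseDeleting below performs the transitions into deleting.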
+ Update("state", configstore.ManagedWarehouseStateDeleting) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + // Distinguish "not found" from "wrong state" + var count int64 + s.cs.DB().Model(&configstore.ManagedWarehouse{}).Where("org_id = ?", orgID).Count(&count) + if count == 0 { + return gorm.ErrRecordNotFound + } + return fmt.Errorf("warehouse %q not in expected state %q", orgID, expectedState) + } + return nil +} diff --git a/controlplane/shared_worker_activator.go b/controlplane/shared_worker_activator.go index 4d9da0a..050ec81 100644 --- a/controlplane/shared_worker_activator.go +++ b/controlplane/shared_worker_activator.go @@ -19,6 +19,7 @@ import ( type SharedWorkerActivator struct { clientset kubernetes.Interface defaultNamespace string + stsBroker *STSBroker resolveOrgConfig func(string) (*configstore.OrgConfig, error) activateReservedWorker func(context.Context, *ManagedWorker, TenantActivationPayload) error } @@ -30,13 +31,14 @@ type TenantActivationPayload struct { DuckLake server.DuckLakeConfig `json:"ducklake"` } -func NewSharedWorkerActivator(shared *K8sWorkerPool, resolveOrgConfig func(string) (*configstore.OrgConfig, error)) *SharedWorkerActivator { +func NewSharedWorkerActivator(shared *K8sWorkerPool, stsBroker *STSBroker, resolveOrgConfig func(string) (*configstore.OrgConfig, error)) *SharedWorkerActivator { if shared == nil { return nil } return &SharedWorkerActivator{ clientset: shared.clientset, defaultNamespace: shared.namespace, + stsBroker: stsBroker, resolveOrgConfig: resolveOrgConfig, activateReservedWorker: shared.ActivateReservedWorker, } @@ -110,7 +112,19 @@ func (a *SharedWorkerActivator) BuildActivationRequest(ctx context.Context, org dl.S3AccessKey = accessKey dl.S3SecretKey = secretKey case strings.EqualFold(warehouse.S3.Provider, "aws"): - dl.S3Provider = "aws_sdk" + roleARN := warehouse.WorkerIdentity.IAMRoleARN + if roleARN != "" && a.stsBroker != nil { + creds, err := a.stsBroker.AssumeRole(ctx, roleARN) + if err != nil { + return TenantActivationPayload{}, fmt.Errorf("STS AssumeRole for org %q: %w", orgName(org), err) + } + dl.S3Provider = "config" + dl.S3AccessKey = creds.AccessKeyID + dl.S3SecretKey = creds.SecretAccessKey + dl.S3SessionToken = creds.SessionToken + } else { + dl.S3Provider = "aws_sdk" + } } usernames := make([]string, 0, len(org.Users)) @@ -127,10 +141,11 @@ func (a *SharedWorkerActivator) BuildActivationRequest(ctx context.Context, org }, nil } -func BuildTenantActivationPayload(ctx context.Context, clientset kubernetes.Interface, defaultNamespace string, org *configstore.OrgConfig) (TenantActivationPayload, error) { +func BuildTenantActivationPayload(ctx context.Context, clientset kubernetes.Interface, defaultNamespace string, org *configstore.OrgConfig, stsBroker *STSBroker) (TenantActivationPayload, error) { activator := &SharedWorkerActivator{ clientset: clientset, defaultNamespace: defaultNamespace, + stsBroker: stsBroker, } assignment := &WorkerAssignment{ OrgID: orgName(org), diff --git a/controlplane/sts_broker.go b/controlplane/sts_broker.go new file mode 100644 index 0000000..6a9682a --- /dev/null +++ b/controlplane/sts_broker.go @@ -0,0 +1,69 @@ +//go:build kubernetes + +package controlplane + +import ( + "context" + "fmt" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +const ( + stsSessionDuration = 1 * time.Hour + stsSessionName = "duckgres-cp" +) + +// STSBroker brokers 
diff --git a/controlplane/sts_broker.go b/controlplane/sts_broker.go
new file mode 100644
index 0000000..6a9682a
--- /dev/null
+++ b/controlplane/sts_broker.go
@@ -0,0 +1,69 @@
+//go:build kubernetes
+
+package controlplane
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	awsconfig "github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/service/sts"
+)
+
+const (
+	stsSessionDuration = 1 * time.Hour
+	stsSessionName     = "duckgres-cp"
+)
+
+// STSBroker brokers short-lived AWS credentials by assuming per-org IAM roles.
+type STSBroker struct {
+	client *sts.Client
+}
+
+// AssumedCredentials holds the temporary credentials from STS AssumeRole.
+type AssumedCredentials struct {
+	AccessKeyID     string
+	SecretAccessKey string
+	SessionToken    string
+	Expiration      time.Time
+}
+
+// NewSTSBroker creates an STS broker using the control plane's own credentials.
+func NewSTSBroker(ctx context.Context, region string) (*STSBroker, error) {
+	opts := []func(*awsconfig.LoadOptions) error{}
+	if region != "" {
+		opts = append(opts, awsconfig.WithRegion(region))
+	}
+	cfg, err := awsconfig.LoadDefaultConfig(ctx, opts...)
+	if err != nil {
+		return nil, fmt.Errorf("load AWS config: %w", err)
+	}
+	return &STSBroker{
+		client: sts.NewFromConfig(cfg),
+	}, nil
+}
+
+// AssumeRole mints short-lived credentials for the given IAM role ARN.
+func (b *STSBroker) AssumeRole(ctx context.Context, roleARN string) (*AssumedCredentials, error) {
+	out, err := b.client.AssumeRole(ctx, &sts.AssumeRoleInput{
+		RoleArn:         aws.String(roleARN),
+		RoleSessionName: aws.String(stsSessionName),
+		DurationSeconds: aws.Int32(int32(stsSessionDuration.Seconds())),
+	})
+	if err != nil {
+		return nil, fmt.Errorf("STS AssumeRole %s: %w", roleARN, err)
+	}
+	if out.Credentials == nil {
+		return nil, fmt.Errorf("STS AssumeRole returned nil credentials for %s", roleARN)
+	}
+	return &AssumedCredentials{
+		AccessKeyID:     aws.ToString(out.Credentials.AccessKeyId),
+		SecretAccessKey: aws.ToString(out.Credentials.SecretAccessKey),
+		SessionToken:    aws.ToString(out.Credentials.SessionToken),
+		Expiration:      aws.ToTime(out.Credentials.Expiration),
+	}, nil
+}
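+
+// Usage sketch (hypothetical wiring; see BuildActivationRequest in
+// shared_worker_activator.go for the real call site):
+//
+//	broker, err := NewSTSBroker(ctx, cfg.K8s.AWSRegion)
+//	if err != nil { /* handle */ }
+//	creds, err := broker.AssumeRole(ctx, warehouse.WorkerIdentity.IAMRoleARN)
+//	// creds feed DuckLake's "config" S3 provider as
+//	// S3AccessKey / S3SecretKey / S3SessionToken.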
KUBECONFIG="${DUCKGRES_KIND_KUBECONFIG:-/tmp/duckgres-kind-kubeconfig}" kubectl -n duckgres wait deployment/duckgres-control-plane --for=condition=available --timeout=120s || { echo "=== Pod status ==="; KUBECONFIG="${DUCKGRES_KIND_KUBECONFIG:-/tmp/duckgres-kind-kubeconfig}" kubectl -n duckgres get pods -o wide; echo "=== Pod describe ==="; KUBECONFIG="${DUCKGRES_KIND_KUBECONFIG:-/tmp/duckgres-kind-kubeconfig}" kubectl -n duckgres describe pod -l app=duckgres-control-plane; echo "=== Pod logs ==="; KUBECONFIG="${DUCKGRES_KIND_KUBECONFIG:-/tmp/duckgres-kind-kubeconfig}" kubectl -n duckgres logs -l app=duckgres-control-plane --tail=100 --all-containers; exit 1; } # End-to-end local multi-tenant setup: optional OrbStack K8s + config store + control plane [group('dev')] @@ -185,7 +185,7 @@ run-multitenant-local: multitenant-config-store-up build-k8s-image deploy-multit @echo "Multi-tenant control plane ready." @echo "Default login: postgres / postgres" @echo "Fetch admin token with: kubectl -n duckgres logs deployment/duckgres-control-plane | rg 'Generated admin API token'" - @echo "Run 'just multitenant-port-forward-pg' in one terminal and 'just multitenant-port-forward-admin' in another." + @echo "Run 'just multitenant-port-forward-pg' and 'just multitenant-port-forward-api' in separate terminals." # End-to-end local multi-tenant setup: kind K8s + config store + control plane [group('dev')] @@ -215,10 +215,10 @@ cleanup-multitenant-kind: multitenant-port-forward-pg: kubectl -n duckgres port-forward svc/duckgres 5432:5432 -# Port-forward the admin dashboard and API from the local control plane +# Port-forward the API server (admin + provisioning) from the local control plane [group('dev')] -multitenant-port-forward-admin: - kubectl -n duckgres port-forward deployment/duckgres-control-plane 9090:9090 +multitenant-port-forward-api: + kubectl -n duckgres port-forward deployment/duckgres-control-plane 8080:8080 # Run with DuckLake config [group('dev')] diff --git a/k8s/control-plane-multitenant-local.yaml b/k8s/control-plane-multitenant-local.yaml index 53f2ff7..16f4cee 100644 --- a/k8s/control-plane-multitenant-local.yaml +++ b/k8s/control-plane-multitenant-local.yaml @@ -61,8 +61,8 @@ spec: - name: flight containerPort: 8815 protocol: TCP - - name: admin - containerPort: 9090 + - name: api + containerPort: 8080 protocol: TCP volumeMounts: - name: config @@ -75,7 +75,7 @@ spec: readinessProbe: httpGet: path: /health - port: admin + port: api initialDelaySeconds: 2 periodSeconds: 2 failureThreshold: 15 diff --git a/k8s/kind/config-store.seed.sql b/k8s/kind/config-store.seed.sql index 1dc6d0f..db86c4d 100644 --- a/k8s/kind/config-store.seed.sql +++ b/k8s/kind/config-store.seed.sql @@ -5,6 +5,9 @@ SET updated_at = NOW(); INSERT INTO duckgres_managed_warehouses ( org_id, + image, + aurora_min_acu, + aurora_max_acu, warehouse_database_region, warehouse_database_endpoint, warehouse_database_port, @@ -58,6 +61,9 @@ INSERT INTO duckgres_managed_warehouses ( ) VALUES ( 'local', + '', + 0, + 0, 'kind-dev', 'duckgres-local-warehouse-db', 5432, diff --git a/k8s/kind/control-plane.yaml b/k8s/kind/control-plane.yaml index 9736446..2a1334c 100644 --- a/k8s/kind/control-plane.yaml +++ b/k8s/kind/control-plane.yaml @@ -52,8 +52,8 @@ spec: - name: pg containerPort: 5432 protocol: TCP - - name: admin - containerPort: 9090 + - name: api + containerPort: 8080 protocol: TCP volumeMounts: - name: config @@ -66,7 +66,7 @@ spec: readinessProbe: httpGet: path: /health - port: admin + port: api initialDelaySeconds: 2 

 # Run with DuckLake config
 [group('dev')]
diff --git a/k8s/control-plane-multitenant-local.yaml b/k8s/control-plane-multitenant-local.yaml
index 53f2ff7..16f4cee 100644
--- a/k8s/control-plane-multitenant-local.yaml
+++ b/k8s/control-plane-multitenant-local.yaml
@@ -61,8 +61,8 @@ spec:
         - name: flight
           containerPort: 8815
           protocol: TCP
-        - name: admin
-          containerPort: 9090
+        - name: api
+          containerPort: 8080
           protocol: TCP
         volumeMounts:
         - name: config
@@ -75,7 +75,7 @@ spec:
         readinessProbe:
           httpGet:
             path: /health
-            port: admin
+            port: api
           initialDelaySeconds: 2
           periodSeconds: 2
           failureThreshold: 15
diff --git a/k8s/kind/config-store.seed.sql b/k8s/kind/config-store.seed.sql
index 1dc6d0f..db86c4d 100644
--- a/k8s/kind/config-store.seed.sql
+++ b/k8s/kind/config-store.seed.sql
@@ -5,6 +5,9 @@ SET updated_at = NOW();

 INSERT INTO duckgres_managed_warehouses (
     org_id,
+    image,
+    aurora_min_acu,
+    aurora_max_acu,
     warehouse_database_region,
     warehouse_database_endpoint,
     warehouse_database_port,
@@ -58,6 +61,9 @@ INSERT INTO duckgres_managed_warehouses (
 ) VALUES (
     'local',
+    '',
+    0,
+    0,
     'kind-dev',
     'duckgres-local-warehouse-db',
     5432,
diff --git a/k8s/kind/control-plane.yaml b/k8s/kind/control-plane.yaml
index 9736446..2a1334c 100644
--- a/k8s/kind/control-plane.yaml
+++ b/k8s/kind/control-plane.yaml
@@ -52,8 +52,8 @@ spec:
         - name: pg
           containerPort: 5432
           protocol: TCP
-        - name: admin
-          containerPort: 9090
+        - name: api
+          containerPort: 8080
           protocol: TCP
         volumeMounts:
         - name: config
@@ -66,7 +66,7 @@ spec:
         readinessProbe:
           httpGet:
             path: /health
-            port: admin
+            port: api
           initialDelaySeconds: 2
           periodSeconds: 2
           failureThreshold: 15
@@ -102,5 +102,9 @@ spec:
     port: 5432
     targetPort: pg
     protocol: TCP
+  - name: api
+    port: 8080
+    targetPort: api
+    protocol: TCP
   selector:
     app: duckgres-control-plane
diff --git a/main.go b/main.go
index a5ca480..d8712e7 100644
--- a/main.go
+++ b/main.go
@@ -153,9 +153,6 @@ func env(key, defaultVal string) string {
 func initMetrics() *http.Server {
 	mux := http.NewServeMux()
 	mux.Handle("/metrics", promhttp.Handler())
-	mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
-		w.WriteHeader(http.StatusOK)
-	})
 	srv := &http.Server{
 		Addr:    ":9090",
 		Handler: mux,
@@ -239,11 +236,12 @@ func main() {
 	k8sWorkerServiceAccount := flag.String("k8s-worker-service-account", "", "ServiceAccount name for K8s worker pods (env: DUCKGRES_K8S_WORKER_SERVICE_ACCOUNT)")
 	k8sMaxWorkers := flag.Int("k8s-max-workers", 0, "Max K8s workers in the shared pool, 0=auto-derived (env: DUCKGRES_K8S_MAX_WORKERS)")
 	k8sSharedWarmTarget := flag.Int("k8s-shared-warm-target", 0, "Neutral shared warm-worker target for K8s multi-tenant mode, 0=disabled (env: DUCKGRES_K8S_SHARED_WARM_TARGET)")
+	awsRegion := flag.String("aws-region", "", "AWS region for STS client (env: DUCKGRES_AWS_REGION)")

 	// Config store flags (multi-tenant mode)
 	configStore := flag.String("config-store", "", "PostgreSQL connection string for config store (env: DUCKGRES_CONFIG_STORE)")
 	configPollInterval := flag.String("config-poll-interval", "", "How often to poll config store for changes (default: 30s) (env: DUCKGRES_CONFIG_POLL_INTERVAL)")
-	adminToken := flag.String("admin-token", "", "Bearer token for admin API authentication (env: DUCKGRES_ADMIN_TOKEN)")
+	internalSecret := flag.String("internal-secret", "", "Shared secret for API authentication (env: DUCKGRES_INTERNAL_SECRET)")

 	// ACME/Let's Encrypt flags
 	acmeDomain := flag.String("acme-domain", "", "Domain for ACME/Let's Encrypt certificate (env: DUCKGRES_ACME_DOMAIN)")
@@ -299,9 +297,10 @@ func main() {
 	fmt.Fprintf(os.Stderr, "  DUCKGRES_DUCKDB_MAX_SESSIONS     DuckDB service max sessions (duckdb-service mode)\n")
 	fmt.Fprintf(os.Stderr, "  DUCKGRES_CONFIG_STORE            PostgreSQL connection string for config store (multi-tenant)\n")
 	fmt.Fprintf(os.Stderr, "  DUCKGRES_CONFIG_POLL_INTERVAL    Config store poll interval (default: 30s)\n")
-	fmt.Fprintf(os.Stderr, "  DUCKGRES_ADMIN_TOKEN             Bearer token for admin API authentication\n")
+	fmt.Fprintf(os.Stderr, "  DUCKGRES_INTERNAL_SECRET         Shared secret for API authentication\n")
 	fmt.Fprintf(os.Stderr, "  DUCKGRES_K8S_MAX_WORKERS         Max K8s workers in the shared pool\n")
 	fmt.Fprintf(os.Stderr, "  DUCKGRES_K8S_SHARED_WARM_TARGET  Neutral shared warm-worker target for K8s multi-tenant mode\n")
+	fmt.Fprintf(os.Stderr, "  DUCKGRES_AWS_REGION              AWS region for STS client\n")
 	fmt.Fprintf(os.Stderr, "  DUCKGRES_LOG_LEVEL               Log level: debug, info, warn, error (default: info)\n")
 	fmt.Fprintf(os.Stderr, "\nPrecedence: CLI flags > environment variables > config file > defaults\n")
 }
@@ -402,7 +401,7 @@ func main() {
 		MaxConnections:     *maxConnections,
 		ConfigStoreConn:    *configStore,
 		ConfigPollInterval: *configPollInterval,
-		AdminToken:         *adminToken,
+		InternalSecret:     *internalSecret,
 		WorkerBackend:      *workerBackend,
 		K8sWorkerImage:     *k8sWorkerImage,
 		K8sWorkerNamespace: *k8sWorkerNamespace,
@@ -414,6 +413,7 @@ func main() {
 		K8sWorkerServiceAccount: *k8sWorkerServiceAccount,
 		K8sMaxWorkers:           *k8sMaxWorkers,
 		K8sSharedWarmTarget:     *k8sSharedWarmTarget,
+		AWSRegion:               *awsRegion,
 		QueryLog:                *queryLog,
 	}, os.Getenv, func(msg string) {
 		slog.Warn(msg)
 	})
@@ -546,7 +546,7 @@ func main() {
 		WorkerBackend:      resolved.WorkerBackend,
 		ConfigStoreConn:    resolved.ConfigStoreConn,
 		ConfigPollInterval: resolved.ConfigPollInterval,
-		AdminToken:         resolved.AdminToken,
+		InternalSecret:     resolved.InternalSecret,
 		K8s: controlplane.K8sConfig{
 			WorkerImage:     resolved.K8sWorkerImage,
 			WorkerNamespace: resolved.K8sWorkerNamespace,
@@ -558,6 +558,7 @@ func main() {
 			ServiceAccount:   resolved.K8sWorkerServiceAccount,
 			MaxWorkers:       resolved.K8sMaxWorkers,
 			SharedWarmTarget: resolved.K8sSharedWarmTarget,
+			AWSRegion:        resolved.AWSRegion,
 		},
 	}
 	controlplane.RunControlPlane(cpCfg)
diff --git a/server/server.go b/server/server.go
index 070a554..a463766 100644
--- a/server/server.go
+++ b/server/server.go
@@ -219,12 +219,13 @@ type DuckLakeConfig struct {
 	S3Provider string

 	// S3 configuration for "config" provider (explicit credentials for MinIO or S3)
-	S3Endpoint  string // e.g., "localhost:9000" for MinIO
-	S3AccessKey string // S3 access key ID
-	S3SecretKey string // S3 secret access key
-	S3Region    string // S3 region (default: us-east-1)
-	S3UseSSL    bool   // Use HTTPS for S3 connections (default: false for MinIO)
-	S3URLStyle  string // "path" or "vhost" (default: "path" for MinIO compatibility)
+	S3Endpoint     string // e.g., "localhost:9000" for MinIO
+	S3AccessKey    string // S3 access key ID
+	S3SecretKey    string // S3 secret access key
+	S3SessionToken string // STS session token for temporary credentials
+	S3Region       string // S3 region (default: us-east-1)
+	S3UseSSL       bool   // Use HTTPS for S3 connections (default: false for MinIO)
+	S3URLStyle     string // "path" or "vhost" (default: "path" for MinIO compatibility)

 	// S3 configuration for "credential_chain" provider (AWS SDK credential chain)
 	// Chain specifies which credential sources to check, semicolon-separated
@@ -1215,6 +1216,10 @@ func buildConfigSecret(dlCfg DuckLakeConfig) string {
 		secret += fmt.Sprintf(",\n\t\t\tENDPOINT '%s'", dlCfg.S3Endpoint)
 	}

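+	// With STS-brokered credentials the generated DuckDB secret gains a
+	// SESSION_TOKEN entry alongside the key/secret entries built above,
+	// roughly (sketch; the exact statement is assembled by this function):
+	//   CREATE SECRET ... (TYPE S3, KEY_ID '...', SECRET '...', SESSION_TOKEN '...')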
+	if dlCfg.S3SessionToken != "" {
+		secret += fmt.Sprintf(",\n\t\t\tSESSION_TOKEN '%s'", dlCfg.S3SessionToken)
+	}
+
 	secret += "\n\t\t)"
 	return secret
 }
@@ -1357,7 +1362,7 @@ func needsCredentialRefresh(dlCfg DuckLakeConfig) bool {
 		return false
 	}
 	p := s3ProviderForConfig(dlCfg)
-	return p == "credential_chain" || p == "aws_sdk"
+	return p == "credential_chain" || p == "aws_sdk" || dlCfg.S3SessionToken != ""
 }

 // isTransactionAborted returns true if the error indicates DuckDB's connection