diff --git a/cmd/thv/app/config.go b/cmd/thv/app/config.go index 517591ed9d..007de1272b 100644 --- a/cmd/thv/app/config.go +++ b/cmd/thv/app/config.go @@ -181,9 +181,13 @@ func unsetCACertCmdFunc(_ *cobra.Command, _ []string) error { func setRegistryCmdFunc(cmd *cobra.Command, args []string) error { input := args[0] + issuer, clientID := registry.ResolveAuthDefaults( + cmd.Context(), registryAuthIssuer, registryAuthClientID, registry.ActiveAuthDefaulter(), + ) + cfg := ®istry.UpdateRegistryConfig{ AllowPrivateIP: allowPrivateRegistryIp, - HasAuth: registryAuthIssuer != "" && registryAuthClientID != "", + HasAuth: issuer != "" && clientID != "", } if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") { cfg.URL = input @@ -209,9 +213,9 @@ func setRegistryCmdFunc(cmd *cobra.Command, args []string) error { return enhanceRegistryError(err, input, registryType) } - // If auth flags were provided, configure the new auth - if registryAuthIssuer != "" && registryAuthClientID != "" { - if err := authManager.SetOAuthAuth(cmd.Context(), registryAuthIssuer, registryAuthClientID, registryAuthAudience, + // If auth was provided (via flags or discovered from the platform), configure it. + if issuer != "" && clientID != "" { + if err := authManager.SetOAuthAuth(cmd.Context(), issuer, clientID, registryAuthAudience, registryAuthScopes); err != nil { return fmt.Errorf("failed to configure registry auth: %w", err) } diff --git a/pkg/api/v1/registry.go b/pkg/api/v1/registry.go index 3727ef3f8f..b603d1047d 100644 --- a/pkg/api/v1/registry.go +++ b/pkg/api/v1/registry.go @@ -552,20 +552,9 @@ func (rr *RegistryRoutes) updateRegistry(w http.ResponseWriter, r *http.Request) } responseType = registryType - // Always overwrite auth: if auth is provided, set it; if not, clear it. - // This prevents stale tokens from being sent to the wrong registry server. - if req.Auth != nil { - if err := rr.processAuthUpdate(r.Context(), req.Auth); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - } else { - authMgr := regpkg.NewAuthManager(rr.configProvider) - if err := authMgr.UnsetAuth(); err != nil { - slog.Error("failed to clear registry auth", "error", err) - http.Error(w, "Failed to clear registry auth", http.StatusInternalServerError) - return - } + if err := rr.applyRegistryAuth(r.Context(), req.Auth); err != nil { + writeRegistryAuthError(w, err) + return } // Reset the registry provider cache to pick up configuration changes @@ -621,6 +610,51 @@ func updateRegistryConfigFromRequest(req *UpdateRegistryRequest) *regpkg.UpdateR return cfg } +// authClearError signals that clearing the registry auth (the no-auth path) +// failed. updateRegistry maps it to a 500; the explicit-auth path returns 400. +type authClearError struct{ err error } + +func (e *authClearError) Error() string { return e.err.Error() } +func (e *authClearError) Unwrap() error { return e.err } + +// writeRegistryAuthError translates an applyRegistryAuth error into the +// appropriate HTTP status. Clearing failures are 500 (internal); validation +// or OIDC discovery failures from processAuthUpdate are 400 (client input). +func writeRegistryAuthError(w http.ResponseWriter, err error) { + var clearErr *authClearError + if errors.As(err, &clearErr) { + slog.Error("failed to clear registry auth", "error", clearErr.err) + http.Error(w, "Failed to clear registry auth", http.StatusInternalServerError) + return + } + http.Error(w, err.Error(), http.StatusBadRequest) +} + +// applyRegistryAuth overwrites the registry auth config to match the incoming +// request. If the caller supplied auth it is applied verbatim. If not, an +// enterprise-registered AuthDefaulter gets a chance to supply issuer/client_id +// from /.well-known discovery before falling back to clearing the auth — this +// lets Studio and `thv config set-registry` work without distributing OAuth +// params out of band. Always overwrites to prevent stale tokens from being +// sent to a different registry server. +func (rr *RegistryRoutes) applyRegistryAuth(ctx context.Context, authReq *UpdateRegistryAuthRequest) error { + if authReq == nil { + if issuer, clientID := regpkg.ResolveAuthDefaults( + ctx, "", "", regpkg.ActiveAuthDefaulter(), + ); issuer != "" && clientID != "" { + authReq = &UpdateRegistryAuthRequest{Issuer: issuer, ClientID: clientID} + } + } + if authReq != nil { + return rr.processAuthUpdate(ctx, authReq) + } + authMgr := regpkg.NewAuthManager(rr.configProvider) + if err := authMgr.UnsetAuth(); err != nil { + return &authClearError{err: err} + } + return nil +} + // processAuthUpdate validates and applies OAuth configuration for registry auth. func (rr *RegistryRoutes) processAuthUpdate(ctx context.Context, authReq *UpdateRegistryAuthRequest) error { if authReq.Issuer == "" || authReq.ClientID == "" { diff --git a/pkg/api/v1/registry_test.go b/pkg/api/v1/registry_test.go index 2813feaf09..5fa270c8cf 100644 --- a/pkg/api/v1/registry_test.go +++ b/pkg/api/v1/registry_test.go @@ -411,6 +411,75 @@ func TestRemoveRegistry_BlockedByPolicyGate(t *testing.T) { assert.Contains(t, w.Body.String(), "organization policy") } +//nolint:paralleltest // Mutates global registry auth defaulter singleton +func TestUpdateRegistry_AuthDefaulterFillsMissingAuth(t *testing.T) { + originalGate := registry.ActivePolicyGate() + originalDefaulter := registry.ActiveAuthDefaulter() + t.Cleanup(func() { + registry.RegisterPolicyGate(originalGate) + registry.RegisterAuthDefaulter(originalDefaulter) + }) + registry.RegisterPolicyGate(registry.NoopPolicyGate{}) + + // Mock an OIDC discovery server so SetOAuthAuth's discovery succeeds. + var oidcSrv *httptest.Server + oidcSrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/openid-configuration" && + r.URL.Path != "/.well-known/oauth-authorization-server" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "issuer": oidcSrv.URL, + "authorization_endpoint": oidcSrv.URL + "/authorize", + "token_endpoint": oidcSrv.URL + "/token", + }) + })) + t.Cleanup(oidcSrv.Close) + + const discoveredClientID = "disco-client" + registry.RegisterAuthDefaulter(func(_ context.Context) (string, string, error) { + return oidcSrv.URL, discoveredClientID, nil + }) + + validRegistryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "$schema": "https://example.com/schema.json", + "version": "1.0.0", + "meta": map[string]interface{}{"last_updated": "2025-01-01T00:00:00Z"}, + "data": map[string]interface{}{ + "servers": []interface{}{ + map[string]interface{}{"name": "io.example.test-server"}, + }, + }, + }) + })) + t.Cleanup(validRegistryServer.Close) + + provider, cleanup := CreateTestConfigProvider(t, nil) + defer cleanup() + routes := NewRegistryRoutesWithProvider(provider) + + body := `{"url":"` + validRegistryServer.URL + `","allow_private_ip":true}` + req := httptest.NewRequest(http.MethodPut, "/default", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rctx := chi.NewRouteContext() + rctx.URLParams.Add("name", "default") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + w := httptest.NewRecorder() + routes.updateRegistry(w, req) + + require.Equal(t, http.StatusOK, w.Code, "expected success: %s", w.Body.String()) + + cfg := provider.GetConfig() + require.NotNil(t, cfg.RegistryAuth.OAuth, "defaulter should have populated OAuth config") + assert.Equal(t, oidcSrv.URL, cfg.RegistryAuth.OAuth.Issuer) + assert.Equal(t, discoveredClientID, cfg.RegistryAuth.OAuth.ClientID) +} + //nolint:paralleltest // Mutates global registry policy gate singleton func TestUpdateRegistry_AllowedByDefaultGate(t *testing.T) { original := registry.ActivePolicyGate() diff --git a/pkg/registry/auth_defaulter.go b/pkg/registry/auth_defaulter.go new file mode 100644 index 0000000000..f60bbc40e1 --- /dev/null +++ b/pkg/registry/auth_defaulter.go @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package registry + +import ( + "context" + "log/slog" + "sync" +) + +// AuthDefaulter resolves OIDC issuer and client ID for a registry update +// when the caller did not supply them explicitly. Enterprise builds register +// a defaulter that queries the platform's well-known configuration endpoint, +// so admins enforcing a registry no longer need to distribute issuer and +// client_id out of band. +type AuthDefaulter func(ctx context.Context) (issuer, clientID string, err error) + +var ( + authDefaulterMu sync.RWMutex + authDefaulter AuthDefaulter +) + +// RegisterAuthDefaulter sets the active registry auth defaulter. Safe for +// concurrent use, though it is intended to be called once at startup. +// Passing nil clears any previously registered defaulter. +func RegisterAuthDefaulter(d AuthDefaulter) { + authDefaulterMu.Lock() + defer authDefaulterMu.Unlock() + authDefaulter = d +} + +// ActiveAuthDefaulter returns the currently registered auth defaulter, or +// nil if none has been registered. +func ActiveAuthDefaulter() AuthDefaulter { + authDefaulterMu.RLock() + defer authDefaulterMu.RUnlock() + return authDefaulter +} + +// ResolveAuthDefaults returns the OIDC issuer and client ID for a registry +// update. Explicit values always win; when both are empty and a defaulter +// is registered, the defaulter is consulted. A defaulter error falls back +// to empty so callers preserve the legacy "no auth" behaviour. +func ResolveAuthDefaults(ctx context.Context, issuer, clientID string, defaulter AuthDefaulter) (string, string) { + if issuer != "" || clientID != "" { + return issuer, clientID + } + if defaulter == nil { + return "", "" + } + resolvedIssuer, resolvedClientID, err := defaulter(ctx) + if err != nil { + slog.Debug("registry auth discovery failed, proceeding without auth", "error", err) + return "", "" + } + return resolvedIssuer, resolvedClientID +} diff --git a/pkg/registry/auth_defaulter_test.go b/pkg/registry/auth_defaulter_test.go new file mode 100644 index 0000000000..25637818b4 --- /dev/null +++ b/pkg/registry/auth_defaulter_test.go @@ -0,0 +1,96 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package registry + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestResolveAuthDefaults(t *testing.T) { + t.Parallel() + + const ( + discoveredIssuer = "https://disco.example.com" + discoveredClientID = "disco-client" + flagIssuer = "https://flag.example.com" + flagClientID = "flag-client" + ) + + defaulterOK := func(_ context.Context) (string, string, error) { + return discoveredIssuer, discoveredClientID, nil + } + defaulterErr := func(_ context.Context) (string, string, error) { + return "", "", errors.New("config server unreachable") + } + + tests := []struct { + name string + issuer string + clientID string + defaulter AuthDefaulter + wantIssuer string + wantClientID string + }{ + { + name: "explicit values take precedence over defaulter", + issuer: flagIssuer, + clientID: flagClientID, + defaulter: defaulterOK, + wantIssuer: flagIssuer, + wantClientID: flagClientID, + }, + { + name: "no explicit values and defaulter succeeds returns discovered values", + defaulter: defaulterOK, + wantIssuer: discoveredIssuer, + wantClientID: discoveredClientID, + }, + { + name: "no explicit values and defaulter fails falls back to empty", + defaulter: defaulterErr, + }, + { + name: "no explicit values and nil defaulter returns empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + gotIssuer, gotClientID := ResolveAuthDefaults( + t.Context(), tt.issuer, tt.clientID, tt.defaulter, + ) + + require.Equal(t, tt.wantIssuer, gotIssuer) + assert.Equal(t, tt.wantClientID, gotClientID) + }) + } +} + +//nolint:paralleltest // Mutates global registry auth defaulter singleton +func TestRegisterAuthDefaulter(t *testing.T) { + original := ActiveAuthDefaulter() + t.Cleanup(func() { RegisterAuthDefaulter(original) }) + + const wantIssuer = "https://reg.example.com" + RegisterAuthDefaulter(func(_ context.Context) (string, string, error) { + return wantIssuer, "reg-client", nil + }) + + active := ActiveAuthDefaulter() + require.NotNil(t, active) + issuer, clientID, err := active(t.Context()) + require.NoError(t, err) + assert.Equal(t, wantIssuer, issuer) + assert.Equal(t, "reg-client", clientID) + + RegisterAuthDefaulter(nil) + assert.Nil(t, ActiveAuthDefaulter()) +}