Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions cmd/thv/app/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,13 @@ func unsetCACertCmdFunc(_ *cobra.Command, _ []string) error {
func setRegistryCmdFunc(cmd *cobra.Command, args []string) error {
input := args[0]

issuer, clientID := registry.ResolveAuthDefaults(
cmd.Context(), registryAuthIssuer, registryAuthClientID, registry.ActiveAuthDefaulter(),
)

cfg := &registry.UpdateRegistryConfig{
AllowPrivateIP: allowPrivateRegistryIp,
HasAuth: registryAuthIssuer != "" && registryAuthClientID != "",
HasAuth: issuer != "" && clientID != "",
}
if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") {
cfg.URL = input
Expand All @@ -209,9 +213,9 @@ func setRegistryCmdFunc(cmd *cobra.Command, args []string) error {
return enhanceRegistryError(err, input, registryType)
}

// If auth flags were provided, configure the new auth
if registryAuthIssuer != "" && registryAuthClientID != "" {
if err := authManager.SetOAuthAuth(cmd.Context(), registryAuthIssuer, registryAuthClientID, registryAuthAudience,
// If auth was provided (via flags or discovered from the platform), configure it.
if issuer != "" && clientID != "" {
if err := authManager.SetOAuthAuth(cmd.Context(), issuer, clientID, registryAuthAudience,
registryAuthScopes); err != nil {
return fmt.Errorf("failed to configure registry auth: %w", err)
}
Expand Down
62 changes: 48 additions & 14 deletions pkg/api/v1/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,20 +552,9 @@ func (rr *RegistryRoutes) updateRegistry(w http.ResponseWriter, r *http.Request)
}
responseType = registryType

// Always overwrite auth: if auth is provided, set it; if not, clear it.
// This prevents stale tokens from being sent to the wrong registry server.
if req.Auth != nil {
if err := rr.processAuthUpdate(r.Context(), req.Auth); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
} else {
authMgr := regpkg.NewAuthManager(rr.configProvider)
if err := authMgr.UnsetAuth(); err != nil {
slog.Error("failed to clear registry auth", "error", err)
http.Error(w, "Failed to clear registry auth", http.StatusInternalServerError)
return
}
if err := rr.applyRegistryAuth(r.Context(), req.Auth); err != nil {
writeRegistryAuthError(w, err)
return
}

// Reset the registry provider cache to pick up configuration changes
Expand Down Expand Up @@ -621,6 +610,51 @@ func updateRegistryConfigFromRequest(req *UpdateRegistryRequest) *regpkg.UpdateR
return cfg
}

// authClearError signals that clearing the registry auth (the no-auth path)
// failed. updateRegistry maps it to a 500; the explicit-auth path returns 400.
type authClearError struct{ err error }

func (e *authClearError) Error() string { return e.err.Error() }
func (e *authClearError) Unwrap() error { return e.err }

// writeRegistryAuthError translates an applyRegistryAuth error into the
// appropriate HTTP status. Clearing failures are 500 (internal); validation
// or OIDC discovery failures from processAuthUpdate are 400 (client input).
func writeRegistryAuthError(w http.ResponseWriter, err error) {
var clearErr *authClearError
if errors.As(err, &clearErr) {
slog.Error("failed to clear registry auth", "error", clearErr.err)
http.Error(w, "Failed to clear registry auth", http.StatusInternalServerError)
return
}
http.Error(w, err.Error(), http.StatusBadRequest)
}

// applyRegistryAuth overwrites the registry auth config to match the incoming
// request. If the caller supplied auth it is applied verbatim. If not, an
// enterprise-registered AuthDefaulter gets a chance to supply issuer/client_id
// from /.well-known discovery before falling back to clearing the auth — this
// lets Studio and `thv config set-registry` work without distributing OAuth
// params out of band. Always overwrites to prevent stale tokens from being
// sent to a different registry server.
func (rr *RegistryRoutes) applyRegistryAuth(ctx context.Context, authReq *UpdateRegistryAuthRequest) error {
if authReq == nil {
if issuer, clientID := regpkg.ResolveAuthDefaults(
ctx, "", "", regpkg.ActiveAuthDefaulter(),
); issuer != "" && clientID != "" {
authReq = &UpdateRegistryAuthRequest{Issuer: issuer, ClientID: clientID}
}
}
if authReq != nil {
return rr.processAuthUpdate(ctx, authReq)
}
authMgr := regpkg.NewAuthManager(rr.configProvider)
if err := authMgr.UnsetAuth(); err != nil {
return &authClearError{err: err}
}
return nil
}

// processAuthUpdate validates and applies OAuth configuration for registry auth.
func (rr *RegistryRoutes) processAuthUpdate(ctx context.Context, authReq *UpdateRegistryAuthRequest) error {
if authReq.Issuer == "" || authReq.ClientID == "" {
Expand Down
69 changes: 69 additions & 0 deletions pkg/api/v1/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,75 @@ func TestRemoveRegistry_BlockedByPolicyGate(t *testing.T) {
assert.Contains(t, w.Body.String(), "organization policy")
}

//nolint:paralleltest // Mutates global registry auth defaulter singleton
func TestUpdateRegistry_AuthDefaulterFillsMissingAuth(t *testing.T) {
originalGate := registry.ActivePolicyGate()
originalDefaulter := registry.ActiveAuthDefaulter()
t.Cleanup(func() {
registry.RegisterPolicyGate(originalGate)
registry.RegisterAuthDefaulter(originalDefaulter)
})
registry.RegisterPolicyGate(registry.NoopPolicyGate{})

// Mock an OIDC discovery server so SetOAuthAuth's discovery succeeds.
var oidcSrv *httptest.Server
oidcSrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" &&
r.URL.Path != "/.well-known/oauth-authorization-server" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"issuer": oidcSrv.URL,
"authorization_endpoint": oidcSrv.URL + "/authorize",
"token_endpoint": oidcSrv.URL + "/token",
})
}))
t.Cleanup(oidcSrv.Close)

const discoveredClientID = "disco-client"
registry.RegisterAuthDefaulter(func(_ context.Context) (string, string, error) {
return oidcSrv.URL, discoveredClientID, nil
})

validRegistryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"$schema": "https://example.com/schema.json",
"version": "1.0.0",
"meta": map[string]interface{}{"last_updated": "2025-01-01T00:00:00Z"},
"data": map[string]interface{}{
"servers": []interface{}{
map[string]interface{}{"name": "io.example.test-server"},
},
},
})
}))
t.Cleanup(validRegistryServer.Close)

provider, cleanup := CreateTestConfigProvider(t, nil)
defer cleanup()
routes := NewRegistryRoutesWithProvider(provider)

body := `{"url":"` + validRegistryServer.URL + `","allow_private_ip":true}`
req := httptest.NewRequest(http.MethodPut, "/default", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rctx := chi.NewRouteContext()
rctx.URLParams.Add("name", "default")
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))

w := httptest.NewRecorder()
routes.updateRegistry(w, req)

require.Equal(t, http.StatusOK, w.Code, "expected success: %s", w.Body.String())

cfg := provider.GetConfig()
require.NotNil(t, cfg.RegistryAuth.OAuth, "defaulter should have populated OAuth config")
assert.Equal(t, oidcSrv.URL, cfg.RegistryAuth.OAuth.Issuer)
assert.Equal(t, discoveredClientID, cfg.RegistryAuth.OAuth.ClientID)
}

//nolint:paralleltest // Mutates global registry policy gate singleton
func TestUpdateRegistry_AllowedByDefaultGate(t *testing.T) {
original := registry.ActivePolicyGate()
Expand Down
58 changes: 58 additions & 0 deletions pkg/registry/auth_defaulter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package registry

import (
"context"
"log/slog"
"sync"
)

// AuthDefaulter resolves OIDC issuer and client ID for a registry update
// when the caller did not supply them explicitly. Enterprise builds register
// a defaulter that queries the platform's well-known configuration endpoint,
// so admins enforcing a registry no longer need to distribute issuer and
// client_id out of band.
type AuthDefaulter func(ctx context.Context) (issuer, clientID string, err error)

var (
authDefaulterMu sync.RWMutex
authDefaulter AuthDefaulter
)

// RegisterAuthDefaulter sets the active registry auth defaulter. Safe for
// concurrent use, though it is intended to be called once at startup.
// Passing nil clears any previously registered defaulter.
func RegisterAuthDefaulter(d AuthDefaulter) {
authDefaulterMu.Lock()
defer authDefaulterMu.Unlock()
authDefaulter = d
}

// ActiveAuthDefaulter returns the currently registered auth defaulter, or
// nil if none has been registered.
func ActiveAuthDefaulter() AuthDefaulter {
authDefaulterMu.RLock()
defer authDefaulterMu.RUnlock()
return authDefaulter
}

// ResolveAuthDefaults returns the OIDC issuer and client ID for a registry
// update. Explicit values always win; when both are empty and a defaulter
// is registered, the defaulter is consulted. A defaulter error falls back
// to empty so callers preserve the legacy "no auth" behaviour.
func ResolveAuthDefaults(ctx context.Context, issuer, clientID string, defaulter AuthDefaulter) (string, string) {
if issuer != "" || clientID != "" {
return issuer, clientID
}
if defaulter == nil {
return "", ""
}
resolvedIssuer, resolvedClientID, err := defaulter(ctx)
if err != nil {
slog.Debug("registry auth discovery failed, proceeding without auth", "error", err)
return "", ""
}
return resolvedIssuer, resolvedClientID
}
96 changes: 96 additions & 0 deletions pkg/registry/auth_defaulter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package registry

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestResolveAuthDefaults(t *testing.T) {
t.Parallel()

const (
discoveredIssuer = "https://disco.example.com"
discoveredClientID = "disco-client"
flagIssuer = "https://flag.example.com"
flagClientID = "flag-client"
)

defaulterOK := func(_ context.Context) (string, string, error) {
return discoveredIssuer, discoveredClientID, nil
}
defaulterErr := func(_ context.Context) (string, string, error) {
return "", "", errors.New("config server unreachable")
}

tests := []struct {
name string
issuer string
clientID string
defaulter AuthDefaulter
wantIssuer string
wantClientID string
}{
{
name: "explicit values take precedence over defaulter",
issuer: flagIssuer,
clientID: flagClientID,
defaulter: defaulterOK,
wantIssuer: flagIssuer,
wantClientID: flagClientID,
},
{
name: "no explicit values and defaulter succeeds returns discovered values",
defaulter: defaulterOK,
wantIssuer: discoveredIssuer,
wantClientID: discoveredClientID,
},
{
name: "no explicit values and defaulter fails falls back to empty",
defaulter: defaulterErr,
},
{
name: "no explicit values and nil defaulter returns empty",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

gotIssuer, gotClientID := ResolveAuthDefaults(
t.Context(), tt.issuer, tt.clientID, tt.defaulter,
)

require.Equal(t, tt.wantIssuer, gotIssuer)
assert.Equal(t, tt.wantClientID, gotClientID)
})
}
}

//nolint:paralleltest // Mutates global registry auth defaulter singleton
func TestRegisterAuthDefaulter(t *testing.T) {
original := ActiveAuthDefaulter()
t.Cleanup(func() { RegisterAuthDefaulter(original) })

const wantIssuer = "https://reg.example.com"
RegisterAuthDefaulter(func(_ context.Context) (string, string, error) {
return wantIssuer, "reg-client", nil
})

active := ActiveAuthDefaulter()
require.NotNil(t, active)
issuer, clientID, err := active(t.Context())
require.NoError(t, err)
assert.Equal(t, wantIssuer, issuer)
assert.Equal(t, "reg-client", clientID)

RegisterAuthDefaulter(nil)
assert.Nil(t, ActiveAuthDefaulter())
}
Loading