From 6dec0bb1f59d0e69f2cc21b6223a7cc0d5175b5b Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Thu, 19 Mar 2026 09:19:15 +0000 Subject: [PATCH 1/8] Add UpstreamTokens field to Identity with redacted serialization The auth middleware will populate upstream provider access tokens on the Identity struct, allowing downstream middleware (upstreamswap) to read tokens without coupling to the storage layer. MarshalJSON redacts token values while preserving provider keys, and both nil and empty maps are omitted via omitempty. Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/auth/identity.go | 56 ++++++++++++++++------- pkg/auth/identity_test.go | 96 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 16 deletions(-) diff --git a/pkg/auth/identity.go b/pkg/auth/identity.go index 4af1fefdf4..0215e54a23 100644 --- a/pkg/auth/identity.go +++ b/pkg/auth/identity.go @@ -56,6 +56,13 @@ type Identity struct { // Metadata stores additional identity information. Metadata map[string]string + + // UpstreamTokens maps upstream provider names to their access tokens. + // This is populated by the auth middleware when an embedded auth server + // is active and the JWT contains a token session ID (tsid claim). + // Redacted in MarshalJSON() to prevent token leakage. + // MUST NOT be mutated after the Identity is placed in the request context. + UpstreamTokens map[string]string } // String returns a string representation of the Identity with sensitive fields redacted. @@ -77,14 +84,15 @@ func (i *Identity) MarshalJSON() ([]byte, error) { // Create a safe representation with lowercase field names and redacted token type SafeIdentity struct { - Subject string `json:"subject"` - Name string `json:"name"` - Email string `json:"email"` - Groups []string `json:"groups"` - Claims map[string]any `json:"claims"` - Token string `json:"token"` - TokenType string `json:"tokenType"` - Metadata map[string]string `json:"metadata"` + Subject string `json:"subject"` + Name string `json:"name"` + Email string `json:"email"` + Groups []string `json:"groups"` + Claims map[string]any `json:"claims"` + Token string `json:"token"` + TokenType string `json:"tokenType"` + Metadata map[string]string `json:"metadata"` + UpstreamTokens map[string]string `json:"upstreamTokens,omitempty"` } token := i.Token @@ -92,15 +100,31 @@ func (i *Identity) MarshalJSON() ([]byte, error) { token = "REDACTED" } + // Redact upstream tokens: preserve keys, replace non-empty values + var redactedUpstreamTokens map[string]string + // Guard with len() > 0 (not != nil) so that both nil and empty maps + // produce a nil redactedUpstreamTokens, which omitempty then omits. + if len(i.UpstreamTokens) > 0 { + redactedUpstreamTokens = make(map[string]string, len(i.UpstreamTokens)) + for k, v := range i.UpstreamTokens { + if v != "" { + redactedUpstreamTokens[k] = "REDACTED" + } else { + redactedUpstreamTokens[k] = "" + } + } + } + return json.Marshal(&SafeIdentity{ - Subject: i.Subject, - Name: i.Name, - Email: i.Email, - Groups: i.Groups, - Claims: i.Claims, - Token: token, - TokenType: i.TokenType, - Metadata: i.Metadata, + Subject: i.Subject, + Name: i.Name, + Email: i.Email, + Groups: i.Groups, + Claims: i.Claims, + Token: token, + TokenType: i.TokenType, + Metadata: i.Metadata, + UpstreamTokens: redactedUpstreamTokens, }) } diff --git a/pkg/auth/identity_test.go b/pkg/auth/identity_test.go index 5ec9d1661e..f37d995c57 100644 --- a/pkg/auth/identity_test.go +++ b/pkg/auth/identity_test.go @@ -150,6 +150,16 @@ func TestIdentity_String(t *testing.T) { identity: nil, want: "", }, + { + name: "does_not_leak_upstream_tokens", + identity: &Identity{ + PrincipalInfo: PrincipalInfo{Subject: "user123"}, + UpstreamTokens: map[string]string{ + "github": "gho_secret123", + }, + }, + want: `Identity{Subject:"user123"}`, + }, } for _, tt := range tests { @@ -293,6 +303,92 @@ func TestIdentity_MarshalJSON(t *testing.T) { assert.Equal(t, "null", string(data)) }, }, + { + name: "redacts_upstream_tokens", + identity: &Identity{ + PrincipalInfo: PrincipalInfo{Subject: "user123"}, + UpstreamTokens: map[string]string{ + "github": "gho_secret123", + "atlassian": "atl_secret456", + }, + }, + wantErr: false, + checkFunc: func(t *testing.T, data []byte) { + t.Helper() + + var result map[string]any + err := json.Unmarshal(data, &result) + require.NoError(t, err) + + tokens, ok := result["upstreamTokens"].(map[string]any) + require.True(t, ok, "upstreamTokens should be a map") + assert.Equal(t, "REDACTED", tokens["github"]) + assert.Equal(t, "REDACTED", tokens["atlassian"]) + assert.NotContains(t, string(data), "gho_secret123") + assert.NotContains(t, string(data), "atl_secret456") + }, + }, + { + name: "empty_upstream_tokens_omitted", + identity: &Identity{ + PrincipalInfo: PrincipalInfo{Subject: "user123"}, + UpstreamTokens: map[string]string{}, + }, + wantErr: false, + checkFunc: func(t *testing.T, data []byte) { + t.Helper() + + var result map[string]any + err := json.Unmarshal(data, &result) + require.NoError(t, err) + + // Empty map should be omitted because len() == 0 produces nil redacted map + _, exists := result["upstreamTokens"] + assert.False(t, exists, "empty upstreamTokens should be omitted") + }, + }, + { + name: "nil_upstream_tokens_omitted", + identity: &Identity{ + PrincipalInfo: PrincipalInfo{Subject: "user123"}, + UpstreamTokens: nil, + }, + wantErr: false, + checkFunc: func(t *testing.T, data []byte) { + t.Helper() + + var result map[string]any + err := json.Unmarshal(data, &result) + require.NoError(t, err) + + _, exists := result["upstreamTokens"] + assert.False(t, exists, "nil upstreamTokens should be omitted") + }, + }, + { + name: "upstream_tokens_mixed_empty_and_populated", + identity: &Identity{ + PrincipalInfo: PrincipalInfo{Subject: "user123"}, + UpstreamTokens: map[string]string{ + "github": "gho_secret123", + "pending": "", + }, + }, + wantErr: false, + checkFunc: func(t *testing.T, data []byte) { + t.Helper() + + var result map[string]any + err := json.Unmarshal(data, &result) + require.NoError(t, err) + + tokens, ok := result["upstreamTokens"].(map[string]any) + require.True(t, ok, "upstreamTokens should be a map") + assert.Equal(t, "REDACTED", tokens["github"]) + assert.Equal(t, "", tokens["pending"]) + assert.NotContains(t, string(data), "gho_secret123") + }, + }, } for _, tt := range tests { From 975a97301a65d840dd5227e42558e55645aafa46 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Thu, 19 Mar 2026 09:41:51 +0000 Subject: [PATCH 2/8] Add UpstreamTokenReader interface and bulk token retrieval Introduce a narrow UpstreamTokenReader interface that decouples the auth middleware from storage internals. InProcessService.GetAllValidTokens performs a bulk read of all upstream providers for a session, refreshing expired tokens transparently and falling back to expired tokens when refresh fails. Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/auth/upstreamtoken/service.go | 51 +++++- pkg/auth/upstreamtoken/service_test.go | 210 +++++++++++++++++++++++++ pkg/auth/upstreamtoken/types.go | 13 ++ 3 files changed, 272 insertions(+), 2 deletions(-) diff --git a/pkg/auth/upstreamtoken/service.go b/pkg/auth/upstreamtoken/service.go index 537a0e0a4c..131ace05e4 100644 --- a/pkg/auth/upstreamtoken/service.go +++ b/pkg/auth/upstreamtoken/service.go @@ -29,8 +29,11 @@ type InProcessService struct { sfGroup singleflight.Group } -// Compile-time check. -var _ Service = (*InProcessService)(nil) +// Compile-time checks. +var ( + _ Service = (*InProcessService)(nil) + _ UpstreamTokenReader = (*InProcessService)(nil) +) // NewInProcessService creates a new InProcessService. // The refresher may be nil if upstream token refresh is not configured; @@ -77,6 +80,50 @@ func (s *InProcessService) GetValidTokens(ctx context.Context, sessionID, provid return &UpstreamCredential{AccessToken: tokens.AccessToken}, nil } +// GetAllValidTokens returns access tokens for all upstream providers in a session. +// Expired tokens are refreshed transparently; if refresh fails, the provider is +// omitted from the result so downstream middleware can return a clean 401. +func (s *InProcessService) GetAllValidTokens(ctx context.Context, sessionID string) (map[string]string, error) { + allTokens, err := s.storage.GetAllUpstreamTokens(ctx, sessionID) + if err != nil { + return nil, fmt.Errorf("bulk read upstream tokens: %w", err) + } + + if len(allTokens) == 0 { + return map[string]string{}, nil + } + + result := make(map[string]string, len(allTokens)) + // TODO(auth): Refresh providers in parallel using errgroup to avoid + // worst-case latency of N * refreshTimeout when multiple providers need refresh. + for providerName, tokens := range allTokens { + if tokens == nil { + continue + } + + // If token is not expired, use it directly. + if tokens.ExpiresAt.IsZero() || !tokens.IsExpired(time.Now()) { + result[providerName] = tokens.AccessToken + continue + } + + // Token is expired — attempt refresh. + refreshed, refreshErr := s.refreshOrFail(ctx, sessionID, providerName, tokens) + if refreshErr != nil { + // Refresh failed — omit provider so downstream middleware returns 401. + slog.WarnContext(ctx, "omitting provider with unrefreshable expired token", + "session_id", sessionID, + "provider", providerName, + "error", refreshErr, + ) + continue + } + result[providerName] = refreshed.AccessToken + } + + return result, nil +} + // refreshOrFail attempts a singleflight-deduplicated refresh and maps errors // to the service's sentinel errors. func (s *InProcessService) refreshOrFail( diff --git a/pkg/auth/upstreamtoken/service_test.go b/pkg/auth/upstreamtoken/service_test.go index 0dde8d1acc..70a3553d68 100644 --- a/pkg/auth/upstreamtoken/service_test.go +++ b/pkg/auth/upstreamtoken/service_test.go @@ -233,3 +233,213 @@ func TestInProcessService_NilRefresher(t *testing.T) { assert.ErrorIs(t, err, ErrNoRefreshToken) assert.Nil(t, cred) } + +func TestInProcessService_GetAllValidTokens(t *testing.T) { + t.Parallel() + + freshTokens := &storage.UpstreamTokens{ + ProviderID: "github", + AccessToken: "github-access-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + freshTokens2 := &storage.UpstreamTokens{ + ProviderID: "atlassian", + AccessToken: "atlassian-access-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + expiredTokens := &storage.UpstreamTokens{ + ProviderID: "github", + AccessToken: "expired-github-token", + RefreshToken: "github-refresh-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), + } + refreshedTokens := &storage.UpstreamTokens{ + ProviderID: "github", + AccessToken: "new-github-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + tests := []struct { + name string + sessionID string + setupStorage func(*storagemocks.MockUpstreamTokenStorage) + setupRefresher func(*storagemocks.MockUpstreamTokenRefresher) + wantResult map[string]string + wantErr bool + }{ + { + name: "all fresh tokens returned directly", + sessionID: "session-1", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-1"). + Return(map[string]*storage.UpstreamTokens{ + "github": freshTokens, + "atlassian": freshTokens2, + }, nil) + }, + setupRefresher: func(_ *storagemocks.MockUpstreamTokenRefresher) {}, + wantResult: map[string]string{ + "github": "github-access-token", + "atlassian": "atlassian-access-token", + }, + }, + { + name: "mixed fresh and expired with successful refresh", + sessionID: "session-2", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-2"). + Return(map[string]*storage.UpstreamTokens{ + "atlassian": freshTokens2, + "github": expiredTokens, + }, nil) + }, + setupRefresher: func(r *storagemocks.MockUpstreamTokenRefresher) { + r.EXPECT().RefreshAndStore(gomock.Any(), "session-2", expiredTokens). + Return(refreshedTokens, nil) + }, + wantResult: map[string]string{ + "atlassian": "atlassian-access-token", + "github": "new-github-token", + }, + }, + { + name: "expired refresh fails omits provider", + sessionID: "session-3", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-3"). + Return(map[string]*storage.UpstreamTokens{ + "github": expiredTokens, + }, nil) + }, + setupRefresher: func(r *storagemocks.MockUpstreamTokenRefresher) { + r.EXPECT().RefreshAndStore(gomock.Any(), "session-3", expiredTokens). + Return(nil, errors.New("upstream IDP unavailable")) + }, + wantResult: map[string]string{}, + }, + { + name: "empty session returns empty map", + sessionID: "session-4", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-4"). + Return(map[string]*storage.UpstreamTokens{}, nil) + }, + setupRefresher: func(_ *storagemocks.MockUpstreamTokenRefresher) {}, + wantResult: map[string]string{}, + }, + { + name: "storage error propagated", + sessionID: "session-5", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-5"). + Return(nil, errors.New("redis connection lost")) + }, + setupRefresher: func(_ *storagemocks.MockUpstreamTokenRefresher) {}, + wantErr: true, + }, + { + name: "nil tokens entry skipped", + sessionID: "session-6", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-6"). + Return(map[string]*storage.UpstreamTokens{ + "github": freshTokens, + "atlassian": nil, + }, nil) + }, + setupRefresher: func(_ *storagemocks.MockUpstreamTokenRefresher) {}, + wantResult: map[string]string{ + "github": "github-access-token", + }, + }, + { + name: "expired with no refresh token omits provider", + sessionID: "session-7", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-7"). + Return(map[string]*storage.UpstreamTokens{ + "github": { + ProviderID: "github", + AccessToken: "expired-no-refresh", + ExpiresAt: time.Now().Add(-1 * time.Hour), + RefreshToken: "", + }, + }, nil) + }, + setupRefresher: func(_ *storagemocks.MockUpstreamTokenRefresher) {}, + wantResult: map[string]string{}, + }, + { + name: "zero ExpiresAt treated as non-expiring", + sessionID: "session-8", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-8"). + Return(map[string]*storage.UpstreamTokens{ + "github": { + ProviderID: "github", + AccessToken: "no-expiry-token", + ExpiresAt: time.Time{}, + }, + }, nil) + }, + setupRefresher: func(_ *storagemocks.MockUpstreamTokenRefresher) {}, + wantResult: map[string]string{ + "github": "no-expiry-token", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) + + tt.setupStorage(mockStorage) + tt.setupRefresher(mockRefresher) + + svc := NewInProcessService(mockStorage, mockRefresher) + + result, err := svc.GetAllValidTokens(context.Background(), tt.sessionID) + + if tt.wantErr { + require.Error(t, err) + assert.Nil(t, result) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantResult, result) + }) + } +} + +// TestInProcessService_GetAllValidTokens_NilRefresher verifies that when the +// refresher is nil, expired tokens in the bulk path are omitted (not panicking). +func TestInProcessService_GetAllValidTokens_NilRefresher(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + mockStorage.EXPECT(). + GetAllUpstreamTokens(gomock.Any(), "session-1"). + Return(map[string]*storage.UpstreamTokens{ + "github": { + ProviderID: "github", + AccessToken: "expired-token", + RefreshToken: "has-refresh", + ExpiresAt: time.Now().Add(-1 * time.Hour), + }, + }, nil) + + svc := NewInProcessService(mockStorage, nil) + + result, err := svc.GetAllValidTokens(context.Background(), "session-1") + + require.NoError(t, err) + assert.Equal(t, map[string]string{}, result) +} diff --git a/pkg/auth/upstreamtoken/types.go b/pkg/auth/upstreamtoken/types.go index 7eafa96577..e77383a215 100644 --- a/pkg/auth/upstreamtoken/types.go +++ b/pkg/auth/upstreamtoken/types.go @@ -19,6 +19,19 @@ type UpstreamCredential struct { AccessToken string } +// UpstreamTokenReader retrieves upstream provider access tokens for a session. +// This narrow interface decouples the auth middleware from storage internals. +// +// TODO(auth): Consider enriching the return type from map[string]string to +// map[string]UpstreamCredential to carry per-provider freshness/error metadata. +type UpstreamTokenReader interface { + // GetAllValidTokens returns access tokens for all upstream providers in a session. + // Expired tokens are refreshed transparently when possible; if refresh fails, + // the provider is omitted from the result. + // Returns an empty map (not error) for unknown sessions. + GetAllValidTokens(ctx context.Context, sessionID string) (map[string]string, error) +} + // Service owns the upstream token lifecycle: read, refresh, error handling. type Service interface { // GetValidTokens returns a valid upstream credential for a session and provider. From 20d01ebc065f56f7f6a22b1e91800a8a4ac18df7 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Thu, 19 Mar 2026 10:04:19 +0000 Subject: [PATCH 3/8] Thread UpstreamTokenReader through middleware runner and auth factory Add GetUpstreamTokenReader to MiddlewareRunner so the auth middleware can access the bulk token reader for identity enrichment. The runner stores the reader as a separate field (set alongside the token service) avoiding type assertions. GetAuthenticationMiddleware now accepts variadic TokenValidatorOption for forward-compatible option passing. Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/auth/middleware.go | 7 +++++- pkg/auth/middleware_test.go | 6 +++++ pkg/auth/utils.go | 4 +-- pkg/runner/runner.go | 17 ++++++++++++- pkg/runner/runner_test.go | 28 +++++++++++++++++++++ pkg/transport/types/mocks/mock_transport.go | 14 +++++++++++ pkg/transport/types/transport.go | 4 +++ 7 files changed, 76 insertions(+), 4 deletions(-) diff --git a/pkg/auth/middleware.go b/pkg/auth/middleware.go index 3c1bcd7662..ed2661d67c 100644 --- a/pkg/auth/middleware.go +++ b/pkg/auth/middleware.go @@ -52,7 +52,12 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun return fmt.Errorf("failed to unmarshal auth middleware parameters: %w", err) } - middleware, authInfoHandler, err := GetAuthenticationMiddleware(context.Background(), params.OIDCConfig) + var opts []TokenValidatorOption + if reader := runner.GetUpstreamTokenReader(); reader != nil { + opts = append(opts, WithUpstreamTokenReader(reader)) + } + + middleware, authInfoHandler, err := GetAuthenticationMiddleware(context.Background(), params.OIDCConfig, opts...) if err != nil { return fmt.Errorf("failed to create authentication middleware: %w", err) } diff --git a/pkg/auth/middleware_test.go b/pkg/auth/middleware_test.go index dca9cd6976..c4a4fd38c4 100644 --- a/pkg/auth/middleware_test.go +++ b/pkg/auth/middleware_test.go @@ -102,6 +102,9 @@ func TestCreateMiddleware_WithoutOIDCConfig(t *testing.T) { // Create mock runner mockRunner := mocks.NewMockMiddlewareRunner(ctrl) + // Expect GetUpstreamTokenReader to be called (returns nil = no auth server) + mockRunner.EXPECT().GetUpstreamTokenReader().Return(nil) + // Expect AddMiddleware to be called with a middleware instance mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any()).Do(func(name string, mw types.Middleware) { // Verify it's our auth middleware @@ -210,6 +213,9 @@ func TestCreateMiddleware_EmptyParameters(t *testing.T) { mockRunner := mocks.NewMockMiddlewareRunner(ctrl) + // Expect GetUpstreamTokenReader to be called (returns nil = no auth server) + mockRunner.EXPECT().GetUpstreamTokenReader().Return(nil) + // Expect AddMiddleware to be called mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any()) diff --git a/pkg/auth/utils.go b/pkg/auth/utils.go index dc5350b9fe..10137afb9c 100644 --- a/pkg/auth/utils.go +++ b/pkg/auth/utils.go @@ -61,13 +61,13 @@ func ExtractBearerToken(r *http.Request) (string, error) { // GetAuthenticationMiddleware returns the appropriate authentication middleware based on the configuration. // If OIDC config is provided, it returns JWT middleware. Otherwise, it returns local user middleware. -func GetAuthenticationMiddleware(ctx context.Context, oidcConfig *TokenValidatorConfig, +func GetAuthenticationMiddleware(ctx context.Context, oidcConfig *TokenValidatorConfig, opts ...TokenValidatorOption, ) (func(http.Handler) http.Handler, http.Handler, error) { if oidcConfig != nil { slog.Debug("oidc validation enabled") // Create JWT validator - jwtValidator, err := NewTokenValidator(ctx, *oidcConfig) + jwtValidator, err := NewTokenValidator(ctx, *oidcConfig, opts...) if err != nil { return nil, nil, err } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 0d102774d7..2159547167 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -80,6 +80,12 @@ type Runner struct { // after the embedded auth server is initialized in Run(). // Nil when no embedded auth server is configured. upstreamTokenService upstreamtoken.Service + + // upstreamTokenReader provides read-only access to upstream tokens for + // identity enrichment in auth middleware. Set alongside upstreamTokenService + // when the embedded auth server is initialized in Run(). + // Nil when no embedded auth server is configured. + upstreamTokenReader upstreamtoken.UpstreamTokenReader } // statusManagerAdapter adapts statuses.StatusManager to auth.StatusUpdater interface @@ -142,6 +148,13 @@ func (r *Runner) GetUpstreamTokenService() func() upstreamtoken.Service { } } +// GetUpstreamTokenReader returns the UpstreamTokenReader for identity +// enrichment in the auth middleware. Returns nil if no embedded auth +// server is configured. +func (r *Runner) GetUpstreamTokenReader() upstreamtoken.UpstreamTokenReader { + return r.upstreamTokenReader +} + // GetName returns the name of the mcp-service from the runner config (implements types.RunnerConfig) func (c *RunConfig) GetName() string { return c.Name @@ -258,7 +271,9 @@ func (r *Runner) Run(ctx context.Context) error { // InProcessService handles this gracefully (returns ErrNoRefreshToken). stor := r.embeddedAuthServer.IDPTokenStorage() refresher := r.embeddedAuthServer.UpstreamTokenRefresher() - r.upstreamTokenService = upstreamtoken.NewInProcessService(stor, refresher) + inProc := upstreamtoken.NewInProcessService(stor, refresher) + r.upstreamTokenService = inProc + r.upstreamTokenReader = inProc // Mount auth server routes at specific prefixes to avoid conflicts with MCP endpoints // (e.g., /.well-known/oauth-protected-resource is an MCP endpoint, not auth server) diff --git a/pkg/runner/runner_test.go b/pkg/runner/runner_test.go index deef4bfbd6..ef9ac4bb18 100644 --- a/pkg/runner/runner_test.go +++ b/pkg/runner/runner_test.go @@ -17,6 +17,7 @@ import ( "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" "github.com/stacklok/toolhive/pkg/authserver" authserverrunner "github.com/stacklok/toolhive/pkg/authserver/runner" + storagemocks "github.com/stacklok/toolhive/pkg/authserver/storage/mocks" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/transport/types" statusesmocks "github.com/stacklok/toolhive/pkg/workloads/statuses/mocks" @@ -560,3 +561,30 @@ func TestRunner_RejectsMultiUpstreamConfig(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "does not support multiple upstream providers") } + +func TestRunner_GetUpstreamTokenReader(t *testing.T) { + t.Parallel() + + t.Run("returns nil when no auth server configured", func(t *testing.T) { + t.Parallel() + + r := &Runner{} + reader := r.GetUpstreamTokenReader() + assert.Nil(t, reader) + }) + + t.Run("returns reader when auth server configured", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + svc := upstreamtoken.NewInProcessService(mockStorage, nil) + + r := &Runner{ + upstreamTokenReader: svc, + } + reader := r.GetUpstreamTokenReader() + assert.NotNil(t, reader) + assert.Equal(t, svc, reader) + }) +} diff --git a/pkg/transport/types/mocks/mock_transport.go b/pkg/transport/types/mocks/mock_transport.go index f852f046e4..7ab524ffc2 100644 --- a/pkg/transport/types/mocks/mock_transport.go +++ b/pkg/transport/types/mocks/mock_transport.go @@ -123,6 +123,20 @@ func (mr *MockMiddlewareRunnerMockRecorder) GetConfig() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfig", reflect.TypeOf((*MockMiddlewareRunner)(nil).GetConfig)) } +// GetUpstreamTokenReader mocks base method. +func (m *MockMiddlewareRunner) GetUpstreamTokenReader() upstreamtoken.UpstreamTokenReader { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUpstreamTokenReader") + ret0, _ := ret[0].(upstreamtoken.UpstreamTokenReader) + return ret0 +} + +// GetUpstreamTokenReader indicates an expected call of GetUpstreamTokenReader. +func (mr *MockMiddlewareRunnerMockRecorder) GetUpstreamTokenReader() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUpstreamTokenReader", reflect.TypeOf((*MockMiddlewareRunner)(nil).GetUpstreamTokenReader)) +} + // GetUpstreamTokenService mocks base method. func (m *MockMiddlewareRunner) GetUpstreamTokenService() func() upstreamtoken.Service { m.ctrl.T.Helper() diff --git a/pkg/transport/types/transport.go b/pkg/transport/types/transport.go index a8892863f1..976c902a92 100644 --- a/pkg/transport/types/transport.go +++ b/pkg/transport/types/transport.go @@ -84,6 +84,10 @@ type MiddlewareRunner interface { // This method always returns a non-nil function. Service availability is // determined at request time when the returned function is called. GetUpstreamTokenService() func() upstreamtoken.Service + + // GetUpstreamTokenReader returns an UpstreamTokenReader for identity enrichment. + // Returns nil if the embedded auth server is not configured. + GetUpstreamTokenReader() upstreamtoken.UpstreamTokenReader } // RunnerConfig defines the config interface needed by middleware to access runner configuration From 764aa5536eeec4076e56f6e54029b901609e9706 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Thu, 19 Mar 2026 09:53:21 +0000 Subject: [PATCH 4/8] Enrich Identity with upstream tokens during JWT validation Add WithUpstreamTokenReader option to TokenValidator that loads all upstream provider tokens from storage when a tsid claim is present in the JWT. The enrichment happens between claimsToIdentity and context injection, populating Identity.UpstreamTokens for downstream middleware consumption. Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/auth/context.go | 25 +++- pkg/auth/token.go | 75 +++++++++-- pkg/auth/token_test.go | 279 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 365 insertions(+), 14 deletions(-) diff --git a/pkg/auth/context.go b/pkg/auth/context.go index af7ea8d8fa..c7e5420541 100644 --- a/pkg/auth/context.go +++ b/pkg/auth/context.go @@ -67,10 +67,16 @@ func claimsToIdentity(claims jwt.MapClaims, token string) (*Identity, error) { return nil, errors.New("missing or invalid 'sub' claim (required by OIDC Core 1.0 § 5.1)") } + // Filter internal claims that should not be externalized (e.g., in + // webhook payloads or audit logs). The tsid is a session identifier + // used to look up upstream tokens in storage; exposing it widens the + // attack surface if a webhook receiver is compromised. + filteredClaims := filterInternalClaims(claims) + identity := &Identity{ PrincipalInfo: PrincipalInfo{ Subject: sub, - Claims: claims, + Claims: filteredClaims, }, Token: token, TokenType: "Bearer", @@ -86,3 +92,20 @@ func claimsToIdentity(claims jwt.MapClaims, token string) (*Identity, error) { return identity, nil } + +// internalClaims are JWT claim keys used internally by the auth server +// that must not be externalized in webhook payloads, audit logs, etc. +// "tsid" is the token session ID used to look up upstream tokens in storage. +var internalClaims = []string{"tsid"} + +// filterInternalClaims returns a copy of claims with internal keys removed. +func filterInternalClaims(claims jwt.MapClaims) jwt.MapClaims { + filtered := make(jwt.MapClaims, len(claims)) + for k, v := range claims { + filtered[k] = v + } + for _, key := range internalClaims { + delete(filtered, key) + } + return filtered +} diff --git a/pkg/auth/token.go b/pkg/auth/token.go index 7d5e369674..3d5de73c8c 100644 --- a/pkg/auth/token.go +++ b/pkg/auth/token.go @@ -25,6 +25,7 @@ import ( "github.com/stacklok/toolhive-core/env" "github.com/stacklok/toolhive/pkg/auth/oauth" + "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" "github.com/stacklok/toolhive/pkg/networking" oauthproto "github.com/stacklok/toolhive/pkg/oauth" ) @@ -367,6 +368,10 @@ type TokenValidator struct { registry *Registry // Token introspection providers insecureAllowHTTP bool // Allow HTTP (non-HTTPS) OIDC issuers for development/testing + // upstreamTokenReader loads upstream provider tokens for identity enrichment. + // nil means no enrichment (no embedded auth server). + upstreamTokenReader upstreamtoken.UpstreamTokenReader + // Lazy JWKS registration jwksRegistered bool jwksRegistrationMu sync.Mutex @@ -540,7 +545,8 @@ func registerIntrospectionProviders(config TokenValidatorConfig, clientSecret st // tokenValidatorOptions holds optional dependencies for NewTokenValidator. type tokenValidatorOptions struct { - envReader env.Reader + envReader env.Reader + upstreamTokenReader upstreamtoken.UpstreamTokenReader } // TokenValidatorOption is a functional option for NewTokenValidator. @@ -554,6 +560,16 @@ func WithEnvReader(reader env.Reader) TokenValidatorOption { } } +// WithUpstreamTokenReader configures the token validator to enrich Identity +// with upstream provider tokens. When set, the Middleware extracts the token +// session ID (tsid) from JWT claims and loads all upstream tokens into +// Identity.UpstreamTokens before placing the Identity in the request context. +func WithUpstreamTokenReader(reader upstreamtoken.UpstreamTokenReader) TokenValidatorOption { + return func(o *tokenValidatorOptions) { + o.upstreamTokenReader = reader + } +} + // NewTokenValidator creates a new token validator. func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ...TokenValidatorOption) (*TokenValidator, error) { // Apply functional options @@ -638,18 +654,19 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts .. } validator := &TokenValidator{ - issuer: config.Issuer, - audience: config.Audience, - jwksURL: jwksURL, - introspectURL: config.IntrospectionURL, - clientID: config.ClientID, - clientSecret: clientSecret, - jwksClient: cache, - client: config.httpClient, - resourceURL: config.ResourceURL, - scopes: config.Scopes, - registry: registry, - insecureAllowHTTP: config.InsecureAllowHTTP, + issuer: config.Issuer, + audience: config.Audience, + jwksURL: jwksURL, + introspectURL: config.IntrospectionURL, + clientID: config.ClientID, + clientSecret: clientSecret, + jwksClient: cache, + client: config.httpClient, + resourceURL: config.ResourceURL, + scopes: config.Scopes, + registry: registry, + insecureAllowHTTP: config.InsecureAllowHTTP, + upstreamTokenReader: o.upstreamTokenReader, } return validator, nil @@ -1064,6 +1081,24 @@ func writeOAuthError(w http.ResponseWriter, errorCode, description string, statu _, _ = w.Write(body) } +// loadUpstreamTokens extracts the token session ID from claims and loads +// all upstream provider tokens for that session. Returns (nil, nil) if no +// tsid claim exists. Returns a non-nil error when a tsid claim is present +// but token loading fails (infrastructure error). +func (v *TokenValidator) loadUpstreamTokens(ctx context.Context, claims jwt.MapClaims) (map[string]string, error) { + tsid, ok := claims[upstreamtoken.TokenSessionIDClaimKey].(string) + if !ok || tsid == "" { + return nil, nil + } + + tokens, err := v.upstreamTokenReader.GetAllValidTokens(ctx, tsid) + if err != nil { + return nil, fmt.Errorf("load upstream tokens for session %s: %w", tsid, err) + } + + return tokens, nil +} + // Middleware creates an HTTP middleware that validates JWT tokens and creates Identity. func (v *TokenValidator) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -1092,6 +1127,20 @@ func (v *TokenValidator) Middleware(next http.Handler) http.Handler { return } + // Enrich Identity with upstream provider tokens when an embedded + // auth server is active (reader configured via WithUpstreamTokenReader). + if v.upstreamTokenReader != nil { + tokens, loadErr := v.loadUpstreamTokens(r.Context(), claims) + if loadErr != nil { + slog.WarnContext(r.Context(), "upstream token storage unavailable", + "error", loadErr, + ) + http.Error(w, "authentication service temporarily unavailable", http.StatusServiceUnavailable) + return + } + identity.UpstreamTokens = tokens + } + // Add the Identity to the request context ctx := WithIdentity(r.Context(), identity) next.ServeHTTP(w, r.WithContext(ctx)) diff --git a/pkg/auth/token_test.go b/pkg/auth/token_test.go index b6006f369c..5ba11df27a 100644 --- a/pkg/auth/token_test.go +++ b/pkg/auth/token_test.go @@ -24,6 +24,7 @@ import ( "go.uber.org/mock/gomock" envmocks "github.com/stacklok/toolhive-core/env/mocks" + "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" "github.com/stacklok/toolhive/pkg/networking" oauthproto "github.com/stacklok/toolhive/pkg/oauth" ) @@ -2223,3 +2224,281 @@ func TestMiddleware_RFC6750JSONErrorResponse(t *testing.T) { }) } } + +func TestLoadUpstreamTokens(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + claims jwt.MapClaims + reader upstreamtoken.UpstreamTokenReader + wantResult map[string]string + wantErr bool + }{ + { + name: "loads tokens when tsid present", + claims: jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: "session-abc", + }, + reader: &mockUpstreamTokenReader{ + tokens: map[string]string{ + "github": "gh-token", + "atlassian": "atl-token", + }, + }, + wantResult: map[string]string{ + "github": "gh-token", + "atlassian": "atl-token", + }, + }, + { + name: "returns nil when no tsid claim", + claims: jwt.MapClaims{ + "sub": "user123", + }, + reader: &mockUpstreamTokenReader{}, + wantResult: nil, + }, + { + name: "returns nil when tsid is empty string", + claims: jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: "", + }, + reader: &mockUpstreamTokenReader{}, + wantResult: nil, + }, + { + name: "returns nil when tsid is non-string type", + claims: jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: 12345, + }, + reader: &mockUpstreamTokenReader{}, + wantResult: nil, + }, + { + name: "returns error when reader fails", + claims: jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: "session-abc", + }, + reader: &mockUpstreamTokenReader{ + err: errors.New("storage unavailable"), + }, + wantErr: true, + }, + { + name: "returns empty map from reader", + claims: jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: "session-abc", + }, + reader: &mockUpstreamTokenReader{ + tokens: map[string]string{}, + }, + wantResult: map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + v := &TokenValidator{ + upstreamTokenReader: tt.reader, + } + + result, err := v.loadUpstreamTokens(context.Background(), tt.claims) + + if tt.wantErr { + require.Error(t, err) + require.Nil(t, result) + return + } + + require.NoError(t, err) + if tt.wantResult == nil { + require.Nil(t, result) + } else { + require.Equal(t, tt.wantResult, result) + } + }) + } +} + +func TestLoadUpstreamTokens_PassesCorrectSessionID(t *testing.T) { + t.Parallel() + + reader := &mockUpstreamTokenReader{ + tokens: map[string]string{"github": "token"}, + } + + v := &TokenValidator{ + upstreamTokenReader: reader, + } + + claims := jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: "session-xyz", + } + + result, err := v.loadUpstreamTokens(context.Background(), claims) + + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "session-xyz", reader.calledWith) +} + +// mockUpstreamTokenReader is a simple mock for testing loadUpstreamTokens. +type mockUpstreamTokenReader struct { + tokens map[string]string + err error + calledWith string +} + +func (m *mockUpstreamTokenReader) GetAllValidTokens(_ context.Context, sessionID string) (map[string]string, error) { + m.calledWith = sessionID + return m.tokens, m.err +} + +func TestWithUpstreamTokenReader(t *testing.T) { + t.Parallel() + + reader := &mockUpstreamTokenReader{} + opt := WithUpstreamTokenReader(reader) + + o := &tokenValidatorOptions{} + opt(o) + + require.Equal(t, reader, o.upstreamTokenReader) +} + +// TestMiddleware_UpstreamTokenEnrichment verifies the full middleware pipeline: +// JWT validation → tsid extraction → token loading → Identity.UpstreamTokens. +func TestMiddleware_UpstreamTokenEnrichment(t *testing.T) { + t.Parallel() + + // Shared JWKS infrastructure + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + key, err := jwk.Import(&privateKey.PublicKey) + require.NoError(t, err) + require.NoError(t, key.Set(jwk.KeyIDKey, testKeyID)) + require.NoError(t, key.Set(jwk.AlgorithmKey, "RS256")) + require.NoError(t, key.Set(jwk.KeyUsageKey, "sig")) + + keySet := jwk.NewSet() + require.NoError(t, keySet.AddKey(key)) + jwksServer, caCertPath := createTestJWKSServer(t, keySet) + t.Cleanup(jwksServer.Close) + + makeValidator := func(t *testing.T, opts ...TokenValidatorOption) *TokenValidator { + t.Helper() + v, vErr := NewTokenValidator(context.Background(), TokenValidatorConfig{ + Issuer: "test-issuer", Audience: "test-audience", + JWKSURL: jwksServer.URL, ClientID: "test-client", + CACertPath: caCertPath, AllowPrivateIP: true, + }, opts...) + require.NoError(t, vErr) + require.NoError(t, v.ensureJWKSRegistered(context.Background())) + _, lErr := v.jwksClient.Lookup(context.Background(), jwksServer.URL) + require.NoError(t, lErr) + return v + } + + signToken := func(claims jwt.MapClaims) string { + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tok.Header["kid"] = testKeyID + s, sErr := tok.SignedString(privateKey) + require.NoError(t, sErr) + return s + } + + claimsWithTsid := jwt.MapClaims{ + "iss": "test-issuer", "aud": "test-audience", "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + upstreamtoken.TokenSessionIDClaimKey: "session-xyz", + } + + t.Run("enriches identity with upstream tokens", func(t *testing.T) { + t.Parallel() + reader := &mockUpstreamTokenReader{tokens: map[string]string{"github": "gh-tok"}} + v := makeValidator(t, WithUpstreamTokenReader(reader)) + + var captured *Identity + handler := v.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + captured, _ = IdentityFromContext(r.Context()) + })) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "Bearer "+signToken(claimsWithTsid)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, map[string]string{"github": "gh-tok"}, captured.UpstreamTokens) + }) + + t.Run("returns 503 when storage fails", func(t *testing.T) { + t.Parallel() + reader := &mockUpstreamTokenReader{err: errors.New("redis down")} + v := makeValidator(t, WithUpstreamTokenReader(reader)) + + nextCalled := false + handler := v.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + })) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "Bearer "+signToken(claimsWithTsid)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusServiceUnavailable, rr.Code) + require.False(t, nextCalled) + }) + + t.Run("no enrichment without tsid", func(t *testing.T) { + t.Parallel() + reader := &mockUpstreamTokenReader{tokens: map[string]string{"github": "should-not-appear"}} + v := makeValidator(t, WithUpstreamTokenReader(reader)) + + var captured *Identity + handler := v.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + captured, _ = IdentityFromContext(r.Context()) + })) + + noTsid := jwt.MapClaims{ + "iss": "test-issuer", "aud": "test-audience", "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + } + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "Bearer "+signToken(noTsid)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + require.Nil(t, captured.UpstreamTokens) + }) + + t.Run("no enrichment when reader is nil", func(t *testing.T) { + t.Parallel() + v := makeValidator(t) // no WithUpstreamTokenReader + + var captured *Identity + handler := v.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + captured, _ = IdentityFromContext(r.Context()) + })) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "Bearer "+signToken(claimsWithTsid)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + require.Nil(t, captured.UpstreamTokens) + }) +} From c978192984612fbe2721600b6ce0fa3f8d7b48cf Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Thu, 19 Mar 2026 10:14:01 +0000 Subject: [PATCH 5/8] Simplify upstreamswap to read from Identity.UpstreamTokens Replace the ServiceGetter/GetValidTokens pattern with a direct read from the pre-enriched Identity.UpstreamTokens map. The auth middleware now handles storage reads and token refresh, so upstreamswap becomes a pure reader that looks up the provider token and injects it into the request header. Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/middleware.md | 27 +- pkg/auth/upstreamswap/middleware.go | 78 +-- pkg/auth/upstreamswap/middleware_test.go | 660 ++--------------------- 3 files changed, 76 insertions(+), 689 deletions(-) diff --git a/docs/middleware.md b/docs/middleware.md index 28cc1bfc00..8ab0eedc73 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -133,11 +133,9 @@ sequenceDiagram **Availability**: Automatically enabled when using the embedded auth server (`EmbeddedAuthServerConfig`) **Responsibilities**: -- Extract the token session ID (`tsid`) claim from the ToolHive JWT -- Look up the stored upstream IdP tokens associated with that session +- Read the upstream access token for the configured provider from `Identity.UpstreamTokens` - Inject the upstream access token into the request (replacing Authorization header or using a custom header) -- Return 401 Unauthorized with WWW-Authenticate header when tokens are expired or not found -- Gracefully proceed without modification if identity, session ID, or storage is unavailable +- Return 401 Unauthorized with WWW-Authenticate header when the provider token is missing or empty **Configuration**: @@ -148,16 +146,16 @@ sequenceDiagram **Behavior**: - **Automatic activation**: Enabled whenever the embedded auth server is configured, even without explicit `UpstreamSwapConfig` -- **Expired tokens**: Returns 401 Unauthorized with `WWW-Authenticate` header indicating re-authentication is required -- **Tokens not found**: Returns 401 Unauthorized with `WWW-Authenticate` header indicating re-authentication is required -- **Missing identity/tsid**: Proceeds without modification (logs debug message) -- **Storage unavailable**: Proceeds without modification (logs warning) -- **Other storage errors**: Returns 503 Service Unavailable to fail closed (logs warning) +- **Provider token found**: Injects the token into the request using the configured header strategy +- **Provider not in UpstreamTokens**: Returns 401 Unauthorized with `WWW-Authenticate` header indicating re-authentication is required +- **Empty token value**: Returns 401 Unauthorized (same as missing provider) +- **No identity in context**: Passes through without modification (auth middleware not in chain) +- **Storage unavailable**: The auth middleware returns 503 before the request reaches this middleware **Context Data Used**: -- Identity from Authentication middleware (specifically the `tsid` claim) +- `Identity.UpstreamTokens` map populated by the Authentication middleware during JWT validation -**Note**: This middleware is designed for use with ToolHive's embedded auth server. It reads from the auth server's token storage to retrieve upstream IdP tokens that were captured during the OAuth authorization flow. +**Note**: This middleware is a simple map reader. All upstream token loading, refresh, and error handling occurs in the Authentication middleware (Step 3), which populates `Identity.UpstreamTokens` from the token session ID (`tsid`) claim during JWT validation. --- @@ -785,10 +783,9 @@ type MiddlewareRunner interface { // GetConfig returns a config interface for middleware to access runner configuration GetConfig() RunnerConfig - // GetUpstreamTokenService returns a lazy accessor for the upstream token service. - // Returns a function that provides the service at request time. - // Used by upstream swap middleware to get valid upstream tokens (with transparent refresh). - GetUpstreamTokenService() func() upstreamtoken.Service + // GetUpstreamTokenReader returns an UpstreamTokenReader for identity enrichment. + // Returns nil if the embedded auth server is not configured. + GetUpstreamTokenReader() upstreamtoken.UpstreamTokenReader } ``` diff --git a/pkg/auth/upstreamswap/middleware.go b/pkg/auth/upstreamswap/middleware.go index ca4edd1fbf..d4f515aafb 100644 --- a/pkg/auth/upstreamswap/middleware.go +++ b/pkg/auth/upstreamswap/middleware.go @@ -7,13 +7,10 @@ package upstreamswap import ( "encoding/json" - "errors" "fmt" - "log/slog" "net/http" "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -46,10 +43,6 @@ type MiddlewareParams struct { Config *Config `json:"config,omitempty"` } -// ServiceGetter is a function that returns an upstream token service. -// It returns nil when the service is unavailable (e.g., auth server not configured). -type ServiceGetter func() upstreamtoken.Service - // Middleware wraps the upstream swap middleware functionality. type Middleware struct { middleware types.MiddlewareFunction @@ -83,10 +76,7 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun return fmt.Errorf("invalid upstream swap configuration: %w", err) } - // Get the lazy service accessor from the runner. - serviceGetter := ServiceGetter(runner.GetUpstreamTokenService()) - - middleware := createMiddlewareFunc(cfg, serviceGetter) + middleware := createMiddlewareFunc(cfg) upstreamSwapMw := &Middleware{ middleware: middleware, @@ -146,7 +136,9 @@ func createCustomInjector(headerName string) injectionFunc { } // createMiddlewareFunc creates the actual middleware function. -func createMiddlewareFunc(cfg *Config, serviceGetter ServiceGetter) types.MiddlewareFunction { +// It reads upstream tokens from Identity.UpstreamTokens, which are populated +// during JWT validation by the auth middleware (Step 3). +func createMiddlewareFunc(cfg *Config) types.MiddlewareFunction { // Determine injection strategy at startup time strategy := cfg.HeaderStrategy if strategy == "" { @@ -166,73 +158,19 @@ func createMiddlewareFunc(cfg *Config, serviceGetter ServiceGetter) types.Middle return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // 1. Get identity from auth middleware identity, ok := auth.IdentityFromContext(r.Context()) if !ok { - slog.Debug("No identity in context, proceeding without swap", - "middleware", "upstreamswap") next.ServeHTTP(w, r) return } - // 2. Extract tsid from claims - tsid, ok := identity.Claims[upstreamtoken.TokenSessionIDClaimKey].(string) - if !ok || tsid == "" { - slog.Debug("No tsid claim in identity, proceeding without swap", - "middleware", "upstreamswap") - next.ServeHTTP(w, r) + token, exists := identity.UpstreamTokens[cfg.ProviderName] + if !exists || token == "" { + writeUpstreamAuthRequired(w) return } - // 3. Get token service — fail closed if unavailable. - // The tsid claim confirms this request expects upstream token injection; - // passing through with the original JWT would leak it to the backend. - svc := serviceGetter() - if svc == nil { - slog.Warn("Token service unavailable, cannot perform required upstream swap", - "middleware", "upstreamswap") - http.Error(w, "authentication service temporarily unavailable", http.StatusServiceUnavailable) - return - } - - // 4. Get valid upstream tokens (with transparent refresh) - cred, err := svc.GetValidTokens(r.Context(), tsid, cfg.ProviderName) - if err != nil { - // Client-attributable errors require re-authentication. - if errors.Is(err, upstreamtoken.ErrSessionNotFound) || - errors.Is(err, upstreamtoken.ErrNoRefreshToken) || - errors.Is(err, upstreamtoken.ErrRefreshFailed) || - errors.Is(err, upstreamtoken.ErrInvalidBinding) { - slog.Debug("Upstream token needs re-authentication", - "middleware", "upstreamswap", - "provider", cfg.ProviderName, - "error", err) - writeUpstreamAuthRequired(w) - return - } - // Other errors: fail closed to avoid bypassing the token swap - slog.Warn("Failed to get upstream tokens", - "middleware", "upstreamswap", - "provider", cfg.ProviderName, - "error", err) - http.Error(w, "authentication service temporarily unavailable", http.StatusServiceUnavailable) - return - } - - // 5. Inject access token — fail closed if empty to prevent bypassing the swap - if cred.AccessToken == "" { - slog.Warn("Upstream token service returned empty access token", - "middleware", "upstreamswap", - "provider", cfg.ProviderName) - http.Error(w, "authentication service temporarily unavailable", http.StatusServiceUnavailable) - return - } - - injectToken(r, cred.AccessToken) - slog.Debug("Injected upstream access token", - "middleware", "upstreamswap", - "provider", cfg.ProviderName) - + injectToken(r, token) next.ServeHTTP(w, r) }) } diff --git a/pkg/auth/upstreamswap/middleware_test.go b/pkg/auth/upstreamswap/middleware_test.go index 39454e36ca..3f298d3b98 100644 --- a/pkg/auth/upstreamswap/middleware_test.go +++ b/pkg/auth/upstreamswap/middleware_test.go @@ -6,51 +6,28 @@ package upstreamswap import ( "context" "encoding/json" - "errors" "net/http" "net/http/httptest" - "sync" - "sync/atomic" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" - "github.com/stacklok/toolhive/pkg/authserver/storage" - storagemocks "github.com/stacklok/toolhive/pkg/authserver/storage/mocks" "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/transport/types/mocks" ) -// serviceGetterFromMocks creates a ServiceGetter that wraps mock storage and refresher -// in an InProcessService. This is the standard pattern for middleware tests. -func serviceGetterFromMocks( - stor storage.UpstreamTokenStorage, - refresher storage.UpstreamTokenRefresher, -) ServiceGetter { - svc := upstreamtoken.NewInProcessService(stor, refresher) - return func() upstreamtoken.Service { return svc } -} - -// nilServiceGetter returns a ServiceGetter that always returns nil (service unavailable). -func nilServiceGetter() ServiceGetter { - return func() upstreamtoken.Service { return nil } -} - // requestWithIdentity creates an HTTP request with the given identity in context. -func requestWithIdentity(tsid string) *http.Request { +func requestWithIdentity(providerName, token string) *http.Request { req := httptest.NewRequest(http.MethodGet, "/test", nil) identity := &auth.Identity{ PrincipalInfo: auth.PrincipalInfo{ Subject: "user123", - Claims: map[string]any{ - "sub": "user123", - upstreamtoken.TokenSessionIDClaimKey: tsid, - }, + }, + UpstreamTokens: map[string]string{ + providerName: token, }, } ctx := auth.WithIdentity(req.Context(), identity) @@ -126,13 +103,12 @@ func TestValidateConfig(t *testing.T) { func TestMiddleware_NoIdentity(t *testing.T) { t.Parallel() - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, nilServiceGetter()) + cfg := &Config{ProviderName: "github"} + middleware := createMiddlewareFunc(cfg) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { nextCalled = true - // Verify Authorization header was NOT modified assert.Empty(t, r.Header.Get("Authorization")) }) @@ -143,31 +119,25 @@ func TestMiddleware_NoIdentity(t *testing.T) { handler.ServeHTTP(rr, req) assert.True(t, nextCalled, "next handler should be called") + assert.Equal(t, http.StatusOK, rr.Code) } -func TestMiddleware_NoTsidClaim(t *testing.T) { +func TestMiddleware_NilUpstreamTokens(t *testing.T) { t.Parallel() - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, nilServiceGetter()) + cfg := &Config{ProviderName: "github"} + middleware := createMiddlewareFunc(cfg) - var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true + t.Error("next handler should NOT be called when upstream tokens are nil") }) handler := middleware(nextHandler) - // Create request with identity but no tsid claim + // Identity exists but UpstreamTokens is nil (not populated by auth middleware) req := httptest.NewRequest(http.MethodGet, "/test", nil) identity := &auth.Identity{ - PrincipalInfo: auth.PrincipalInfo{ - Subject: "user123", - Claims: map[string]any{ - "sub": "user123", - // No tsid claim - }, - }, + PrincipalInfo: auth.PrincipalInfo{Subject: "user123"}, } ctx := auth.WithIdentity(req.Context(), identity) req = req.WithContext(ctx) @@ -175,142 +145,56 @@ func TestMiddleware_NoTsidClaim(t *testing.T) { rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - assert.True(t, nextCalled, "next handler should be called") + assert.Equal(t, http.StatusUnauthorized, rr.Code) + assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) } -func TestMiddleware_ServiceUnavailable_FailsClosed(t *testing.T) { +func TestMiddleware_ProviderMissing_Returns401(t *testing.T) { t.Parallel() - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, nilServiceGetter()) + cfg := &Config{ProviderName: "atlassian"} + middleware := createMiddlewareFunc(cfg) - var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true + t.Error("next handler should NOT be called when provider is missing") }) handler := middleware(nextHandler) - req := requestWithIdentity("session-123") + req := requestWithIdentity("github", "gh-token") // has github but not atlassian rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - assert.False(t, nextCalled, "next handler should NOT be called when service is unavailable") - assert.Equal(t, http.StatusServiceUnavailable, rr.Code) -} - -func TestMiddleware_ClientAttributableErrors_Returns401(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - setupStorage func(*storagemocks.MockUpstreamTokenStorage) - }{ - { - name: "not found", - setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { - s.EXPECT().GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(nil, storage.ErrNotFound) - }, - }, - { - name: "expired with nil tokens", - setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { - s.EXPECT().GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(nil, storage.ErrExpired) - }, - }, - { - name: "invalid binding", - setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { - s.EXPECT().GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(nil, storage.ErrInvalidBinding) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - tt.setupStorage(mockStorage) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) - - var nextCalled bool - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true - }) - - handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.False(t, nextCalled, "next handler should NOT be called") - assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) - }) - } + assert.Equal(t, http.StatusUnauthorized, rr.Code) } -func TestMiddleware_StorageError(t *testing.T) { +func TestMiddleware_EmptyToken_Returns401(t *testing.T) { t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() + cfg := &Config{ProviderName: "github"} + middleware := createMiddlewareFunc(cfg) - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(nil, errors.New("database error")) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) - - var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true + t.Error("next handler should NOT be called when token is empty") }) handler := middleware(nextHandler) - req := requestWithIdentity("session-123") + req := requestWithIdentity("github", "") // empty token rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - assert.False(t, nextCalled, "next handler should NOT be called on storage error") - assert.Equal(t, http.StatusServiceUnavailable, rr.Code) + assert.Equal(t, http.StatusUnauthorized, rr.Code) } -func TestMiddleware_SuccessfulSwap_AccessToken(t *testing.T) { +func TestMiddleware_SuccessfulSwap_ReplaceStrategy(t *testing.T) { t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - tokens := &storage.UpstreamTokens{ - AccessToken: "upstream-access-token", - IDToken: "upstream-id-token", - ExpiresAt: time.Now().Add(1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(tokens, nil) - cfg := &Config{ HeaderStrategy: HeaderStrategyReplace, - ProviderName: "default", + ProviderName: "github", } - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) + middleware := createMiddlewareFunc(cfg) var capturedAuthHeader string nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { @@ -318,7 +202,7 @@ func TestMiddleware_SuccessfulSwap_AccessToken(t *testing.T) { }) handler := middleware(nextHandler) - req := requestWithIdentity("session-123") + req := requestWithIdentity("github", "upstream-access-token") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) @@ -326,118 +210,56 @@ func TestMiddleware_SuccessfulSwap_AccessToken(t *testing.T) { assert.Equal(t, "Bearer upstream-access-token", capturedAuthHeader) } -func TestMiddleware_CustomHeader(t *testing.T) { +func TestMiddleware_SuccessfulSwap_DefaultStrategy(t *testing.T) { t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - tokens := &storage.UpstreamTokens{ - AccessToken: "upstream-access-token", - ExpiresAt: time.Now().Add(1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(tokens, nil) - cfg := &Config{ - HeaderStrategy: HeaderStrategyCustom, - CustomHeaderName: "X-Upstream-Token", - ProviderName: "default", + ProviderName: "github", + // HeaderStrategy intentionally empty — defaults to replace } - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) + middleware := createMiddlewareFunc(cfg) - var capturedCustomHeader string var capturedAuthHeader string nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - capturedCustomHeader = r.Header.Get("X-Upstream-Token") capturedAuthHeader = r.Header.Get("Authorization") }) handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - req.Header.Set("Authorization", "Bearer original-token") + req := requestWithIdentity("github", "default-strategy-token") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - assert.Equal(t, "Bearer upstream-access-token", capturedCustomHeader) - // Original Authorization header should remain unchanged - assert.Equal(t, "Bearer original-token", capturedAuthHeader) -} - -func TestMiddleware_ExpiredTokens_Returns401(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Token is expired and has no refresh token - tokens := &storage.UpstreamTokens{ - AccessToken: "expired-upstream-token", - ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired 1 hour ago - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(tokens, nil) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) - - var nextCalled bool - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true - }) - - handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.False(t, nextCalled, "next handler should NOT be called for expired tokens") - assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "Bearer default-strategy-token", capturedAuthHeader) } -func TestMiddleware_EmptySelectedToken_FailsClosed(t *testing.T) { +func TestMiddleware_SuccessfulSwap_CustomHeader(t *testing.T) { t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Tokens exist but the access token is empty - tokens := &storage.UpstreamTokens{ - AccessToken: "", // Empty access token - IDToken: "upstream-id-token", - ExpiresAt: time.Now().Add(1 * time.Hour), + cfg := &Config{ + HeaderStrategy: HeaderStrategyCustom, + CustomHeaderName: "X-Upstream-Token", + ProviderName: "github", } + middleware := createMiddlewareFunc(cfg) - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(tokens, nil) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) - - var nextCalled bool - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true + var capturedCustomHeader string + var capturedAuthHeader string + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedCustomHeader = r.Header.Get("X-Upstream-Token") + capturedAuthHeader = r.Header.Get("Authorization") }) handler := middleware(nextHandler) - req := requestWithIdentity("session-123") + req := requestWithIdentity("github", "upstream-access-token") + req.Header.Set("Authorization", "Bearer original-token") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - assert.False(t, nextCalled, "next handler should NOT be called when access token is empty") - assert.Equal(t, http.StatusServiceUnavailable, rr.Code) + assert.Equal(t, "Bearer upstream-access-token", capturedCustomHeader) + assert.Equal(t, "Bearer original-token", capturedAuthHeader) } func TestMiddleware_Close(t *testing.T) { @@ -483,35 +305,20 @@ func TestCreateInjectors(t *testing.T) { func TestMiddlewareWithContext(t *testing.T) { t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - tokens := &storage.UpstreamTokens{ - AccessToken: "ctx-test-token", - ExpiresAt: time.Now().Add(1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-ctx", "default"). - Return(tokens, nil) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) + cfg := &Config{ProviderName: "github"} + middleware := createMiddlewareFunc(cfg) - // Test that context is properly passed through var receivedCtx context.Context nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { receivedCtx = r.Context() }) handler := middleware(nextHandler) - req := requestWithIdentity("session-ctx") + req := requestWithIdentity("github", "ctx-test-token") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - // Verify identity is still in context identityFromCtx, ok := auth.IdentityFromContext(receivedCtx) assert.True(t, ok) assert.Equal(t, "user123", identityFromCtx.Subject) @@ -601,12 +408,7 @@ func TestCreateMiddleware(t *testing.T) { mockRunner := mocks.NewMockMiddlewareRunner(ctrl) - // Service getter is only called if validation passes (expectAddMiddleware) - // because validation happens before GetUpstreamTokenService is called if tt.expectAddMiddleware { - mockRunner.EXPECT().GetUpstreamTokenService().Return(func() upstreamtoken.Service { - return nil // Service availability is checked at request time - }) mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any()).Do(func(_ string, mw types.Middleware) { _, ok := mw.(*Middleware) assert.True(t, ok, "Expected middleware to be of type *upstreamswap.Middleware") @@ -652,353 +454,3 @@ func TestCreateMiddleware_InvalidJSON(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "failed to unmarshal upstream swap middleware parameters") } - -// TestMiddleware_TsidClaimWrongType tests behavior when tsid claim exists but is wrong type. -func TestMiddleware_TsidClaimWrongType(t *testing.T) { - t.Parallel() - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, nilServiceGetter()) - - var nextCalled bool - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true - }) - - handler := middleware(nextHandler) - - // Create request with identity but tsid claim is wrong type (int instead of string) - req := httptest.NewRequest(http.MethodGet, "/test", nil) - identity := &auth.Identity{ - PrincipalInfo: auth.PrincipalInfo{ - Subject: "user123", - Claims: map[string]any{ - "sub": "user123", - upstreamtoken.TokenSessionIDClaimKey: 12345, // Wrong type: int instead of string - }, - }, - } - ctx := auth.WithIdentity(req.Context(), identity) - req = req.WithContext(ctx) - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.True(t, nextCalled, "next handler should be called when tsid is wrong type") -} - -func TestMiddleware_ExpiredTokens_RefreshSuccess(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - expiredTokens := &storage.UpstreamTokens{ - AccessToken: "expired-access-token", - RefreshToken: "my-refresh-token", - ExpiresAt: time.Now().Add(-1 * time.Hour), - } - - refreshedTokens := &storage.UpstreamTokens{ - AccessToken: "new-access-token", - RefreshToken: "new-refresh-token", - ExpiresAt: time.Now().Add(1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(expiredTokens, storage.ErrExpired) - - mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) - mockRefresher.EXPECT(). - RefreshAndStore(gomock.Any(), "session-123", expiredTokens). - Return(refreshedTokens, nil) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, mockRefresher)) - - var nextCalled bool - var capturedAuthHeader string - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - nextCalled = true - capturedAuthHeader = r.Header.Get("Authorization") - }) - - handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.True(t, nextCalled, "next handler should be called after successful refresh") - assert.Equal(t, "Bearer new-access-token", capturedAuthHeader) -} - -func TestMiddleware_ExpiredTokens_RefreshFails(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - expiredTokens := &storage.UpstreamTokens{ - AccessToken: "expired-access-token", - RefreshToken: "my-refresh-token", - ExpiresAt: time.Now().Add(-1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(expiredTokens, storage.ErrExpired) - - mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) - mockRefresher.EXPECT(). - RefreshAndStore(gomock.Any(), "session-123", expiredTokens). - Return(nil, errors.New("refresh failed")) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, mockRefresher)) - - var nextCalled bool - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true - }) - - handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.False(t, nextCalled, "next handler should NOT be called when refresh fails") - assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) -} - -func TestMiddleware_ExpiredTokens_NoRefreshToken(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - expiredTokens := &storage.UpstreamTokens{ - AccessToken: "expired-access-token", - RefreshToken: "", // No refresh token - ExpiresAt: time.Now().Add(-1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(expiredTokens, storage.ErrExpired) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) - - var nextCalled bool - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true - }) - - handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.False(t, nextCalled, "next handler should NOT be called when no refresh token available") - assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) -} - -func TestMiddleware_DefenseInDepth_ExpiredButNoError_RefreshSuccess(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Storage returns tokens with ExpiresAt in the past but NO error - expiredTokens := &storage.UpstreamTokens{ - AccessToken: "expired-access-token", - RefreshToken: "my-refresh-token", - ExpiresAt: time.Now().Add(-1 * time.Hour), - } - - refreshedTokens := &storage.UpstreamTokens{ - AccessToken: "refreshed-access-token", - RefreshToken: "refreshed-refresh-token", - ExpiresAt: time.Now().Add(1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(expiredTokens, nil) // No error, but token is expired - - mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) - mockRefresher.EXPECT(). - RefreshAndStore(gomock.Any(), "session-123", expiredTokens). - Return(refreshedTokens, nil) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, mockRefresher)) - - var nextCalled bool - var capturedAuthHeader string - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - nextCalled = true - capturedAuthHeader = r.Header.Get("Authorization") - }) - - handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.True(t, nextCalled, "next handler should be called after successful defense-in-depth refresh") - assert.Equal(t, "Bearer refreshed-access-token", capturedAuthHeader) -} - -// TestSingleFlightRefresh_ConcurrentRequests verifies that concurrent requests -// with the same expired session only trigger a single upstream refresh call. -// Without singleflight, providers that rotate refresh tokens (single-use) -// would fail all but the first concurrent caller. -func TestSingleFlightRefresh_ConcurrentRequests(t *testing.T) { - t.Parallel() - - const numRequests = 10 - var refreshCallCount atomic.Int32 - - expiredTokens := &storage.UpstreamTokens{ - AccessToken: "expired-access", - RefreshToken: "one-time-refresh", - ExpiresAt: time.Now().Add(-1 * time.Hour), - } - - refreshedTokens := &storage.UpstreamTokens{ - AccessToken: "fresh-access", - RefreshToken: "new-refresh", - ExpiresAt: time.Now().Add(1 * time.Hour), - } - - // Use a barrier so all goroutines complete GetUpstreamTokens before any - // enters singleflight. This guarantees they all contend on the same key. - var storageBarrier sync.WaitGroup - storageBarrier.Add(numRequests) - storageGate := make(chan struct{}) - - stor := &fakeTokenStorage{ - tokens: expiredTokens, - err: storage.ErrExpired, - barrier: &storageBarrier, - gate: storageGate, - } - proceed := make(chan struct{}) - refresher := &fakeRefresher{ - result: refreshedTokens, - callCount: &refreshCallCount, - proceed: proceed, - } - - // Create service with fake storage/refresher — singleflight lives in the service - svc := upstreamtoken.NewInProcessService(stor, refresher) - serviceGetter := func() upstreamtoken.Service { return svc } - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetter) - - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}) - handler := middleware(nextHandler) - - // Use a barrier to ensure all goroutines start at the same time - ready := make(chan struct{}) - var wg sync.WaitGroup - results := make([]int, numRequests) - - for i := range numRequests { - wg.Add(1) - go func(idx int) { - defer wg.Done() - <-ready // Wait for all goroutines to be ready - - req := requestWithIdentity("sf-concurrent-session") - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - results[idx] = rr.Code - }(i) - } - - // Release all goroutines simultaneously - close(ready) - - // Wait for ALL goroutines to reach GetUpstreamTokens - storageBarrier.Wait() - - // Release them all at once — they all proceed to singleflight.Do concurrently - close(storageGate) - - // Give goroutines a moment to enter singleflight, then let refresh complete - // The singleflight ensures only one actually calls RefreshAndStore - time.Sleep(10 * time.Millisecond) - close(proceed) - - wg.Wait() - - // KEY ASSERTION: RefreshAndStore should be called exactly once. - // Without singleflight, all 10 goroutines would call it independently. - assert.Equal(t, int32(1), refreshCallCount.Load(), - "RefreshAndStore should be called exactly once — singleflight deduplicates concurrent refreshes") - - // All requests should succeed - for i, code := range results { - assert.Equal(t, http.StatusOK, code, - "request %d should succeed", i) - } -} - -// fakeTokenStorage returns configured tokens and optionally blocks until -// a barrier is released, ensuring all goroutines reach storage before any -// proceeds to the singleflight refresh. -type fakeTokenStorage struct { - tokens *storage.UpstreamTokens - err error - barrier *sync.WaitGroup // if set, each call does barrier.Done() then waits - gate chan struct{} // if set, blocks until closed -} - -func (f *fakeTokenStorage) GetUpstreamTokens(_ context.Context, _, _ string) (*storage.UpstreamTokens, error) { - if f.barrier != nil { - f.barrier.Done() - } - if f.gate != nil { - <-f.gate - } - return f.tokens, f.err -} - -func (*fakeTokenStorage) GetAllUpstreamTokens(_ context.Context, _ string) (map[string]*storage.UpstreamTokens, error) { - return nil, nil -} - -func (*fakeTokenStorage) StoreUpstreamTokens(_ context.Context, _, _ string, _ *storage.UpstreamTokens) error { - return nil -} - -func (*fakeTokenStorage) DeleteUpstreamTokens(_ context.Context, _ string) error { - return nil -} - -// fakeRefresher counts calls and blocks until proceed is closed. -type fakeRefresher struct { - result *storage.UpstreamTokens - callCount *atomic.Int32 - proceed chan struct{} -} - -func (f *fakeRefresher) RefreshAndStore(_ context.Context, _ string, _ *storage.UpstreamTokens) (*storage.UpstreamTokens, error) { - f.callCount.Add(1) - <-f.proceed - return f.result, nil -} From d12dd31ee6c65986a44674a5847955b556bf1249 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Thu, 19 Mar 2026 10:30:12 +0000 Subject: [PATCH 6/8] Remove GetUpstreamTokenService from MiddlewareRunner The upstreamswap middleware no longer calls the token service directly, so GetUpstreamTokenService is dead code. Remove the interface method, its Runner implementation, the mock, tests, and docs reference. Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/runner/runner.go | 25 ++-------- pkg/runner/runner_test.go | 51 --------------------- pkg/transport/types/mocks/mock_transport.go | 14 ------ pkg/transport/types/transport.go | 8 ---- 4 files changed, 3 insertions(+), 95 deletions(-) diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 2159547167..e7834a594f 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -76,14 +76,9 @@ type Runner struct { // Only initialized when Config.EmbeddedAuthServerConfig is set. embeddedAuthServer *authserverrunner.EmbeddedAuthServer - // upstreamTokenService is the upstream token service, created eagerly - // after the embedded auth server is initialized in Run(). - // Nil when no embedded auth server is configured. - upstreamTokenService upstreamtoken.Service - // upstreamTokenReader provides read-only access to upstream tokens for - // identity enrichment in auth middleware. Set alongside upstreamTokenService - // when the embedded auth server is initialized in Run(). + // identity enrichment in auth middleware. Set when the embedded auth + // server is initialized in Run(). // Nil when no embedded auth server is configured. upstreamTokenReader upstreamtoken.UpstreamTokenReader } @@ -136,18 +131,6 @@ func (r *Runner) GetConfig() types.RunnerConfig { return r.Config } -// GetUpstreamTokenService returns an accessor for the upstream token service. -// The returned function should be called at request time; it returns nil if -// the embedded auth server is not configured. -// -// This method always returns a non-nil function. Service availability is -// determined at request time when the returned function is called. -func (r *Runner) GetUpstreamTokenService() func() upstreamtoken.Service { - return func() upstreamtoken.Service { - return r.upstreamTokenService - } -} - // GetUpstreamTokenReader returns the UpstreamTokenReader for identity // enrichment in the auth middleware. Returns nil if no embedded auth // server is configured. @@ -271,9 +254,7 @@ func (r *Runner) Run(ctx context.Context) error { // InProcessService handles this gracefully (returns ErrNoRefreshToken). stor := r.embeddedAuthServer.IDPTokenStorage() refresher := r.embeddedAuthServer.UpstreamTokenRefresher() - inProc := upstreamtoken.NewInProcessService(stor, refresher) - r.upstreamTokenService = inProc - r.upstreamTokenReader = inProc + r.upstreamTokenReader = upstreamtoken.NewInProcessService(stor, refresher) // Mount auth server routes at specific prefixes to avoid conflicts with MCP endpoints // (e.g., /.well-known/oauth-protected-resource is an MCP endpoint, not auth server) diff --git a/pkg/runner/runner_test.go b/pkg/runner/runner_test.go index ef9ac4bb18..03eecdd394 100644 --- a/pkg/runner/runner_test.go +++ b/pkg/runner/runner_test.go @@ -464,57 +464,6 @@ func TestRunner_EmbeddedAuthServer_Integration(t *testing.T) { }) } -func TestRunner_GetUpstreamTokenService(t *testing.T) { - t.Parallel() - - t.Run("returns nil when no auth server configured", func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusManager := statusesmocks.NewMockStatusManager(ctrl) - - runConfig := NewRunConfig() - runner := NewRunner(runConfig, mockStatusManager) - - serviceGetter := runner.GetUpstreamTokenService() - assert.NotNil(t, serviceGetter, "GetUpstreamTokenService should always return a non-nil function") - - svc := serviceGetter() - assert.Nil(t, svc, "service should be nil when no auth server is configured") - }) - - t.Run("returns service when upstreamTokenService is set", func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusManager := statusesmocks.NewMockStatusManager(ctrl) - - runConfig := NewRunConfig() - runner := NewRunner(runConfig, mockStatusManager) - - // Simulate what Run() does: create auth server, then set the service - authServerCfg := createMinimalAuthServerConfig() - embeddedServer, err := authserverrunner.NewEmbeddedAuthServer(context.Background(), authServerCfg) - require.NoError(t, err) - require.NotNil(t, embeddedServer) - defer func() { _ = embeddedServer.Close() }() - - runner.embeddedAuthServer = embeddedServer - - stor := embeddedServer.IDPTokenStorage() - refresher := embeddedServer.UpstreamTokenRefresher() - runner.upstreamTokenService = upstreamtoken.NewInProcessService(stor, refresher) - - serviceGetter := runner.GetUpstreamTokenService() - svc := serviceGetter() - assert.NotNil(t, svc, "service should not be nil when upstreamTokenService is set") - }) -} - func TestRunner_RejectsMultiUpstreamConfig(t *testing.T) { t.Parallel() diff --git a/pkg/transport/types/mocks/mock_transport.go b/pkg/transport/types/mocks/mock_transport.go index 7ab524ffc2..c277459484 100644 --- a/pkg/transport/types/mocks/mock_transport.go +++ b/pkg/transport/types/mocks/mock_transport.go @@ -137,20 +137,6 @@ func (mr *MockMiddlewareRunnerMockRecorder) GetUpstreamTokenReader() *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUpstreamTokenReader", reflect.TypeOf((*MockMiddlewareRunner)(nil).GetUpstreamTokenReader)) } -// GetUpstreamTokenService mocks base method. -func (m *MockMiddlewareRunner) GetUpstreamTokenService() func() upstreamtoken.Service { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUpstreamTokenService") - ret0, _ := ret[0].(func() upstreamtoken.Service) - return ret0 -} - -// GetUpstreamTokenService indicates an expected call of GetUpstreamTokenService. -func (mr *MockMiddlewareRunnerMockRecorder) GetUpstreamTokenService() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUpstreamTokenService", reflect.TypeOf((*MockMiddlewareRunner)(nil).GetUpstreamTokenService)) -} - // SetAuthInfoHandler mocks base method. func (m *MockMiddlewareRunner) SetAuthInfoHandler(handler http.Handler) { m.ctrl.T.Helper() diff --git a/pkg/transport/types/transport.go b/pkg/transport/types/transport.go index 976c902a92..1b232cfa4b 100644 --- a/pkg/transport/types/transport.go +++ b/pkg/transport/types/transport.go @@ -77,14 +77,6 @@ type MiddlewareRunner interface { // GetConfig returns a config interface for middleware to access runner configuration GetConfig() RunnerConfig - // GetUpstreamTokenService returns an accessor for the upstream token service. - // The returned function should be called at request time; it returns nil if - // the embedded auth server is not configured. - // - // This method always returns a non-nil function. Service availability is - // determined at request time when the returned function is called. - GetUpstreamTokenService() func() upstreamtoken.Service - // GetUpstreamTokenReader returns an UpstreamTokenReader for identity enrichment. // Returns nil if the embedded auth server is not configured. GetUpstreamTokenReader() upstreamtoken.UpstreamTokenReader From 5e6e907926b458ba1ea584923e709fbddf6f6c93 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Wed, 25 Mar 2026 15:43:05 +0000 Subject: [PATCH 7/8] Fix lint: rename UpstreamTokenReader to TokenReader, fix unparam Rename the type to avoid stutter (upstreamtoken.TokenReader instead of upstreamtoken.UpstreamTokenReader). Method and function names that include "UpstreamTokenReader" are unchanged since they don't stutter. Also remove the unused providerName parameter from the test helper requestWithIdentity in upstreamswap tests. Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/auth/token.go | 6 +++--- pkg/auth/upstreamswap/middleware_test.go | 18 +++++++++--------- pkg/auth/upstreamtoken/service.go | 10 +++++++--- pkg/auth/upstreamtoken/types.go | 6 ++++-- pkg/runner/runner.go | 4 ++-- pkg/transport/types/mocks/mock_transport.go | 4 ++-- pkg/transport/types/transport.go | 4 ++-- 7 files changed, 29 insertions(+), 23 deletions(-) diff --git a/pkg/auth/token.go b/pkg/auth/token.go index 3d5de73c8c..5c955e499a 100644 --- a/pkg/auth/token.go +++ b/pkg/auth/token.go @@ -370,7 +370,7 @@ type TokenValidator struct { // upstreamTokenReader loads upstream provider tokens for identity enrichment. // nil means no enrichment (no embedded auth server). - upstreamTokenReader upstreamtoken.UpstreamTokenReader + upstreamTokenReader upstreamtoken.TokenReader // Lazy JWKS registration jwksRegistered bool @@ -546,7 +546,7 @@ func registerIntrospectionProviders(config TokenValidatorConfig, clientSecret st // tokenValidatorOptions holds optional dependencies for NewTokenValidator. type tokenValidatorOptions struct { envReader env.Reader - upstreamTokenReader upstreamtoken.UpstreamTokenReader + upstreamTokenReader upstreamtoken.TokenReader } // TokenValidatorOption is a functional option for NewTokenValidator. @@ -564,7 +564,7 @@ func WithEnvReader(reader env.Reader) TokenValidatorOption { // with upstream provider tokens. When set, the Middleware extracts the token // session ID (tsid) from JWT claims and loads all upstream tokens into // Identity.UpstreamTokens before placing the Identity in the request context. -func WithUpstreamTokenReader(reader upstreamtoken.UpstreamTokenReader) TokenValidatorOption { +func WithUpstreamTokenReader(reader upstreamtoken.TokenReader) TokenValidatorOption { return func(o *tokenValidatorOptions) { o.upstreamTokenReader = reader } diff --git a/pkg/auth/upstreamswap/middleware_test.go b/pkg/auth/upstreamswap/middleware_test.go index 3f298d3b98..9594c3b9f3 100644 --- a/pkg/auth/upstreamswap/middleware_test.go +++ b/pkg/auth/upstreamswap/middleware_test.go @@ -19,15 +19,15 @@ import ( "github.com/stacklok/toolhive/pkg/transport/types/mocks" ) -// requestWithIdentity creates an HTTP request with the given identity in context. -func requestWithIdentity(providerName, token string) *http.Request { +// requestWithIdentity creates an HTTP request with a "github" upstream token in context. +func requestWithIdentity(token string) *http.Request { req := httptest.NewRequest(http.MethodGet, "/test", nil) identity := &auth.Identity{ PrincipalInfo: auth.PrincipalInfo{ Subject: "user123", }, UpstreamTokens: map[string]string{ - providerName: token, + "github": token, }, } ctx := auth.WithIdentity(req.Context(), identity) @@ -160,7 +160,7 @@ func TestMiddleware_ProviderMissing_Returns401(t *testing.T) { }) handler := middleware(nextHandler) - req := requestWithIdentity("github", "gh-token") // has github but not atlassian + req := requestWithIdentity("gh-token") // has github but not atlassian rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) @@ -179,7 +179,7 @@ func TestMiddleware_EmptyToken_Returns401(t *testing.T) { }) handler := middleware(nextHandler) - req := requestWithIdentity("github", "") // empty token + req := requestWithIdentity("") // empty token rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) @@ -202,7 +202,7 @@ func TestMiddleware_SuccessfulSwap_ReplaceStrategy(t *testing.T) { }) handler := middleware(nextHandler) - req := requestWithIdentity("github", "upstream-access-token") + req := requestWithIdentity("upstream-access-token") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) @@ -225,7 +225,7 @@ func TestMiddleware_SuccessfulSwap_DefaultStrategy(t *testing.T) { }) handler := middleware(nextHandler) - req := requestWithIdentity("github", "default-strategy-token") + req := requestWithIdentity("default-strategy-token") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) @@ -252,7 +252,7 @@ func TestMiddleware_SuccessfulSwap_CustomHeader(t *testing.T) { }) handler := middleware(nextHandler) - req := requestWithIdentity("github", "upstream-access-token") + req := requestWithIdentity("upstream-access-token") req.Header.Set("Authorization", "Bearer original-token") rr := httptest.NewRecorder() @@ -314,7 +314,7 @@ func TestMiddlewareWithContext(t *testing.T) { }) handler := middleware(nextHandler) - req := requestWithIdentity("github", "ctx-test-token") + req := requestWithIdentity("ctx-test-token") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) diff --git a/pkg/auth/upstreamtoken/service.go b/pkg/auth/upstreamtoken/service.go index 131ace05e4..ff12423455 100644 --- a/pkg/auth/upstreamtoken/service.go +++ b/pkg/auth/upstreamtoken/service.go @@ -31,8 +31,8 @@ type InProcessService struct { // Compile-time checks. var ( - _ Service = (*InProcessService)(nil) - _ UpstreamTokenReader = (*InProcessService)(nil) + _ Service = (*InProcessService)(nil) + _ TokenReader = (*InProcessService)(nil) ) // NewInProcessService creates a new InProcessService. @@ -137,7 +137,10 @@ func (s *InProcessService) refreshOrFail( } if s.refresher == nil { - slog.Debug("token refresher not configured, cannot refresh upstream tokens") + slog.Debug("token refresher not configured, cannot refresh upstream tokens", + "session_id", sessionID, + "provider", providerName, + ) return nil, ErrNoRefreshToken } @@ -156,6 +159,7 @@ func (s *InProcessService) refreshOrFail( if err != nil { slog.Warn("upstream token refresh failed", "session_id", sessionID, + "provider", providerName, "error", err, ) return nil, fmt.Errorf("%w: %w", ErrRefreshFailed, err) diff --git a/pkg/auth/upstreamtoken/types.go b/pkg/auth/upstreamtoken/types.go index e77383a215..1668f71144 100644 --- a/pkg/auth/upstreamtoken/types.go +++ b/pkg/auth/upstreamtoken/types.go @@ -5,6 +5,8 @@ // lifecycle, including transparent refresh of expired access tokens. package upstreamtoken +//go:generate go run go.uber.org/mock/mockgen -destination=mocks/mock_token_reader.go -package=mocks github.com/stacklok/toolhive/pkg/auth/upstreamtoken TokenReader + import "context" // TokenSessionIDClaimKey is the JWT claim key for the token session ID. @@ -19,12 +21,12 @@ type UpstreamCredential struct { AccessToken string } -// UpstreamTokenReader retrieves upstream provider access tokens for a session. +// TokenReader retrieves upstream provider access tokens for a session. // This narrow interface decouples the auth middleware from storage internals. // // TODO(auth): Consider enriching the return type from map[string]string to // map[string]UpstreamCredential to carry per-provider freshness/error metadata. -type UpstreamTokenReader interface { +type TokenReader interface { // GetAllValidTokens returns access tokens for all upstream providers in a session. // Expired tokens are refreshed transparently when possible; if refresh fails, // the provider is omitted from the result. diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index e7834a594f..f9e107f617 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -80,7 +80,7 @@ type Runner struct { // identity enrichment in auth middleware. Set when the embedded auth // server is initialized in Run(). // Nil when no embedded auth server is configured. - upstreamTokenReader upstreamtoken.UpstreamTokenReader + upstreamTokenReader upstreamtoken.TokenReader } // statusManagerAdapter adapts statuses.StatusManager to auth.StatusUpdater interface @@ -134,7 +134,7 @@ func (r *Runner) GetConfig() types.RunnerConfig { // GetUpstreamTokenReader returns the UpstreamTokenReader for identity // enrichment in the auth middleware. Returns nil if no embedded auth // server is configured. -func (r *Runner) GetUpstreamTokenReader() upstreamtoken.UpstreamTokenReader { +func (r *Runner) GetUpstreamTokenReader() upstreamtoken.TokenReader { return r.upstreamTokenReader } diff --git a/pkg/transport/types/mocks/mock_transport.go b/pkg/transport/types/mocks/mock_transport.go index c277459484..9ae9e5b009 100644 --- a/pkg/transport/types/mocks/mock_transport.go +++ b/pkg/transport/types/mocks/mock_transport.go @@ -124,10 +124,10 @@ func (mr *MockMiddlewareRunnerMockRecorder) GetConfig() *gomock.Call { } // GetUpstreamTokenReader mocks base method. -func (m *MockMiddlewareRunner) GetUpstreamTokenReader() upstreamtoken.UpstreamTokenReader { +func (m *MockMiddlewareRunner) GetUpstreamTokenReader() upstreamtoken.TokenReader { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUpstreamTokenReader") - ret0, _ := ret[0].(upstreamtoken.UpstreamTokenReader) + ret0, _ := ret[0].(upstreamtoken.TokenReader) return ret0 } diff --git a/pkg/transport/types/transport.go b/pkg/transport/types/transport.go index 1b232cfa4b..ae1216f676 100644 --- a/pkg/transport/types/transport.go +++ b/pkg/transport/types/transport.go @@ -77,9 +77,9 @@ type MiddlewareRunner interface { // GetConfig returns a config interface for middleware to access runner configuration GetConfig() RunnerConfig - // GetUpstreamTokenReader returns an UpstreamTokenReader for identity enrichment. + // GetUpstreamTokenReader returns a TokenReader for identity enrichment. // Returns nil if the embedded auth server is not configured. - GetUpstreamTokenReader() upstreamtoken.UpstreamTokenReader + GetUpstreamTokenReader() upstreamtoken.TokenReader } // RunnerConfig defines the config interface needed by middleware to access runner configuration From bd92306cacd07c8ac7cf82c7931c85dea820af29 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Wed, 25 Mar 2026 15:43:14 +0000 Subject: [PATCH 8/8] Address review feedback: generated mock, providerName in logs Replace hand-written mockUpstreamTokenReader with go.uber.org/mock generated MockTokenReader, per project convention of never hand-writing mocks. Add //go:generate directive to types.go. Add providerName to log lines in refreshOrFail so operators can identify which upstream provider failed during concurrent refreshes. Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/auth/token_test.go | 222 ++++++++---------- .../upstreamtoken/mocks/mock_token_reader.go | 56 +++++ 2 files changed, 149 insertions(+), 129 deletions(-) create mode 100644 pkg/auth/upstreamtoken/mocks/mock_token_reader.go diff --git a/pkg/auth/token_test.go b/pkg/auth/token_test.go index 5ba11df27a..ea7eac2520 100644 --- a/pkg/auth/token_test.go +++ b/pkg/auth/token_test.go @@ -25,6 +25,7 @@ import ( envmocks "github.com/stacklok/toolhive-core/env/mocks" "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" + upstreamtokenmocks "github.com/stacklok/toolhive/pkg/auth/upstreamtoken/mocks" "github.com/stacklok/toolhive/pkg/networking" oauthproto "github.com/stacklok/toolhive/pkg/oauth" ) @@ -2228,145 +2229,100 @@ func TestMiddleware_RFC6750JSONErrorResponse(t *testing.T) { func TestLoadUpstreamTokens(t *testing.T) { t.Parallel() - tests := []struct { - name string - claims jwt.MapClaims - reader upstreamtoken.UpstreamTokenReader - wantResult map[string]string - wantErr bool - }{ - { - name: "loads tokens when tsid present", - claims: jwt.MapClaims{ - "sub": "user123", - upstreamtoken.TokenSessionIDClaimKey: "session-abc", - }, - reader: &mockUpstreamTokenReader{ - tokens: map[string]string{ - "github": "gh-token", - "atlassian": "atl-token", - }, - }, - wantResult: map[string]string{ - "github": "gh-token", - "atlassian": "atl-token", - }, - }, - { - name: "returns nil when no tsid claim", - claims: jwt.MapClaims{ - "sub": "user123", - }, - reader: &mockUpstreamTokenReader{}, - wantResult: nil, - }, - { - name: "returns nil when tsid is empty string", - claims: jwt.MapClaims{ - "sub": "user123", - upstreamtoken.TokenSessionIDClaimKey: "", - }, - reader: &mockUpstreamTokenReader{}, - wantResult: nil, - }, - { - name: "returns nil when tsid is non-string type", - claims: jwt.MapClaims{ - "sub": "user123", - upstreamtoken.TokenSessionIDClaimKey: 12345, - }, - reader: &mockUpstreamTokenReader{}, - wantResult: nil, - }, - { - name: "returns error when reader fails", - claims: jwt.MapClaims{ - "sub": "user123", - upstreamtoken.TokenSessionIDClaimKey: "session-abc", - }, - reader: &mockUpstreamTokenReader{ - err: errors.New("storage unavailable"), - }, - wantErr: true, - }, - { - name: "returns empty map from reader", - claims: jwt.MapClaims{ - "sub": "user123", - upstreamtoken.TokenSessionIDClaimKey: "session-abc", - }, - reader: &mockUpstreamTokenReader{ - tokens: map[string]string{}, - }, - wantResult: map[string]string{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - v := &TokenValidator{ - upstreamTokenReader: tt.reader, - } - - result, err := v.loadUpstreamTokens(context.Background(), tt.claims) - - if tt.wantErr { - require.Error(t, err) - require.Nil(t, result) - return - } - - require.NoError(t, err) - if tt.wantResult == nil { - require.Nil(t, result) - } else { - require.Equal(t, tt.wantResult, result) - } + t.Run("loads tokens when tsid present", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + reader.EXPECT().GetAllValidTokens(gomock.Any(), "session-abc"). + Return(map[string]string{"github": "gh-token", "atlassian": "atl-token"}, nil) + + v := &TokenValidator{upstreamTokenReader: reader} + result, err := v.loadUpstreamTokens(context.Background(), jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: "session-abc", }) - } -} - -func TestLoadUpstreamTokens_PassesCorrectSessionID(t *testing.T) { - t.Parallel() + require.NoError(t, err) + require.Equal(t, map[string]string{"github": "gh-token", "atlassian": "atl-token"}, result) + }) - reader := &mockUpstreamTokenReader{ - tokens: map[string]string{"github": "token"}, - } + t.Run("returns nil when no tsid claim", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + // No EXPECT — reader should not be called + + v := &TokenValidator{upstreamTokenReader: reader} + result, err := v.loadUpstreamTokens(context.Background(), jwt.MapClaims{"sub": "user123"}) + require.NoError(t, err) + require.Nil(t, result) + }) - v := &TokenValidator{ - upstreamTokenReader: reader, - } + t.Run("returns nil when tsid is empty string", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) - claims := jwt.MapClaims{ - "sub": "user123", - upstreamtoken.TokenSessionIDClaimKey: "session-xyz", - } + v := &TokenValidator{upstreamTokenReader: reader} + result, err := v.loadUpstreamTokens(context.Background(), jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: "", + }) + require.NoError(t, err) + require.Nil(t, result) + }) - result, err := v.loadUpstreamTokens(context.Background(), claims) + t.Run("returns nil when tsid is non-string type", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "session-xyz", reader.calledWith) -} + v := &TokenValidator{upstreamTokenReader: reader} + result, err := v.loadUpstreamTokens(context.Background(), jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: 12345, + }) + require.NoError(t, err) + require.Nil(t, result) + }) -// mockUpstreamTokenReader is a simple mock for testing loadUpstreamTokens. -type mockUpstreamTokenReader struct { - tokens map[string]string - err error - calledWith string -} + t.Run("returns error when reader fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + reader.EXPECT().GetAllValidTokens(gomock.Any(), "session-abc"). + Return(nil, errors.New("storage unavailable")) + + v := &TokenValidator{upstreamTokenReader: reader} + result, err := v.loadUpstreamTokens(context.Background(), jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: "session-abc", + }) + require.Error(t, err) + require.Nil(t, result) + }) -func (m *mockUpstreamTokenReader) GetAllValidTokens(_ context.Context, sessionID string) (map[string]string, error) { - m.calledWith = sessionID - return m.tokens, m.err + t.Run("returns empty map from reader", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + reader.EXPECT().GetAllValidTokens(gomock.Any(), "session-abc"). + Return(map[string]string{}, nil) + + v := &TokenValidator{upstreamTokenReader: reader} + result, err := v.loadUpstreamTokens(context.Background(), jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: "session-abc", + }) + require.NoError(t, err) + require.Equal(t, map[string]string{}, result) + }) } func TestWithUpstreamTokenReader(t *testing.T) { t.Parallel() - reader := &mockUpstreamTokenReader{} + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) opt := WithUpstreamTokenReader(reader) o := &tokenValidatorOptions{} @@ -2425,7 +2381,10 @@ func TestMiddleware_UpstreamTokenEnrichment(t *testing.T) { t.Run("enriches identity with upstream tokens", func(t *testing.T) { t.Parallel() - reader := &mockUpstreamTokenReader{tokens: map[string]string{"github": "gh-tok"}} + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + reader.EXPECT().GetAllValidTokens(gomock.Any(), "session-xyz"). + Return(map[string]string{"github": "gh-tok"}, nil) v := makeValidator(t, WithUpstreamTokenReader(reader)) var captured *Identity @@ -2444,7 +2403,10 @@ func TestMiddleware_UpstreamTokenEnrichment(t *testing.T) { t.Run("returns 503 when storage fails", func(t *testing.T) { t.Parallel() - reader := &mockUpstreamTokenReader{err: errors.New("redis down")} + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + reader.EXPECT().GetAllValidTokens(gomock.Any(), "session-xyz"). + Return(nil, errors.New("redis down")) v := makeValidator(t, WithUpstreamTokenReader(reader)) nextCalled := false @@ -2463,7 +2425,9 @@ func TestMiddleware_UpstreamTokenEnrichment(t *testing.T) { t.Run("no enrichment without tsid", func(t *testing.T) { t.Parallel() - reader := &mockUpstreamTokenReader{tokens: map[string]string{"github": "should-not-appear"}} + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + // No EXPECT — reader should not be called when tsid is absent v := makeValidator(t, WithUpstreamTokenReader(reader)) var captured *Identity diff --git a/pkg/auth/upstreamtoken/mocks/mock_token_reader.go b/pkg/auth/upstreamtoken/mocks/mock_token_reader.go new file mode 100644 index 0000000000..c29a64f5c2 --- /dev/null +++ b/pkg/auth/upstreamtoken/mocks/mock_token_reader.go @@ -0,0 +1,56 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/stacklok/toolhive/pkg/auth/upstreamtoken (interfaces: TokenReader) +// +// Generated by this command: +// +// mockgen -destination=mocks/mock_token_reader.go -package=mocks github.com/stacklok/toolhive/pkg/auth/upstreamtoken TokenReader +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockTokenReader is a mock of TokenReader interface. +type MockTokenReader struct { + ctrl *gomock.Controller + recorder *MockTokenReaderMockRecorder + isgomock struct{} +} + +// MockTokenReaderMockRecorder is the mock recorder for MockTokenReader. +type MockTokenReaderMockRecorder struct { + mock *MockTokenReader +} + +// NewMockTokenReader creates a new mock instance. +func NewMockTokenReader(ctrl *gomock.Controller) *MockTokenReader { + mock := &MockTokenReader{ctrl: ctrl} + mock.recorder = &MockTokenReaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTokenReader) EXPECT() *MockTokenReaderMockRecorder { + return m.recorder +} + +// GetAllValidTokens mocks base method. +func (m *MockTokenReader) GetAllValidTokens(ctx context.Context, sessionID string) (map[string]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllValidTokens", ctx, sessionID) + ret0, _ := ret[0].(map[string]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllValidTokens indicates an expected call of GetAllValidTokens. +func (mr *MockTokenReaderMockRecorder) GetAllValidTokens(ctx, sessionID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllValidTokens", reflect.TypeOf((*MockTokenReader)(nil).GetAllValidTokens), ctx, sessionID) +}