diff --git a/docs/middleware.md b/docs/middleware.md index 28cc1bfc00..8ab0eedc73 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -133,11 +133,9 @@ sequenceDiagram **Availability**: Automatically enabled when using the embedded auth server (`EmbeddedAuthServerConfig`) **Responsibilities**: -- Extract the token session ID (`tsid`) claim from the ToolHive JWT -- Look up the stored upstream IdP tokens associated with that session +- Read the upstream access token for the configured provider from `Identity.UpstreamTokens` - Inject the upstream access token into the request (replacing Authorization header or using a custom header) -- Return 401 Unauthorized with WWW-Authenticate header when tokens are expired or not found -- Gracefully proceed without modification if identity, session ID, or storage is unavailable +- Return 401 Unauthorized with WWW-Authenticate header when the provider token is missing or empty **Configuration**: @@ -148,16 +146,16 @@ sequenceDiagram **Behavior**: - **Automatic activation**: Enabled whenever the embedded auth server is configured, even without explicit `UpstreamSwapConfig` -- **Expired tokens**: Returns 401 Unauthorized with `WWW-Authenticate` header indicating re-authentication is required -- **Tokens not found**: Returns 401 Unauthorized with `WWW-Authenticate` header indicating re-authentication is required -- **Missing identity/tsid**: Proceeds without modification (logs debug message) -- **Storage unavailable**: Proceeds without modification (logs warning) -- **Other storage errors**: Returns 503 Service Unavailable to fail closed (logs warning) +- **Provider token found**: Injects the token into the request using the configured header strategy +- **Provider not in UpstreamTokens**: Returns 401 Unauthorized with `WWW-Authenticate` header indicating re-authentication is required +- **Empty token value**: Returns 401 Unauthorized (same as missing provider) +- **No identity in context**: Passes through without modification (auth middleware not in chain) +- **Storage unavailable**: The auth middleware returns 503 before the request reaches this middleware **Context Data Used**: -- Identity from Authentication middleware (specifically the `tsid` claim) +- `Identity.UpstreamTokens` map populated by the Authentication middleware during JWT validation -**Note**: This middleware is designed for use with ToolHive's embedded auth server. It reads from the auth server's token storage to retrieve upstream IdP tokens that were captured during the OAuth authorization flow. +**Note**: This middleware is a simple map reader. All upstream token loading, refresh, and error handling occurs in the Authentication middleware (Step 3), which populates `Identity.UpstreamTokens` from the token session ID (`tsid`) claim during JWT validation. --- @@ -785,10 +783,9 @@ type MiddlewareRunner interface { // GetConfig returns a config interface for middleware to access runner configuration GetConfig() RunnerConfig - // GetUpstreamTokenService returns a lazy accessor for the upstream token service. - // Returns a function that provides the service at request time. - // Used by upstream swap middleware to get valid upstream tokens (with transparent refresh). - GetUpstreamTokenService() func() upstreamtoken.Service + // GetUpstreamTokenReader returns an UpstreamTokenReader for identity enrichment. + // Returns nil if the embedded auth server is not configured. + GetUpstreamTokenReader() upstreamtoken.UpstreamTokenReader } ``` diff --git a/pkg/auth/context.go b/pkg/auth/context.go index af7ea8d8fa..c7e5420541 100644 --- a/pkg/auth/context.go +++ b/pkg/auth/context.go @@ -67,10 +67,16 @@ func claimsToIdentity(claims jwt.MapClaims, token string) (*Identity, error) { return nil, errors.New("missing or invalid 'sub' claim (required by OIDC Core 1.0 § 5.1)") } + // Filter internal claims that should not be externalized (e.g., in + // webhook payloads or audit logs). The tsid is a session identifier + // used to look up upstream tokens in storage; exposing it widens the + // attack surface if a webhook receiver is compromised. + filteredClaims := filterInternalClaims(claims) + identity := &Identity{ PrincipalInfo: PrincipalInfo{ Subject: sub, - Claims: claims, + Claims: filteredClaims, }, Token: token, TokenType: "Bearer", @@ -86,3 +92,20 @@ func claimsToIdentity(claims jwt.MapClaims, token string) (*Identity, error) { return identity, nil } + +// internalClaims are JWT claim keys used internally by the auth server +// that must not be externalized in webhook payloads, audit logs, etc. +// "tsid" is the token session ID used to look up upstream tokens in storage. +var internalClaims = []string{"tsid"} + +// filterInternalClaims returns a copy of claims with internal keys removed. +func filterInternalClaims(claims jwt.MapClaims) jwt.MapClaims { + filtered := make(jwt.MapClaims, len(claims)) + for k, v := range claims { + filtered[k] = v + } + for _, key := range internalClaims { + delete(filtered, key) + } + return filtered +} diff --git a/pkg/auth/identity.go b/pkg/auth/identity.go index 4af1fefdf4..0215e54a23 100644 --- a/pkg/auth/identity.go +++ b/pkg/auth/identity.go @@ -56,6 +56,13 @@ type Identity struct { // Metadata stores additional identity information. Metadata map[string]string + + // UpstreamTokens maps upstream provider names to their access tokens. + // This is populated by the auth middleware when an embedded auth server + // is active and the JWT contains a token session ID (tsid claim). + // Redacted in MarshalJSON() to prevent token leakage. + // MUST NOT be mutated after the Identity is placed in the request context. + UpstreamTokens map[string]string } // String returns a string representation of the Identity with sensitive fields redacted. @@ -77,14 +84,15 @@ func (i *Identity) MarshalJSON() ([]byte, error) { // Create a safe representation with lowercase field names and redacted token type SafeIdentity struct { - Subject string `json:"subject"` - Name string `json:"name"` - Email string `json:"email"` - Groups []string `json:"groups"` - Claims map[string]any `json:"claims"` - Token string `json:"token"` - TokenType string `json:"tokenType"` - Metadata map[string]string `json:"metadata"` + Subject string `json:"subject"` + Name string `json:"name"` + Email string `json:"email"` + Groups []string `json:"groups"` + Claims map[string]any `json:"claims"` + Token string `json:"token"` + TokenType string `json:"tokenType"` + Metadata map[string]string `json:"metadata"` + UpstreamTokens map[string]string `json:"upstreamTokens,omitempty"` } token := i.Token @@ -92,15 +100,31 @@ func (i *Identity) MarshalJSON() ([]byte, error) { token = "REDACTED" } + // Redact upstream tokens: preserve keys, replace non-empty values + var redactedUpstreamTokens map[string]string + // Guard with len() > 0 (not != nil) so that both nil and empty maps + // produce a nil redactedUpstreamTokens, which omitempty then omits. + if len(i.UpstreamTokens) > 0 { + redactedUpstreamTokens = make(map[string]string, len(i.UpstreamTokens)) + for k, v := range i.UpstreamTokens { + if v != "" { + redactedUpstreamTokens[k] = "REDACTED" + } else { + redactedUpstreamTokens[k] = "" + } + } + } + return json.Marshal(&SafeIdentity{ - Subject: i.Subject, - Name: i.Name, - Email: i.Email, - Groups: i.Groups, - Claims: i.Claims, - Token: token, - TokenType: i.TokenType, - Metadata: i.Metadata, + Subject: i.Subject, + Name: i.Name, + Email: i.Email, + Groups: i.Groups, + Claims: i.Claims, + Token: token, + TokenType: i.TokenType, + Metadata: i.Metadata, + UpstreamTokens: redactedUpstreamTokens, }) } diff --git a/pkg/auth/identity_test.go b/pkg/auth/identity_test.go index 5ec9d1661e..f37d995c57 100644 --- a/pkg/auth/identity_test.go +++ b/pkg/auth/identity_test.go @@ -150,6 +150,16 @@ func TestIdentity_String(t *testing.T) { identity: nil, want: "", }, + { + name: "does_not_leak_upstream_tokens", + identity: &Identity{ + PrincipalInfo: PrincipalInfo{Subject: "user123"}, + UpstreamTokens: map[string]string{ + "github": "gho_secret123", + }, + }, + want: `Identity{Subject:"user123"}`, + }, } for _, tt := range tests { @@ -293,6 +303,92 @@ func TestIdentity_MarshalJSON(t *testing.T) { assert.Equal(t, "null", string(data)) }, }, + { + name: "redacts_upstream_tokens", + identity: &Identity{ + PrincipalInfo: PrincipalInfo{Subject: "user123"}, + UpstreamTokens: map[string]string{ + "github": "gho_secret123", + "atlassian": "atl_secret456", + }, + }, + wantErr: false, + checkFunc: func(t *testing.T, data []byte) { + t.Helper() + + var result map[string]any + err := json.Unmarshal(data, &result) + require.NoError(t, err) + + tokens, ok := result["upstreamTokens"].(map[string]any) + require.True(t, ok, "upstreamTokens should be a map") + assert.Equal(t, "REDACTED", tokens["github"]) + assert.Equal(t, "REDACTED", tokens["atlassian"]) + assert.NotContains(t, string(data), "gho_secret123") + assert.NotContains(t, string(data), "atl_secret456") + }, + }, + { + name: "empty_upstream_tokens_omitted", + identity: &Identity{ + PrincipalInfo: PrincipalInfo{Subject: "user123"}, + UpstreamTokens: map[string]string{}, + }, + wantErr: false, + checkFunc: func(t *testing.T, data []byte) { + t.Helper() + + var result map[string]any + err := json.Unmarshal(data, &result) + require.NoError(t, err) + + // Empty map should be omitted because len() == 0 produces nil redacted map + _, exists := result["upstreamTokens"] + assert.False(t, exists, "empty upstreamTokens should be omitted") + }, + }, + { + name: "nil_upstream_tokens_omitted", + identity: &Identity{ + PrincipalInfo: PrincipalInfo{Subject: "user123"}, + UpstreamTokens: nil, + }, + wantErr: false, + checkFunc: func(t *testing.T, data []byte) { + t.Helper() + + var result map[string]any + err := json.Unmarshal(data, &result) + require.NoError(t, err) + + _, exists := result["upstreamTokens"] + assert.False(t, exists, "nil upstreamTokens should be omitted") + }, + }, + { + name: "upstream_tokens_mixed_empty_and_populated", + identity: &Identity{ + PrincipalInfo: PrincipalInfo{Subject: "user123"}, + UpstreamTokens: map[string]string{ + "github": "gho_secret123", + "pending": "", + }, + }, + wantErr: false, + checkFunc: func(t *testing.T, data []byte) { + t.Helper() + + var result map[string]any + err := json.Unmarshal(data, &result) + require.NoError(t, err) + + tokens, ok := result["upstreamTokens"].(map[string]any) + require.True(t, ok, "upstreamTokens should be a map") + assert.Equal(t, "REDACTED", tokens["github"]) + assert.Equal(t, "", tokens["pending"]) + assert.NotContains(t, string(data), "gho_secret123") + }, + }, } for _, tt := range tests { diff --git a/pkg/auth/middleware.go b/pkg/auth/middleware.go index 3c1bcd7662..ed2661d67c 100644 --- a/pkg/auth/middleware.go +++ b/pkg/auth/middleware.go @@ -52,7 +52,12 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun return fmt.Errorf("failed to unmarshal auth middleware parameters: %w", err) } - middleware, authInfoHandler, err := GetAuthenticationMiddleware(context.Background(), params.OIDCConfig) + var opts []TokenValidatorOption + if reader := runner.GetUpstreamTokenReader(); reader != nil { + opts = append(opts, WithUpstreamTokenReader(reader)) + } + + middleware, authInfoHandler, err := GetAuthenticationMiddleware(context.Background(), params.OIDCConfig, opts...) if err != nil { return fmt.Errorf("failed to create authentication middleware: %w", err) } diff --git a/pkg/auth/middleware_test.go b/pkg/auth/middleware_test.go index dca9cd6976..c4a4fd38c4 100644 --- a/pkg/auth/middleware_test.go +++ b/pkg/auth/middleware_test.go @@ -102,6 +102,9 @@ func TestCreateMiddleware_WithoutOIDCConfig(t *testing.T) { // Create mock runner mockRunner := mocks.NewMockMiddlewareRunner(ctrl) + // Expect GetUpstreamTokenReader to be called (returns nil = no auth server) + mockRunner.EXPECT().GetUpstreamTokenReader().Return(nil) + // Expect AddMiddleware to be called with a middleware instance mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any()).Do(func(name string, mw types.Middleware) { // Verify it's our auth middleware @@ -210,6 +213,9 @@ func TestCreateMiddleware_EmptyParameters(t *testing.T) { mockRunner := mocks.NewMockMiddlewareRunner(ctrl) + // Expect GetUpstreamTokenReader to be called (returns nil = no auth server) + mockRunner.EXPECT().GetUpstreamTokenReader().Return(nil) + // Expect AddMiddleware to be called mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any()) diff --git a/pkg/auth/token.go b/pkg/auth/token.go index 7d5e369674..5c955e499a 100644 --- a/pkg/auth/token.go +++ b/pkg/auth/token.go @@ -25,6 +25,7 @@ import ( "github.com/stacklok/toolhive-core/env" "github.com/stacklok/toolhive/pkg/auth/oauth" + "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" "github.com/stacklok/toolhive/pkg/networking" oauthproto "github.com/stacklok/toolhive/pkg/oauth" ) @@ -367,6 +368,10 @@ type TokenValidator struct { registry *Registry // Token introspection providers insecureAllowHTTP bool // Allow HTTP (non-HTTPS) OIDC issuers for development/testing + // upstreamTokenReader loads upstream provider tokens for identity enrichment. + // nil means no enrichment (no embedded auth server). + upstreamTokenReader upstreamtoken.TokenReader + // Lazy JWKS registration jwksRegistered bool jwksRegistrationMu sync.Mutex @@ -540,7 +545,8 @@ func registerIntrospectionProviders(config TokenValidatorConfig, clientSecret st // tokenValidatorOptions holds optional dependencies for NewTokenValidator. type tokenValidatorOptions struct { - envReader env.Reader + envReader env.Reader + upstreamTokenReader upstreamtoken.TokenReader } // TokenValidatorOption is a functional option for NewTokenValidator. @@ -554,6 +560,16 @@ func WithEnvReader(reader env.Reader) TokenValidatorOption { } } +// WithUpstreamTokenReader configures the token validator to enrich Identity +// with upstream provider tokens. When set, the Middleware extracts the token +// session ID (tsid) from JWT claims and loads all upstream tokens into +// Identity.UpstreamTokens before placing the Identity in the request context. +func WithUpstreamTokenReader(reader upstreamtoken.TokenReader) TokenValidatorOption { + return func(o *tokenValidatorOptions) { + o.upstreamTokenReader = reader + } +} + // NewTokenValidator creates a new token validator. func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ...TokenValidatorOption) (*TokenValidator, error) { // Apply functional options @@ -638,18 +654,19 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts .. } validator := &TokenValidator{ - issuer: config.Issuer, - audience: config.Audience, - jwksURL: jwksURL, - introspectURL: config.IntrospectionURL, - clientID: config.ClientID, - clientSecret: clientSecret, - jwksClient: cache, - client: config.httpClient, - resourceURL: config.ResourceURL, - scopes: config.Scopes, - registry: registry, - insecureAllowHTTP: config.InsecureAllowHTTP, + issuer: config.Issuer, + audience: config.Audience, + jwksURL: jwksURL, + introspectURL: config.IntrospectionURL, + clientID: config.ClientID, + clientSecret: clientSecret, + jwksClient: cache, + client: config.httpClient, + resourceURL: config.ResourceURL, + scopes: config.Scopes, + registry: registry, + insecureAllowHTTP: config.InsecureAllowHTTP, + upstreamTokenReader: o.upstreamTokenReader, } return validator, nil @@ -1064,6 +1081,24 @@ func writeOAuthError(w http.ResponseWriter, errorCode, description string, statu _, _ = w.Write(body) } +// loadUpstreamTokens extracts the token session ID from claims and loads +// all upstream provider tokens for that session. Returns (nil, nil) if no +// tsid claim exists. Returns a non-nil error when a tsid claim is present +// but token loading fails (infrastructure error). +func (v *TokenValidator) loadUpstreamTokens(ctx context.Context, claims jwt.MapClaims) (map[string]string, error) { + tsid, ok := claims[upstreamtoken.TokenSessionIDClaimKey].(string) + if !ok || tsid == "" { + return nil, nil + } + + tokens, err := v.upstreamTokenReader.GetAllValidTokens(ctx, tsid) + if err != nil { + return nil, fmt.Errorf("load upstream tokens for session %s: %w", tsid, err) + } + + return tokens, nil +} + // Middleware creates an HTTP middleware that validates JWT tokens and creates Identity. func (v *TokenValidator) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -1092,6 +1127,20 @@ func (v *TokenValidator) Middleware(next http.Handler) http.Handler { return } + // Enrich Identity with upstream provider tokens when an embedded + // auth server is active (reader configured via WithUpstreamTokenReader). + if v.upstreamTokenReader != nil { + tokens, loadErr := v.loadUpstreamTokens(r.Context(), claims) + if loadErr != nil { + slog.WarnContext(r.Context(), "upstream token storage unavailable", + "error", loadErr, + ) + http.Error(w, "authentication service temporarily unavailable", http.StatusServiceUnavailable) + return + } + identity.UpstreamTokens = tokens + } + // Add the Identity to the request context ctx := WithIdentity(r.Context(), identity) next.ServeHTTP(w, r.WithContext(ctx)) diff --git a/pkg/auth/token_test.go b/pkg/auth/token_test.go index b6006f369c..ea7eac2520 100644 --- a/pkg/auth/token_test.go +++ b/pkg/auth/token_test.go @@ -24,6 +24,8 @@ import ( "go.uber.org/mock/gomock" envmocks "github.com/stacklok/toolhive-core/env/mocks" + "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" + upstreamtokenmocks "github.com/stacklok/toolhive/pkg/auth/upstreamtoken/mocks" "github.com/stacklok/toolhive/pkg/networking" oauthproto "github.com/stacklok/toolhive/pkg/oauth" ) @@ -2223,3 +2225,244 @@ func TestMiddleware_RFC6750JSONErrorResponse(t *testing.T) { }) } } + +func TestLoadUpstreamTokens(t *testing.T) { + t.Parallel() + + t.Run("loads tokens when tsid present", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + reader.EXPECT().GetAllValidTokens(gomock.Any(), "session-abc"). + Return(map[string]string{"github": "gh-token", "atlassian": "atl-token"}, nil) + + v := &TokenValidator{upstreamTokenReader: reader} + result, err := v.loadUpstreamTokens(context.Background(), jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: "session-abc", + }) + require.NoError(t, err) + require.Equal(t, map[string]string{"github": "gh-token", "atlassian": "atl-token"}, result) + }) + + t.Run("returns nil when no tsid claim", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + // No EXPECT — reader should not be called + + v := &TokenValidator{upstreamTokenReader: reader} + result, err := v.loadUpstreamTokens(context.Background(), jwt.MapClaims{"sub": "user123"}) + require.NoError(t, err) + require.Nil(t, result) + }) + + t.Run("returns nil when tsid is empty string", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + + v := &TokenValidator{upstreamTokenReader: reader} + result, err := v.loadUpstreamTokens(context.Background(), jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: "", + }) + require.NoError(t, err) + require.Nil(t, result) + }) + + t.Run("returns nil when tsid is non-string type", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + + v := &TokenValidator{upstreamTokenReader: reader} + result, err := v.loadUpstreamTokens(context.Background(), jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: 12345, + }) + require.NoError(t, err) + require.Nil(t, result) + }) + + t.Run("returns error when reader fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + reader.EXPECT().GetAllValidTokens(gomock.Any(), "session-abc"). + Return(nil, errors.New("storage unavailable")) + + v := &TokenValidator{upstreamTokenReader: reader} + result, err := v.loadUpstreamTokens(context.Background(), jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: "session-abc", + }) + require.Error(t, err) + require.Nil(t, result) + }) + + t.Run("returns empty map from reader", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + reader.EXPECT().GetAllValidTokens(gomock.Any(), "session-abc"). + Return(map[string]string{}, nil) + + v := &TokenValidator{upstreamTokenReader: reader} + result, err := v.loadUpstreamTokens(context.Background(), jwt.MapClaims{ + "sub": "user123", + upstreamtoken.TokenSessionIDClaimKey: "session-abc", + }) + require.NoError(t, err) + require.Equal(t, map[string]string{}, result) + }) +} + +func TestWithUpstreamTokenReader(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + opt := WithUpstreamTokenReader(reader) + + o := &tokenValidatorOptions{} + opt(o) + + require.Equal(t, reader, o.upstreamTokenReader) +} + +// TestMiddleware_UpstreamTokenEnrichment verifies the full middleware pipeline: +// JWT validation → tsid extraction → token loading → Identity.UpstreamTokens. +func TestMiddleware_UpstreamTokenEnrichment(t *testing.T) { + t.Parallel() + + // Shared JWKS infrastructure + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + key, err := jwk.Import(&privateKey.PublicKey) + require.NoError(t, err) + require.NoError(t, key.Set(jwk.KeyIDKey, testKeyID)) + require.NoError(t, key.Set(jwk.AlgorithmKey, "RS256")) + require.NoError(t, key.Set(jwk.KeyUsageKey, "sig")) + + keySet := jwk.NewSet() + require.NoError(t, keySet.AddKey(key)) + jwksServer, caCertPath := createTestJWKSServer(t, keySet) + t.Cleanup(jwksServer.Close) + + makeValidator := func(t *testing.T, opts ...TokenValidatorOption) *TokenValidator { + t.Helper() + v, vErr := NewTokenValidator(context.Background(), TokenValidatorConfig{ + Issuer: "test-issuer", Audience: "test-audience", + JWKSURL: jwksServer.URL, ClientID: "test-client", + CACertPath: caCertPath, AllowPrivateIP: true, + }, opts...) + require.NoError(t, vErr) + require.NoError(t, v.ensureJWKSRegistered(context.Background())) + _, lErr := v.jwksClient.Lookup(context.Background(), jwksServer.URL) + require.NoError(t, lErr) + return v + } + + signToken := func(claims jwt.MapClaims) string { + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tok.Header["kid"] = testKeyID + s, sErr := tok.SignedString(privateKey) + require.NoError(t, sErr) + return s + } + + claimsWithTsid := jwt.MapClaims{ + "iss": "test-issuer", "aud": "test-audience", "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + upstreamtoken.TokenSessionIDClaimKey: "session-xyz", + } + + t.Run("enriches identity with upstream tokens", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + reader.EXPECT().GetAllValidTokens(gomock.Any(), "session-xyz"). + Return(map[string]string{"github": "gh-tok"}, nil) + v := makeValidator(t, WithUpstreamTokenReader(reader)) + + var captured *Identity + handler := v.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + captured, _ = IdentityFromContext(r.Context()) + })) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "Bearer "+signToken(claimsWithTsid)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, map[string]string{"github": "gh-tok"}, captured.UpstreamTokens) + }) + + t.Run("returns 503 when storage fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + reader.EXPECT().GetAllValidTokens(gomock.Any(), "session-xyz"). + Return(nil, errors.New("redis down")) + v := makeValidator(t, WithUpstreamTokenReader(reader)) + + nextCalled := false + handler := v.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + })) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "Bearer "+signToken(claimsWithTsid)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusServiceUnavailable, rr.Code) + require.False(t, nextCalled) + }) + + t.Run("no enrichment without tsid", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + reader := upstreamtokenmocks.NewMockTokenReader(ctrl) + // No EXPECT — reader should not be called when tsid is absent + v := makeValidator(t, WithUpstreamTokenReader(reader)) + + var captured *Identity + handler := v.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + captured, _ = IdentityFromContext(r.Context()) + })) + + noTsid := jwt.MapClaims{ + "iss": "test-issuer", "aud": "test-audience", "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + } + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "Bearer "+signToken(noTsid)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + require.Nil(t, captured.UpstreamTokens) + }) + + t.Run("no enrichment when reader is nil", func(t *testing.T) { + t.Parallel() + v := makeValidator(t) // no WithUpstreamTokenReader + + var captured *Identity + handler := v.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + captured, _ = IdentityFromContext(r.Context()) + })) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "Bearer "+signToken(claimsWithTsid)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + require.Nil(t, captured.UpstreamTokens) + }) +} diff --git a/pkg/auth/upstreamswap/middleware.go b/pkg/auth/upstreamswap/middleware.go index ca4edd1fbf..d4f515aafb 100644 --- a/pkg/auth/upstreamswap/middleware.go +++ b/pkg/auth/upstreamswap/middleware.go @@ -7,13 +7,10 @@ package upstreamswap import ( "encoding/json" - "errors" "fmt" - "log/slog" "net/http" "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -46,10 +43,6 @@ type MiddlewareParams struct { Config *Config `json:"config,omitempty"` } -// ServiceGetter is a function that returns an upstream token service. -// It returns nil when the service is unavailable (e.g., auth server not configured). -type ServiceGetter func() upstreamtoken.Service - // Middleware wraps the upstream swap middleware functionality. type Middleware struct { middleware types.MiddlewareFunction @@ -83,10 +76,7 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun return fmt.Errorf("invalid upstream swap configuration: %w", err) } - // Get the lazy service accessor from the runner. - serviceGetter := ServiceGetter(runner.GetUpstreamTokenService()) - - middleware := createMiddlewareFunc(cfg, serviceGetter) + middleware := createMiddlewareFunc(cfg) upstreamSwapMw := &Middleware{ middleware: middleware, @@ -146,7 +136,9 @@ func createCustomInjector(headerName string) injectionFunc { } // createMiddlewareFunc creates the actual middleware function. -func createMiddlewareFunc(cfg *Config, serviceGetter ServiceGetter) types.MiddlewareFunction { +// It reads upstream tokens from Identity.UpstreamTokens, which are populated +// during JWT validation by the auth middleware (Step 3). +func createMiddlewareFunc(cfg *Config) types.MiddlewareFunction { // Determine injection strategy at startup time strategy := cfg.HeaderStrategy if strategy == "" { @@ -166,73 +158,19 @@ func createMiddlewareFunc(cfg *Config, serviceGetter ServiceGetter) types.Middle return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // 1. Get identity from auth middleware identity, ok := auth.IdentityFromContext(r.Context()) if !ok { - slog.Debug("No identity in context, proceeding without swap", - "middleware", "upstreamswap") next.ServeHTTP(w, r) return } - // 2. Extract tsid from claims - tsid, ok := identity.Claims[upstreamtoken.TokenSessionIDClaimKey].(string) - if !ok || tsid == "" { - slog.Debug("No tsid claim in identity, proceeding without swap", - "middleware", "upstreamswap") - next.ServeHTTP(w, r) + token, exists := identity.UpstreamTokens[cfg.ProviderName] + if !exists || token == "" { + writeUpstreamAuthRequired(w) return } - // 3. Get token service — fail closed if unavailable. - // The tsid claim confirms this request expects upstream token injection; - // passing through with the original JWT would leak it to the backend. - svc := serviceGetter() - if svc == nil { - slog.Warn("Token service unavailable, cannot perform required upstream swap", - "middleware", "upstreamswap") - http.Error(w, "authentication service temporarily unavailable", http.StatusServiceUnavailable) - return - } - - // 4. Get valid upstream tokens (with transparent refresh) - cred, err := svc.GetValidTokens(r.Context(), tsid, cfg.ProviderName) - if err != nil { - // Client-attributable errors require re-authentication. - if errors.Is(err, upstreamtoken.ErrSessionNotFound) || - errors.Is(err, upstreamtoken.ErrNoRefreshToken) || - errors.Is(err, upstreamtoken.ErrRefreshFailed) || - errors.Is(err, upstreamtoken.ErrInvalidBinding) { - slog.Debug("Upstream token needs re-authentication", - "middleware", "upstreamswap", - "provider", cfg.ProviderName, - "error", err) - writeUpstreamAuthRequired(w) - return - } - // Other errors: fail closed to avoid bypassing the token swap - slog.Warn("Failed to get upstream tokens", - "middleware", "upstreamswap", - "provider", cfg.ProviderName, - "error", err) - http.Error(w, "authentication service temporarily unavailable", http.StatusServiceUnavailable) - return - } - - // 5. Inject access token — fail closed if empty to prevent bypassing the swap - if cred.AccessToken == "" { - slog.Warn("Upstream token service returned empty access token", - "middleware", "upstreamswap", - "provider", cfg.ProviderName) - http.Error(w, "authentication service temporarily unavailable", http.StatusServiceUnavailable) - return - } - - injectToken(r, cred.AccessToken) - slog.Debug("Injected upstream access token", - "middleware", "upstreamswap", - "provider", cfg.ProviderName) - + injectToken(r, token) next.ServeHTTP(w, r) }) } diff --git a/pkg/auth/upstreamswap/middleware_test.go b/pkg/auth/upstreamswap/middleware_test.go index 39454e36ca..9594c3b9f3 100644 --- a/pkg/auth/upstreamswap/middleware_test.go +++ b/pkg/auth/upstreamswap/middleware_test.go @@ -6,51 +6,28 @@ package upstreamswap import ( "context" "encoding/json" - "errors" "net/http" "net/http/httptest" - "sync" - "sync/atomic" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" - "github.com/stacklok/toolhive/pkg/authserver/storage" - storagemocks "github.com/stacklok/toolhive/pkg/authserver/storage/mocks" "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/transport/types/mocks" ) -// serviceGetterFromMocks creates a ServiceGetter that wraps mock storage and refresher -// in an InProcessService. This is the standard pattern for middleware tests. -func serviceGetterFromMocks( - stor storage.UpstreamTokenStorage, - refresher storage.UpstreamTokenRefresher, -) ServiceGetter { - svc := upstreamtoken.NewInProcessService(stor, refresher) - return func() upstreamtoken.Service { return svc } -} - -// nilServiceGetter returns a ServiceGetter that always returns nil (service unavailable). -func nilServiceGetter() ServiceGetter { - return func() upstreamtoken.Service { return nil } -} - -// requestWithIdentity creates an HTTP request with the given identity in context. -func requestWithIdentity(tsid string) *http.Request { +// requestWithIdentity creates an HTTP request with a "github" upstream token in context. +func requestWithIdentity(token string) *http.Request { req := httptest.NewRequest(http.MethodGet, "/test", nil) identity := &auth.Identity{ PrincipalInfo: auth.PrincipalInfo{ Subject: "user123", - Claims: map[string]any{ - "sub": "user123", - upstreamtoken.TokenSessionIDClaimKey: tsid, - }, + }, + UpstreamTokens: map[string]string{ + "github": token, }, } ctx := auth.WithIdentity(req.Context(), identity) @@ -126,13 +103,12 @@ func TestValidateConfig(t *testing.T) { func TestMiddleware_NoIdentity(t *testing.T) { t.Parallel() - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, nilServiceGetter()) + cfg := &Config{ProviderName: "github"} + middleware := createMiddlewareFunc(cfg) var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { nextCalled = true - // Verify Authorization header was NOT modified assert.Empty(t, r.Header.Get("Authorization")) }) @@ -143,31 +119,25 @@ func TestMiddleware_NoIdentity(t *testing.T) { handler.ServeHTTP(rr, req) assert.True(t, nextCalled, "next handler should be called") + assert.Equal(t, http.StatusOK, rr.Code) } -func TestMiddleware_NoTsidClaim(t *testing.T) { +func TestMiddleware_NilUpstreamTokens(t *testing.T) { t.Parallel() - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, nilServiceGetter()) + cfg := &Config{ProviderName: "github"} + middleware := createMiddlewareFunc(cfg) - var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true + t.Error("next handler should NOT be called when upstream tokens are nil") }) handler := middleware(nextHandler) - // Create request with identity but no tsid claim + // Identity exists but UpstreamTokens is nil (not populated by auth middleware) req := httptest.NewRequest(http.MethodGet, "/test", nil) identity := &auth.Identity{ - PrincipalInfo: auth.PrincipalInfo{ - Subject: "user123", - Claims: map[string]any{ - "sub": "user123", - // No tsid claim - }, - }, + PrincipalInfo: auth.PrincipalInfo{Subject: "user123"}, } ctx := auth.WithIdentity(req.Context(), identity) req = req.WithContext(ctx) @@ -175,142 +145,56 @@ func TestMiddleware_NoTsidClaim(t *testing.T) { rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - assert.True(t, nextCalled, "next handler should be called") + assert.Equal(t, http.StatusUnauthorized, rr.Code) + assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) } -func TestMiddleware_ServiceUnavailable_FailsClosed(t *testing.T) { +func TestMiddleware_ProviderMissing_Returns401(t *testing.T) { t.Parallel() - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, nilServiceGetter()) + cfg := &Config{ProviderName: "atlassian"} + middleware := createMiddlewareFunc(cfg) - var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true + t.Error("next handler should NOT be called when provider is missing") }) handler := middleware(nextHandler) - req := requestWithIdentity("session-123") + req := requestWithIdentity("gh-token") // has github but not atlassian rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - assert.False(t, nextCalled, "next handler should NOT be called when service is unavailable") - assert.Equal(t, http.StatusServiceUnavailable, rr.Code) -} - -func TestMiddleware_ClientAttributableErrors_Returns401(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - setupStorage func(*storagemocks.MockUpstreamTokenStorage) - }{ - { - name: "not found", - setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { - s.EXPECT().GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(nil, storage.ErrNotFound) - }, - }, - { - name: "expired with nil tokens", - setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { - s.EXPECT().GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(nil, storage.ErrExpired) - }, - }, - { - name: "invalid binding", - setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { - s.EXPECT().GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(nil, storage.ErrInvalidBinding) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - tt.setupStorage(mockStorage) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) - - var nextCalled bool - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true - }) - - handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.False(t, nextCalled, "next handler should NOT be called") - assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) - }) - } + assert.Equal(t, http.StatusUnauthorized, rr.Code) } -func TestMiddleware_StorageError(t *testing.T) { +func TestMiddleware_EmptyToken_Returns401(t *testing.T) { t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() + cfg := &Config{ProviderName: "github"} + middleware := createMiddlewareFunc(cfg) - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(nil, errors.New("database error")) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) - - var nextCalled bool nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true + t.Error("next handler should NOT be called when token is empty") }) handler := middleware(nextHandler) - req := requestWithIdentity("session-123") + req := requestWithIdentity("") // empty token rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - assert.False(t, nextCalled, "next handler should NOT be called on storage error") - assert.Equal(t, http.StatusServiceUnavailable, rr.Code) + assert.Equal(t, http.StatusUnauthorized, rr.Code) } -func TestMiddleware_SuccessfulSwap_AccessToken(t *testing.T) { +func TestMiddleware_SuccessfulSwap_ReplaceStrategy(t *testing.T) { t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - tokens := &storage.UpstreamTokens{ - AccessToken: "upstream-access-token", - IDToken: "upstream-id-token", - ExpiresAt: time.Now().Add(1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(tokens, nil) - cfg := &Config{ HeaderStrategy: HeaderStrategyReplace, - ProviderName: "default", + ProviderName: "github", } - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) + middleware := createMiddlewareFunc(cfg) var capturedAuthHeader string nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { @@ -318,7 +202,7 @@ func TestMiddleware_SuccessfulSwap_AccessToken(t *testing.T) { }) handler := middleware(nextHandler) - req := requestWithIdentity("session-123") + req := requestWithIdentity("upstream-access-token") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) @@ -326,118 +210,56 @@ func TestMiddleware_SuccessfulSwap_AccessToken(t *testing.T) { assert.Equal(t, "Bearer upstream-access-token", capturedAuthHeader) } -func TestMiddleware_CustomHeader(t *testing.T) { +func TestMiddleware_SuccessfulSwap_DefaultStrategy(t *testing.T) { t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - tokens := &storage.UpstreamTokens{ - AccessToken: "upstream-access-token", - ExpiresAt: time.Now().Add(1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(tokens, nil) - cfg := &Config{ - HeaderStrategy: HeaderStrategyCustom, - CustomHeaderName: "X-Upstream-Token", - ProviderName: "default", + ProviderName: "github", + // HeaderStrategy intentionally empty — defaults to replace } - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) + middleware := createMiddlewareFunc(cfg) - var capturedCustomHeader string var capturedAuthHeader string nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - capturedCustomHeader = r.Header.Get("X-Upstream-Token") capturedAuthHeader = r.Header.Get("Authorization") }) handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - req.Header.Set("Authorization", "Bearer original-token") + req := requestWithIdentity("default-strategy-token") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - assert.Equal(t, "Bearer upstream-access-token", capturedCustomHeader) - // Original Authorization header should remain unchanged - assert.Equal(t, "Bearer original-token", capturedAuthHeader) -} - -func TestMiddleware_ExpiredTokens_Returns401(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Token is expired and has no refresh token - tokens := &storage.UpstreamTokens{ - AccessToken: "expired-upstream-token", - ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired 1 hour ago - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(tokens, nil) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) - - var nextCalled bool - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true - }) - - handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.False(t, nextCalled, "next handler should NOT be called for expired tokens") - assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "Bearer default-strategy-token", capturedAuthHeader) } -func TestMiddleware_EmptySelectedToken_FailsClosed(t *testing.T) { +func TestMiddleware_SuccessfulSwap_CustomHeader(t *testing.T) { t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Tokens exist but the access token is empty - tokens := &storage.UpstreamTokens{ - AccessToken: "", // Empty access token - IDToken: "upstream-id-token", - ExpiresAt: time.Now().Add(1 * time.Hour), + cfg := &Config{ + HeaderStrategy: HeaderStrategyCustom, + CustomHeaderName: "X-Upstream-Token", + ProviderName: "github", } + middleware := createMiddlewareFunc(cfg) - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(tokens, nil) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) - - var nextCalled bool - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true + var capturedCustomHeader string + var capturedAuthHeader string + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedCustomHeader = r.Header.Get("X-Upstream-Token") + capturedAuthHeader = r.Header.Get("Authorization") }) handler := middleware(nextHandler) - req := requestWithIdentity("session-123") + req := requestWithIdentity("upstream-access-token") + req.Header.Set("Authorization", "Bearer original-token") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - assert.False(t, nextCalled, "next handler should NOT be called when access token is empty") - assert.Equal(t, http.StatusServiceUnavailable, rr.Code) + assert.Equal(t, "Bearer upstream-access-token", capturedCustomHeader) + assert.Equal(t, "Bearer original-token", capturedAuthHeader) } func TestMiddleware_Close(t *testing.T) { @@ -483,35 +305,20 @@ func TestCreateInjectors(t *testing.T) { func TestMiddlewareWithContext(t *testing.T) { t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - tokens := &storage.UpstreamTokens{ - AccessToken: "ctx-test-token", - ExpiresAt: time.Now().Add(1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-ctx", "default"). - Return(tokens, nil) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) + cfg := &Config{ProviderName: "github"} + middleware := createMiddlewareFunc(cfg) - // Test that context is properly passed through var receivedCtx context.Context nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { receivedCtx = r.Context() }) handler := middleware(nextHandler) - req := requestWithIdentity("session-ctx") + req := requestWithIdentity("ctx-test-token") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - // Verify identity is still in context identityFromCtx, ok := auth.IdentityFromContext(receivedCtx) assert.True(t, ok) assert.Equal(t, "user123", identityFromCtx.Subject) @@ -601,12 +408,7 @@ func TestCreateMiddleware(t *testing.T) { mockRunner := mocks.NewMockMiddlewareRunner(ctrl) - // Service getter is only called if validation passes (expectAddMiddleware) - // because validation happens before GetUpstreamTokenService is called if tt.expectAddMiddleware { - mockRunner.EXPECT().GetUpstreamTokenService().Return(func() upstreamtoken.Service { - return nil // Service availability is checked at request time - }) mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any()).Do(func(_ string, mw types.Middleware) { _, ok := mw.(*Middleware) assert.True(t, ok, "Expected middleware to be of type *upstreamswap.Middleware") @@ -652,353 +454,3 @@ func TestCreateMiddleware_InvalidJSON(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "failed to unmarshal upstream swap middleware parameters") } - -// TestMiddleware_TsidClaimWrongType tests behavior when tsid claim exists but is wrong type. -func TestMiddleware_TsidClaimWrongType(t *testing.T) { - t.Parallel() - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, nilServiceGetter()) - - var nextCalled bool - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true - }) - - handler := middleware(nextHandler) - - // Create request with identity but tsid claim is wrong type (int instead of string) - req := httptest.NewRequest(http.MethodGet, "/test", nil) - identity := &auth.Identity{ - PrincipalInfo: auth.PrincipalInfo{ - Subject: "user123", - Claims: map[string]any{ - "sub": "user123", - upstreamtoken.TokenSessionIDClaimKey: 12345, // Wrong type: int instead of string - }, - }, - } - ctx := auth.WithIdentity(req.Context(), identity) - req = req.WithContext(ctx) - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.True(t, nextCalled, "next handler should be called when tsid is wrong type") -} - -func TestMiddleware_ExpiredTokens_RefreshSuccess(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - expiredTokens := &storage.UpstreamTokens{ - AccessToken: "expired-access-token", - RefreshToken: "my-refresh-token", - ExpiresAt: time.Now().Add(-1 * time.Hour), - } - - refreshedTokens := &storage.UpstreamTokens{ - AccessToken: "new-access-token", - RefreshToken: "new-refresh-token", - ExpiresAt: time.Now().Add(1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(expiredTokens, storage.ErrExpired) - - mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) - mockRefresher.EXPECT(). - RefreshAndStore(gomock.Any(), "session-123", expiredTokens). - Return(refreshedTokens, nil) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, mockRefresher)) - - var nextCalled bool - var capturedAuthHeader string - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - nextCalled = true - capturedAuthHeader = r.Header.Get("Authorization") - }) - - handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.True(t, nextCalled, "next handler should be called after successful refresh") - assert.Equal(t, "Bearer new-access-token", capturedAuthHeader) -} - -func TestMiddleware_ExpiredTokens_RefreshFails(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - expiredTokens := &storage.UpstreamTokens{ - AccessToken: "expired-access-token", - RefreshToken: "my-refresh-token", - ExpiresAt: time.Now().Add(-1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(expiredTokens, storage.ErrExpired) - - mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) - mockRefresher.EXPECT(). - RefreshAndStore(gomock.Any(), "session-123", expiredTokens). - Return(nil, errors.New("refresh failed")) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, mockRefresher)) - - var nextCalled bool - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true - }) - - handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.False(t, nextCalled, "next handler should NOT be called when refresh fails") - assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) -} - -func TestMiddleware_ExpiredTokens_NoRefreshToken(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - expiredTokens := &storage.UpstreamTokens{ - AccessToken: "expired-access-token", - RefreshToken: "", // No refresh token - ExpiresAt: time.Now().Add(-1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(expiredTokens, storage.ErrExpired) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, nil)) - - var nextCalled bool - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - nextCalled = true - }) - - handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.False(t, nextCalled, "next handler should NOT be called when no refresh token available") - assert.Equal(t, http.StatusUnauthorized, rr.Code) - assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="invalid_token"`) -} - -func TestMiddleware_DefenseInDepth_ExpiredButNoError_RefreshSuccess(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Storage returns tokens with ExpiresAt in the past but NO error - expiredTokens := &storage.UpstreamTokens{ - AccessToken: "expired-access-token", - RefreshToken: "my-refresh-token", - ExpiresAt: time.Now().Add(-1 * time.Hour), - } - - refreshedTokens := &storage.UpstreamTokens{ - AccessToken: "refreshed-access-token", - RefreshToken: "refreshed-refresh-token", - ExpiresAt: time.Now().Add(1 * time.Hour), - } - - mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) - mockStorage.EXPECT(). - GetUpstreamTokens(gomock.Any(), "session-123", "default"). - Return(expiredTokens, nil) // No error, but token is expired - - mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) - mockRefresher.EXPECT(). - RefreshAndStore(gomock.Any(), "session-123", expiredTokens). - Return(refreshedTokens, nil) - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetterFromMocks(mockStorage, mockRefresher)) - - var nextCalled bool - var capturedAuthHeader string - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - nextCalled = true - capturedAuthHeader = r.Header.Get("Authorization") - }) - - handler := middleware(nextHandler) - req := requestWithIdentity("session-123") - - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - assert.True(t, nextCalled, "next handler should be called after successful defense-in-depth refresh") - assert.Equal(t, "Bearer refreshed-access-token", capturedAuthHeader) -} - -// TestSingleFlightRefresh_ConcurrentRequests verifies that concurrent requests -// with the same expired session only trigger a single upstream refresh call. -// Without singleflight, providers that rotate refresh tokens (single-use) -// would fail all but the first concurrent caller. -func TestSingleFlightRefresh_ConcurrentRequests(t *testing.T) { - t.Parallel() - - const numRequests = 10 - var refreshCallCount atomic.Int32 - - expiredTokens := &storage.UpstreamTokens{ - AccessToken: "expired-access", - RefreshToken: "one-time-refresh", - ExpiresAt: time.Now().Add(-1 * time.Hour), - } - - refreshedTokens := &storage.UpstreamTokens{ - AccessToken: "fresh-access", - RefreshToken: "new-refresh", - ExpiresAt: time.Now().Add(1 * time.Hour), - } - - // Use a barrier so all goroutines complete GetUpstreamTokens before any - // enters singleflight. This guarantees they all contend on the same key. - var storageBarrier sync.WaitGroup - storageBarrier.Add(numRequests) - storageGate := make(chan struct{}) - - stor := &fakeTokenStorage{ - tokens: expiredTokens, - err: storage.ErrExpired, - barrier: &storageBarrier, - gate: storageGate, - } - proceed := make(chan struct{}) - refresher := &fakeRefresher{ - result: refreshedTokens, - callCount: &refreshCallCount, - proceed: proceed, - } - - // Create service with fake storage/refresher — singleflight lives in the service - svc := upstreamtoken.NewInProcessService(stor, refresher) - serviceGetter := func() upstreamtoken.Service { return svc } - - cfg := &Config{ProviderName: "default"} - middleware := createMiddlewareFunc(cfg, serviceGetter) - - nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}) - handler := middleware(nextHandler) - - // Use a barrier to ensure all goroutines start at the same time - ready := make(chan struct{}) - var wg sync.WaitGroup - results := make([]int, numRequests) - - for i := range numRequests { - wg.Add(1) - go func(idx int) { - defer wg.Done() - <-ready // Wait for all goroutines to be ready - - req := requestWithIdentity("sf-concurrent-session") - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - results[idx] = rr.Code - }(i) - } - - // Release all goroutines simultaneously - close(ready) - - // Wait for ALL goroutines to reach GetUpstreamTokens - storageBarrier.Wait() - - // Release them all at once — they all proceed to singleflight.Do concurrently - close(storageGate) - - // Give goroutines a moment to enter singleflight, then let refresh complete - // The singleflight ensures only one actually calls RefreshAndStore - time.Sleep(10 * time.Millisecond) - close(proceed) - - wg.Wait() - - // KEY ASSERTION: RefreshAndStore should be called exactly once. - // Without singleflight, all 10 goroutines would call it independently. - assert.Equal(t, int32(1), refreshCallCount.Load(), - "RefreshAndStore should be called exactly once — singleflight deduplicates concurrent refreshes") - - // All requests should succeed - for i, code := range results { - assert.Equal(t, http.StatusOK, code, - "request %d should succeed", i) - } -} - -// fakeTokenStorage returns configured tokens and optionally blocks until -// a barrier is released, ensuring all goroutines reach storage before any -// proceeds to the singleflight refresh. -type fakeTokenStorage struct { - tokens *storage.UpstreamTokens - err error - barrier *sync.WaitGroup // if set, each call does barrier.Done() then waits - gate chan struct{} // if set, blocks until closed -} - -func (f *fakeTokenStorage) GetUpstreamTokens(_ context.Context, _, _ string) (*storage.UpstreamTokens, error) { - if f.barrier != nil { - f.barrier.Done() - } - if f.gate != nil { - <-f.gate - } - return f.tokens, f.err -} - -func (*fakeTokenStorage) GetAllUpstreamTokens(_ context.Context, _ string) (map[string]*storage.UpstreamTokens, error) { - return nil, nil -} - -func (*fakeTokenStorage) StoreUpstreamTokens(_ context.Context, _, _ string, _ *storage.UpstreamTokens) error { - return nil -} - -func (*fakeTokenStorage) DeleteUpstreamTokens(_ context.Context, _ string) error { - return nil -} - -// fakeRefresher counts calls and blocks until proceed is closed. -type fakeRefresher struct { - result *storage.UpstreamTokens - callCount *atomic.Int32 - proceed chan struct{} -} - -func (f *fakeRefresher) RefreshAndStore(_ context.Context, _ string, _ *storage.UpstreamTokens) (*storage.UpstreamTokens, error) { - f.callCount.Add(1) - <-f.proceed - return f.result, nil -} diff --git a/pkg/auth/upstreamtoken/mocks/mock_token_reader.go b/pkg/auth/upstreamtoken/mocks/mock_token_reader.go new file mode 100644 index 0000000000..c29a64f5c2 --- /dev/null +++ b/pkg/auth/upstreamtoken/mocks/mock_token_reader.go @@ -0,0 +1,56 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/stacklok/toolhive/pkg/auth/upstreamtoken (interfaces: TokenReader) +// +// Generated by this command: +// +// mockgen -destination=mocks/mock_token_reader.go -package=mocks github.com/stacklok/toolhive/pkg/auth/upstreamtoken TokenReader +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockTokenReader is a mock of TokenReader interface. +type MockTokenReader struct { + ctrl *gomock.Controller + recorder *MockTokenReaderMockRecorder + isgomock struct{} +} + +// MockTokenReaderMockRecorder is the mock recorder for MockTokenReader. +type MockTokenReaderMockRecorder struct { + mock *MockTokenReader +} + +// NewMockTokenReader creates a new mock instance. +func NewMockTokenReader(ctrl *gomock.Controller) *MockTokenReader { + mock := &MockTokenReader{ctrl: ctrl} + mock.recorder = &MockTokenReaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTokenReader) EXPECT() *MockTokenReaderMockRecorder { + return m.recorder +} + +// GetAllValidTokens mocks base method. +func (m *MockTokenReader) GetAllValidTokens(ctx context.Context, sessionID string) (map[string]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllValidTokens", ctx, sessionID) + ret0, _ := ret[0].(map[string]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllValidTokens indicates an expected call of GetAllValidTokens. +func (mr *MockTokenReaderMockRecorder) GetAllValidTokens(ctx, sessionID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllValidTokens", reflect.TypeOf((*MockTokenReader)(nil).GetAllValidTokens), ctx, sessionID) +} diff --git a/pkg/auth/upstreamtoken/service.go b/pkg/auth/upstreamtoken/service.go index 537a0e0a4c..ff12423455 100644 --- a/pkg/auth/upstreamtoken/service.go +++ b/pkg/auth/upstreamtoken/service.go @@ -29,8 +29,11 @@ type InProcessService struct { sfGroup singleflight.Group } -// Compile-time check. -var _ Service = (*InProcessService)(nil) +// Compile-time checks. +var ( + _ Service = (*InProcessService)(nil) + _ TokenReader = (*InProcessService)(nil) +) // NewInProcessService creates a new InProcessService. // The refresher may be nil if upstream token refresh is not configured; @@ -77,6 +80,50 @@ func (s *InProcessService) GetValidTokens(ctx context.Context, sessionID, provid return &UpstreamCredential{AccessToken: tokens.AccessToken}, nil } +// GetAllValidTokens returns access tokens for all upstream providers in a session. +// Expired tokens are refreshed transparently; if refresh fails, the provider is +// omitted from the result so downstream middleware can return a clean 401. +func (s *InProcessService) GetAllValidTokens(ctx context.Context, sessionID string) (map[string]string, error) { + allTokens, err := s.storage.GetAllUpstreamTokens(ctx, sessionID) + if err != nil { + return nil, fmt.Errorf("bulk read upstream tokens: %w", err) + } + + if len(allTokens) == 0 { + return map[string]string{}, nil + } + + result := make(map[string]string, len(allTokens)) + // TODO(auth): Refresh providers in parallel using errgroup to avoid + // worst-case latency of N * refreshTimeout when multiple providers need refresh. + for providerName, tokens := range allTokens { + if tokens == nil { + continue + } + + // If token is not expired, use it directly. + if tokens.ExpiresAt.IsZero() || !tokens.IsExpired(time.Now()) { + result[providerName] = tokens.AccessToken + continue + } + + // Token is expired — attempt refresh. + refreshed, refreshErr := s.refreshOrFail(ctx, sessionID, providerName, tokens) + if refreshErr != nil { + // Refresh failed — omit provider so downstream middleware returns 401. + slog.WarnContext(ctx, "omitting provider with unrefreshable expired token", + "session_id", sessionID, + "provider", providerName, + "error", refreshErr, + ) + continue + } + result[providerName] = refreshed.AccessToken + } + + return result, nil +} + // refreshOrFail attempts a singleflight-deduplicated refresh and maps errors // to the service's sentinel errors. func (s *InProcessService) refreshOrFail( @@ -90,7 +137,10 @@ func (s *InProcessService) refreshOrFail( } if s.refresher == nil { - slog.Debug("token refresher not configured, cannot refresh upstream tokens") + slog.Debug("token refresher not configured, cannot refresh upstream tokens", + "session_id", sessionID, + "provider", providerName, + ) return nil, ErrNoRefreshToken } @@ -109,6 +159,7 @@ func (s *InProcessService) refreshOrFail( if err != nil { slog.Warn("upstream token refresh failed", "session_id", sessionID, + "provider", providerName, "error", err, ) return nil, fmt.Errorf("%w: %w", ErrRefreshFailed, err) diff --git a/pkg/auth/upstreamtoken/service_test.go b/pkg/auth/upstreamtoken/service_test.go index 0dde8d1acc..70a3553d68 100644 --- a/pkg/auth/upstreamtoken/service_test.go +++ b/pkg/auth/upstreamtoken/service_test.go @@ -233,3 +233,213 @@ func TestInProcessService_NilRefresher(t *testing.T) { assert.ErrorIs(t, err, ErrNoRefreshToken) assert.Nil(t, cred) } + +func TestInProcessService_GetAllValidTokens(t *testing.T) { + t.Parallel() + + freshTokens := &storage.UpstreamTokens{ + ProviderID: "github", + AccessToken: "github-access-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + freshTokens2 := &storage.UpstreamTokens{ + ProviderID: "atlassian", + AccessToken: "atlassian-access-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + expiredTokens := &storage.UpstreamTokens{ + ProviderID: "github", + AccessToken: "expired-github-token", + RefreshToken: "github-refresh-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), + } + refreshedTokens := &storage.UpstreamTokens{ + ProviderID: "github", + AccessToken: "new-github-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + tests := []struct { + name string + sessionID string + setupStorage func(*storagemocks.MockUpstreamTokenStorage) + setupRefresher func(*storagemocks.MockUpstreamTokenRefresher) + wantResult map[string]string + wantErr bool + }{ + { + name: "all fresh tokens returned directly", + sessionID: "session-1", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-1"). + Return(map[string]*storage.UpstreamTokens{ + "github": freshTokens, + "atlassian": freshTokens2, + }, nil) + }, + setupRefresher: func(_ *storagemocks.MockUpstreamTokenRefresher) {}, + wantResult: map[string]string{ + "github": "github-access-token", + "atlassian": "atlassian-access-token", + }, + }, + { + name: "mixed fresh and expired with successful refresh", + sessionID: "session-2", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-2"). + Return(map[string]*storage.UpstreamTokens{ + "atlassian": freshTokens2, + "github": expiredTokens, + }, nil) + }, + setupRefresher: func(r *storagemocks.MockUpstreamTokenRefresher) { + r.EXPECT().RefreshAndStore(gomock.Any(), "session-2", expiredTokens). + Return(refreshedTokens, nil) + }, + wantResult: map[string]string{ + "atlassian": "atlassian-access-token", + "github": "new-github-token", + }, + }, + { + name: "expired refresh fails omits provider", + sessionID: "session-3", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-3"). + Return(map[string]*storage.UpstreamTokens{ + "github": expiredTokens, + }, nil) + }, + setupRefresher: func(r *storagemocks.MockUpstreamTokenRefresher) { + r.EXPECT().RefreshAndStore(gomock.Any(), "session-3", expiredTokens). + Return(nil, errors.New("upstream IDP unavailable")) + }, + wantResult: map[string]string{}, + }, + { + name: "empty session returns empty map", + sessionID: "session-4", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-4"). + Return(map[string]*storage.UpstreamTokens{}, nil) + }, + setupRefresher: func(_ *storagemocks.MockUpstreamTokenRefresher) {}, + wantResult: map[string]string{}, + }, + { + name: "storage error propagated", + sessionID: "session-5", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-5"). + Return(nil, errors.New("redis connection lost")) + }, + setupRefresher: func(_ *storagemocks.MockUpstreamTokenRefresher) {}, + wantErr: true, + }, + { + name: "nil tokens entry skipped", + sessionID: "session-6", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-6"). + Return(map[string]*storage.UpstreamTokens{ + "github": freshTokens, + "atlassian": nil, + }, nil) + }, + setupRefresher: func(_ *storagemocks.MockUpstreamTokenRefresher) {}, + wantResult: map[string]string{ + "github": "github-access-token", + }, + }, + { + name: "expired with no refresh token omits provider", + sessionID: "session-7", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-7"). + Return(map[string]*storage.UpstreamTokens{ + "github": { + ProviderID: "github", + AccessToken: "expired-no-refresh", + ExpiresAt: time.Now().Add(-1 * time.Hour), + RefreshToken: "", + }, + }, nil) + }, + setupRefresher: func(_ *storagemocks.MockUpstreamTokenRefresher) {}, + wantResult: map[string]string{}, + }, + { + name: "zero ExpiresAt treated as non-expiring", + sessionID: "session-8", + setupStorage: func(s *storagemocks.MockUpstreamTokenStorage) { + s.EXPECT().GetAllUpstreamTokens(gomock.Any(), "session-8"). + Return(map[string]*storage.UpstreamTokens{ + "github": { + ProviderID: "github", + AccessToken: "no-expiry-token", + ExpiresAt: time.Time{}, + }, + }, nil) + }, + setupRefresher: func(_ *storagemocks.MockUpstreamTokenRefresher) {}, + wantResult: map[string]string{ + "github": "no-expiry-token", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + mockRefresher := storagemocks.NewMockUpstreamTokenRefresher(ctrl) + + tt.setupStorage(mockStorage) + tt.setupRefresher(mockRefresher) + + svc := NewInProcessService(mockStorage, mockRefresher) + + result, err := svc.GetAllValidTokens(context.Background(), tt.sessionID) + + if tt.wantErr { + require.Error(t, err) + assert.Nil(t, result) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantResult, result) + }) + } +} + +// TestInProcessService_GetAllValidTokens_NilRefresher verifies that when the +// refresher is nil, expired tokens in the bulk path are omitted (not panicking). +func TestInProcessService_GetAllValidTokens_NilRefresher(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + mockStorage.EXPECT(). + GetAllUpstreamTokens(gomock.Any(), "session-1"). + Return(map[string]*storage.UpstreamTokens{ + "github": { + ProviderID: "github", + AccessToken: "expired-token", + RefreshToken: "has-refresh", + ExpiresAt: time.Now().Add(-1 * time.Hour), + }, + }, nil) + + svc := NewInProcessService(mockStorage, nil) + + result, err := svc.GetAllValidTokens(context.Background(), "session-1") + + require.NoError(t, err) + assert.Equal(t, map[string]string{}, result) +} diff --git a/pkg/auth/upstreamtoken/types.go b/pkg/auth/upstreamtoken/types.go index 7eafa96577..1668f71144 100644 --- a/pkg/auth/upstreamtoken/types.go +++ b/pkg/auth/upstreamtoken/types.go @@ -5,6 +5,8 @@ // lifecycle, including transparent refresh of expired access tokens. package upstreamtoken +//go:generate go run go.uber.org/mock/mockgen -destination=mocks/mock_token_reader.go -package=mocks github.com/stacklok/toolhive/pkg/auth/upstreamtoken TokenReader + import "context" // TokenSessionIDClaimKey is the JWT claim key for the token session ID. @@ -19,6 +21,19 @@ type UpstreamCredential struct { AccessToken string } +// TokenReader retrieves upstream provider access tokens for a session. +// This narrow interface decouples the auth middleware from storage internals. +// +// TODO(auth): Consider enriching the return type from map[string]string to +// map[string]UpstreamCredential to carry per-provider freshness/error metadata. +type TokenReader interface { + // GetAllValidTokens returns access tokens for all upstream providers in a session. + // Expired tokens are refreshed transparently when possible; if refresh fails, + // the provider is omitted from the result. + // Returns an empty map (not error) for unknown sessions. + GetAllValidTokens(ctx context.Context, sessionID string) (map[string]string, error) +} + // Service owns the upstream token lifecycle: read, refresh, error handling. type Service interface { // GetValidTokens returns a valid upstream credential for a session and provider. diff --git a/pkg/auth/utils.go b/pkg/auth/utils.go index dc5350b9fe..10137afb9c 100644 --- a/pkg/auth/utils.go +++ b/pkg/auth/utils.go @@ -61,13 +61,13 @@ func ExtractBearerToken(r *http.Request) (string, error) { // GetAuthenticationMiddleware returns the appropriate authentication middleware based on the configuration. // If OIDC config is provided, it returns JWT middleware. Otherwise, it returns local user middleware. -func GetAuthenticationMiddleware(ctx context.Context, oidcConfig *TokenValidatorConfig, +func GetAuthenticationMiddleware(ctx context.Context, oidcConfig *TokenValidatorConfig, opts ...TokenValidatorOption, ) (func(http.Handler) http.Handler, http.Handler, error) { if oidcConfig != nil { slog.Debug("oidc validation enabled") // Create JWT validator - jwtValidator, err := NewTokenValidator(ctx, *oidcConfig) + jwtValidator, err := NewTokenValidator(ctx, *oidcConfig, opts...) if err != nil { return nil, nil, err } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 0d102774d7..f9e107f617 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -76,10 +76,11 @@ type Runner struct { // Only initialized when Config.EmbeddedAuthServerConfig is set. embeddedAuthServer *authserverrunner.EmbeddedAuthServer - // upstreamTokenService is the upstream token service, created eagerly - // after the embedded auth server is initialized in Run(). + // upstreamTokenReader provides read-only access to upstream tokens for + // identity enrichment in auth middleware. Set when the embedded auth + // server is initialized in Run(). // Nil when no embedded auth server is configured. - upstreamTokenService upstreamtoken.Service + upstreamTokenReader upstreamtoken.TokenReader } // statusManagerAdapter adapts statuses.StatusManager to auth.StatusUpdater interface @@ -130,16 +131,11 @@ func (r *Runner) GetConfig() types.RunnerConfig { return r.Config } -// GetUpstreamTokenService returns an accessor for the upstream token service. -// The returned function should be called at request time; it returns nil if -// the embedded auth server is not configured. -// -// This method always returns a non-nil function. Service availability is -// determined at request time when the returned function is called. -func (r *Runner) GetUpstreamTokenService() func() upstreamtoken.Service { - return func() upstreamtoken.Service { - return r.upstreamTokenService - } +// GetUpstreamTokenReader returns the UpstreamTokenReader for identity +// enrichment in the auth middleware. Returns nil if no embedded auth +// server is configured. +func (r *Runner) GetUpstreamTokenReader() upstreamtoken.TokenReader { + return r.upstreamTokenReader } // GetName returns the name of the mcp-service from the runner config (implements types.RunnerConfig) @@ -258,7 +254,7 @@ func (r *Runner) Run(ctx context.Context) error { // InProcessService handles this gracefully (returns ErrNoRefreshToken). stor := r.embeddedAuthServer.IDPTokenStorage() refresher := r.embeddedAuthServer.UpstreamTokenRefresher() - r.upstreamTokenService = upstreamtoken.NewInProcessService(stor, refresher) + r.upstreamTokenReader = upstreamtoken.NewInProcessService(stor, refresher) // Mount auth server routes at specific prefixes to avoid conflicts with MCP endpoints // (e.g., /.well-known/oauth-protected-resource is an MCP endpoint, not auth server) diff --git a/pkg/runner/runner_test.go b/pkg/runner/runner_test.go index deef4bfbd6..03eecdd394 100644 --- a/pkg/runner/runner_test.go +++ b/pkg/runner/runner_test.go @@ -17,6 +17,7 @@ import ( "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" "github.com/stacklok/toolhive/pkg/authserver" authserverrunner "github.com/stacklok/toolhive/pkg/authserver/runner" + storagemocks "github.com/stacklok/toolhive/pkg/authserver/storage/mocks" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/transport/types" statusesmocks "github.com/stacklok/toolhive/pkg/workloads/statuses/mocks" @@ -463,57 +464,6 @@ func TestRunner_EmbeddedAuthServer_Integration(t *testing.T) { }) } -func TestRunner_GetUpstreamTokenService(t *testing.T) { - t.Parallel() - - t.Run("returns nil when no auth server configured", func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusManager := statusesmocks.NewMockStatusManager(ctrl) - - runConfig := NewRunConfig() - runner := NewRunner(runConfig, mockStatusManager) - - serviceGetter := runner.GetUpstreamTokenService() - assert.NotNil(t, serviceGetter, "GetUpstreamTokenService should always return a non-nil function") - - svc := serviceGetter() - assert.Nil(t, svc, "service should be nil when no auth server is configured") - }) - - t.Run("returns service when upstreamTokenService is set", func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusManager := statusesmocks.NewMockStatusManager(ctrl) - - runConfig := NewRunConfig() - runner := NewRunner(runConfig, mockStatusManager) - - // Simulate what Run() does: create auth server, then set the service - authServerCfg := createMinimalAuthServerConfig() - embeddedServer, err := authserverrunner.NewEmbeddedAuthServer(context.Background(), authServerCfg) - require.NoError(t, err) - require.NotNil(t, embeddedServer) - defer func() { _ = embeddedServer.Close() }() - - runner.embeddedAuthServer = embeddedServer - - stor := embeddedServer.IDPTokenStorage() - refresher := embeddedServer.UpstreamTokenRefresher() - runner.upstreamTokenService = upstreamtoken.NewInProcessService(stor, refresher) - - serviceGetter := runner.GetUpstreamTokenService() - svc := serviceGetter() - assert.NotNil(t, svc, "service should not be nil when upstreamTokenService is set") - }) -} - func TestRunner_RejectsMultiUpstreamConfig(t *testing.T) { t.Parallel() @@ -560,3 +510,30 @@ func TestRunner_RejectsMultiUpstreamConfig(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "does not support multiple upstream providers") } + +func TestRunner_GetUpstreamTokenReader(t *testing.T) { + t.Parallel() + + t.Run("returns nil when no auth server configured", func(t *testing.T) { + t.Parallel() + + r := &Runner{} + reader := r.GetUpstreamTokenReader() + assert.Nil(t, reader) + }) + + t.Run("returns reader when auth server configured", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockStorage := storagemocks.NewMockUpstreamTokenStorage(ctrl) + svc := upstreamtoken.NewInProcessService(mockStorage, nil) + + r := &Runner{ + upstreamTokenReader: svc, + } + reader := r.GetUpstreamTokenReader() + assert.NotNil(t, reader) + assert.Equal(t, svc, reader) + }) +} diff --git a/pkg/transport/types/mocks/mock_transport.go b/pkg/transport/types/mocks/mock_transport.go index f852f046e4..9ae9e5b009 100644 --- a/pkg/transport/types/mocks/mock_transport.go +++ b/pkg/transport/types/mocks/mock_transport.go @@ -123,18 +123,18 @@ func (mr *MockMiddlewareRunnerMockRecorder) GetConfig() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfig", reflect.TypeOf((*MockMiddlewareRunner)(nil).GetConfig)) } -// GetUpstreamTokenService mocks base method. -func (m *MockMiddlewareRunner) GetUpstreamTokenService() func() upstreamtoken.Service { +// GetUpstreamTokenReader mocks base method. +func (m *MockMiddlewareRunner) GetUpstreamTokenReader() upstreamtoken.TokenReader { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUpstreamTokenService") - ret0, _ := ret[0].(func() upstreamtoken.Service) + ret := m.ctrl.Call(m, "GetUpstreamTokenReader") + ret0, _ := ret[0].(upstreamtoken.TokenReader) return ret0 } -// GetUpstreamTokenService indicates an expected call of GetUpstreamTokenService. -func (mr *MockMiddlewareRunnerMockRecorder) GetUpstreamTokenService() *gomock.Call { +// GetUpstreamTokenReader indicates an expected call of GetUpstreamTokenReader. +func (mr *MockMiddlewareRunnerMockRecorder) GetUpstreamTokenReader() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUpstreamTokenService", reflect.TypeOf((*MockMiddlewareRunner)(nil).GetUpstreamTokenService)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUpstreamTokenReader", reflect.TypeOf((*MockMiddlewareRunner)(nil).GetUpstreamTokenReader)) } // SetAuthInfoHandler mocks base method. diff --git a/pkg/transport/types/transport.go b/pkg/transport/types/transport.go index a8892863f1..ae1216f676 100644 --- a/pkg/transport/types/transport.go +++ b/pkg/transport/types/transport.go @@ -77,13 +77,9 @@ type MiddlewareRunner interface { // GetConfig returns a config interface for middleware to access runner configuration GetConfig() RunnerConfig - // GetUpstreamTokenService returns an accessor for the upstream token service. - // The returned function should be called at request time; it returns nil if - // the embedded auth server is not configured. - // - // This method always returns a non-nil function. Service availability is - // determined at request time when the returned function is called. - GetUpstreamTokenService() func() upstreamtoken.Service + // GetUpstreamTokenReader returns a TokenReader for identity enrichment. + // Returns nil if the embedded auth server is not configured. + GetUpstreamTokenReader() upstreamtoken.TokenReader } // RunnerConfig defines the config interface needed by middleware to access runner configuration