Skip to content
27 changes: 12 additions & 15 deletions docs/middleware.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,9 @@ sequenceDiagram
**Availability**: Automatically enabled when using the embedded auth server (`EmbeddedAuthServerConfig`)

**Responsibilities**:
- Extract the token session ID (`tsid`) claim from the ToolHive JWT
- Look up the stored upstream IdP tokens associated with that session
- Read the upstream access token for the configured provider from `Identity.UpstreamTokens`
- Inject the upstream access token into the request (replacing Authorization header or using a custom header)
- Return 401 Unauthorized with WWW-Authenticate header when tokens are expired or not found
- Gracefully proceed without modification if identity, session ID, or storage is unavailable
- Return 401 Unauthorized with WWW-Authenticate header when the provider token is missing or empty

**Configuration**:

Expand All @@ -148,16 +146,16 @@ sequenceDiagram

**Behavior**:
- **Automatic activation**: Enabled whenever the embedded auth server is configured, even without explicit `UpstreamSwapConfig`
- **Expired tokens**: Returns 401 Unauthorized with `WWW-Authenticate` header indicating re-authentication is required
- **Tokens not found**: Returns 401 Unauthorized with `WWW-Authenticate` header indicating re-authentication is required
- **Missing identity/tsid**: Proceeds without modification (logs debug message)
- **Storage unavailable**: Proceeds without modification (logs warning)
- **Other storage errors**: Returns 503 Service Unavailable to fail closed (logs warning)
- **Provider token found**: Injects the token into the request using the configured header strategy
- **Provider not in UpstreamTokens**: Returns 401 Unauthorized with `WWW-Authenticate` header indicating re-authentication is required
- **Empty token value**: Returns 401 Unauthorized (same as missing provider)
- **No identity in context**: Passes through without modification (auth middleware not in chain)
- **Storage unavailable**: The auth middleware returns 503 before the request reaches this middleware

**Context Data Used**:
- Identity from Authentication middleware (specifically the `tsid` claim)
- `Identity.UpstreamTokens` map populated by the Authentication middleware during JWT validation

**Note**: This middleware is designed for use with ToolHive's embedded auth server. It reads from the auth server's token storage to retrieve upstream IdP tokens that were captured during the OAuth authorization flow.
**Note**: This middleware is a simple map reader. All upstream token loading, refresh, and error handling occurs in the Authentication middleware (Step 3), which populates `Identity.UpstreamTokens` from the token session ID (`tsid`) claim during JWT validation.

---

Expand Down Expand Up @@ -785,10 +783,9 @@ type MiddlewareRunner interface {
// GetConfig returns a config interface for middleware to access runner configuration
GetConfig() RunnerConfig

// GetUpstreamTokenService returns a lazy accessor for the upstream token service.
// Returns a function that provides the service at request time.
// Used by upstream swap middleware to get valid upstream tokens (with transparent refresh).
GetUpstreamTokenService() func() upstreamtoken.Service
// GetUpstreamTokenReader returns an UpstreamTokenReader for identity enrichment.
// Returns nil if the embedded auth server is not configured.
GetUpstreamTokenReader() upstreamtoken.UpstreamTokenReader
}
```

Expand Down
25 changes: 24 additions & 1 deletion pkg/auth/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,16 @@ func claimsToIdentity(claims jwt.MapClaims, token string) (*Identity, error) {
return nil, errors.New("missing or invalid 'sub' claim (required by OIDC Core 1.0 § 5.1)")
}

// Filter internal claims that should not be externalized (e.g., in
// webhook payloads or audit logs). The tsid is a session identifier
// used to look up upstream tokens in storage; exposing it widens the
// attack surface if a webhook receiver is compromised.
filteredClaims := filterInternalClaims(claims)

identity := &Identity{
PrincipalInfo: PrincipalInfo{
Subject: sub,
Claims: claims,
Claims: filteredClaims,
},
Token: token,
TokenType: "Bearer",
Expand All @@ -86,3 +92,20 @@ func claimsToIdentity(claims jwt.MapClaims, token string) (*Identity, error) {

return identity, nil
}

// internalClaims lists the JWT claim keys that are used internally by the
// auth server and must not be externalized (webhook payloads, audit logs,
// etc.). filterInternalClaims removes every key in this list from the copy
// of the claims it returns.
//
// "tsid" is the token session ID used to look up upstream tokens in storage.
var internalClaims = []string{"tsid"}

// filterInternalClaims builds a new claims map that holds every entry of
// claims except those whose key appears in internalClaims. The input map is
// never modified, and the result is always a non-nil map (empty when claims
// is nil or empty).
func filterInternalClaims(claims jwt.MapClaims) jwt.MapClaims {
	sanitized := make(jwt.MapClaims, len(claims))

copyClaims:
	for name, value := range claims {
		// Skip internal keys up front instead of copying and deleting them.
		for _, hidden := range internalClaims {
			if name == hidden {
				continue copyClaims
			}
		}
		sanitized[name] = value
	}

	return sanitized
}
56 changes: 40 additions & 16 deletions pkg/auth/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ type Identity struct {

// Metadata stores additional identity information.
Metadata map[string]string

// UpstreamTokens maps upstream provider names to their access tokens.
// This is populated by the auth middleware when an embedded auth server
// is active and the JWT contains a token session ID (tsid claim).
// Redacted in MarshalJSON() to prevent token leakage.
// MUST NOT be mutated after the Identity is placed in the request context.
UpstreamTokens map[string]string
}

// String returns a string representation of the Identity with sensitive fields redacted.
Expand All @@ -77,30 +84,47 @@ func (i *Identity) MarshalJSON() ([]byte, error) {

// Create a safe representation with lowerCamelCase JSON field names and redacted token values
type SafeIdentity struct {
Subject string `json:"subject"`
Name string `json:"name"`
Email string `json:"email"`
Groups []string `json:"groups"`
Claims map[string]any `json:"claims"`
Token string `json:"token"`
TokenType string `json:"tokenType"`
Metadata map[string]string `json:"metadata"`
Subject string `json:"subject"`
Name string `json:"name"`
Email string `json:"email"`
Groups []string `json:"groups"`
Claims map[string]any `json:"claims"`
Token string `json:"token"`
TokenType string `json:"tokenType"`
Metadata map[string]string `json:"metadata"`
UpstreamTokens map[string]string `json:"upstreamTokens,omitempty"`
}

token := i.Token
if token != "" {
token = "REDACTED"
}

// Redact upstream tokens: preserve keys, replace non-empty values
var redactedUpstreamTokens map[string]string
// Guard with len() > 0 (not != nil) so that both nil and empty maps
// produce a nil redactedUpstreamTokens, which omitempty then omits.
if len(i.UpstreamTokens) > 0 {
redactedUpstreamTokens = make(map[string]string, len(i.UpstreamTokens))
for k, v := range i.UpstreamTokens {
if v != "" {
redactedUpstreamTokens[k] = "REDACTED"
} else {
redactedUpstreamTokens[k] = ""
}
}
}

return json.Marshal(&SafeIdentity{
Subject: i.Subject,
Name: i.Name,
Email: i.Email,
Groups: i.Groups,
Claims: i.Claims,
Token: token,
TokenType: i.TokenType,
Metadata: i.Metadata,
Subject: i.Subject,
Name: i.Name,
Email: i.Email,
Groups: i.Groups,
Claims: i.Claims,
Token: token,
TokenType: i.TokenType,
Metadata: i.Metadata,
UpstreamTokens: redactedUpstreamTokens,
})
}

Expand Down
96 changes: 96 additions & 0 deletions pkg/auth/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,16 @@ func TestIdentity_String(t *testing.T) {
identity: nil,
want: "<nil>",
},
{
name: "does_not_leak_upstream_tokens",
identity: &Identity{
PrincipalInfo: PrincipalInfo{Subject: "user123"},
UpstreamTokens: map[string]string{
"github": "gho_secret123",
},
},
want: `Identity{Subject:"user123"}`,
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -293,6 +303,92 @@ func TestIdentity_MarshalJSON(t *testing.T) {
assert.Equal(t, "null", string(data))
},
},
{
name: "redacts_upstream_tokens",
identity: &Identity{
PrincipalInfo: PrincipalInfo{Subject: "user123"},
UpstreamTokens: map[string]string{
"github": "gho_secret123",
"atlassian": "atl_secret456",
},
},
wantErr: false,
checkFunc: func(t *testing.T, data []byte) {
t.Helper()

var result map[string]any
err := json.Unmarshal(data, &result)
require.NoError(t, err)

tokens, ok := result["upstreamTokens"].(map[string]any)
require.True(t, ok, "upstreamTokens should be a map")
assert.Equal(t, "REDACTED", tokens["github"])
assert.Equal(t, "REDACTED", tokens["atlassian"])
assert.NotContains(t, string(data), "gho_secret123")
assert.NotContains(t, string(data), "atl_secret456")
},
},
{
name: "empty_upstream_tokens_omitted",
identity: &Identity{
PrincipalInfo: PrincipalInfo{Subject: "user123"},
UpstreamTokens: map[string]string{},
},
wantErr: false,
checkFunc: func(t *testing.T, data []byte) {
t.Helper()

var result map[string]any
err := json.Unmarshal(data, &result)
require.NoError(t, err)

// Empty map should be omitted because len() == 0 produces nil redacted map
_, exists := result["upstreamTokens"]
assert.False(t, exists, "empty upstreamTokens should be omitted")
},
},
{
name: "nil_upstream_tokens_omitted",
identity: &Identity{
PrincipalInfo: PrincipalInfo{Subject: "user123"},
UpstreamTokens: nil,
},
wantErr: false,
checkFunc: func(t *testing.T, data []byte) {
t.Helper()

var result map[string]any
err := json.Unmarshal(data, &result)
require.NoError(t, err)

_, exists := result["upstreamTokens"]
assert.False(t, exists, "nil upstreamTokens should be omitted")
},
},
{
name: "upstream_tokens_mixed_empty_and_populated",
identity: &Identity{
PrincipalInfo: PrincipalInfo{Subject: "user123"},
UpstreamTokens: map[string]string{
"github": "gho_secret123",
"pending": "",
},
},
wantErr: false,
checkFunc: func(t *testing.T, data []byte) {
t.Helper()

var result map[string]any
err := json.Unmarshal(data, &result)
require.NoError(t, err)

tokens, ok := result["upstreamTokens"].(map[string]any)
require.True(t, ok, "upstreamTokens should be a map")
assert.Equal(t, "REDACTED", tokens["github"])
assert.Equal(t, "", tokens["pending"])
assert.NotContains(t, string(data), "gho_secret123")
},
},
}

for _, tt := range tests {
Expand Down
7 changes: 6 additions & 1 deletion pkg/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun
return fmt.Errorf("failed to unmarshal auth middleware parameters: %w", err)
}

middleware, authInfoHandler, err := GetAuthenticationMiddleware(context.Background(), params.OIDCConfig)
var opts []TokenValidatorOption
if reader := runner.GetUpstreamTokenReader(); reader != nil {
opts = append(opts, WithUpstreamTokenReader(reader))
}

middleware, authInfoHandler, err := GetAuthenticationMiddleware(context.Background(), params.OIDCConfig, opts...)
if err != nil {
return fmt.Errorf("failed to create authentication middleware: %w", err)
}
Expand Down
6 changes: 6 additions & 0 deletions pkg/auth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ func TestCreateMiddleware_WithoutOIDCConfig(t *testing.T) {
// Create mock runner
mockRunner := mocks.NewMockMiddlewareRunner(ctrl)

// Expect GetUpstreamTokenReader to be called (returns nil = no auth server)
mockRunner.EXPECT().GetUpstreamTokenReader().Return(nil)

// Expect AddMiddleware to be called with a middleware instance
mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any()).Do(func(name string, mw types.Middleware) {
// Verify it's our auth middleware
Expand Down Expand Up @@ -210,6 +213,9 @@ func TestCreateMiddleware_EmptyParameters(t *testing.T) {

mockRunner := mocks.NewMockMiddlewareRunner(ctrl)

// Expect GetUpstreamTokenReader to be called (returns nil = no auth server)
mockRunner.EXPECT().GetUpstreamTokenReader().Return(nil)

// Expect AddMiddleware to be called
mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any())

Expand Down
Loading
Loading