Skip to content

Commit 4d5a1f7

Browse files
committed
add support for arbitrary token validation funcs
1 parent 0ab422d commit 4d5a1f7

File tree

2 files changed

+38
-24
lines changed

2 files changed

+38
-24
lines changed

config.go

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"strconv"
66
"strings"
77

8+
"github.com/golang-jwt/jwt/v5"
89
"github.com/tuannvm/oauth-mcp-proxy/provider"
910
)
1011

@@ -35,10 +36,11 @@ type Config struct {
3536
// integrate with your application's logging system (e.g., zap, logrus).
3637
Logger Logger
3738

38-
// Validation skip configuration
39-
SkipIssuerCheck bool
40-
SkipAudienceCheck bool
41-
SkipExpiryCheck bool
39+
// Token validation configuration
40+
SkipIssuerCheck bool
41+
SkipAudienceCheck bool
42+
SkipExpiryCheck bool
43+
TokenValidationFuncs []func(claims jwt.MapClaims) error
4244
}
4345

4446
// Validate validates the configuration
@@ -127,14 +129,15 @@ func SetupOAuth(cfg *Config) (provider.TokenValidator, error) {
127129
func createValidator(cfg *Config, logger Logger) (provider.TokenValidator, error) {
128130
// Convert root Config to provider.Config
129131
providerCfg := &provider.Config{
130-
Provider: cfg.Provider,
131-
Issuer: cfg.Issuer,
132-
Audience: cfg.Audience,
133-
JWTSecret: cfg.JWTSecret,
134-
Logger: logger,
135-
SkipIssuerCheck: cfg.SkipIssuerCheck,
136-
SkipAudienceCheck: cfg.SkipAudienceCheck,
137-
SkipExpiryCheck: cfg.SkipExpiryCheck,
132+
Provider: cfg.Provider,
133+
Issuer: cfg.Issuer,
134+
Audience: cfg.Audience,
135+
JWTSecret: cfg.JWTSecret,
136+
Logger: logger,
137+
SkipIssuerCheck: cfg.SkipIssuerCheck,
138+
SkipAudienceCheck: cfg.SkipAudienceCheck,
139+
SkipExpiryCheck: cfg.SkipExpiryCheck,
140+
TokenValidatorFuncs: cfg.TokenValidationFuncs,
138141
}
139142

140143
var validator provider.TokenValidator

provider/provider.go

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@ type Logger interface {
3030

3131
// Config holds OAuth configuration (subset needed by provider)
3232
type Config struct {
33-
Provider string
34-
Issuer string
35-
Audience string
36-
JWTSecret []byte
37-
Logger Logger
38-
SkipIssuerCheck bool
39-
SkipAudienceCheck bool
40-
SkipExpiryCheck bool
33+
Provider string
34+
Issuer string
35+
Audience string
36+
JWTSecret []byte
37+
Logger Logger
38+
SkipIssuerCheck bool
39+
SkipAudienceCheck bool
40+
SkipExpiryCheck bool
41+
TokenValidatorFuncs []func(claims jwt.MapClaims) error
4142
}
4243

4344
// TokenValidator interface for OAuth token validation
@@ -55,10 +56,11 @@ type HMACValidator struct {
5556

5657
// OIDCValidator validates JWT tokens using OIDC/JWKS (Okta, Google, Azure)
5758
type OIDCValidator struct {
58-
verifier *oidc.IDTokenVerifier
59-
provider *oidc.Provider
60-
audience string
61-
logger Logger
59+
verifier *oidc.IDTokenVerifier
60+
provider *oidc.Provider
61+
audience string
62+
TokenValidatorFuncs []func(claims jwt.MapClaims) error
63+
logger Logger
6264
}
6365

6466
// Initialize sets up the HMAC validator with JWT secret and audience
@@ -215,6 +217,7 @@ func (v *OIDCValidator) Initialize(cfg *Config) error {
215217

216218
v.provider = provider
217219
v.verifier = verifier
220+
v.TokenValidatorFuncs = cfg.TokenValidatorFuncs
218221
return nil
219222
}
220223

@@ -263,6 +266,14 @@ func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (
263266
return nil, fmt.Errorf("audience validation failed: %w", err)
264267
}
265268

269+
// Run extra validation functions
270+
for _, fn := range v.TokenValidatorFuncs {
271+
err := fn(rawClaims)
272+
if err != nil {
273+
return nil, fmt.Errorf("validation function failed with error: %w", err)
274+
}
275+
}
276+
266277
return &User{
267278
Subject: claims.Subject,
268279
Username: claims.PreferredUsername,

0 commit comments

Comments
 (0)