From 4be258db9cbc33104a3ea1c611d53438e9423735 Mon Sep 17 00:00:00 2001 From: Christian Troelsen Date: Mon, 24 Nov 2025 14:29:15 +0000 Subject: [PATCH] add more custom config vars Signed-off-by: Christian Troelsen fix suggestions by coderabbit Signed-off-by: Christian Troelsen --- config.go | 71 ++++++++++++++++++++++++++++++++++++++++---- handlers.go | 9 +++++- metadata.go | 4 +-- provider/provider.go | 39 ++++++++++++++++-------- 4 files changed, 102 insertions(+), 21 deletions(-) diff --git a/config.go b/config.go index d4e5071..73143a2 100644 --- a/config.go +++ b/config.go @@ -2,7 +2,10 @@ package oauth import ( "fmt" + "strconv" + "strings" + "github.com/golang-jwt/jwt/v5" "github.com/tuannvm/oauth-mcp-proxy/provider" ) @@ -18,6 +21,7 @@ type Config struct { Audience string ClientID string ClientSecret string + Scopes []string // Server configuration ServerURL string // Full URL of the MCP server @@ -31,6 +35,12 @@ type Config struct { // Implement the Logger interface (Debug, Info, Warn, Error methods) to // integrate with your application's logging system (e.g., zap, logrus). Logger Logger + + // Token validation configuration + SkipIssuerCheck bool + SkipAudienceCheck bool + SkipExpiryCheck bool + TokenValidationFuncs []func(claims jwt.MapClaims) error } // Validate validates the configuration @@ -119,11 +129,15 @@ func SetupOAuth(cfg *Config) (provider.TokenValidator, error) { func createValidator(cfg *Config, logger Logger) (provider.TokenValidator, error) { // Convert root Config to provider.Config providerCfg := &provider.Config{ - Provider: cfg.Provider, - Issuer: cfg.Issuer, - Audience: cfg.Audience, - JWTSecret: cfg.JWTSecret, - Logger: logger, + Provider: cfg.Provider, + Issuer: cfg.Issuer, + Audience: cfg.Audience, + JWTSecret: cfg.JWTSecret, + Logger: logger, + SkipIssuerCheck: cfg.SkipIssuerCheck, + SkipAudienceCheck: cfg.SkipAudienceCheck, + SkipExpiryCheck: cfg.SkipExpiryCheck, + TokenValidatorFuncs: cfg.TokenValidationFuncs, } var validator provider.TokenValidator @@ -217,12 +231,36 @@ func (b *ConfigBuilder) WithJWTSecret(secret []byte) *ConfigBuilder { return b } +// WithScopes sets the OIDC scopes +func (b *ConfigBuilder) WithScopes(scopes []string) *ConfigBuilder { + b.config.Scopes = scopes + return b +} + // WithLogger sets the logger func (b *ConfigBuilder) WithLogger(logger Logger) *ConfigBuilder { b.config.Logger = logger return b } +// WithSkipIssuerCheck sets issuer check toggle +func (b *ConfigBuilder) WithSkipIssuerCheck(skipIssuerCheck bool) *ConfigBuilder { + b.config.SkipIssuerCheck = skipIssuerCheck + return b +} + +// WithSkipAudienceCheck sets audience check toggle +func (b *ConfigBuilder) WithSkipAudienceCheck(skipAudienceCheck bool) *ConfigBuilder { + b.config.SkipAudienceCheck = skipAudienceCheck + return b +} + +// WithSkipExpiryCheck sets expiry check toggle +func (b *ConfigBuilder) WithSkipExpiryCheck(skipExpiryCheck bool) *ConfigBuilder { + b.config.SkipExpiryCheck = skipExpiryCheck + return b +} + // WithServerURL sets the full server URL directly func (b *ConfigBuilder) WithServerURL(url string) *ConfigBuilder { b.config.ServerURL = url @@ -281,6 +319,12 @@ func FromEnv() (*Config, error) { jwtSecret := getEnv("JWT_SECRET", "") + scopes := []string{} + scopesEnv := getEnv("OIDC_SCOPES", "") + if scopesEnv != "" { + scopes = strings.Split(scopesEnv, " ") + } + return NewConfigBuilder(). WithMode(getEnv("OAUTH_MODE", "")). WithProvider(getEnv("OAUTH_PROVIDER", "")). @@ -289,7 +333,24 @@ func FromEnv() (*Config, error) { WithAudience(getEnv("OIDC_AUDIENCE", "")). WithClientID(getEnv("OIDC_CLIENT_ID", "")). WithClientSecret(getEnv("OIDC_CLIENT_SECRET", "")). + WithScopes(scopes). + WithSkipAudienceCheck(parseBoolEnv("OIDC_SKIP_AUDIENCE_CHECK", false)). + WithSkipIssuerCheck(parseBoolEnv("OIDC_SKIP_ISSUER_CHECK", false)). + WithSkipExpiryCheck(parseBoolEnv("OIDC_SKIP_EXPIRY_CHECK", false)). WithServerURL(serverURL). WithJWTSecret([]byte(jwtSecret)). Build() } + +// parseBoolEnv parses a boolean environment variable +func parseBoolEnv(key string, defaultVal bool) bool { + val := getEnv(key, "") + if val == "" { + return defaultVal + } + parsed, err := strconv.ParseBool(val) + if err != nil { + return defaultVal + } + return parsed +} diff --git a/handlers.go b/handlers.go index a1fec89..780a98b 100644 --- a/handlers.go +++ b/handlers.go @@ -46,6 +46,7 @@ type OAuth2Config struct { Audience string ClientID string ClientSecret string + Scopes []string // Server configuration MCPHost string @@ -96,7 +97,7 @@ func NewOAuth2Handler(cfg *OAuth2Config, logger Logger) *OAuth2Handler { ClientID: cfg.ClientID, ClientSecret: cfg.ClientSecret, Endpoint: endpoint, - Scopes: []string{"openid", "profile", "email"}, + Scopes: cfg.Scopes, } // Log client configuration type for debugging @@ -177,6 +178,11 @@ func NewOAuth2ConfigFromConfig(cfg *Config, version string) *OAuth2Config { mcpURL = getEnv("MCP_URL", fmt.Sprintf("%s://%s:%s", scheme, mcpHost, mcpPort)) } + scopes := cfg.Scopes + if len(scopes) == 0 { + scopes = []string{"openid", "profile", "email"} + } + return &OAuth2Config{ Enabled: true, Mode: cfg.Mode, @@ -186,6 +192,7 @@ func NewOAuth2ConfigFromConfig(cfg *Config, version string) *OAuth2Config { Audience: cfg.Audience, ClientID: cfg.ClientID, ClientSecret: cfg.ClientSecret, + Scopes: scopes, MCPHost: mcpHost, MCPPort: mcpPort, MCPURL: mcpURL, diff --git a/metadata.go b/metadata.go index 81c4e56..a6a9116 100644 --- a/metadata.go +++ b/metadata.go @@ -242,7 +242,7 @@ func (h *OAuth2Handler) HandleOIDCDiscovery(w http.ResponseWriter, r *http.Reque "token_endpoint_auth_methods_supported": []string{"none"}, "code_challenge_methods_supported": []string{"plain", "S256"}, "subject_types_supported": []string{"public"}, - "scopes_supported": []string{"openid", "profile", "email"}, + "scopes_supported": h.config.Scopes, } // Add provider-specific fields @@ -282,7 +282,7 @@ func (h *OAuth2Handler) GetAuthorizationServerMetadata() map[string]interface{} "grant_types_supported": []string{"authorization_code"}, "token_endpoint_auth_methods_supported": []string{"none"}, "code_challenge_methods_supported": []string{"plain", "S256"}, - "scopes_supported": []string{"openid", "profile", "email"}, + "scopes_supported": h.config.Scopes, } // Add provider-specific endpoints diff --git a/provider/provider.go b/provider/provider.go index 39c5847..bd63c93 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -30,11 +30,15 @@ type Logger interface { // Config holds OAuth configuration (subset needed by provider) type Config struct { - Provider string - Issuer string - Audience string - JWTSecret []byte - Logger Logger + Provider string + Issuer string + Audience string + JWTSecret []byte + Logger Logger + SkipIssuerCheck bool + SkipAudienceCheck bool + SkipExpiryCheck bool + TokenValidatorFuncs []func(claims jwt.MapClaims) error } // TokenValidator interface for OAuth token validation @@ -52,10 +56,11 @@ type HMACValidator struct { // OIDCValidator validates JWT tokens using OIDC/JWKS (Okta, Google, Azure) type OIDCValidator struct { - verifier *oidc.IDTokenVerifier - provider *oidc.Provider - audience string - logger Logger + verifier *oidc.IDTokenVerifier + provider *oidc.Provider + audience string + TokenValidatorFuncs []func(claims jwt.MapClaims) error + logger Logger } // Initialize sets up the HMAC validator with JWT secret and audience @@ -90,7 +95,6 @@ func (v *HMACValidator) ValidateToken(ctx context.Context, tokenString string) ( } return []byte(v.secret), nil }) - if err != nil { return nil, fmt.Errorf("failed to parse and validate token: %w", err) } @@ -204,15 +208,16 @@ func (v *OIDCValidator) Initialize(cfg *Config) error { verifier := provider.Verifier(&oidc.Config{ ClientID: cfg.Audience, // Note: go-oidc uses ClientID field for audience validation - see https://github.com/coreos/go-oidc/blob/v3/oidc/verify.go#L85 SupportedSigningAlgs: []string{oidc.RS256, oidc.ES256}, - SkipClientIDCheck: false, // Always validate if ClientID is provided - SkipExpiryCheck: false, // Verify expiration - SkipIssuerCheck: false, // Verify issuer + SkipClientIDCheck: cfg.SkipAudienceCheck, + SkipExpiryCheck: cfg.SkipExpiryCheck, + SkipIssuerCheck: cfg.SkipIssuerCheck, }) v.logger.Info("OAuth: OIDC validator initialized with audience validation: %s", cfg.Audience) v.provider = provider v.verifier = verifier + v.TokenValidatorFuncs = cfg.TokenValidatorFuncs return nil } @@ -261,6 +266,14 @@ func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) ( return nil, fmt.Errorf("audience validation failed: %w", err) } + // Run extra validation functions + for _, fn := range v.TokenValidatorFuncs { + err := fn(rawClaims) + if err != nil { + return nil, fmt.Errorf("validation function failed with error: %w", err) + } + } + return &User{ Subject: claims.Subject, Username: claims.PreferredUsername,