Skip to content

Commit 8beec21

Browse files
committed
add more custom config vars
Signed-off-by: Christian Troelsen <christian.troelsen@tryg.dk>
1 parent 14f9009 commit 8beec21

File tree

4 files changed

+96
-20
lines changed

4 files changed

+96
-20
lines changed

config.go

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ package oauth
22

33
import (
44
"fmt"
5+
"strconv"
6+
"strings"
57

8+
"github.com/golang-jwt/jwt/v5"
69
"github.com/tuannvm/oauth-mcp-proxy/provider"
710
)
811

@@ -18,6 +21,7 @@ type Config struct {
1821
Audience string
1922
ClientID string
2023
ClientSecret string
24+
Scopes []string
2125

2226
// Server configuration
2327
ServerURL string // Full URL of the MCP server
@@ -31,6 +35,12 @@ type Config struct {
3135
// Implement the Logger interface (Debug, Info, Warn, Error methods) to
3236
// integrate with your application's logging system (e.g., zap, logrus).
3337
Logger Logger
38+
39+
// Token validation configuration
40+
SkipIssuerCheck bool
41+
SkipAudienceCheck bool
42+
SkipExpiryCheck bool
43+
TokenValidationFuncs []func(claims jwt.MapClaims) error
3444
}
3545

3646
// Validate validates the configuration
@@ -119,11 +129,15 @@ func SetupOAuth(cfg *Config) (provider.TokenValidator, error) {
119129
func createValidator(cfg *Config, logger Logger) (provider.TokenValidator, error) {
120130
// Convert root Config to provider.Config
121131
providerCfg := &provider.Config{
122-
Provider: cfg.Provider,
123-
Issuer: cfg.Issuer,
124-
Audience: cfg.Audience,
125-
JWTSecret: cfg.JWTSecret,
126-
Logger: logger,
132+
Provider: cfg.Provider,
133+
Issuer: cfg.Issuer,
134+
Audience: cfg.Audience,
135+
JWTSecret: cfg.JWTSecret,
136+
Logger: logger,
137+
SkipIssuerCheck: cfg.SkipIssuerCheck,
138+
SkipAudienceCheck: cfg.SkipAudienceCheck,
139+
SkipExpiryCheck: cfg.SkipExpiryCheck,
140+
TokenValidatorFuncs: cfg.TokenValidationFuncs,
127141
}
128142

129143
var validator provider.TokenValidator
@@ -217,12 +231,36 @@ func (b *ConfigBuilder) WithJWTSecret(secret []byte) *ConfigBuilder {
217231
return b
218232
}
219233

234+
// WithScopes sets the OIDC scopes
235+
func (b *ConfigBuilder) WithScopes(scopes []string) *ConfigBuilder {
236+
b.config.Scopes = scopes
237+
return b
238+
}
239+
220240
// WithLogger sets the logger
221241
func (b *ConfigBuilder) WithLogger(logger Logger) *ConfigBuilder {
222242
b.config.Logger = logger
223243
return b
224244
}
225245

246+
// WithSkipIssuerCheck sets issuer check toggle
247+
func (b *ConfigBuilder) WithSkipIssuerCheck(skipIssuerCheck bool) *ConfigBuilder {
248+
b.config.SkipIssuerCheck = skipIssuerCheck
249+
return b
250+
}
251+
252+
// WithSkipAudienceCheck sets audience check toggle
253+
func (b *ConfigBuilder) WithSkipAudienceCheck(skipAudienceCheck bool) *ConfigBuilder {
254+
b.config.SkipAudienceCheck = skipAudienceCheck
255+
return b
256+
}
257+
258+
// WithSkipExpiryCheck sets expiry check toggle
259+
func (b *ConfigBuilder) WithSkipExpiryCheck(skipExpiryCheck bool) *ConfigBuilder {
260+
b.config.SkipExpiryCheck = skipExpiryCheck
261+
return b
262+
}
263+
226264
// WithServerURL sets the full server URL directly
227265
func (b *ConfigBuilder) WithServerURL(url string) *ConfigBuilder {
228266
b.config.ServerURL = url
@@ -280,6 +318,7 @@ func FromEnv() (*Config, error) {
280318
}
281319

282320
jwtSecret := getEnv("JWT_SECRET", "")
321+
scopes := strings.Split(getEnv("OIDC_SCOPES", ""), " ")
283322

284323
return NewConfigBuilder().
285324
WithMode(getEnv("OAUTH_MODE", "")).
@@ -289,7 +328,24 @@ func FromEnv() (*Config, error) {
289328
WithAudience(getEnv("OIDC_AUDIENCE", "")).
290329
WithClientID(getEnv("OIDC_CLIENT_ID", "")).
291330
WithClientSecret(getEnv("OIDC_CLIENT_SECRET", "")).
331+
WithScopes(scopes).
332+
WithSkipAudienceCheck(parseBoolEnv("OIDC_SKIP_AUDIENCE_CHECK", false)).
333+
WithSkipIssuerCheck(parseBoolEnv("OIDC_SKIP_ISSUER_CHECK", false)).
334+
WithSkipExpiryCheck(parseBoolEnv("OIDC_SKIP_EXPIRY_CHECK", false)).
292335
WithServerURL(serverURL).
293336
WithJWTSecret([]byte(jwtSecret)).
294337
Build()
295338
}
339+
340+
// parseBoolEnv parses a boolean environment variable
341+
func parseBoolEnv(key string, defaultVal bool) bool {
342+
val := getEnv(key, "")
343+
if val == "" {
344+
return defaultVal
345+
}
346+
parsed, err := strconv.ParseBool(val)
347+
if err != nil {
348+
return defaultVal
349+
}
350+
return parsed
351+
}

handlers.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type OAuth2Config struct {
4646
Audience string
4747
ClientID string
4848
ClientSecret string
49+
Scopes []string
4950

5051
// Server configuration
5152
MCPHost string
@@ -96,7 +97,7 @@ func NewOAuth2Handler(cfg *OAuth2Config, logger Logger) *OAuth2Handler {
9697
ClientID: cfg.ClientID,
9798
ClientSecret: cfg.ClientSecret,
9899
Endpoint: endpoint,
99-
Scopes: []string{"openid", "profile", "email"},
100+
Scopes: cfg.Scopes,
100101
}
101102

102103
// Log client configuration type for debugging
@@ -177,6 +178,11 @@ func NewOAuth2ConfigFromConfig(cfg *Config, version string) *OAuth2Config {
177178
mcpURL = getEnv("MCP_URL", fmt.Sprintf("%s://%s:%s", scheme, mcpHost, mcpPort))
178179
}
179180

181+
scopes := cfg.Scopes
182+
if len(scopes) == 0 {
183+
scopes = []string{"openid", "profile", "email"}
184+
}
185+
180186
return &OAuth2Config{
181187
Enabled: true,
182188
Mode: cfg.Mode,
@@ -186,6 +192,7 @@ func NewOAuth2ConfigFromConfig(cfg *Config, version string) *OAuth2Config {
186192
Audience: cfg.Audience,
187193
ClientID: cfg.ClientID,
188194
ClientSecret: cfg.ClientSecret,
195+
Scopes: scopes,
189196
MCPHost: mcpHost,
190197
MCPPort: mcpPort,
191198
MCPURL: mcpURL,

metadata.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ func (h *OAuth2Handler) HandleOIDCDiscovery(w http.ResponseWriter, r *http.Reque
242242
"token_endpoint_auth_methods_supported": []string{"none"},
243243
"code_challenge_methods_supported": []string{"plain", "S256"},
244244
"subject_types_supported": []string{"public"},
245-
"scopes_supported": []string{"openid", "profile", "email"},
245+
"scopes_supported": h.config.Scopes,
246246
}
247247

248248
// Add provider-specific fields

provider/provider.go

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,15 @@ type Logger interface {
3030

3131
// Config holds OAuth configuration (subset needed by provider)
3232
type Config struct {
33-
Provider string
34-
Issuer string
35-
Audience string
36-
JWTSecret []byte
37-
Logger Logger
33+
Provider string
34+
Issuer string
35+
Audience string
36+
JWTSecret []byte
37+
Logger Logger
38+
SkipIssuerCheck bool
39+
SkipAudienceCheck bool
40+
SkipExpiryCheck bool
41+
TokenValidatorFuncs []func(claims jwt.MapClaims) error
3842
}
3943

4044
// TokenValidator interface for OAuth token validation
@@ -52,10 +56,11 @@ type HMACValidator struct {
5256

5357
// OIDCValidator validates JWT tokens using OIDC/JWKS (Okta, Google, Azure)
5458
type OIDCValidator struct {
55-
verifier *oidc.IDTokenVerifier
56-
provider *oidc.Provider
57-
audience string
58-
logger Logger
59+
verifier *oidc.IDTokenVerifier
60+
provider *oidc.Provider
61+
audience string
62+
TokenValidatorFuncs []func(claims jwt.MapClaims) error
63+
logger Logger
5964
}
6065

6166
// Initialize sets up the HMAC validator with JWT secret and audience
@@ -90,7 +95,6 @@ func (v *HMACValidator) ValidateToken(ctx context.Context, tokenString string) (
9095
}
9196
return []byte(v.secret), nil
9297
})
93-
9498
if err != nil {
9599
return nil, fmt.Errorf("failed to parse and validate token: %w", err)
96100
}
@@ -204,15 +208,16 @@ func (v *OIDCValidator) Initialize(cfg *Config) error {
204208
verifier := provider.Verifier(&oidc.Config{
205209
ClientID: cfg.Audience, // Note: go-oidc uses ClientID field for audience validation - see https://github.com/coreos/go-oidc/blob/v3/oidc/verify.go#L85
206210
SupportedSigningAlgs: []string{oidc.RS256, oidc.ES256},
207-
SkipClientIDCheck: false, // Always validate if ClientID is provided
208-
SkipExpiryCheck: false, // Verify expiration
209-
SkipIssuerCheck: false, // Verify issuer
211+
SkipClientIDCheck: cfg.SkipAudienceCheck,
212+
SkipExpiryCheck: cfg.SkipExpiryCheck,
213+
SkipIssuerCheck: cfg.SkipIssuerCheck,
210214
})
211215

212216
v.logger.Info("OAuth: OIDC validator initialized with audience validation: %s", cfg.Audience)
213217

214218
v.provider = provider
215219
v.verifier = verifier
220+
v.TokenValidatorFuncs = cfg.TokenValidatorFuncs
216221
return nil
217222
}
218223

@@ -261,6 +266,14 @@ func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (
261266
return nil, fmt.Errorf("audience validation failed: %w", err)
262267
}
263268

269+
// Run extra validation functions
270+
for _, fn := range v.TokenValidatorFuncs {
271+
err := fn(rawClaims)
272+
if err != nil {
273+
return nil, fmt.Errorf("validation function failed with error: %w", err)
274+
}
275+
}
276+
264277
return &User{
265278
Subject: claims.Subject,
266279
Username: claims.PreferredUsername,

0 commit comments

Comments
 (0)