@@ -2,7 +2,10 @@ package oauth
22
33import (
44 "fmt"
5+ "strconv"
6+ "strings"
57
8+ "github.com/golang-jwt/jwt/v5"
69 "github.com/tuannvm/oauth-mcp-proxy/provider"
710)
811
@@ -18,6 +21,7 @@ type Config struct {
1821 Audience string
1922 ClientID string
2023 ClientSecret string
24+ Scopes []string
2125
2226 // Server configuration
2327 ServerURL string // Full URL of the MCP server
@@ -31,6 +35,12 @@ type Config struct {
3135 // Implement the Logger interface (Debug, Info, Warn, Error methods) to
3236 // integrate with your application's logging system (e.g., zap, logrus).
3337 Logger Logger
38+
39+ // Token validation configuration
40+ SkipIssuerCheck bool
41+ SkipAudienceCheck bool
42+ SkipExpiryCheck bool
43+ TokenValidationFuncs []func (claims jwt.MapClaims ) error
3444}
3545
3646// Validate validates the configuration
@@ -119,11 +129,15 @@ func SetupOAuth(cfg *Config) (provider.TokenValidator, error) {
119129func createValidator (cfg * Config , logger Logger ) (provider.TokenValidator , error ) {
120130 // Convert root Config to provider.Config
121131 providerCfg := & provider.Config {
122- Provider : cfg .Provider ,
123- Issuer : cfg .Issuer ,
124- Audience : cfg .Audience ,
125- JWTSecret : cfg .JWTSecret ,
126- Logger : logger ,
132+ Provider : cfg .Provider ,
133+ Issuer : cfg .Issuer ,
134+ Audience : cfg .Audience ,
135+ JWTSecret : cfg .JWTSecret ,
136+ Logger : logger ,
137+ SkipIssuerCheck : cfg .SkipIssuerCheck ,
138+ SkipAudienceCheck : cfg .SkipAudienceCheck ,
139+ SkipExpiryCheck : cfg .SkipExpiryCheck ,
140+ TokenValidatorFuncs : cfg .TokenValidationFuncs ,
127141 }
128142
129143 var validator provider.TokenValidator
@@ -217,12 +231,36 @@ func (b *ConfigBuilder) WithJWTSecret(secret []byte) *ConfigBuilder {
217231 return b
218232}
219233
234+ // WithScopes sets the OIDC scopes
235+ func (b * ConfigBuilder ) WithScopes (scopes []string ) * ConfigBuilder {
236+ b .config .Scopes = scopes
237+ return b
238+ }
239+
220240// WithLogger sets the logger
221241func (b * ConfigBuilder ) WithLogger (logger Logger ) * ConfigBuilder {
222242 b .config .Logger = logger
223243 return b
224244}
225245
246+ // WithSkipIssuerCheck sets issuer check toggle
247+ func (b * ConfigBuilder ) WithSkipIssuerCheck (skipIssuerCheck bool ) * ConfigBuilder {
248+ b .config .SkipIssuerCheck = skipIssuerCheck
249+ return b
250+ }
251+
252+ // WithSkipAudienceCheck sets audience check toggle
253+ func (b * ConfigBuilder ) WithSkipAudienceCheck (skipAudienceCheck bool ) * ConfigBuilder {
254+ b .config .SkipAudienceCheck = skipAudienceCheck
255+ return b
256+ }
257+
258+ // WithSkipExpiryCheck sets expiry check toggle
259+ func (b * ConfigBuilder ) WithSkipExpiryCheck (skipExpiryCheck bool ) * ConfigBuilder {
260+ b .config .SkipExpiryCheck = skipExpiryCheck
261+ return b
262+ }
263+
226264// WithServerURL sets the full server URL directly
227265func (b * ConfigBuilder ) WithServerURL (url string ) * ConfigBuilder {
228266 b .config .ServerURL = url
@@ -280,6 +318,7 @@ func FromEnv() (*Config, error) {
280318 }
281319
282320 jwtSecret := getEnv ("JWT_SECRET" , "" )
321+ scopes := strings .Split (getEnv ("OIDC_SCOPES" , "" ), " " )
283322
284323 return NewConfigBuilder ().
285324 WithMode (getEnv ("OAUTH_MODE" , "" )).
@@ -289,7 +328,24 @@ func FromEnv() (*Config, error) {
289328 WithAudience (getEnv ("OIDC_AUDIENCE" , "" )).
290329 WithClientID (getEnv ("OIDC_CLIENT_ID" , "" )).
291330 WithClientSecret (getEnv ("OIDC_CLIENT_SECRET" , "" )).
331+ WithScopes (scopes ).
332+ WithSkipAudienceCheck (parseBoolEnv ("OIDC_SKIP_AUDIENCE_CHECK" , false )).
333+ WithSkipIssuerCheck (parseBoolEnv ("OIDC_SKIP_ISSUER_CHECK" , false )).
334+ WithSkipExpiryCheck (parseBoolEnv ("OIDC_SKIP_EXPIRY_CHECK" , false )).
292335 WithServerURL (serverURL ).
293336 WithJWTSecret ([]byte (jwtSecret )).
294337 Build ()
295338}
339+
340+ // parseBoolEnv parses a boolean environment variable
341+ func parseBoolEnv (key string , defaultVal bool ) bool {
342+ val := getEnv (key , "" )
343+ if val == "" {
344+ return defaultVal
345+ }
346+ parsed , err := strconv .ParseBool (val )
347+ if err != nil {
348+ return defaultVal
349+ }
350+ return parsed
351+ }
0 commit comments