Skip to content

Commit d71ecb7

Browse files
authored
fix(mcp/oauth): Add WrapMCPEndpoint for automatic 401 handling (#6)
* fix(mcp/oauth): Add WrapMCPEndpoint for automatic 401 handling Signed-off-by: Tommy Nguyen <tuannvm@hotmail.com> * fix(oauth): trim whitespace from Bearer tokens in headers Signed-off-by: Tommy Nguyen <tuannvm@hotmail.com> * fix(oauth): reject non-Bearer auth schemes and handle malformed tokens Signed-off-by: Tommy Nguyen <tuannvm@hotmail.com> --------- Signed-off-by: Tommy Nguyen <tuannvm@hotmail.com>
1 parent 33bddcf commit d71ecb7

File tree

6 files changed

+167
-14
lines changed

6 files changed

+167
-14
lines changed

examples/mark3labs/advanced/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func main() {
5050
)
5151
streamableServer := mcpserver.NewStreamableHTTPServer(mcpServer, httpOpts...)
5252

53-
// Feature 4: WrapHandler - Auto Bearer token pre-check with 401
53+
// Feature 4: WrapMCPEndpoint - Automatic 401 handling with CORS support
5454
mcpHandler := func(w http.ResponseWriter, r *http.Request) {
5555
w.Header().Set("Access-Control-Allow-Origin", "*")
5656
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
@@ -64,7 +64,7 @@ func main() {
6464
streamableServer.ServeHTTP(w, r)
6565
}
6666

67-
mux.HandleFunc("/mcp", oauthServer.WrapHandlerFunc(mcpHandler))
67+
mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(http.HandlerFunc(mcpHandler)))
6868

6969
// Add status endpoint (not OAuth protected)
7070
mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) {

examples/mark3labs/simple/main.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func main() {
2222
// export OKTA_DOMAIN="dev-12345.okta.com" (your Okta domain)
2323
// export OKTA_AUDIENCE="api://my-mcp-server" (your API identifier)
2424
// export SERVER_URL="https://mcp.example.com" (your server URL)
25-
_, oauthOption, err := mark3labs.WithOAuth(mux, &oauth.Config{
25+
oauthServer, oauthOption, err := mark3labs.WithOAuth(mux, &oauth.Config{
2626
Provider: "okta",
2727
Issuer: fmt.Sprintf("https://%s", getEnv("OKTA_DOMAIN", "dev-12345.okta.com")),
2828
Audience: getEnv("OKTA_AUDIENCE", "api://my-mcp-server"),
@@ -51,13 +51,13 @@ func main() {
5151
},
5252
)
5353

54-
// 5. Setup MCP endpoint
54+
// 5. Setup MCP endpoint with automatic 401 handling
5555
streamableServer := mcpserver.NewStreamableHTTPServer(
5656
mcpServer,
5757
mcpserver.WithEndpointPath("/mcp"),
5858
mcpserver.WithHTTPContextFunc(oauth.CreateHTTPContextFunc()),
5959
)
60-
mux.Handle("/mcp", streamableServer)
60+
mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(streamableServer))
6161

6262
// 6. Start server
6363
// Note: PORT is the local bind port. If you change SERVER_URL port

mark3labs/oauth.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,16 @@ import (
2323
// })
2424
// mcpServer := server.NewMCPServer("Server", "1.0.0", oauthOption)
2525
//
26+
// streamableServer := server.NewStreamableHTTPServer(mcpServer, ...)
27+
// mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(streamableServer))
28+
//
2629
// This function:
2730
// - Creates OAuth server instance
2831
// - Registers OAuth HTTP endpoints on mux
2932
// - Returns server instance and middleware as server option
3033
//
3134
// The returned Server instance provides access to:
35+
// - WrapMCPEndpoint() - Wrap /mcp endpoint with automatic 401 handling
3236
// - WrapHandler() - Wrap HTTP handlers with OAuth token validation
3337
// - GetHTTPServerOptions() - Get StreamableHTTPServer options
3438
// - LogStartup() - Log OAuth endpoint information

mcp/oauth.go

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package mcp
33
import (
44
"fmt"
55
"net/http"
6+
"strings"
67

78
"github.com/modelcontextprotocol/go-sdk/mcp"
89
oauth "github.com/tuannvm/oauth-mcp-proxy"
@@ -32,14 +33,18 @@ import (
3233
// This function:
3334
// - Creates OAuth server instance
3435
// - Registers OAuth HTTP endpoints on mux
35-
// - Wraps MCP StreamableHTTPHandler with OAuth token validation
36+
// - Wraps MCP StreamableHTTPHandler with automatic 401 handling
3637
// - Returns OAuth server and protected HTTP handler
3738
//
39+
// The returned handler automatically:
40+
// - Returns 401 with WWW-Authenticate headers if Bearer token missing
41+
// - Passes through OPTIONS requests (CORS pre-flight)
42+
// - Rejects non-Bearer auth schemes (OAuth-only endpoint)
43+
//
3844
// The returned oauth.Server instance provides access to:
3945
// - LogStartup() - Log OAuth endpoint information
4046
// - Discovery URL helpers (GetCallbackURL, GetMetadataURL, etc.)
4147
//
42-
// The HTTP handler validates OAuth tokens before delegating to the MCP server.
4348
// Tool handlers can access the authenticated user via oauth.GetUserFromContext(ctx).
4449
func WithOAuth(mux *http.ServeMux, cfg *oauth.Config, mcpServer *mcp.Server) (*oauth.Server, http.Handler, error) {
4550
oauthServer, err := oauth.NewServer(cfg)
@@ -54,24 +59,64 @@ func WithOAuth(mux *http.ServeMux, cfg *oauth.Config, mcpServer *mcp.Server) (*o
5459
}, nil)
5560

5661
wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
62+
// Pass through OPTIONS requests (CORS pre-flight)
63+
if r.Method == http.MethodOptions {
64+
mcpHandler.ServeHTTP(w, r)
65+
return
66+
}
67+
68+
// Check Authorization header
5769
authHeader := r.Header.Get("Authorization")
58-
if authHeader == "" || len(authHeader) < 7 || authHeader[:7] != "Bearer " {
59-
http.Error(w, "Missing or invalid Authorization header", http.StatusUnauthorized)
70+
authLower := strings.ToLower(authHeader)
71+
72+
// Return 401 if Bearer token missing
73+
if authHeader == "" {
74+
oauthServer.Return401(w)
6075
return
6176
}
6277

63-
token := authHeader[7:]
78+
// Check if it's a Bearer token (case-insensitive per OAuth 2.0 spec)
79+
if !strings.HasPrefix(authLower, "bearer") {
80+
// Reject non-Bearer schemes (OAuth endpoints require Bearer tokens only)
81+
oauthServer.Return401(w)
82+
return
83+
}
84+
85+
// Malformed Bearer token (no space after "Bearer")
86+
if !strings.HasPrefix(authLower, "bearer ") {
87+
oauthServer.Return401InvalidToken(w)
88+
return
89+
}
90+
91+
// Extract and validate token (safe slice operation)
92+
const bearerPrefix = "Bearer "
93+
if len(authHeader) < len(bearerPrefix)+1 {
94+
oauthServer.Return401InvalidToken(w)
95+
return
96+
}
97+
token := authHeader[len(bearerPrefix):]
98+
99+
// Clean any whitespace (e.g., "Bearer token ")
100+
token = strings.TrimSpace(token)
101+
102+
// Validate token is not empty
103+
if token == "" {
104+
oauthServer.Return401InvalidToken(w)
105+
return
106+
}
64107

65108
user, err := oauthServer.ValidateTokenCached(r.Context(), token)
66109
if err != nil {
67-
http.Error(w, fmt.Sprintf("Authentication failed: %v", err), http.StatusUnauthorized)
110+
oauthServer.Return401InvalidToken(w)
68111
return
69112
}
70113

114+
// Add token and user to context
71115
ctx := oauth.WithOAuthToken(r.Context(), token)
72116
ctx = oauth.WithUser(ctx, user)
73117
r = r.WithContext(ctx)
74118

119+
// Pass to wrapped handler
75120
mcpHandler.ServeHTTP(w, r)
76121
})
77122

middleware.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,13 @@ func OAuthMiddleware(validator provider.TokenValidator, enabled bool) func(serve
119119
// via WithOAuthToken(). The OAuth middleware then retrieves it via GetOAuthToken().
120120
func CreateHTTPContextFunc() func(context.Context, *http.Request) context.Context {
121121
return func(ctx context.Context, r *http.Request) context.Context {
122-
// Extract Bearer token from Authorization header
122+
// Extract Bearer token from Authorization header (case-insensitive per OAuth 2.0 spec)
123123
authHeader := r.Header.Get("Authorization")
124-
if strings.HasPrefix(authHeader, "Bearer ") {
125-
token := strings.TrimPrefix(authHeader, "Bearer ")
124+
authLower := strings.ToLower(authHeader)
125+
126+
if strings.HasPrefix(authLower, "bearer ") {
127+
// Extract token (skip "Bearer " or "bearer " prefix)
128+
token := authHeader[7:]
126129
// Clean any whitespace
127130
token = strings.TrimSpace(token)
128131
ctx = WithOAuthToken(ctx, token)

oauth.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"fmt"
88
"net/http"
9+
"strings"
910
"time"
1011

1112
mcpserver "github.com/mark3labs/mcp-go/server"
@@ -312,6 +313,106 @@ func (s *Server) WrapHandlerFunc(next http.HandlerFunc) http.HandlerFunc {
312313
return s.WrapHandler(next).ServeHTTP
313314
}
314315

316+
// WrapMCPEndpoint wraps an MCP endpoint handler with automatic 401 handling.
317+
// Returns 401 with WWW-Authenticate headers if Bearer token is missing or invalid.
318+
//
319+
// This method provides automatic OAuth discovery for MCP clients by:
320+
// - Passing through OPTIONS requests (CORS pre-flight)
321+
// - Rejecting non-Bearer auth schemes (OAuth-only endpoint)
322+
// - Returning 401 with proper headers if Bearer token is missing/malformed
323+
// - Extracting token to context and passing to wrapped handler
324+
//
325+
// Usage with mark3labs SDK:
326+
//
327+
// streamableServer := server.NewStreamableHTTPServer(mcpServer, ...)
328+
// mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(streamableServer))
329+
//
330+
// For official SDK, use mcp.WithOAuth() which includes this automatically.
331+
func (s *Server) WrapMCPEndpoint(handler http.Handler) http.HandlerFunc {
332+
return func(w http.ResponseWriter, r *http.Request) {
333+
// Pass through OPTIONS requests (CORS pre-flight)
334+
if r.Method == http.MethodOptions {
335+
handler.ServeHTTP(w, r)
336+
return
337+
}
338+
339+
// Check Authorization header
340+
authHeader := r.Header.Get("Authorization")
341+
authLower := strings.ToLower(authHeader)
342+
343+
// Return 401 if Bearer token missing
344+
if authHeader == "" {
345+
s.Return401(w)
346+
return
347+
}
348+
349+
// Check if it's a Bearer token (case-insensitive per OAuth 2.0 spec)
350+
if !strings.HasPrefix(authLower, "bearer") {
351+
// Reject non-Bearer schemes (OAuth endpoints require Bearer tokens only)
352+
s.Return401(w)
353+
return
354+
}
355+
356+
// Malformed Bearer token (no space after "Bearer")
357+
if !strings.HasPrefix(authLower, "bearer ") {
358+
s.Return401InvalidToken(w)
359+
return
360+
}
361+
362+
// Extract token to context
363+
contextFunc := CreateHTTPContextFunc()
364+
ctx := contextFunc(r.Context(), r)
365+
r = r.WithContext(ctx)
366+
367+
// Pass to wrapped handler
368+
handler.ServeHTTP(w, r)
369+
}
370+
}
371+
372+
// Return401 writes a 401 response with WWW-Authenticate header.
373+
// Used by WrapMCPEndpoint and can be called by adapters.
374+
//
375+
// Returns error code "invalid_request" per RFC 6750 §3.1 for missing tokens.
376+
// Includes resource_metadata URL for OAuth discovery.
377+
func (s *Server) Return401(w http.ResponseWriter) {
378+
metadataURL := s.GetProtectedResourceMetadataURL()
379+
380+
// RFC 6750 compliant: all parameters in single Bearer header
381+
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
382+
`Bearer realm="OAuth", error="invalid_request", error_description="Bearer token required", resource_metadata="%s"`,
383+
metadataURL))
384+
w.Header().Set("Content-Type", "application/json")
385+
w.WriteHeader(http.StatusUnauthorized)
386+
387+
errorResponse := map[string]string{
388+
"error": "invalid_request",
389+
"error_description": "Bearer token required",
390+
}
391+
_ = json.NewEncoder(w).Encode(errorResponse)
392+
}
393+
394+
// Return401InvalidToken writes a 401 response for invalid/expired tokens.
395+
// Used when token validation fails (vs missing token).
396+
//
397+
// Returns error code "invalid_token" per RFC 6750 §3.1 for invalid tokens.
398+
// Includes resource_metadata URL for OAuth discovery.
399+
func (s *Server) Return401InvalidToken(w http.ResponseWriter) {
400+
metadataURL := s.GetProtectedResourceMetadataURL()
401+
402+
// RFC 6750 compliant: all parameters in single Bearer header
403+
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
404+
`Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed", resource_metadata="%s"`,
405+
metadataURL))
406+
w.Header().Set("Content-Type", "application/json")
407+
w.WriteHeader(http.StatusUnauthorized)
408+
409+
errorResponse := map[string]string{
410+
"error": "invalid_token",
411+
"error_description": "Authentication failed",
412+
}
413+
_ = json.NewEncoder(w).Encode(errorResponse)
414+
}
415+
315416
// WithOAuth returns a server option that enables OAuth authentication
316417
// This is the composable API for mcp-go v0.41.1
317418
//

0 commit comments

Comments
 (0)