|
6 | 6 | "encoding/json" |
7 | 7 | "fmt" |
8 | 8 | "net/http" |
| 9 | + "strings" |
9 | 10 | "time" |
10 | 11 |
|
11 | 12 | mcpserver "github.com/mark3labs/mcp-go/server" |
@@ -312,6 +313,106 @@ func (s *Server) WrapHandlerFunc(next http.HandlerFunc) http.HandlerFunc { |
312 | 313 | return s.WrapHandler(next).ServeHTTP |
313 | 314 | } |
314 | 315 |
|
| 316 | +// WrapMCPEndpoint wraps an MCP endpoint handler with automatic 401 handling. |
| 317 | +// Returns 401 with WWW-Authenticate headers if Bearer token is missing or invalid. |
| 318 | +// |
| 319 | +// This method provides automatic OAuth discovery for MCP clients by: |
| 320 | +// - Passing through OPTIONS requests (CORS pre-flight) |
| 321 | +// - Rejecting non-Bearer auth schemes (OAuth-only endpoint) |
| 322 | +// - Returning 401 with proper headers if Bearer token is missing/malformed |
| 323 | +// - Extracting token to context and passing to wrapped handler |
| 324 | +// |
| 325 | +// Usage with mark3labs SDK: |
| 326 | +// |
| 327 | +// streamableServer := server.NewStreamableHTTPServer(mcpServer, ...) |
| 328 | +// mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(streamableServer)) |
| 329 | +// |
| 330 | +// For official SDK, use mcp.WithOAuth() which includes this automatically. |
| 331 | +func (s *Server) WrapMCPEndpoint(handler http.Handler) http.HandlerFunc { |
| 332 | + return func(w http.ResponseWriter, r *http.Request) { |
| 333 | + // Pass through OPTIONS requests (CORS pre-flight) |
| 334 | + if r.Method == http.MethodOptions { |
| 335 | + handler.ServeHTTP(w, r) |
| 336 | + return |
| 337 | + } |
| 338 | + |
| 339 | + // Check Authorization header |
| 340 | + authHeader := r.Header.Get("Authorization") |
| 341 | + authLower := strings.ToLower(authHeader) |
| 342 | + |
| 343 | + // Return 401 if Bearer token missing |
| 344 | + if authHeader == "" { |
| 345 | + s.Return401(w) |
| 346 | + return |
| 347 | + } |
| 348 | + |
| 349 | + // Check if it's a Bearer token (case-insensitive per OAuth 2.0 spec) |
| 350 | + if !strings.HasPrefix(authLower, "bearer") { |
| 351 | + // Reject non-Bearer schemes (OAuth endpoints require Bearer tokens only) |
| 352 | + s.Return401(w) |
| 353 | + return |
| 354 | + } |
| 355 | + |
| 356 | + // Malformed Bearer token (no space after "Bearer") |
| 357 | + if !strings.HasPrefix(authLower, "bearer ") { |
| 358 | + s.Return401InvalidToken(w) |
| 359 | + return |
| 360 | + } |
| 361 | + |
| 362 | + // Extract token to context |
| 363 | + contextFunc := CreateHTTPContextFunc() |
| 364 | + ctx := contextFunc(r.Context(), r) |
| 365 | + r = r.WithContext(ctx) |
| 366 | + |
| 367 | + // Pass to wrapped handler |
| 368 | + handler.ServeHTTP(w, r) |
| 369 | + } |
| 370 | +} |
| 371 | + |
| 372 | +// Return401 writes a 401 response with WWW-Authenticate header. |
| 373 | +// Used by WrapMCPEndpoint and can be called by adapters. |
| 374 | +// |
| 375 | +// Returns error code "invalid_request" per RFC 6750 §3.1 for missing tokens. |
| 376 | +// Includes resource_metadata URL for OAuth discovery. |
| 377 | +func (s *Server) Return401(w http.ResponseWriter) { |
| 378 | + metadataURL := s.GetProtectedResourceMetadataURL() |
| 379 | + |
| 380 | + // RFC 6750 compliant: all parameters in single Bearer header |
| 381 | + w.Header().Set("WWW-Authenticate", fmt.Sprintf( |
| 382 | + `Bearer realm="OAuth", error="invalid_request", error_description="Bearer token required", resource_metadata="%s"`, |
| 383 | + metadataURL)) |
| 384 | + w.Header().Set("Content-Type", "application/json") |
| 385 | + w.WriteHeader(http.StatusUnauthorized) |
| 386 | + |
| 387 | + errorResponse := map[string]string{ |
| 388 | + "error": "invalid_request", |
| 389 | + "error_description": "Bearer token required", |
| 390 | + } |
| 391 | + _ = json.NewEncoder(w).Encode(errorResponse) |
| 392 | +} |
| 393 | + |
| 394 | +// Return401InvalidToken writes a 401 response for invalid/expired tokens. |
| 395 | +// Used when token validation fails (vs missing token). |
| 396 | +// |
| 397 | +// Returns error code "invalid_token" per RFC 6750 §3.1 for invalid tokens. |
| 398 | +// Includes resource_metadata URL for OAuth discovery. |
| 399 | +func (s *Server) Return401InvalidToken(w http.ResponseWriter) { |
| 400 | + metadataURL := s.GetProtectedResourceMetadataURL() |
| 401 | + |
| 402 | + // RFC 6750 compliant: all parameters in single Bearer header |
| 403 | + w.Header().Set("WWW-Authenticate", fmt.Sprintf( |
| 404 | + `Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed", resource_metadata="%s"`, |
| 405 | + metadataURL)) |
| 406 | + w.Header().Set("Content-Type", "application/json") |
| 407 | + w.WriteHeader(http.StatusUnauthorized) |
| 408 | + |
| 409 | + errorResponse := map[string]string{ |
| 410 | + "error": "invalid_token", |
| 411 | + "error_description": "Authentication failed", |
| 412 | + } |
| 413 | + _ = json.NewEncoder(w).Encode(errorResponse) |
| 414 | +} |
| 415 | + |
315 | 416 | // WithOAuth returns a server option that enables OAuth authentication |
316 | 417 | // This is the composable API for mcp-go v0.41.1 |
317 | 418 | // |
|
0 commit comments