Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/modelcontextprotocol/go-sdk/internal/util"
"github.com/modelcontextprotocol/go-sdk/internal/xcontext"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
"golang.org/x/oauth2"
)

// A StreamableHTTPHandler is an http.Handler that serves streamable MCP
Expand Down Expand Up @@ -1803,6 +1804,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")

if err := c.setMCPHeaders(req); err != nil {
// Failure to set headers means that the request was not sent.
// Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr
Expand Down Expand Up @@ -1934,9 +1936,20 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) error {
if ts != nil {
token, err := ts.Token()
if err != nil {
return err
}
if token != nil {
// If the error is an invalid_grant oauth2.RetrieveError it indicates
// that the token source doesn't have valid authorization for the token
// endpoint, per RFC 6749 section 5.2. For example, the refresh token
// may be expired or invalid.
//
// In that case, ignore the error, skip setting the Authorization
// header, and proceed with the request. Callers that support
// authorization flows get a 401/403 response and trigger the
// Authorize() flow to refresh their token.
var retrieveErr *oauth2.RetrieveError
if !errors.As(err, &retrieveErr) || retrieveErr.ErrorCode != "invalid_grant" {
return err
}
} else if token != nil {
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
}
}
Expand Down
56 changes: 56 additions & 0 deletions mcp/streamable_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package mcp

import (
"context"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -1156,3 +1157,58 @@ func TestTokenInfo(t *testing.T) {
t.Errorf("got %q, want %q", g, w)
}
}

// errTestAuthorizeFailed is a sentinel error returned by
// retrieveErrorOAuthHandler.Authorize().
var errTestAuthorizeFailed = errors.New("authorize intentionally failed for test")

// retrieveErrorOAuthHandler is a mock OAuthHandler that always returns
// an oauth2.RetrieveError from its TokenSource's Token() method.
type retrieveErrorOAuthHandler struct{}

func (h *retrieveErrorOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
return h, nil
}

func (h *retrieveErrorOAuthHandler) Token() (*oauth2.Token, error) {
return nil, &oauth2.RetrieveError{
Response: &http.Response{StatusCode: http.StatusBadRequest},
Body: []byte("test retrieve error"),
ErrorCode: "invalid_grant",
}
}

func (h *retrieveErrorOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error {
return errTestAuthorizeFailed
}

// TestStreamableClientOAuth_RetrieveError verifies that an invalid_grant RetrieveError
// from the OAuth token source correctly skips sending Authorization header and relies on
// the server's 401 response to trigger the Authorize fallback flow.
func TestStreamableClientOAuth_RetrieveError(t *testing.T) {
ctx := context.Background()
oauthHandler := &retrieveErrorOAuthHandler{}

// Mock MCP server returns 401 Unauthorized to simulate a server rejecting
// the request that omitted the Authorization header.
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}))
t.Cleanup(httpServer.Close)

transport := &StreamableClientTransport{
Endpoint: httpServer.URL,
OAuthHandler: oauthHandler,
}
client := NewClient(testImpl, nil)

// Attempt to connect. The Connect call will trigger the initialization request,
// which will fail to retrieve the token and proceed without auth header, receive 401,
// and invoke Authorize().
_, err := client.Connect(ctx, transport, nil)

// Expect the connection to fail with the sentinel error, not the RetrieveError.
if !errors.Is(err, errTestAuthorizeFailed) {
t.Fatalf("client.Connect() error = %v, want %v", err, errTestAuthorizeFailed)
}
}