diff --git a/mcp/streamable.go b/mcp/streamable.go index 708b1326..b49822e6 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -37,6 +37,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/util" "github.com/modelcontextprotocol/go-sdk/internal/xcontext" "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "golang.org/x/oauth2" ) // A StreamableHTTPHandler is an http.Handler that serves streamable MCP @@ -1803,6 +1804,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") + if err := c.setMCPHeaders(req); err != nil { // Failure to set headers means that the request was not sent. // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr @@ -1934,9 +1936,20 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) error { if ts != nil { token, err := ts.Token() if err != nil { - return err - } - if token != nil { + // If the error is an invalid_grant oauth2.RetrieveError it indicates + // that the token source doesn't have valid authorization for the token + // endpoint, per RFC 6749 section 5.2. For example, the refresh token + // may be expired or invalid. + // + // In that case, ignore the error, skip setting the Authorization + // header, and proceed with the request. Callers that support + // authorization flows get a 401/403 response and trigger the + // Authorize() flow to refresh their token. + var retrieveErr *oauth2.RetrieveError + if !errors.As(err, &retrieveErr) || retrieveErr.ErrorCode != "invalid_grant" { + return err + } + } else if token != nil { req.Header.Set("Authorization", "Bearer "+token.AccessToken) } } diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 96c5e75f..517e51af 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -6,6 +6,7 @@ package mcp import ( "context" + "errors" "fmt" "io" "net/http" @@ -1156,3 +1157,58 @@ func TestTokenInfo(t *testing.T) { t.Errorf("got %q, want %q", g, w) } } + +// errTestAuthorizeFailed is a sentinel error returned by +// retrieveErrorOAuthHandler.Authorize(). +var errTestAuthorizeFailed = errors.New("authorize intentionally failed for test") + +// retrieveErrorOAuthHandler is a mock OAuthHandler that always returns +// an oauth2.RetrieveError from its TokenSource's Token() method. +type retrieveErrorOAuthHandler struct{} + +func (h *retrieveErrorOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return h, nil +} + +func (h *retrieveErrorOAuthHandler) Token() (*oauth2.Token, error) { + return nil, &oauth2.RetrieveError{ + Response: &http.Response{StatusCode: http.StatusBadRequest}, + Body: []byte("test retrieve error"), + ErrorCode: "invalid_grant", + } +} + +func (h *retrieveErrorOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + return errTestAuthorizeFailed +} + +// TestStreamableClientOAuth_RetrieveError verifies that an invalid_grant RetrieveError +// from the OAuth token source correctly skips sending Authorization header and relies on +// the server's 401 response to trigger the Authorize fallback flow. +func TestStreamableClientOAuth_RetrieveError(t *testing.T) { + ctx := context.Background() + oauthHandler := &retrieveErrorOAuthHandler{} + + // Mock MCP server returns 401 Unauthorized to simulate a server rejecting + // the request that omitted the Authorization header. + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + t.Cleanup(httpServer.Close) + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: oauthHandler, + } + client := NewClient(testImpl, nil) + + // Attempt to connect. The Connect call will trigger the initialization request, + // which will fail to retrieve the token and proceed without auth header, receive 401, + // and invoke Authorize(). + _, err := client.Connect(ctx, transport, nil) + + // Expect the connection to fail with the sentinel error, not the RetrieveError. + if !errors.Is(err, errTestAuthorizeFailed) { + t.Fatalf("client.Connect() error = %v, want %v", err, errTestAuthorizeFailed) + } +}