From f00d0f4a93519dc07917bde917e974e9b20105f1 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 24 Apr 2026 14:08:13 +0000 Subject: [PATCH 01/17] refactor: centralize MCP headers and add support for validating standard request headers --- mcp/mcp_http_headers.go | 93 ++++++++ mcp/mcp_http_headers_test.go | 436 ++++++++++++++++++++++++++++++++++ mcp/shared.go | 3 + mcp/streamable.go | 61 +++-- mcp/streamable_bench_test.go | 2 +- mcp/streamable_client_test.go | 34 +-- mcp/streamable_test.go | 282 +++++++++++++++++++++- 7 files changed, 859 insertions(+), 52 deletions(-) create mode 100644 mcp/mcp_http_headers.go create mode 100644 mcp/mcp_http_headers_test.go diff --git a/mcp/mcp_http_headers.go b/mcp/mcp_http_headers.go new file mode 100644 index 00000000..70099f8c --- /dev/null +++ b/mcp/mcp_http_headers.go @@ -0,0 +1,93 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +const ( + ProtocolVersionHeader = "Mcp-Protocol-Version" + SessionIDHeader = "Mcp-Session-Id" + LastEventIDHeader = "Last-Event-ID" + MethodHeader = "Mcp-Method" + NameHeader = "Mcp-Name" + MinVersionForStandardHeaders = "2026-06-XX" +) + +func extractName(method string, params json.RawMessage) (string, bool) { + switch method { + case "tools/call": + var p CallToolParams + if err := json.Unmarshal(params, &p); err == nil { + return p.Name, true + } + case "prompts/get": + var p GetPromptParams + if err := json.Unmarshal(params, &p); err == nil { + return p.Name, true + } + case "resources/read": + var p ReadResourceParams + if err := json.Unmarshal(params, &p); err == nil { + return p.URI, true + } + } + + return "", false +} + +func setStandardHeaders(httpReq *http.Request, msg jsonrpc.Message) { + if msg == nil { + return + } + if httpReq.Header.Get(ProtocolVersionHeader) == "" || httpReq.Header.Get(ProtocolVersionHeader) < MinVersionForStandardHeaders { + return + } + + switch msg := msg.(type) { + case *jsonrpc.Request: + httpReq.Header.Set(MethodHeader, msg.Method) + if name, ok := extractName(msg.Method, msg.Params); ok { + httpReq.Header.Set(NameHeader, name) + } + } +} + +func validateMcpHeaders(req *http.Request, msg jsonrpc.Message) error { + protocolVersion := req.Header.Get(ProtocolVersionHeader) + if protocolVersion == "" || protocolVersion < MinVersionForStandardHeaders { + return nil + } + + switch msg := msg.(type) { + case *jsonrpc.Request: + methodInHeader := req.Header.Get(MethodHeader) + if methodInHeader == "" { + return errors.New("missing required Mcp-Method header") + } + if methodInHeader != msg.Method { + return fmt.Errorf("Header mismatch: Mcp-Method header value '%s' does not match body value '%s'", methodInHeader, msg.Method) + } + + if msg.Method == "tools/call" || msg.Method == "resources/read" || msg.Method == "prompts/get" { + nameInHeader := req.Header.Get(NameHeader) + if nameInHeader == "" { + return fmt.Errorf("missing required Mcp-Name header for method %q", msg.Method) + } + if nameInBody, ok := extractName(msg.Method, msg.Params); ok { + if nameInHeader != nameInBody { + return fmt.Errorf("Header mismatch: Mcp-Name header value '%s' does not match body value '%s'", nameInHeader, nameInBody) + } + } + } + } + return nil +} diff --git a/mcp/mcp_http_headers_test.go b/mcp/mcp_http_headers_test.go new file mode 100644 index 00000000..464ba5f5 --- /dev/null +++ b/mcp/mcp_http_headers_test.go @@ -0,0 +1,436 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "encoding/json" + "net/http" + "strings" + "testing" + + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +func TestExtractName(t *testing.T) { + tests := []struct { + name string + method string + params json.RawMessage + wantName string + wantOK bool + }{ + { + name: "tools/call", + method: "tools/call", + params: mustMarshal(&CallToolParams{Name: "my-tool"}), + wantName: "my-tool", + wantOK: true, + }, + { + name: "prompts/get", + method: "prompts/get", + params: mustMarshal(&GetPromptParams{Name: "code_review"}), + wantName: "code_review", + wantOK: true, + }, + { + name: "resources/read", + method: "resources/read", + params: mustMarshal(&ReadResourceParams{URI: "file:///info.txt"}), + wantName: "file:///info.txt", + wantOK: true, + }, + { + name: "tools/call with empty name", + method: "tools/call", + params: mustMarshal(&CallToolParams{Name: ""}), + wantName: "", + wantOK: true, + }, + { + name: "tool name with hyphen", + method: "tools/call", + params: mustMarshal(&CallToolParams{Name: "my-tool-v2"}), + wantName: "my-tool-v2", + wantOK: true, + }, + { + name: "tool name with underscore", + method: "tools/call", + params: mustMarshal(&CallToolParams{Name: "my_tool_name"}), + wantName: "my_tool_name", + wantOK: true, + }, + { + name: "resource URI with special chars", + method: "resources/read", + params: mustMarshal(&ReadResourceParams{URI: "file:///path/to/file%20name.txt"}), + wantName: "file:///path/to/file%20name.txt", + wantOK: true, + }, + { + name: "resource URI with query string", + method: "resources/read", + params: mustMarshal(&ReadResourceParams{URI: "https://example.com/resource?id=123"}), + wantName: "https://example.com/resource?id=123", + wantOK: true, + }, + { + name: "unrelated method", + method: "initialize", + params: mustMarshal(&InitializeParams{ProtocolVersion: "2025-06-18"}), + wantName: "", + wantOK: false, + }, + { + name: "notification method", + method: "notifications/initialized", + params: nil, + wantName: "", + wantOK: false, + }, + { + name: "invalid JSON params", + method: "tools/call", + params: []byte("not json"), + wantName: "", + wantOK: false, + }, + { + name: "nil params", + method: "tools/call", + params: nil, + wantName: "", + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotName, gotOK := extractName(tt.method, tt.params) + if gotName != tt.wantName || gotOK != tt.wantOK { + t.Errorf("extractName(%q, ...) = (%q, %v), want (%q, %v)", + tt.method, gotName, gotOK, tt.wantName, tt.wantOK) + } + }) + } +} + +func TestSetStandardHeaders(t *testing.T) { + tests := []struct { + name string + protocolVersion string + msg jsonrpc.Message + wantMethodHeader string + wantNameHeader string + }{ + { + name: "tools/call with future version", + protocolVersion: MinVersionForStandardHeaders, + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantMethodHeader: "tools/call", + wantNameHeader: "my-tool", + }, + { + name: "prompts/get with future version", + protocolVersion: MinVersionForStandardHeaders, + msg: &jsonrpc.Request{Method: "prompts/get", Params: mustMarshal(&GetPromptParams{Name: "code_review"})}, + wantMethodHeader: "prompts/get", + wantNameHeader: "code_review", + }, + { + name: "resources/read with future version", + protocolVersion: MinVersionForStandardHeaders, + msg: &jsonrpc.Request{Method: "resources/read", Params: mustMarshal(&ReadResourceParams{URI: "file:///info.txt"})}, + wantMethodHeader: "resources/read", + wantNameHeader: "file:///info.txt", + }, + { + name: "initialize sets method only", + protocolVersion: MinVersionForStandardHeaders, + msg: &jsonrpc.Request{Method: "initialize", Params: mustMarshal(&InitializeParams{ProtocolVersion: MinVersionForStandardHeaders})}, + wantMethodHeader: "initialize", + wantNameHeader: "", + }, + { + name: "notification sets method only", + protocolVersion: MinVersionForStandardHeaders, + msg: &jsonrpc.Request{Method: "notifications/initialized"}, + wantMethodHeader: "notifications/initialized", + wantNameHeader: "", + }, + { + name: "old version skips all headers", + protocolVersion: protocolVersion20251125, + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantMethodHeader: "", + wantNameHeader: "", + }, + { + name: "empty version skips all headers", + protocolVersion: "", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantMethodHeader: "", + wantNameHeader: "", + }, + { + name: "nil message is a no-op", + protocolVersion: MinVersionForStandardHeaders, + msg: nil, + wantMethodHeader: "", + wantNameHeader: "", + }, + { + name: "response message is ignored", + protocolVersion: MinVersionForStandardHeaders, + msg: &jsonrpc.Response{}, + wantMethodHeader: "", + wantNameHeader: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + httpReq, err := http.NewRequest("POST", "http://localhost/mcp", nil) + if err != nil { + t.Fatal(err) + } + if tt.protocolVersion != "" { + httpReq.Header.Set(ProtocolVersionHeader, tt.protocolVersion) + } + + setStandardHeaders(httpReq, tt.msg) + + if got := httpReq.Header.Get(MethodHeader); got != tt.wantMethodHeader { + t.Errorf("MethodHeader = %q, want %q", got, tt.wantMethodHeader) + } + if got := httpReq.Header.Get(NameHeader); got != tt.wantNameHeader { + t.Errorf("NameHeader = %q, want %q", got, tt.wantNameHeader) + } + }) + } +} + +func TestValidateMcpHeaders(t *testing.T) { + futureVersion := MinVersionForStandardHeaders + oldVersion := protocolVersion20251125 + + tests := []struct { + name string + version string + methodHeader string + nameHeader string + msg jsonrpc.Message + wantErr bool + wantErrContain string + }{ + // -- Version gating -- + { + name: "old version skips validation", + version: oldVersion, + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantErr: false, + }, + { + name: "empty version skips validation", + version: "", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantErr: false, + }, + + // -- Missing headers -- + { + name: "missing Mcp-Method header", + version: futureVersion, + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantErr: true, + wantErrContain: "missing required Mcp-Method header", + }, + { + name: "missing Mcp-Name for tools/call", + version: futureVersion, + methodHeader: "tools/call", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantErr: true, + wantErrContain: "missing required Mcp-Name header", + }, + { + name: "missing Mcp-Name for resources/read", + version: futureVersion, + methodHeader: "resources/read", + msg: &jsonrpc.Request{Method: "resources/read", Params: mustMarshal(&ReadResourceParams{URI: "file:///info.txt"})}, + wantErr: true, + wantErrContain: "missing required Mcp-Name header", + }, + { + name: "missing Mcp-Name for prompts/get", + version: futureVersion, + methodHeader: "prompts/get", + msg: &jsonrpc.Request{Method: "prompts/get", Params: mustMarshal(&GetPromptParams{Name: "review"})}, + wantErr: true, + wantErrContain: "missing required Mcp-Name header", + }, + + // -- Mismatches -- + { + name: "method mismatch", + version: futureVersion, + methodHeader: "tools/call", + msg: &jsonrpc.Request{Method: "prompts/get", Params: mustMarshal(&GetPromptParams{Name: "review"})}, + wantErr: true, + wantErrContain: "Mcp-Method header value 'tools/call' does not match body value 'prompts/get'", + }, + { + name: "tool name mismatch", + version: futureVersion, + methodHeader: "tools/call", + nameHeader: "wrong-tool", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "right-tool"})}, + wantErr: true, + wantErrContain: "Mcp-Name header value 'wrong-tool' does not match body value 'right-tool'", + }, + { + name: "resource URI mismatch", + version: futureVersion, + methodHeader: "resources/read", + nameHeader: "file:///wrong.txt", + msg: &jsonrpc.Request{Method: "resources/read", Params: mustMarshal(&ReadResourceParams{URI: "file:///right.txt"})}, + wantErr: true, + wantErrContain: "Mcp-Name header value 'file:///wrong.txt' does not match body value 'file:///right.txt'", + }, + { + name: "prompt name mismatch", + version: futureVersion, + methodHeader: "prompts/get", + nameHeader: "wrong-prompt", + msg: &jsonrpc.Request{Method: "prompts/get", Params: mustMarshal(&GetPromptParams{Name: "right-prompt"})}, + wantErr: true, + wantErrContain: "Mcp-Name header value 'wrong-prompt' does not match body value 'right-prompt'", + }, + + // -- Case sensitivity -- + { + name: "method value is case-sensitive", + version: futureVersion, + methodHeader: "TOOLS/CALL", + nameHeader: "my-tool", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantErr: true, + wantErrContain: "Mcp-Method header value 'TOOLS/CALL' does not match body value 'tools/call'", + }, + + // -- Valid cases -- + { + name: "valid tools/call", + version: futureVersion, + methodHeader: "tools/call", + nameHeader: "my-tool", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantErr: false, + }, + { + name: "valid resources/read", + version: futureVersion, + methodHeader: "resources/read", + nameHeader: "file:///info.txt", + msg: &jsonrpc.Request{Method: "resources/read", Params: mustMarshal(&ReadResourceParams{URI: "file:///info.txt"})}, + wantErr: false, + }, + { + name: "valid prompts/get", + version: futureVersion, + methodHeader: "prompts/get", + nameHeader: "code_review", + msg: &jsonrpc.Request{Method: "prompts/get", Params: mustMarshal(&GetPromptParams{Name: "code_review"})}, + wantErr: false, + }, + { + name: "valid initialize (no name needed)", + version: futureVersion, + methodHeader: "initialize", + msg: &jsonrpc.Request{Method: "initialize", Params: mustMarshal(&InitializeParams{ProtocolVersion: MinVersionForStandardHeaders})}, + wantErr: false, + }, + { + name: "valid notification (no name needed)", + version: futureVersion, + methodHeader: "notifications/initialized", + msg: &jsonrpc.Request{Method: "notifications/initialized"}, + wantErr: false, + }, + + // -- Special characters -- + { + name: "tool name with hyphen", + version: futureVersion, + methodHeader: "tools/call", + nameHeader: "my-tool-name", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool-name"})}, + wantErr: false, + }, + { + name: "tool name with underscore", + version: futureVersion, + methodHeader: "tools/call", + nameHeader: "my_tool_name", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my_tool_name"})}, + wantErr: false, + }, + { + name: "resource URI with special chars", + version: futureVersion, + methodHeader: "resources/read", + nameHeader: "file:///path/to/file%20name.txt", + msg: &jsonrpc.Request{Method: "resources/read", Params: mustMarshal(&ReadResourceParams{URI: "file:///path/to/file%20name.txt"})}, + wantErr: false, + }, + { + name: "resource URI with query string", + version: futureVersion, + methodHeader: "resources/read", + nameHeader: "https://example.com/resource?id=123", + msg: &jsonrpc.Request{Method: "resources/read", Params: mustMarshal(&ReadResourceParams{URI: "https://example.com/resource?id=123"})}, + wantErr: false, + }, + + // -- Non-request messages -- + { + name: "response message is ignored", + version: futureVersion, + msg: &jsonrpc.Response{}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + httpReq, err := http.NewRequest("POST", "http://localhost/mcp", nil) + if err != nil { + t.Fatal(err) + } + if tt.version != "" { + httpReq.Header.Set(ProtocolVersionHeader, tt.version) + } + if tt.methodHeader != "" { + httpReq.Header.Set(MethodHeader, tt.methodHeader) + } + if tt.nameHeader != "" { + httpReq.Header.Set(NameHeader, tt.nameHeader) + } + + err = validateMcpHeaders(httpReq, tt.msg) + if tt.wantErr { + if err == nil { + t.Fatalf("validateMcpHeaders() = nil, want error containing %q", tt.wantErrContain) + } + if !strings.Contains(err.Error(), tt.wantErrContain) { + t.Errorf("validateMcpHeaders() error = %q, want substring %q", err.Error(), tt.wantErrContain) + } + } else if err != nil { + t.Errorf("validateMcpHeaders() = %v, want nil", err) + } + }) + } +} diff --git a/mcp/shared.go b/mcp/shared.go index cc28de85..53546197 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -343,6 +343,9 @@ func clientSessionMethod[P Params, R Result](f func(*ClientSession, context.Cont // MCP-specific error codes. const ( + // CodeHeaderMismatch indicates that HTTP headers do not match the corresponding values + // in the request body, or that required headers are missing or malformed. + CodeHeaderMismatch = -32001 // CodeResourceNotFound indicates that a requested resource could not be found. CodeResourceNotFound = -32002 // CodeURLElicitationRequired indicates that the server requires URL elicitation diff --git a/mcp/streamable.go b/mcp/streamable.go index 8deb6c93..5922c479 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -39,12 +39,6 @@ import ( "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) -const ( - protocolVersionHeader = "Mcp-Protocol-Version" - sessionIDHeader = "Mcp-Session-Id" - lastEventIDHeader = "Last-Event-ID" -) - // A StreamableHTTPHandler is an http.Handler that serves streamable MCP // sessions, as defined by the [MCP spec]. // @@ -284,7 +278,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } - sessionID := req.Header.Get(sessionIDHeader) + sessionID := req.Header.Get(SessionIDHeader) var sessInfo *sessionInfo if sessionID != "" { h.mu.Lock() @@ -380,7 +374,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // This logic matches the typescript SDK. // // [§2.7]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header - protocolVersion := req.Header.Get(protocolVersionHeader) + protocolVersion := req.Header.Get(ProtocolVersionHeader) if protocolVersion == "" { protocolVersion = protocolVersion20250326 } @@ -924,8 +918,8 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request // By default, we haven't seen a last index. Since indices start at 0, we represent // that by -1. This is incremented just before each event is written. lastIdx := -1 - if len(req.Header.Values(lastEventIDHeader)) > 0 { - eid := req.Header.Get(lastEventIDHeader) + if len(req.Header.Values(LastEventIDHeader)) > 0 { + eid := req.Header.Get(LastEventIDHeader) var ok bool streamID, lastIdx, ok = parseEventID(eid) if !ok { @@ -942,7 +936,7 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request // Read the protocol version from the header. For GET requests, this should // always be present since GET only happens after initialization. - protocolVersion := req.Header.Get(protocolVersionHeader) + protocolVersion := req.Header.Get(ProtocolVersionHeader) if protocolVersion == "" { protocolVersion = protocolVersion20250326 } @@ -1095,7 +1089,7 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons // // It returns an HTTP status code and error message. func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Request) { - if len(req.Header.Values(lastEventIDHeader)) > 0 { + if len(req.Header.Values(LastEventIDHeader)) > 0 { http.Error(w, "can't send Last-Event-ID for POST request", http.StatusBadRequest) return } @@ -1119,7 +1113,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques return } - protocolVersion := req.Header.Get(protocolVersionHeader) + protocolVersion := req.Header.Get(ProtocolVersionHeader) if protocolVersion == "" { protocolVersion = protocolVersion20250326 } @@ -1183,6 +1177,26 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } } + // Validate MCP standard headers (Mcp-Method, Mcp-Name) after checkRequest + // has confirmed the message is structurally valid, so we can safely include + // the request ID in the JSON-RPC error response. + if !isBatch && len(incoming) == 1 { + if err := validateMcpHeaders(req, incoming[0]); err != nil { + resp := &jsonrpc.Response{ + Error: jsonrpc2.NewError(CodeHeaderMismatch, err.Error()), + } + if jreq, ok := incoming[0].(*jsonrpc.Request); ok { + resp.ID = jreq.ID + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + if data, err := jsonrpc2.EncodeMessage(resp); err == nil { + w.Write(data) + } + return + } + } + // The prime and close events were added in protocol version 2025-11-25 (SEP-1699). // Use the version from InitializeParams if this is an initialize request, // otherwise use the protocol version header. @@ -1234,7 +1248,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques w.Header().Set("Connection", "keep-alive") } if c.sessionID != "" && isInitialize { - w.Header().Set(sessionIDHeader, c.sessionID) + w.Header().Set(SessionIDHeader, c.sessionID) } // Set up stream delivery state. @@ -1783,7 +1797,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") - if err := c.setMCPHeaders(req); err != nil { + if err := c.setMCPHeaders(req, msg); err != nil { // Failure to set headers means that the request was not sent. // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr // and permanently break the connection. @@ -1827,7 +1841,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return err } - if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" { + if sessionID := resp.Header.Get(SessionIDHeader); sessionID != "" { c.mu.Lock() hadSessionID := c.sessionID if hadSessionID == "" { @@ -1883,7 +1897,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } -func (c *streamableClientConn) setMCPHeaders(req *http.Request) error { +func (c *streamableClientConn) setMCPHeaders(req *http.Request, msg jsonrpc.Message) error { c.mu.Lock() defer c.mu.Unlock() @@ -1903,11 +1917,14 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) error { } } if c.initializedResult != nil { - req.Header.Set(protocolVersionHeader, c.initializedResult.ProtocolVersion) + req.Header.Set(ProtocolVersionHeader, c.initializedResult.ProtocolVersion) } if c.sessionID != "" { - req.Header.Set(sessionIDHeader, c.sessionID) + req.Header.Set(SessionIDHeader, c.sessionID) } + + setStandardHeaders(req, msg) + return nil } @@ -2161,11 +2178,11 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin if err != nil { return nil, err } - if err := c.setMCPHeaders(req); err != nil { + if err := c.setMCPHeaders(req, nil); err != nil { return nil, err } if lastEventID != "" { - req.Header.Set(lastEventIDHeader, lastEventID) + req.Header.Set(LastEventIDHeader, lastEventID) } req.Header.Set("Accept", "text/event-stream") resp, err := c.client.Do(req) @@ -2194,7 +2211,7 @@ func (c *streamableClientConn) Close() error { if err != nil { c.closeErr = err } else { - if err := c.setMCPHeaders(req); err != nil { + if err := c.setMCPHeaders(req, nil); err != nil { c.closeErr = err } else if _, err := c.client.Do(req); err != nil { c.closeErr = err diff --git a/mcp/streamable_bench_test.go b/mcp/streamable_bench_test.go index fc7b0efe..9e83a59e 100644 --- a/mcp/streamable_bench_test.go +++ b/mcp/streamable_bench_test.go @@ -121,7 +121,7 @@ func BenchmarkStreamableServing_BadSessions(b *testing.B) { if got, want := resp.StatusCode, http.StatusBadRequest; got != want { b.Fatalf("POST got status %d, want %d", got, want) } - if got := resp.Header.Get("Mcp-Session-Id"); got != "" { + if got := resp.Header.Get(mcp.SessionIDHeader); got != "" { b.Fatalf("POST got unexpected session ID") } resp.Body.Close() diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 9da8aeca..338d675c 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -69,8 +69,8 @@ func (s *fakeStreamableServer) missingRequests() []streamableRequestKey { func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { key := streamableRequestKey{ httpMethod: req.Method, - sessionID: req.Header.Get(sessionIDHeader), - lastEventID: req.Header.Get("Last-Event-ID"), // TODO: extract this to a constant, like sessionIDHeader + sessionID: req.Header.Get(SessionIDHeader), + lastEventID: req.Header.Get(LastEventIDHeader), // TODO: extract this to a constant, like SessionIDHeader } var jsonrpcReq *jsonrpc.Request if req.Method == http.MethodPost { @@ -123,7 +123,7 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques w.WriteHeader(status) rc.Flush() // flush response headers - if v := req.Header.Get(protocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" { + if v := req.Header.Get(ProtocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" { s.t.Errorf("%v: bad protocol version header: got %q, want %q", key, v, resp.wantProtocolVersion) } w.Write([]byte(body)) @@ -168,7 +168,7 @@ func TestStreamableClientTransportLifecycle(t *testing.T) { {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json", - sessionIDHeader: "123", + SessionIDHeader: "123", }, body: jsonBody(t, initResp), }, @@ -219,7 +219,7 @@ func TestStreamableClientRedundantDelete(t *testing.T) { {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json", - sessionIDHeader: "123", + SessionIDHeader: "123", }, body: jsonBody(t, initResp), }, @@ -282,7 +282,7 @@ func TestStreamableClientGETHandling(t *testing.T) { {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json; charset=utf-8", // should ignore the charset - sessionIDHeader: "123", + SessionIDHeader: "123", }, body: jsonBody(t, initResp), }, @@ -351,7 +351,7 @@ func TestStreamableClientStrictness(t *testing.T) { {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json", - sessionIDHeader: "123", + SessionIDHeader: "123", }, body: jsonBody(t, initResp), }, @@ -369,7 +369,7 @@ func TestStreamableClientStrictness(t *testing.T) { {"POST", "123", methodListTools, ""}: { header: header{ "Content-Type": "application/json", - sessionIDHeader: "123", + SessionIDHeader: "123", }, body: jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)), optional: true, @@ -410,7 +410,7 @@ func TestStreamableClientUnresumableRequest(t *testing.T) { {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "text/event-stream", - sessionIDHeader: "123", + SessionIDHeader: "123", }, body: "", }, @@ -490,7 +490,7 @@ func TestStreamableClientResumption_Cancelled(t *testing.T) { {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json", - sessionIDHeader: "123", + SessionIDHeader: "123", }, body: jsonBody(t, initResp), }, @@ -518,7 +518,7 @@ data: { "jsonrpc": "2.0", "method": "notifications/message", "params": { "level" {"POST", "123", methodListTools, ""}: { header: header{ "Content-Type": "application/json", - sessionIDHeader: "123", + SessionIDHeader: "123", }, body: jsonBody(t, resp(3, &ListToolsResult{Tools: []*Tool{}}, nil)), }, @@ -633,7 +633,7 @@ func TestStreamableClientTransientErrors(t *testing.T) { {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json", - sessionIDHeader: "123", + SessionIDHeader: "123", }, body: jsonBody(t, initResp), }, @@ -647,7 +647,7 @@ func TestStreamableClientTransientErrors(t *testing.T) { {"POST", "123", methodListTools, ""}: { header: header{ "Content-Type": "application/json", - sessionIDHeader: "123", + SessionIDHeader: "123", }, responseFunc: func(r *jsonrpc.Request) (string, int) { // First call returns transient error, subsequent calls succeed. @@ -722,7 +722,7 @@ func TestStreamableClientRetryWithoutProgress(t *testing.T) { {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json", - sessionIDHeader: "test-session", + SessionIDHeader: "test-session", }, body: jsonBody(t, initResp), }, @@ -825,7 +825,7 @@ func TestStreamableClientDisableStandaloneSSE(t *testing.T) { {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json", - sessionIDHeader: "123", + SessionIDHeader: "123", }, body: jsonBody(t, initResp), }, @@ -942,7 +942,7 @@ func TestStreamableClientOAuth_AuthorizationHeader(t *testing.T) { {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json", - sessionIDHeader: "123", + SessionIDHeader: "123", }, body: jsonBody(t, initResp), }, @@ -989,7 +989,7 @@ func TestStreamableClientOAuth_401(t *testing.T) { {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json", - sessionIDHeader: "123", + SessionIDHeader: "123", }, body: jsonBody(t, initResp), }, diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 592981fc..db5e4e05 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -198,7 +198,7 @@ func TestStreamableTransports(t *testing.T) { if g := session.ID(); g != sid { t.Errorf("session ID: got %q, want %q", g, sid) } - if g, w := lastHeader.Get(protocolVersionHeader), latestProtocolVersion; g != w { + if g, w := lastHeader.Get(ProtocolVersionHeader), latestProtocolVersion; g != w { t.Errorf("got protocol version header %q, want %q", g, w) } want := &CallToolResult{ @@ -825,7 +825,7 @@ func TestStreamableServerTransport(t *testing.T) { } initialized20251125 := streamableRequest{ method: "POST", - headers: http.Header{protocolVersionHeader: {protocolVersion20251125}}, + headers: http.Header{ProtocolVersionHeader: {protocolVersion20251125}}, messages: []jsonrpc.Message{initializedMsg}, wantStatusCode: http.StatusAccepted, } @@ -1175,7 +1175,7 @@ func TestStreamableServerTransport(t *testing.T) { initialized20251125, { method: "POST", - headers: http.Header{protocolVersionHeader: {protocolVersion20251125}}, + headers: http.Header{ProtocolVersionHeader: {protocolVersion20251125}}, messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, @@ -1195,7 +1195,7 @@ func TestStreamableServerTransport(t *testing.T) { initialized20251125, { method: "POST", - headers: http.Header{protocolVersionHeader: {protocolVersion20251125}}, + headers: http.Header{ProtocolVersionHeader: {protocolVersion20251125}}, messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, @@ -1486,7 +1486,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, return "", 0, nil, fmt.Errorf("creating request: %w", err) } if sessionID != "" { - req.Header.Set(sessionIDHeader, sessionID) + req.Header.Set(SessionIDHeader, sessionID) } if s.method == http.MethodPost { req.Header.Set("Content-Type", "application/json") @@ -1494,6 +1494,21 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, req.Header.Set("Accept", "application/json, text/event-stream") maps.Copy(req.Header, s.headers) + // Auto-populate MCP headers for single messages if not already set. + if len(s.messages) == 1 { + msg := s.messages[0] + if jreq, ok := msg.(*jsonrpc.Request); ok { + if req.Header.Get(MethodHeader) == "" { + req.Header.Set(MethodHeader, jreq.Method) + } + if req.Header.Get(NameHeader) == "" { + if name, ok := extractName(jreq.Method, jreq.Params); ok { + req.Header.Set(NameHeader, name) + } + } + } + } + if req.Header.Get("Content-Type") == "" { req.Header.Del("Content-Type") } @@ -1504,7 +1519,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, } defer resp.Body.Close() - newSessionID := resp.Header.Get(sessionIDHeader) + newSessionID := resp.Header.Get(SessionIDHeader) contentType := baseMediaType(resp.Header.Get("Content-Type")) var respBody []byte @@ -1781,7 +1796,7 @@ func TestSessionHijackingPrevention(t *testing.T) { req.Header.Set("Accept", "application/json, text/event-stream") req.Header.Set("Authorization", "Bearer "+userID) if sessionID != "" { - req.Header.Set("Mcp-Session-Id", sessionID) + req.Header.Set(SessionIDHeader, sessionID) } resp, err := http.DefaultClient.Do(req) if err != nil { @@ -1802,7 +1817,7 @@ func TestSessionHijackingPrevention(t *testing.T) { body, _ := io.ReadAll(resp.Body) t.Fatalf("initialize failed with status %d: %s", resp.StatusCode, body) } - sessionID := resp.Header.Get("Mcp-Session-Id") + sessionID := resp.Header.Get(SessionIDHeader) if sessionID == "" { t.Fatal("no session ID in response") } @@ -1886,13 +1901,13 @@ func TestStreamableGET(t *testing.T) { t.Errorf("initialize POST: got status %d, want %d; body:\n%s", got, want, string(body)) } - sessionID := resp.Header.Get(sessionIDHeader) + sessionID := resp.Header.Get(SessionIDHeader) if sessionID == "" { t.Fatalf("initialized missing session ID") } get2 := newReq("GET", nil) - get2.Header.Set(sessionIDHeader, sessionID) + get2.Header.Set(SessionIDHeader, sessionID) resp, err = http.DefaultClient.Do(get2) if err != nil { t.Fatal(err) @@ -1904,7 +1919,7 @@ func TestStreamableGET(t *testing.T) { t.Log("Sending final DELETE request to close session and release resources") del := newReq("DELETE", nil) - del.Header.Set(sessionIDHeader, sessionID) + del.Header.Set(SessionIDHeader, sessionID) resp, err = http.DefaultClient.Do(del) if err != nil { t.Fatal(err) @@ -1915,6 +1930,249 @@ func TestStreamableGET(t *testing.T) { } } +// TestStreamableMcpHeaderValidation tests the server-side Mcp-Method and +// Mcp-Name header validation through the full HTTP handler, as specified +// in SEP-2243. +func TestStreamableMcpHeaderValidation(t *testing.T) { + // Temporarily register the future version so the handler accepts it. + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), MinVersionForStandardHeaders) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + server.AddTool( + &Tool{Name: "my-tool", InputSchema: &jsonschema.Schema{Type: "object"}}, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + }) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + + futureVersionHeader := http.Header{ProtocolVersionHeader: {MinVersionForStandardHeaders}} + + initReq := req(1, methodInitialize, &InitializeParams{ProtocolVersion: MinVersionForStandardHeaders}) + initResp := resp(1, &InitializeResult{ + Capabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + ProtocolVersion: MinVersionForStandardHeaders, + ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, + }, nil) + + initialize := streamableRequest{ + method: "POST", + messages: []jsonrpc.Message{initReq}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{initResp}, + wantSessionID: true, + } + initialized := streamableRequest{ + method: "POST", + headers: futureVersionHeader, + messages: []jsonrpc.Message{req(0, notificationInitialized, &InitializedParams{})}, + wantStatusCode: http.StatusAccepted, + } + + testStreamableHandler(t, handler, []streamableRequest{ + initialize, + initialized, + // Correct headers should succeed. + { + method: "POST", + headers: futureVersionHeader, + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "my-tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + }, + // Method mismatch: header says prompts/get, body says tools/call. + { + method: "POST", + headers: http.Header{ + ProtocolVersionHeader: {MinVersionForStandardHeaders}, + MethodHeader: {"prompts/get"}, + NameHeader: {"my-tool"}, + }, + messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "my-tool"})}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "Mcp-Method header value", + }, + // Name mismatch: header says wrong-tool, body says my-tool. + { + method: "POST", + headers: http.Header{ + ProtocolVersionHeader: {MinVersionForStandardHeaders}, + MethodHeader: {"tools/call"}, + NameHeader: {"wrong-tool"}, + }, + messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "my-tool"})}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "Mcp-Name header value", + }, + // Case-sensitive: TOOLS/CALL != tools/call. + { + method: "POST", + headers: http.Header{ + ProtocolVersionHeader: {MinVersionForStandardHeaders}, + MethodHeader: {"TOOLS/CALL"}, + NameHeader: {"my-tool"}, + }, + messages: []jsonrpc.Message{req(5, "tools/call", &CallToolParams{Name: "my-tool"})}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "Mcp-Method header value", + }, + // Valid request after errors should still succeed. + { + method: "POST", + headers: futureVersionHeader, + messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Name: "my-tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, + }, + }) +} + +// TestStreamableMcpHeaderValidationErrorFormat verifies that header +// validation errors return a JSON-RPC error with code -32001 and +// Content-Type application/json, per SEP-2243. +func TestStreamableMcpHeaderValidationErrorFormat(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), MinVersionForStandardHeaders) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + server.AddTool( + &Tool{Name: "my-tool", InputSchema: &jsonschema.Schema{Type: "object"}}, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + }) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + // Step 1: initialize the session. + initBody, _ := jsonrpc2.EncodeMessage(req(1, methodInitialize, &InitializeParams{ProtocolVersion: MinVersionForStandardHeaders})) + initHTTPReq, _ := http.NewRequest("POST", httpServer.URL, bytes.NewReader(initBody)) + initHTTPReq.Header.Set("Content-Type", "application/json") + initHTTPReq.Header.Set("Accept", "application/json, text/event-stream") + initHTTPReq.Header.Set(MethodHeader, methodInitialize) + initResp, err := http.DefaultClient.Do(initHTTPReq) + if err != nil { + t.Fatal(err) + } + io.ReadAll(initResp.Body) + initResp.Body.Close() + sessionID := initResp.Header.Get(SessionIDHeader) + + // Step 2: send initialized notification. + initdBody, _ := jsonrpc2.EncodeMessage(req(0, notificationInitialized, &InitializedParams{})) + initdHTTPReq, _ := http.NewRequest("POST", httpServer.URL, bytes.NewReader(initdBody)) + initdHTTPReq.Header.Set("Content-Type", "application/json") + initdHTTPReq.Header.Set("Accept", "application/json, text/event-stream") + initdHTTPReq.Header.Set(SessionIDHeader, sessionID) + initdHTTPReq.Header.Set(ProtocolVersionHeader, MinVersionForStandardHeaders) + initdHTTPReq.Header.Set(MethodHeader, notificationInitialized) + initdResp, err := http.DefaultClient.Do(initdHTTPReq) + if err != nil { + t.Fatal(err) + } + initdResp.Body.Close() + + // Step 3: send a request with mismatched Mcp-Method header. + callBody, _ := jsonrpc2.EncodeMessage(req(2, "tools/call", &CallToolParams{Name: "my-tool"})) + callReq, _ := http.NewRequest("POST", httpServer.URL, bytes.NewReader(callBody)) + callReq.Header.Set("Content-Type", "application/json") + callReq.Header.Set("Accept", "application/json, text/event-stream") + callReq.Header.Set(SessionIDHeader, sessionID) + callReq.Header.Set(ProtocolVersionHeader, MinVersionForStandardHeaders) + callReq.Header.Set(MethodHeader, "wrong-method") + callReq.Header.Set(NameHeader, "my-tool") + + resp, err := http.DefaultClient.Do(callReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // Verify HTTP status code. + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", resp.StatusCode, http.StatusBadRequest) + } + + // Verify Content-Type. + ct := baseMediaType(resp.Header.Get("Content-Type")) + if ct != "application/json" { + t.Errorf("Content-Type = %q, want %q", ct, "application/json") + } + + // Verify JSON-RPC error body contains error code -32001. + body, _ := io.ReadAll(resp.Body) + bodyStr := string(body) + if !strings.Contains(bodyStr, `"code":-32001`) { + t.Errorf("response body missing error code -32001:\n%s", bodyStr) + } + if !strings.Contains(bodyStr, "Mcp-Method header value") { + t.Errorf("response body missing error message:\n%s", bodyStr) + } +} + +// TestStreamableMcpHeaderVersionGating verifies that header validation +// is skipped for protocol versions older than MinVersionForStandardHeaders. +func TestStreamableMcpHeaderVersionGating(t *testing.T) { + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + server.AddTool( + &Tool{Name: "my-tool", InputSchema: &jsonschema.Schema{Type: "object"}}, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + }) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + + initReq := req(1, methodInitialize, &InitializeParams{ProtocolVersion: protocolVersion20251125}) + initResp := resp(1, &InitializeResult{ + Capabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + ProtocolVersion: protocolVersion20251125, + ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, + }, nil) + + testStreamableHandler(t, handler, []streamableRequest{ + { + method: "POST", + messages: []jsonrpc.Message{initReq}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{initResp}, + wantSessionID: true, + }, + { + method: "POST", + headers: http.Header{ProtocolVersionHeader: {protocolVersion20251125}}, + messages: []jsonrpc.Message{req(0, notificationInitialized, &InitializedParams{})}, + wantStatusCode: http.StatusAccepted, + }, + // Requests with deliberately wrong MCP headers should still succeed + // because the protocol version is too old for validation. + { + method: "POST", + headers: http.Header{ + ProtocolVersionHeader: {protocolVersion20251125}, + MethodHeader: {"wrong-method"}, + NameHeader: {"wrong-name"}, + }, + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "my-tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + }, + }) +} + // TestStreamable405AllowHeader verifies RFC 9110 §15.5.6 compliance: // 405 Method Not Allowed responses MUST include an Allow header. func TestStreamable405AllowHeader(t *testing.T) { @@ -2082,7 +2340,7 @@ func TestStreamableClientContextPropagation(t *testing.T) { switch req.Method { case "POST": w.Header().Set("Content-Type", "application/json") - w.Header().Set("Mcp-Session-Id", "test-session") + w.Header().Set(SessionIDHeader, "test-session") w.WriteHeader(http.StatusOK) w.Write([]byte(`{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2025-03-26","capabilities":{},"serverInfo":{"name":"test","version":"1.0"}}}`)) case "GET": From d00288d75567c18c6d0b2be8d188b40ae95c8651 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 24 Apr 2026 14:49:08 +0000 Subject: [PATCH 02/17] fix --- mcp/mcp_http_headers.go | 4 ++-- mcp/streamable.go | 4 +--- mcp/streamable_client_test.go | 2 +- mcp/streamable_test.go | 39 +++++++++++++++-------------------- 4 files changed, 21 insertions(+), 28 deletions(-) diff --git a/mcp/mcp_http_headers.go b/mcp/mcp_http_headers.go index 70099f8c..899af59f 100644 --- a/mcp/mcp_http_headers.go +++ b/mcp/mcp_http_headers.go @@ -74,7 +74,7 @@ func validateMcpHeaders(req *http.Request, msg jsonrpc.Message) error { return errors.New("missing required Mcp-Method header") } if methodInHeader != msg.Method { - return fmt.Errorf("Header mismatch: Mcp-Method header value '%s' does not match body value '%s'", methodInHeader, msg.Method) + return fmt.Errorf("header mismatch: Mcp-Method header value '%s' does not match body value '%s'", methodInHeader, msg.Method) } if msg.Method == "tools/call" || msg.Method == "resources/read" || msg.Method == "prompts/get" { @@ -84,7 +84,7 @@ func validateMcpHeaders(req *http.Request, msg jsonrpc.Message) error { } if nameInBody, ok := extractName(msg.Method, msg.Params); ok { if nameInHeader != nameInBody { - return fmt.Errorf("Header mismatch: Mcp-Name header value '%s' does not match body value '%s'", nameInHeader, nameInBody) + return fmt.Errorf("header mismatch: Mcp-Name header value '%s' does not match body value '%s'", nameInHeader, nameInBody) } } } diff --git a/mcp/streamable.go b/mcp/streamable.go index 5922c479..23fb664d 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1177,9 +1177,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } } - // Validate MCP standard headers (Mcp-Method, Mcp-Name) after checkRequest - // has confirmed the message is structurally valid, so we can safely include - // the request ID in the JSON-RPC error response. + // Validate MCP standard headers (Mcp-Method, Mcp-Name) if !isBatch && len(incoming) == 1 { if err := validateMcpHeaders(req, incoming[0]); err != nil { resp := &jsonrpc.Response{ diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 338d675c..9874c0fd 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -70,7 +70,7 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques key := streamableRequestKey{ httpMethod: req.Method, sessionID: req.Header.Get(SessionIDHeader), - lastEventID: req.Header.Get(LastEventIDHeader), // TODO: extract this to a constant, like SessionIDHeader + lastEventID: req.Header.Get(LastEventIDHeader), } var jsonrpcReq *jsonrpc.Request if req.Method == http.MethodPost { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index db5e4e05..0b3e4485 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1494,20 +1494,6 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, req.Header.Set("Accept", "application/json, text/event-stream") maps.Copy(req.Header, s.headers) - // Auto-populate MCP headers for single messages if not already set. - if len(s.messages) == 1 { - msg := s.messages[0] - if jreq, ok := msg.(*jsonrpc.Request); ok { - if req.Header.Get(MethodHeader) == "" { - req.Header.Set(MethodHeader, jreq.Method) - } - if req.Header.Get(NameHeader) == "" { - if name, ok := extractName(jreq.Method, jreq.Params); ok { - req.Header.Set(NameHeader, name) - } - } - } - } if req.Header.Get("Content-Type") == "" { req.Header.Del("Content-Type") @@ -1949,8 +1935,6 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) defer handler.closeAll() - futureVersionHeader := http.Header{ProtocolVersionHeader: {MinVersionForStandardHeaders}} - initReq := req(1, methodInitialize, &InitializeParams{ProtocolVersion: MinVersionForStandardHeaders}) initResp := resp(1, &InitializeResult{ Capabilities: &ServerCapabilities{ @@ -1969,8 +1953,11 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { wantSessionID: true, } initialized := streamableRequest{ - method: "POST", - headers: futureVersionHeader, + method: "POST", + headers: http.Header{ + ProtocolVersionHeader: {MinVersionForStandardHeaders}, + MethodHeader: {notificationInitialized}, + }, messages: []jsonrpc.Message{req(0, notificationInitialized, &InitializedParams{})}, wantStatusCode: http.StatusAccepted, } @@ -1980,8 +1967,12 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { initialized, // Correct headers should succeed. { - method: "POST", - headers: futureVersionHeader, + method: "POST", + headers: http.Header{ + ProtocolVersionHeader: {MinVersionForStandardHeaders}, + MethodHeader: {"tools/call"}, + NameHeader: {"my-tool"}, + }, messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "my-tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, @@ -2024,8 +2015,12 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { }, // Valid request after errors should still succeed. { - method: "POST", - headers: futureVersionHeader, + method: "POST", + headers: http.Header{ + ProtocolVersionHeader: {MinVersionForStandardHeaders}, + MethodHeader: {"tools/call"}, + NameHeader: {"my-tool"}, + }, messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Name: "my-tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, From 57659c0fd3c945f9ef1c646c1b3773a45cd1ca51 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 24 Apr 2026 14:53:55 +0000 Subject: [PATCH 03/17] formatter --- mcp/streamable_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 0b3e4485..a7825c3b 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1494,7 +1494,6 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, req.Header.Set("Accept", "application/json, text/event-stream") maps.Copy(req.Header, s.headers) - if req.Header.Get("Content-Type") == "" { req.Header.Del("Content-Type") } From 604f2d4ebc891f018da226699f66419ebf6d36d0 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 24 Apr 2026 15:15:21 +0000 Subject: [PATCH 04/17] fix after merge --- mcp/streamable_client_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index dcd6d46f..664e2fe2 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -1066,7 +1066,7 @@ func TestStreamableClientOAuth_CancelledAuthorize_NoReprompt(t *testing.T) { {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json", - sessionIDHeader: "123", + SessionIDHeader: "123", }, body: jsonBody(t, initResp), }, From 7df5ab69067ec70c185bd8a1a98ad7d0cde354e4 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 24 Apr 2026 15:24:59 +0000 Subject: [PATCH 05/17] refactor: decouple standard header population from setMCPHeaders in streamableClientConn --- mcp/streamable.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 42cb02bf..b3afcf3f 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1803,12 +1803,13 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") - if err := c.setMCPHeaders(req, msg); err != nil { + if err := c.setMCPHeaders(req); err != nil { // Failure to set headers means that the request was not sent. // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr // and permanently break the connection. return nil, nil, fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) } + setStandardHeaders(req, msg) resp, err := c.client.Do(req) if err != nil { // Any error from client.Do means the request didn't reach the server. @@ -1919,7 +1920,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } -func (c *streamableClientConn) setMCPHeaders(req *http.Request, msg jsonrpc.Message) error { +func (c *streamableClientConn) setMCPHeaders(req *http.Request) error { c.mu.Lock() defer c.mu.Unlock() @@ -1945,8 +1946,6 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request, msg jsonrpc.Mess req.Header.Set(SessionIDHeader, c.sessionID) } - setStandardHeaders(req, msg) - return nil } @@ -2200,7 +2199,7 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin if err != nil { return nil, err } - if err := c.setMCPHeaders(req, nil); err != nil { + if err := c.setMCPHeaders(req); err != nil { return nil, err } if lastEventID != "" { @@ -2233,7 +2232,7 @@ func (c *streamableClientConn) Close() error { if err != nil { c.closeErr = err } else { - if err := c.setMCPHeaders(req, nil); err != nil { + if err := c.setMCPHeaders(req); err != nil { c.closeErr = err } else if _, err := c.client.Do(req); err != nil { c.closeErr = err From 9de3bec99bcca6eae4e98508a71d87235e7e102f Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Sun, 26 Apr 2026 12:17:01 +0000 Subject: [PATCH 06/17] feat: implement SEP-2243 header encoding and support Mcp-Param- headers for tool calls --- mcp/client.go | 37 ++- mcp/header_encoding.go | 90 +++++++ mcp/header_encoding_test.go | 110 ++++++++ mcp/mcp_http_headers.go | 162 +++++++++++ mcp/mcp_http_headers_test.go | 503 +++++++++++++++++++++++++++++++++++ mcp/streamable.go | 5 + 6 files changed, 906 insertions(+), 1 deletion(-) create mode 100644 mcp/header_encoding.go create mode 100644 mcp/header_encoding_test.go diff --git a/mcp/client.go b/mcp/client.go index c82e189d..b132177a 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -299,6 +299,12 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp // // Call [ClientSession.Close] to close the connection, or await server // termination with [ClientSession.Wait]. +// toolContextKeyType is the context key type for passing tool definitions +// from CallTool to the transport layer. +type toolContextKeyType struct{} + +var toolContextKey = toolContextKeyType{} + type ClientSession struct { // Ensure that onClose is called at most once. // We defensively use an atomic CompareAndSwap rather than a sync.Once, in case the @@ -318,6 +324,13 @@ type ClientSession struct { // Pending URL elicitations waiting for completion notifications. pendingElicitationsMu sync.Mutex pendingElicitations map[string]chan struct{} + + // toolCache stores tool definitions keyed by name, populated by + // ListTools/Tools. Used to look up x-mcp-header annotations when + // constructing Mcp-Param-* headers for tools/call requests. + // No mutex is required because CallTool cannot be meaningfully called + // until ListTools has returned (the caller needs the tool name). + toolCache map[string]*Tool } type clientSessionState struct { @@ -363,6 +376,19 @@ func (cs *ClientSession) Wait() error { return cs.conn.Wait() } +func (cs *ClientSession) cacheTools(tools []*Tool) { + if cs.toolCache == nil { + cs.toolCache = make(map[string]*Tool, len(tools)) + } + for _, tool := range tools { + cs.toolCache[tool.Name] = tool + } +} + +func (cs *ClientSession) getCachedTool(name string) *Tool { + return cs.toolCache[name] +} + // registerElicitationWaiter registers a waiter for an elicitation complete // notification with the given elicitation ID. It returns two functions: an await // function that waits for the notification or context cancellation, and a cleanup @@ -981,7 +1007,13 @@ func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) // ListTools lists tools that are currently available on the server. func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { - return handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) + result, err := handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) + if err != nil { + return nil, err + } + result.Tools = filterValidTools(result.Tools) + cs.cacheTools(result.Tools) + return result, nil } // CallTool calls the tool with the given parameters. @@ -995,6 +1027,9 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) ( // Avoid sending nil over the wire. params.Arguments = map[string]any{} } + if tool := cs.getCachedTool(params.Name); tool != nil { + ctx = context.WithValue(ctx, toolContextKey, tool) + } return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) } diff --git a/mcp/header_encoding.go b/mcp/header_encoding.go new file mode 100644 index 00000000..3f032fac --- /dev/null +++ b/mcp/header_encoding.go @@ -0,0 +1,90 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "encoding/base64" + "fmt" + "strings" +) + +const ( + base64Prefix = "=?base64?" + base64Suffix = "?=" +) + +// encodeHeaderValue converts a parameter value to an HTTP header-safe string +// per the SEP-2243 encoding rules: +// - string: used as-is if safe ASCII, otherwise Base64 encoded +// - number (float64): decimal string representation +// - bool: lowercase "true" or "false" +// - nil: returns "", false +// +// Values that contain non-ASCII characters, control characters, or +// leading/trailing whitespace are Base64-encoded with the =?base64?...?= wrapper. +func encodeHeaderValue(value any) (string, bool) { + var s string + switch v := value.(type) { + case string: + s = v + case float64: + s = fmt.Sprintf("%g", v) + case bool: + if v { + s = "true" + } else { + s = "false" + } + default: + return "", false + } + + if requiresBase64Encoding(s) { + return encodeBase64(s), true + } + return s, true +} + +// decodeHeaderValue decodes a header value that may be Base64-encoded +// with the =?base64?...?= wrapper. Returns the decoded string and true +// on success. Returns "", false if Base64 decoding fails. +// Non-encoded values are returned as-is. +func decodeHeaderValue(headerValue string) (string, bool) { + if len(headerValue) == 0 { + return headerValue, true + } + + if strings.HasPrefix(strings.ToLower(headerValue), strings.ToLower(base64Prefix)) && + strings.HasSuffix(headerValue, base64Suffix) { + encoded := headerValue[len(base64Prefix) : len(headerValue)-len(base64Suffix)] + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return "", false + } + return string(decoded), true + } + return headerValue, true +} + +func requiresBase64Encoding(s string) bool { + if len(s) == 0 { + return false + } + if s[0] == ' ' || s[0] == '\t' || s[len(s)-1] == ' ' || s[len(s)-1] == '\t' { + return true + } + for _, c := range s { + if c < 0x20 || c > 0x7E { + if c != '\t' { + return true + } + } + } + return false +} + +func encodeBase64(s string) string { + return base64Prefix + base64.StdEncoding.EncodeToString([]byte(s)) + base64Suffix +} diff --git a/mcp/header_encoding_test.go b/mcp/header_encoding_test.go new file mode 100644 index 00000000..811ab336 --- /dev/null +++ b/mcp/header_encoding_test.go @@ -0,0 +1,110 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import "testing" + +func TestEncodeHeaderValue(t *testing.T) { + tests := []struct { + name string + value any + want string + wantOK bool + }{ + // Strings + {"plain ASCII", "us-west1", "us-west1", true}, + {"empty string", "", "", true}, + {"string with internal spaces", "us west 1", "us west 1", true}, + {"string with leading space", " us-west1", "=?base64?IHVzLXdlc3Qx?=", true}, + {"string with trailing space", "us-west1 ", "=?base64?dXMtd2VzdDEg?=", true}, + {"string with both spaces", " us-west1 ", "=?base64?IHVzLXdlc3QxIA==?=", true}, + {"non-ASCII", "日本語", "=?base64?5pel5pys6Kqe?=", true}, + {"mixed ASCII and non-ASCII", "Hello, 世界", "=?base64?SGVsbG8sIOS4lueVjA==?=", true}, + {"string with newline", "line1\nline2", "=?base64?bGluZTEKbGluZTI=?=", true}, + {"string with carriage return", "line1\r\nline2", "=?base64?bGluZTENCmxpbmUy?=", true}, + {"string with leading tab", "\tindented", "=?base64?CWluZGVudGVk?=", true}, + + // Numbers + {"integer", float64(42), "42", true}, + {"float", float64(3.14159), "3.14159", true}, + + // Booleans + {"true", true, "true", true}, + {"false", false, "false", true}, + + // Unsupported types + {"nil", nil, "", false}, + {"slice", []string{"a"}, "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := encodeHeaderValue(tt.value) + if ok != tt.wantOK { + t.Fatalf("encodeHeaderValue(%v) ok = %v, want %v", tt.value, ok, tt.wantOK) + } + if got != tt.want { + t.Errorf("encodeHeaderValue(%v) = %q, want %q", tt.value, got, tt.want) + } + }) + } +} + +func TestDecodeHeaderValue(t *testing.T) { + tests := []struct { + name string + input string + want string + wantOK bool + }{ + {"plain value", "us-west1", "us-west1", true}, + {"empty value", "", "", true}, + {"valid base64", "=?base64?SGVsbG8=?=", "Hello", true}, + {"non-ASCII decoded", "=?base64?5pel5pys6Kqe?=", "日本語", true}, + {"leading space decoded", "=?base64?IHVzLXdlc3Qx?=", " us-west1", true}, + {"case-insensitive prefix", "=?BASE64?SGVsbG8=?=", "Hello", true}, + {"invalid base64 chars", "=?base64?SGVs!!!bG8=?=", "", false}, + // Missing prefix or suffix: treated as literal values, not base64 + {"missing prefix", "SGVsbG8=", "SGVsbG8=", true}, + {"missing suffix", "=?base64?SGVsbG8=", "=?base64?SGVsbG8=", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := decodeHeaderValue(tt.input) + if ok != tt.wantOK { + t.Fatalf("decodeHeaderValue(%q) ok = %v, want %v", tt.input, ok, tt.wantOK) + } + if got != tt.want { + t.Errorf("decodeHeaderValue(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestEncodeDecodeRoundTrip(t *testing.T) { + values := []string{ + "us-west1", + "", + " leading", + "trailing ", + "Hello, 世界", + "line1\nline2", + "\ttab", + } + for _, v := range values { + encoded, ok := encodeHeaderValue(v) + if !ok { + t.Fatalf("encodeHeaderValue(%q) failed", v) + } + decoded, ok := decodeHeaderValue(encoded) + if !ok { + t.Fatalf("decodeHeaderValue(%q) failed", encoded) + } + if decoded != v { + t.Errorf("round-trip failed: %q -> %q -> %q", v, encoded, decoded) + } + } +} diff --git a/mcp/mcp_http_headers.go b/mcp/mcp_http_headers.go index 899af59f..350273e9 100644 --- a/mcp/mcp_http_headers.go +++ b/mcp/mcp_http_headers.go @@ -8,7 +8,9 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/http" + "strings" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -19,6 +21,7 @@ const ( LastEventIDHeader = "Last-Event-ID" MethodHeader = "Mcp-Method" NameHeader = "Mcp-Name" + ParamHeaderPrefix = "Mcp-Param-" MinVersionForStandardHeaders = "2026-06-XX" ) @@ -58,7 +61,166 @@ func setStandardHeaders(httpReq *http.Request, msg jsonrpc.Message) { if name, ok := extractName(msg.Method, msg.Params); ok { httpReq.Header.Set(NameHeader, name) } + if msg.Method == "tools/call" { + if tool, ok := msg.Extra.(*Tool); ok && tool != nil { + setParamHeaders(httpReq, tool, msg.Params) + } + } + } +} + +// setParamHeaders reads x-mcp-header annotations from the tool's InputSchema +// and sets Mcp-Param-{Name} headers on the HTTP request from the corresponding +// argument values. +func setParamHeaders(httpReq *http.Request, tool *Tool, params json.RawMessage) { + paramHeaders := extractToolParamHeaders(tool) + if len(paramHeaders) == 0 { + return + } + + var raw struct { + Arguments map[string]json.RawMessage `json:"arguments"` + } + if err := json.Unmarshal(params, &raw); err != nil || raw.Arguments == nil { + return + } + + for paramName, headerName := range paramHeaders { + argRaw, ok := raw.Arguments[paramName] + if !ok { + continue + } + // null → omit header per SEP + if string(argRaw) == "null" { + continue + } + val := unmarshalPrimitive(argRaw) + if val == nil { + continue + } + encoded, ok := encodeHeaderValue(val) + if !ok { + continue + } + httpReq.Header.Set(ParamHeaderPrefix+headerName, encoded) + } +} + +// extractToolParamHeaders returns a map of parameter name → header name +// for all properties in the tool's InputSchema that have an x-mcp-header +// annotation. On the client side, InputSchema arrives as map[string]any. +func extractToolParamHeaders(tool *Tool) map[string]string { + schema, ok := tool.InputSchema.(map[string]any) + if !ok { + return nil + } + props, ok := schema["properties"].(map[string]any) + if !ok { + return nil + } + result := make(map[string]string) + for propName, propSchema := range props { + ps, ok := propSchema.(map[string]any) + if !ok { + continue + } + headerName, ok := ps["x-mcp-header"].(string) + if !ok || headerName == "" { + continue + } + result[propName] = headerName + } + if len(result) == 0 { + return nil + } + return result +} + +// unmarshalPrimitive unmarshals a JSON value into a Go primitive +// (string, float64, or bool). Returns nil for non-primitive types. +func unmarshalPrimitive(raw json.RawMessage) any { + var val any + if err := json.Unmarshal(raw, &val); err != nil { + return nil + } + switch val.(type) { + case string, float64, bool: + return val + default: + return nil + } +} + +// validateToolParamHeaders checks that a tool's x-mcp-header annotations +// are valid per SEP-2243. Returns an error describing the first violation, or nil. +// +// Constraints: +// - x-mcp-header values MUST NOT be empty +// - MUST contain only ASCII characters (excluding space and ':') +// - MUST be case-insensitively unique +// - MUST only be applied to properties with primitive types (string, number, boolean) +func validateToolParamHeaders(tool *Tool) error { + schema, ok := tool.InputSchema.(map[string]any) + if !ok { + return nil + } + props, ok := schema["properties"].(map[string]any) + if !ok { + return nil + } + + seen := make(map[string]bool) + for propName, propSchema := range props { + ps, ok := propSchema.(map[string]any) + if !ok { + continue + } + headerNameRaw, exists := ps["x-mcp-header"] + if !exists { + continue + } + headerName, ok := headerNameRaw.(string) + if !ok || headerName == "" { + return fmt.Errorf("property %q: x-mcp-header must be a non-empty string", propName) + } + if err := validateHeaderName(headerName); err != nil { + return fmt.Errorf("property %q: %w", propName, err) + } + lower := strings.ToLower(headerName) + if seen[lower] { + return fmt.Errorf("property %q: duplicate x-mcp-header value %q (case-insensitive)", propName, headerName) + } + seen[lower] = true + + propType, _ := ps["type"].(string) + if propType != "" && propType != "string" && propType != "number" && propType != "integer" && propType != "boolean" { + return fmt.Errorf("property %q: x-mcp-header can only be applied to primitive types, got %q", propName, propType) + } + } + return nil +} + +func validateHeaderName(name string) error { + for _, c := range name { + if c <= 0x20 || c > 0x7E || c == ':' { + return fmt.Errorf("x-mcp-header value %q contains invalid character %q", name, c) + } + } + return nil +} + +// filterValidTools returns a new slice containing only tools with valid +// x-mcp-header annotations. Invalid tools are logged and excluded. +func filterValidTools(tools []*Tool) []*Tool { + result := make([]*Tool, 0, len(tools)) + for _, tool := range tools { + if err := validateToolParamHeaders(tool); err != nil { + log.Printf("mcp: excluding tool %q from tools/list: %v", tool.Name, err) + continue + } + result = append(result, tool) } + return result } func validateMcpHeaders(req *http.Request, msg jsonrpc.Message) error { diff --git a/mcp/mcp_http_headers_test.go b/mcp/mcp_http_headers_test.go index 464ba5f5..68b3e965 100644 --- a/mcp/mcp_http_headers_test.go +++ b/mcp/mcp_http_headers_test.go @@ -6,6 +6,7 @@ package mcp import ( "encoding/json" + "fmt" "net/http" "strings" "testing" @@ -434,3 +435,505 @@ func TestValidateMcpHeaders(t *testing.T) { }) } } + +func TestValidateToolParamHeaders(t *testing.T) { + tests := []struct { + name string + tool *Tool + wantErr bool + wantErrSub string + }{ + { + name: "valid tool with x-mcp-header", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + }, + }, + }, + }, + { + name: "tool with no x-mcp-header annotations", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + }, + }, + }, + }, + { + name: "empty x-mcp-header value", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "non-empty string", + }, + { + name: "x-mcp-header with space", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "My Region", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "invalid character", + }, + { + name: "x-mcp-header with colon", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region:Primary", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "invalid character", + }, + { + name: "x-mcp-header with non-ASCII", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Région", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "invalid character", + }, + { + name: "duplicate header names same case", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "a": map[string]any{"type": "string", "x-mcp-header": "Region"}, + "b": map[string]any{"type": "string", "x-mcp-header": "Region"}, + }, + }, + }, + wantErr: true, + wantErrSub: "duplicate", + }, + { + name: "duplicate header names different case", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "a": map[string]any{"type": "string", "x-mcp-header": "Region"}, + "b": map[string]any{"type": "string", "x-mcp-header": "REGION"}, + }, + }, + }, + wantErr: true, + wantErrSub: "duplicate", + }, + { + name: "x-mcp-header on array type", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "items": map[string]any{ + "type": "array", + "x-mcp-header": "Items", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "primitive types", + }, + { + name: "x-mcp-header on object type", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "nested": map[string]any{ + "type": "object", + "x-mcp-header": "Nested", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "primitive types", + }, + { + name: "x-mcp-header on number type is valid", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{ + "type": "number", + "x-mcp-header": "Count", + }, + }, + }, + }, + }, + { + name: "x-mcp-header on integer type is valid", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{ + "type": "integer", + "x-mcp-header": "Count", + }, + }, + }, + }, + }, + { + name: "x-mcp-header on boolean type is valid", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "flag": map[string]any{ + "type": "boolean", + "x-mcp-header": "Flag", + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateToolParamHeaders(tt.tool) + if tt.wantErr { + if err == nil { + t.Fatal("validateToolParamHeaders() = nil, want error") + } + if !strings.Contains(err.Error(), tt.wantErrSub) { + t.Errorf("error = %q, want substring %q", err.Error(), tt.wantErrSub) + } + } else if err != nil { + t.Errorf("validateToolParamHeaders() = %v, want nil", err) + } + }) + } +} + +func TestFilterValidTools(t *testing.T) { + valid := &Tool{ + Name: "valid", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + }, + }, + } + invalid := &Tool{ + Name: "invalid", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": ""}, + }, + }, + } + noAnnotation := &Tool{ + Name: "plain", + InputSchema: map[string]any{"type": "object", "properties": map[string]any{"q": map[string]any{"type": "string"}}}, + } + + result := filterValidTools([]*Tool{valid, invalid, noAnnotation}) + if len(result) != 2 { + t.Fatalf("filterValidTools returned %d tools, want 2", len(result)) + } + if result[0].Name != "valid" || result[1].Name != "plain" { + t.Errorf("filterValidTools returned [%s, %s], want [valid, plain]", result[0].Name, result[1].Name) + } +} + +func TestSetStandardHeadersWithParamHeaders(t *testing.T) { + toolSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + "query": map[string]any{ + "type": "string", + }, + "priority": map[string]any{ + "type": "string", + "x-mcp-header": "Priority", + }, + }, + } + tool := &Tool{Name: "execute_sql", InputSchema: toolSchema} + + tests := []struct { + name string + tool *Tool + params any + wantHeaders map[string]string + }{ + { + name: "sets param headers from arguments", + tool: tool, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1", "priority": "high"}, + }, + wantHeaders: map[string]string{ + "Mcp-Param-Region": "us-west1", + "Mcp-Param-Priority": "high", + }, + }, + { + name: "omits header when argument is missing", + tool: tool, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"query": "SELECT 1"}, + }, + wantHeaders: map[string]string{}, + }, + { + name: "omits header when argument is null", + tool: tool, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": nil, "query": "SELECT 1"}, + }, + wantHeaders: map[string]string{}, + }, + { + name: "encodes non-ASCII value", + tool: tool, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "日本", "query": "SELECT 1"}, + }, + wantHeaders: map[string]string{ + "Mcp-Param-Region": "=?base64?5pel5pys?=", + }, + }, + { + name: "handles boolean argument", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "flag": map[string]any{"type": "boolean", "x-mcp-header": "Flag"}, + }, + }, + }, + params: &CallToolParams{ + Name: "test", + Arguments: map[string]any{"flag": true}, + }, + wantHeaders: map[string]string{ + "Mcp-Param-Flag": "true", + }, + }, + { + name: "handles number argument", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{"type": "number", "x-mcp-header": "Count"}, + }, + }, + }, + params: &CallToolParams{ + Name: "test", + Arguments: map[string]any{"count": float64(42)}, + }, + wantHeaders: map[string]string{ + "Mcp-Param-Count": "42", + }, + }, + { + name: "no tool in extra does not add param headers", + tool: nil, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1"}, + }, + wantHeaders: map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + httpReq, err := http.NewRequest("POST", "http://localhost/mcp", nil) + if err != nil { + t.Fatal(err) + } + httpReq.Header.Set(ProtocolVersionHeader, MinVersionForStandardHeaders) + + msg := &jsonrpc.Request{ + Method: "tools/call", + Params: mustMarshal(tt.params), + Extra: tt.tool, + } + + setStandardHeaders(httpReq, msg) + + if got := httpReq.Header.Get(MethodHeader); got != "tools/call" { + t.Errorf("MethodHeader = %q, want %q", got, "tools/call") + } + + for header, want := range tt.wantHeaders { + if got := httpReq.Header.Get(header); got != want { + t.Errorf("%s = %q, want %q", header, got, want) + } + } + + // Verify non-annotated params don't get headers + if got := httpReq.Header.Get("Mcp-Param-query"); got != "" { + t.Errorf("non-annotated param got header: Mcp-Param-query = %q", got) + } + }) + } +} + +func TestExtractToolParamHeaders(t *testing.T) { + tests := []struct { + name string + tool *Tool + want map[string]string + }{ + { + name: "extracts x-mcp-header annotations", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + "query": map[string]any{"type": "string"}, + "tenant_id": map[string]any{"type": "string", "x-mcp-header": "TenantId"}, + }, + }, + }, + want: map[string]string{"region": "Region", "tenant_id": "TenantId"}, + }, + { + name: "returns nil for tool without properties", + tool: &Tool{Name: "test", InputSchema: map[string]any{"type": "object"}}, + want: nil, + }, + { + name: "returns nil for non-map schema", + tool: &Tool{Name: "test", InputSchema: "not a map"}, + want: nil, + }, + { + name: "returns nil when no annotations", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{"q": map[string]any{"type": "string"}}, + }, + }, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractToolParamHeaders(tt.tool) + if tt.want == nil { + if got != nil { + t.Errorf("extractToolParamHeaders() = %v, want nil", got) + } + return + } + if len(got) != len(tt.want) { + t.Fatalf("extractToolParamHeaders() returned %d entries, want %d", len(got), len(tt.want)) + } + for k, v := range tt.want { + if got[k] != v { + t.Errorf("extractToolParamHeaders()[%q] = %q, want %q", k, got[k], v) + } + } + }) + } +} + +func TestUnmarshalPrimitive(t *testing.T) { + tests := []struct { + name string + raw string + want any + }{ + {"string", `"hello"`, "hello"}, + {"number", `42`, float64(42)}, + {"float", `3.14`, float64(3.14)}, + {"true", `true`, true}, + {"false", `false`, false}, + {"null", `null`, nil}, + {"array", `[1,2]`, nil}, + {"object", `{"a":1}`, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := unmarshalPrimitive(json.RawMessage(tt.raw)) + if fmt.Sprintf("%v", got) != fmt.Sprintf("%v", tt.want) { + t.Errorf("unmarshalPrimitive(%s) = %v (%T), want %v (%T)", tt.raw, got, got, tt.want, tt.want) + } + }) + } +} diff --git a/mcp/streamable.go b/mcp/streamable.go index b3afcf3f..8467693d 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1785,6 +1785,11 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e if msg.IsCall() { forCall = msg } + if msg.Method == "tools/call" { + if tool, ok := ctx.Value(toolContextKey).(*Tool); ok { + msg.Extra = tool + } + } case *jsonrpc.Response: requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) default: From f429bc57ea362999f63fab181daef5275dfa779e Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Sun, 26 Apr 2026 12:57:24 +0000 Subject: [PATCH 07/17] feat: implement validation for tool-specific parameter headers in MCP requests --- mcp/mcp_http_headers.go | 54 +++++++++++++++++++++++++++++++++++- mcp/mcp_http_headers_test.go | 2 +- mcp/streamable.go | 20 +++++++++++-- 3 files changed, 72 insertions(+), 4 deletions(-) diff --git a/mcp/mcp_http_headers.go b/mcp/mcp_http_headers.go index 350273e9..175aa4a8 100644 --- a/mcp/mcp_http_headers.go +++ b/mcp/mcp_http_headers.go @@ -223,7 +223,7 @@ func filterValidTools(tools []*Tool) []*Tool { return result } -func validateMcpHeaders(req *http.Request, msg jsonrpc.Message) error { +func validateMcpHeaders(req *http.Request, msg jsonrpc.Message, tool *Tool) error { protocolVersion := req.Header.Get(ProtocolVersionHeader) if protocolVersion == "" || protocolVersion < MinVersionForStandardHeaders { return nil @@ -250,6 +250,58 @@ func validateMcpHeaders(req *http.Request, msg jsonrpc.Message) error { } } } + + if msg.Method == "tools/call" && tool != nil { + if err := validateParamHeaders(req, msg, tool); err != nil { + return err + } + } + } + return nil +} + +func validateParamHeaders(req *http.Request, msg *jsonrpc.Request, tool *Tool) error { + paramHeaders := extractToolParamHeaders(tool) + if len(paramHeaders) == 0 { + return nil + } + + var raw struct { + Arguments map[string]json.RawMessage `json:"arguments"` + } + if err := json.Unmarshal(msg.Params, &raw); err != nil { + return nil + } + + for paramName, headerName := range paramHeaders { + fullHeader := ParamHeaderPrefix + headerName + headerVal := req.Header.Get(fullHeader) + argRaw, argExists := raw.Arguments[paramName] + argIsNull := argExists && string(argRaw) == "null" + + if !argExists || argIsNull { + if headerVal != "" { + return fmt.Errorf("header mismatch: unexpected %s header for absent or null parameter %q", fullHeader, paramName) + } + continue + } + + if headerVal == "" { + return fmt.Errorf("header mismatch: missing %s header for parameter %q", fullHeader, paramName) + } + + bodyVal := unmarshalPrimitive(argRaw) + if bodyVal == nil { + continue + } + expected, ok := encodeHeaderValue(bodyVal) + if !ok { + continue + } + + if headerVal != expected { + return fmt.Errorf("header mismatch: %s header value '%s' does not match body value", fullHeader, headerVal) + } } return nil } diff --git a/mcp/mcp_http_headers_test.go b/mcp/mcp_http_headers_test.go index 68b3e965..77d69e16 100644 --- a/mcp/mcp_http_headers_test.go +++ b/mcp/mcp_http_headers_test.go @@ -421,7 +421,7 @@ func TestValidateMcpHeaders(t *testing.T) { httpReq.Header.Set(NameHeader, tt.nameHeader) } - err = validateMcpHeaders(httpReq, tt.msg) + err = validateMcpHeaders(httpReq, tt.msg, nil /*tool*/) if tt.wantErr { if err == nil { t.Fatalf("validateMcpHeaders() = nil, want error containing %q", tt.wantErrContain) diff --git a/mcp/streamable.go b/mcp/streamable.go index 8467693d..d9114e92 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -490,6 +490,14 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "failed connection", http.StatusInternalServerError) return } + transport.connection.toolLookup = func(name string) *Tool { + server.mu.Lock() + defer server.mu.Unlock() + if st, ok := server.tools.get(name); ok { + return st.tool + } + return nil + } // Capture the user ID from the token info to enable session hijacking // prevention on subsequent requests. var userID string @@ -668,6 +676,8 @@ type streamableServerConn struct { logger *slog.Logger + toolLookup func(name string) *Tool + incoming chan jsonrpc.Message // messages from the client to the server mu sync.Mutex // guards all fields below @@ -1185,9 +1195,15 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } } - // Validate MCP standard headers (Mcp-Method, Mcp-Name) + // Validate MCP standard headers (Mcp-Method, Mcp-Name, Mcp-Param-*) if !isBatch && len(incoming) == 1 { - if err := validateMcpHeaders(req, incoming[0]); err != nil { + var tool *Tool + if jreq, ok := incoming[0].(*jsonrpc.Request); ok && jreq.Method == "tools/call" && c.toolLookup != nil { + if name, ok := extractName(jreq.Method, jreq.Params); ok { + tool = c.toolLookup(name) + } + } + if err := validateMcpHeaders(req, incoming[0], tool); err != nil { resp := &jsonrpc.Response{ Error: jsonrpc2.NewError(CodeHeaderMismatch, err.Error()), } From ad175620b2e5e6b00c3b05a26fe97a41ea482e01 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Sun, 26 Apr 2026 13:50:49 +0000 Subject: [PATCH 08/17] feat: implement base64 header decoding and standardize primitive value comparison for parameter validation --- mcp/mcp_http_headers.go | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/mcp/mcp_http_headers.go b/mcp/mcp_http_headers.go index 175aa4a8..6b6cb268 100644 --- a/mcp/mcp_http_headers.go +++ b/mcp/mcp_http_headers.go @@ -136,6 +136,22 @@ func extractToolParamHeaders(tool *Tool) map[string]string { return result } +func primitiveToString(value any) string { + switch v := value.(type) { + case string: + return v + case float64: + return fmt.Sprintf("%g", v) + case bool: + if v { + return "true" + } + return "false" + default: + return fmt.Sprintf("%v", v) + } +} + // unmarshalPrimitive unmarshals a JSON value into a Go primitive // (string, float64, or bool). Returns nil for non-primitive types. func unmarshalPrimitive(raw json.RawMessage) any { @@ -277,9 +293,8 @@ func validateParamHeaders(req *http.Request, msg *jsonrpc.Request, tool *Tool) e fullHeader := ParamHeaderPrefix + headerName headerVal := req.Header.Get(fullHeader) argRaw, argExists := raw.Arguments[paramName] - argIsNull := argExists && string(argRaw) == "null" - if !argExists || argIsNull { + if !argExists || string(argRaw) == "null" { if headerVal != "" { return fmt.Errorf("header mismatch: unexpected %s header for absent or null parameter %q", fullHeader, paramName) } @@ -290,16 +305,18 @@ func validateParamHeaders(req *http.Request, msg *jsonrpc.Request, tool *Tool) e return fmt.Errorf("header mismatch: missing %s header for parameter %q", fullHeader, paramName) } + decoded, ok := decodeHeaderValue(headerVal) + if !ok { + return fmt.Errorf("header mismatch: %s header contains invalid Base64 encoding", fullHeader) + } + bodyVal := unmarshalPrimitive(argRaw) if bodyVal == nil { continue } - expected, ok := encodeHeaderValue(bodyVal) - if !ok { - continue - } + expected := primitiveToString(bodyVal) - if headerVal != expected { + if decoded != expected { return fmt.Errorf("header mismatch: %s header value '%s' does not match body value", fullHeader, headerVal) } } From a4e1e23f681572604e1d42027a1e0a732db92f5e Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 27 Apr 2026 12:12:05 +0000 Subject: [PATCH 09/17] refactor: simplify MCP header logic and remove redundant tool cache accessor --- mcp/client.go | 12 +++--------- mcp/header_encoding.go | 8 ++------ mcp/mcp_http_headers.go | 18 ++++++------------ 3 files changed, 11 insertions(+), 27 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index b132177a..f32807ab 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -325,11 +325,9 @@ type ClientSession struct { pendingElicitationsMu sync.Mutex pendingElicitations map[string]chan struct{} - // toolCache stores tool definitions keyed by name, populated by - // ListTools/Tools. Used to look up x-mcp-header annotations when + // toolCache stores tool definitions keyed by name. + // It is used to look up x-mcp-header annotations when // constructing Mcp-Param-* headers for tools/call requests. - // No mutex is required because CallTool cannot be meaningfully called - // until ListTools has returned (the caller needs the tool name). toolCache map[string]*Tool } @@ -385,10 +383,6 @@ func (cs *ClientSession) cacheTools(tools []*Tool) { } } -func (cs *ClientSession) getCachedTool(name string) *Tool { - return cs.toolCache[name] -} - // registerElicitationWaiter registers a waiter for an elicitation complete // notification with the given elicitation ID. It returns two functions: an await // function that waits for the notification or context cancellation, and a cleanup @@ -1027,7 +1021,7 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) ( // Avoid sending nil over the wire. params.Arguments = map[string]any{} } - if tool := cs.getCachedTool(params.Name); tool != nil { + if tool := cs.toolCache[params.Name]; tool != nil { ctx = context.WithValue(ctx, toolContextKey, tool) } return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) diff --git a/mcp/header_encoding.go b/mcp/header_encoding.go index 3f032fac..77f414f5 100644 --- a/mcp/header_encoding.go +++ b/mcp/header_encoding.go @@ -48,9 +48,7 @@ func encodeHeaderValue(value any) (string, bool) { } // decodeHeaderValue decodes a header value that may be Base64-encoded -// with the =?base64?...?= wrapper. Returns the decoded string and true -// on success. Returns "", false if Base64 decoding fails. -// Non-encoded values are returned as-is. +// with the =?base64?...?= wrapper. func decodeHeaderValue(headerValue string) (string, bool) { if len(headerValue) == 0 { return headerValue, true @@ -77,9 +75,7 @@ func requiresBase64Encoding(s string) bool { } for _, c := range s { if c < 0x20 || c > 0x7E { - if c != '\t' { - return true - } + return true } } return false diff --git a/mcp/mcp_http_headers.go b/mcp/mcp_http_headers.go index 6b6cb268..da3802a0 100644 --- a/mcp/mcp_http_headers.go +++ b/mcp/mcp_http_headers.go @@ -70,8 +70,7 @@ func setStandardHeaders(httpReq *http.Request, msg jsonrpc.Message) { } // setParamHeaders reads x-mcp-header annotations from the tool's InputSchema -// and sets Mcp-Param-{Name} headers on the HTTP request from the corresponding -// argument values. +// and sets Mcp-Param-{Name} headers on the HTTP request. func setParamHeaders(httpReq *http.Request, tool *Tool, params json.RawMessage) { paramHeaders := extractToolParamHeaders(tool) if len(paramHeaders) == 0 { @@ -106,9 +105,9 @@ func setParamHeaders(httpReq *http.Request, tool *Tool, params json.RawMessage) } } -// extractToolParamHeaders returns a map of parameter name → header name +// extractToolParamHeaders returns a map of parameter // for all properties in the tool's InputSchema that have an x-mcp-header -// annotation. On the client side, InputSchema arrives as map[string]any. +// annotation. func extractToolParamHeaders(tool *Tool) map[string]string { schema, ok := tool.InputSchema.(map[string]any) if !ok { @@ -168,13 +167,7 @@ func unmarshalPrimitive(raw json.RawMessage) any { } // validateToolParamHeaders checks that a tool's x-mcp-header annotations -// are valid per SEP-2243. Returns an error describing the first violation, or nil. -// -// Constraints: -// - x-mcp-header values MUST NOT be empty -// - MUST contain only ASCII characters (excluding space and ':') -// - MUST be case-insensitively unique -// - MUST only be applied to properties with primitive types (string, number, boolean) +// are valid. func validateToolParamHeaders(tool *Tool) error { schema, ok := tool.InputSchema.(map[string]any) if !ok { @@ -216,6 +209,7 @@ func validateToolParamHeaders(tool *Tool) error { return nil } +// A valid HeaderName MUST contain only ASCII characters (excluding space and ':'). func validateHeaderName(name string) error { for _, c := range name { if c <= 0x20 || c > 0x7E || c == ':' { @@ -225,7 +219,7 @@ func validateHeaderName(name string) error { return nil } -// filterValidTools returns a new slice containing only tools with valid +// filterValidTools returns only tools that have valid // x-mcp-header annotations. Invalid tools are logged and excluded. func filterValidTools(tools []*Tool) []*Tool { result := make([]*Tool, 0, len(tools)) From aeada36e48127a1c60c05d5c33dbca1d41e6574b Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 27 Apr 2026 17:17:05 +0000 Subject: [PATCH 10/17] refactor: prohibit x-mcp-header annotations on nested object properties and add structural validation --- mcp/client.go | 12 +-- mcp/mcp_http_headers.go | 201 +++++++++++++++++++++-------------- mcp/mcp_http_headers_test.go | 88 ++++++++++++++- 3 files changed, 217 insertions(+), 84 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index f32807ab..d992aac2 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -160,6 +160,12 @@ type ClientOptions struct { KeepAlive time.Duration } +// toolContextKeyType is the context key type for passing tool definitions +// from CallTool to the transport layer. +type toolContextKeyType struct{} + +var toolContextKey = toolContextKeyType{} + // bind implements the binder[*ClientSession] interface, so that Clients can // be connected using [connect]. func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession { @@ -299,12 +305,6 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp // // Call [ClientSession.Close] to close the connection, or await server // termination with [ClientSession.Wait]. -// toolContextKeyType is the context key type for passing tool definitions -// from CallTool to the transport layer. -type toolContextKeyType struct{} - -var toolContextKey = toolContextKeyType{} - type ClientSession struct { // Ensure that onClose is called at most once. // We defensively use an atomic CompareAndSwap rather than a sync.Once, in case the diff --git a/mcp/mcp_http_headers.go b/mcp/mcp_http_headers.go index da3802a0..1723e546 100644 --- a/mcp/mcp_http_headers.go +++ b/mcp/mcp_http_headers.go @@ -23,8 +23,13 @@ const ( NameHeader = "Mcp-Name" ParamHeaderPrefix = "Mcp-Param-" MinVersionForStandardHeaders = "2026-06-XX" + mcpHeaderExtension = "x-mcp-header" ) +// --------------------------------------------------------------------------- +// Shared helpers (used by both client and server) +// --------------------------------------------------------------------------- + func extractName(method string, params json.RawMessage) (string, bool) { switch method { case "tools/call": @@ -47,6 +52,79 @@ func extractName(method string, params json.RawMessage) (string, bool) { return "", false } +func extractSchemaProperties(schema any) map[string]any { + s, ok := schema.(map[string]any) + if !ok { + return nil + } + props, ok := s["properties"].(map[string]any) + if !ok { + return nil + } + return props +} + +// extractToolParamHeaders returns a map of parameter name to header name +// for all properties in the tool's InputSchema that have an x-mcp-header +// annotation. +func extractToolParamHeaders(tool *Tool) map[string]string { + props := extractSchemaProperties(tool.InputSchema) + if props == nil { + return nil + } + result := make(map[string]string) + for propName, propSchema := range props { + ps, ok := propSchema.(map[string]any) + if !ok { + continue + } + headerName, ok := ps[mcpHeaderExtension].(string) + if !ok || headerName == "" { + continue + } + result[propName] = headerName + } + if len(result) == 0 { + return nil + } + return result +} + +func primitiveToString(value any) string { + switch v := value.(type) { + case string: + return v + case float64: + return fmt.Sprintf("%g", v) + case bool: + if v { + return "true" + } + return "false" + default: + return fmt.Sprintf("%v", v) + } +} + +// unmarshalPrimitive unmarshals a JSON value into a Go primitive +// (string, float64, or bool). Returns nil for non-primitive types. +func unmarshalPrimitive(raw json.RawMessage) any { + var val any + if err := json.Unmarshal(raw, &val); err != nil { + return nil + } + switch val.(type) { + case string, float64, bool: + return val + default: + return nil + } +} + +// --------------------------------------------------------------------------- +// Client-side helpers +// --------------------------------------------------------------------------- + func setStandardHeaders(httpReq *http.Request, msg jsonrpc.Message) { if msg == nil { return @@ -89,7 +167,6 @@ func setParamHeaders(httpReq *http.Request, tool *Tool, params json.RawMessage) if !ok { continue } - // null → omit header per SEP if string(argRaw) == "null" { continue } @@ -105,76 +182,25 @@ func setParamHeaders(httpReq *http.Request, tool *Tool, params json.RawMessage) } } -// extractToolParamHeaders returns a map of parameter -// for all properties in the tool's InputSchema that have an x-mcp-header -// annotation. -func extractToolParamHeaders(tool *Tool) map[string]string { - schema, ok := tool.InputSchema.(map[string]any) - if !ok { - return nil - } - props, ok := schema["properties"].(map[string]any) - if !ok { - return nil - } - result := make(map[string]string) - for propName, propSchema := range props { - ps, ok := propSchema.(map[string]any) - if !ok { - continue - } - headerName, ok := ps["x-mcp-header"].(string) - if !ok || headerName == "" { +// filterValidTools returns only tools that have valid +// x-mcp-header annotations. Invalid tools are logged and excluded. +func filterValidTools(tools []*Tool) []*Tool { + result := make([]*Tool, 0, len(tools)) + for _, tool := range tools { + if err := validateToolParamHeaders(tool); err != nil { + log.Printf("mcp: excluding tool %q from tools/list: %v", tool.Name, err) continue } - result[propName] = headerName - } - if len(result) == 0 { - return nil + result = append(result, tool) } return result } -func primitiveToString(value any) string { - switch v := value.(type) { - case string: - return v - case float64: - return fmt.Sprintf("%g", v) - case bool: - if v { - return "true" - } - return "false" - default: - return fmt.Sprintf("%v", v) - } -} - -// unmarshalPrimitive unmarshals a JSON value into a Go primitive -// (string, float64, or bool). Returns nil for non-primitive types. -func unmarshalPrimitive(raw json.RawMessage) any { - var val any - if err := json.Unmarshal(raw, &val); err != nil { - return nil - } - switch val.(type) { - case string, float64, bool: - return val - default: - return nil - } -} - // validateToolParamHeaders checks that a tool's x-mcp-header annotations // are valid. func validateToolParamHeaders(tool *Tool) error { - schema, ok := tool.InputSchema.(map[string]any) - if !ok { - return nil - } - props, ok := schema["properties"].(map[string]any) - if !ok { + props := extractSchemaProperties(tool.InputSchema) + if props == nil { return nil } @@ -184,7 +210,7 @@ func validateToolParamHeaders(tool *Tool) error { if !ok { continue } - headerNameRaw, exists := ps["x-mcp-header"] + headerNameRaw, exists := ps[mcpHeaderExtension] if !exists { continue } @@ -206,10 +232,41 @@ func validateToolParamHeaders(tool *Tool) error { return fmt.Errorf("property %q: x-mcp-header can only be applied to primitive types, got %q", propName, propType) } } + + for propName, propSchema := range props { + ps, ok := propSchema.(map[string]any) + if !ok { + continue + } + if err := checkForNestedHeaders(ps, propName); err != nil { + return err + } + } return nil } -// A valid HeaderName MUST contain only ASCII characters (excluding space and ':'). +func checkForNestedHeaders(schema map[string]any, path string) error { + nestedProps := extractSchemaProperties(schema) + if nestedProps == nil { + return nil + } + for propName, propSchema := range nestedProps { + ps, ok := propSchema.(map[string]any) + if !ok { + continue + } + if _, exists := ps[mcpHeaderExtension]; exists { + return fmt.Errorf("property %q: x-mcp-header cannot be applied to nested properties", path+"."+propName) + } + if err := checkForNestedHeaders(ps, path+"."+propName); err != nil { + return err + } + } + return nil +} + +// validateHeaderName checks that a header name contains only valid +// ASCII characters (excluding space and ':'). func validateHeaderName(name string) error { for _, c := range name { if c <= 0x20 || c > 0x7E || c == ':' { @@ -219,19 +276,9 @@ func validateHeaderName(name string) error { return nil } -// filterValidTools returns only tools that have valid -// x-mcp-header annotations. Invalid tools are logged and excluded. -func filterValidTools(tools []*Tool) []*Tool { - result := make([]*Tool, 0, len(tools)) - for _, tool := range tools { - if err := validateToolParamHeaders(tool); err != nil { - log.Printf("mcp: excluding tool %q from tools/list: %v", tool.Name, err) - continue - } - result = append(result, tool) - } - return result -} +// --------------------------------------------------------------------------- +// Server-side helpers +// --------------------------------------------------------------------------- func validateMcpHeaders(req *http.Request, msg jsonrpc.Message, tool *Tool) error { protocolVersion := req.Header.Get(ProtocolVersionHeader) diff --git a/mcp/mcp_http_headers_test.go b/mcp/mcp_http_headers_test.go index 77d69e16..41fa74cc 100644 --- a/mcp/mcp_http_headers_test.go +++ b/mcp/mcp_http_headers_test.go @@ -647,6 +647,78 @@ func TestValidateToolParamHeaders(t *testing.T) { }, }, }, + { + name: "x-mcp-header on nested property inside object", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + }, + }, + }, + }, + }, + wantErr: true, + wantErrSub: "nested", + }, + { + name: "x-mcp-header on deeply nested property", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "outer": map[string]any{ + "type": "object", + "properties": map[string]any{ + "inner": map[string]any{ + "type": "object", + "properties": map[string]any{ + "value": map[string]any{ + "type": "string", + "x-mcp-header": "Value", + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: true, + wantErrSub: "nested", + }, + { + name: "object property without nested x-mcp-header is valid", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + }, + }, + }, + "flag": map[string]any{ + "type": "boolean", + "x-mcp-header": "Flag", + }, + }, + }, + }, + }, } for _, tt := range tests { @@ -689,8 +761,22 @@ func TestFilterValidTools(t *testing.T) { Name: "plain", InputSchema: map[string]any{"type": "object", "properties": map[string]any{"q": map[string]any{"type": "string"}}}, } + nestedInvalid := &Tool{ + Name: "nested-invalid", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + }, + }, + }, + }, + } - result := filterValidTools([]*Tool{valid, invalid, noAnnotation}) + result := filterValidTools([]*Tool{valid, invalid, noAnnotation, nestedInvalid}) if len(result) != 2 { t.Fatalf("filterValidTools returned %d tools, want 2", len(result)) } From b223143fb6c21624444e2f4ef8e4bee7e4526689 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 29 Apr 2026 16:06:35 +0000 Subject: [PATCH 11/17] refactor: rename mcp_http_headers to streamable_headers --- mcp/{mcp_http_headers.go => streamable_headers.go} | 0 mcp/{mcp_http_headers_test.go => streamable_headers_test.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename mcp/{mcp_http_headers.go => streamable_headers.go} (100%) rename mcp/{mcp_http_headers_test.go => streamable_headers_test.go} (100%) diff --git a/mcp/mcp_http_headers.go b/mcp/streamable_headers.go similarity index 100% rename from mcp/mcp_http_headers.go rename to mcp/streamable_headers.go diff --git a/mcp/mcp_http_headers_test.go b/mcp/streamable_headers_test.go similarity index 100% rename from mcp/mcp_http_headers_test.go rename to mcp/streamable_headers_test.go From 24cc607584de91eed6f0b49ed157c1afc5fb24ab Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 29 Apr 2026 17:23:21 +0000 Subject: [PATCH 12/17] feat: enable MCP parameter headers and add validation tests using internal JSON unmarshaling --- mcp/streamable_headers.go | 18 ++- mcp/streamable_test.go | 272 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 287 insertions(+), 3 deletions(-) diff --git a/mcp/streamable_headers.go b/mcp/streamable_headers.go index 08641841..b54b1760 100644 --- a/mcp/streamable_headers.go +++ b/mcp/streamable_headers.go @@ -27,6 +27,10 @@ const ( mcpHeaderExtension = "x-mcp-header" ) +// --------------------------------------------------------------------------- +// Shared helpers (used by both client and server) +// --------------------------------------------------------------------------- + func extractName(method string, params json.RawMessage) (string, bool) { switch method { case "tools/call": @@ -107,7 +111,7 @@ func primitiveToString(value any) string { // (string, float64, or bool). Returns nil for non-primitive types. func unmarshalPrimitive(raw json.RawMessage) any { var val any - if err := json.Unmarshal(raw, &val); err != nil { + if err := internaljson.Unmarshal(raw, &val); err != nil { return nil } switch val.(type) { @@ -118,6 +122,10 @@ func unmarshalPrimitive(raw json.RawMessage) any { } } +// --------------------------------------------------------------------------- +// Client-side helpers +// --------------------------------------------------------------------------- + // setStandardHeaders populates standard MCP headers. // It requires the protocol version header to be set. func setStandardHeaders(header http.Header, msg jsonrpc.Message) { @@ -153,7 +161,7 @@ func setParamHeaders(header http.Header, tool *Tool, params json.RawMessage) { var raw struct { Arguments map[string]json.RawMessage `json:"arguments"` } - if err := json.Unmarshal(params, &raw); err != nil || raw.Arguments == nil { + if err := internaljson.Unmarshal(params, &raw); err != nil || raw.Arguments == nil { return } @@ -271,6 +279,10 @@ func validateHeaderName(name string) error { return nil } +// --------------------------------------------------------------------------- +// Server-side helpers +// --------------------------------------------------------------------------- + func validateMcpHeaders(header http.Header, msg jsonrpc.Message, tool *Tool) error { protocolVersion := header.Get(protocolVersionHeader) if protocolVersion == "" || protocolVersion < minVersionForStandardHeaders { @@ -319,7 +331,7 @@ func validateParamHeaders(header http.Header, msg *jsonrpc.Request, tool *Tool) var raw struct { Arguments map[string]json.RawMessage `json:"arguments"` } - if err := json.Unmarshal(msg.Params, &raw); err != nil { + if err := internaljson.Unmarshal(msg.Params, &raw); err != nil { return nil } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 80fb4baa..b493b1e0 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -2170,6 +2170,278 @@ func TestStreamableMcpHeaderVersionGating(t *testing.T) { }) } +// TestStreamableParamHeadersClientSetsHeaders verifies that the client sets +// Mcp-Param-* headers on tool calls when the tool has x-mcp-header annotations. +func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), minVersionForStandardHeaders) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + server.AddTool( + &Tool{ + Name: "execute_sql", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + "query": map[string]any{ + "type": "string", + }, + }, + }, + }, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil + }) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + var capturedHeaders http.Header + customClient := &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Header.Get(methodHeader) == "tools/call" { + capturedHeaders = req.Header.Clone() + } + return http.DefaultTransport.RoundTrip(req) + }), + } + + clientTransport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + HTTPClient: customClient, + } + + client := NewClient(&Implementation{Name: "testClient", Version: "v1.0.0"}, nil) + ctx := context.Background() + session, err := client.Connect(ctx, clientTransport, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) + if err != nil { + t.Fatal(err) + } + defer session.Close() + + // ListTools to populate the tool cache (needed for param headers). + if _, err := session.ListTools(ctx, nil); err != nil { + t.Fatal(err) + } + + _, err = session.CallTool(ctx, &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, + }) + if err != nil { + t.Fatal(err) + } + + if capturedHeaders == nil { + t.Fatal("no tool call headers captured") + } + if got := capturedHeaders.Get(methodHeader); got != "tools/call" { + t.Errorf("Mcp-Method = %q, want %q", got, "tools/call") + } + if got := capturedHeaders.Get(nameHeader); got != "execute_sql" { + t.Errorf("Mcp-Name = %q, want %q", got, "execute_sql") + } + if got := capturedHeaders.Get(paramHeaderPrefix + "Region"); got != "us-west1" { + t.Errorf("Mcp-Param-Region = %q, want %q", got, "us-west1") + } + if got := capturedHeaders.Get("Mcp-Param-query"); got != "" { + t.Errorf("non-annotated param got header: Mcp-Param-query = %q", got) + } +} + +// TestStreamableParamHeadersServerValidation verifies that the server +// validates Mcp-Param-* headers against the body for tools with +// x-mcp-header annotations. +func TestStreamableParamHeadersServerValidation(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), minVersionForStandardHeaders) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + server.AddTool( + &Tool{ + Name: "execute_sql", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + "query": map[string]any{ + "type": "string", + }, + }, + }, + }, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + }) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + + initReq := req(1, methodInitialize, &InitializeParams{ProtocolVersion: minVersionForStandardHeaders}) + initResp := resp(1, &InitializeResult{ + Capabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + ProtocolVersion: minVersionForStandardHeaders, + ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, + }, nil) + + testStreamableHandler(t, handler, []streamableRequest{ + { + method: "POST", + messages: []jsonrpc.Message{initReq}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{initResp}, + wantSessionID: true, + }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {notificationInitialized}, + }, + messages: []jsonrpc.Message{req(0, notificationInitialized, &InitializedParams{})}, + wantStatusCode: http.StatusAccepted, + }, + // Correct param header should succeed. + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"execute_sql"}, + paramHeaderPrefix + "Region": {"us-west1"}, + }, + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, + })}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + }, + // Mismatched param header value should fail. + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"execute_sql"}, + paramHeaderPrefix + "Region": {"eu-central1"}, + }, + messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1"}, + })}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "header mismatch", + }, + // Missing param header when body has the argument should fail. + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"execute_sql"}, + }, + messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1"}, + })}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "missing", + }, + }) +} + +// TestStreamableFilterValidToolsIntegration verifies that invalid tools +// (with bad x-mcp-header annotations) are filtered out when the client +// calls ListTools. +func TestStreamableFilterValidToolsIntegration(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), minVersionForStandardHeaders) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + noop := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + } + + // Valid tool with correct x-mcp-header annotation. + server.AddTool(&Tool{ + Name: "valid-tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + }, + }, + }, noop) + + // Invalid tool: x-mcp-header on an array type. + server.AddTool(&Tool{ + Name: "invalid-tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "items": map[string]any{ + "type": "array", + "x-mcp-header": "Items", + }, + }, + }, + }, noop) + + // Tool with no x-mcp-header annotations (always valid). + server.AddTool(&Tool{ + Name: "plain-tool", + InputSchema: &jsonschema.Schema{Type: "object"}, + }, noop) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + client := NewClient(&Implementation{Name: "testClient", Version: "v1.0.0"}, nil) + ctx := context.Background() + session, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) + if err != nil { + t.Fatal(err) + } + defer session.Close() + + result, err := session.ListTools(ctx, nil) + if err != nil { + t.Fatal(err) + } + + toolNames := make([]string, len(result.Tools)) + for i, tool := range result.Tools { + toolNames[i] = tool.Name + } + sort.Strings(toolNames) + + wantNames := []string{"plain-tool", "valid-tool"} + if !slices.Equal(toolNames, wantNames) { + t.Errorf("ListTools returned %v, want %v", toolNames, wantNames) + } +} + // TestStreamable405AllowHeader verifies RFC 9110 §15.5.6 compliance: // 405 Method Not Allowed responses MUST include an Allow header. func TestStreamable405AllowHeader(t *testing.T) { From 9dd69071d4f2b51453a3aeddb6cf53d017ffe7e8 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 30 Apr 2026 07:55:00 +0000 Subject: [PATCH 13/17] fix: add mutex protection to toolCache and provide thread-safe accessors in ClientSession --- mcp/client.go | 12 ++++++- mcp/client_test.go | 74 +++++++++++++++++++++++++++++++++++++++ mcp/streamable_headers.go | 11 ------ 3 files changed, 85 insertions(+), 12 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index d992aac2..a3a66132 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -325,6 +325,8 @@ type ClientSession struct { pendingElicitationsMu sync.Mutex pendingElicitations map[string]chan struct{} + // toolCacheMu guards toolCache. + toolCacheMu sync.RWMutex // toolCache stores tool definitions keyed by name. // It is used to look up x-mcp-header annotations when // constructing Mcp-Param-* headers for tools/call requests. @@ -375,6 +377,8 @@ func (cs *ClientSession) Wait() error { } func (cs *ClientSession) cacheTools(tools []*Tool) { + cs.toolCacheMu.Lock() + defer cs.toolCacheMu.Unlock() if cs.toolCache == nil { cs.toolCache = make(map[string]*Tool, len(tools)) } @@ -383,6 +387,12 @@ func (cs *ClientSession) cacheTools(tools []*Tool) { } } +func (cs *ClientSession) getCachedTool(name string) *Tool { + cs.toolCacheMu.RLock() + defer cs.toolCacheMu.RUnlock() + return cs.toolCache[name] +} + // registerElicitationWaiter registers a waiter for an elicitation complete // notification with the given elicitation ID. It returns two functions: an await // function that waits for the notification or context cancellation, and a cleanup @@ -1021,7 +1031,7 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) ( // Avoid sending nil over the wire. params.Arguments = map[string]any{} } - if tool := cs.toolCache[params.Name]; tool != nil { + if tool := cs.getCachedTool(params.Name); tool != nil { ctx = context.WithValue(ctx, toolContextKey, tool) } return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) diff --git a/mcp/client_test.go b/mcp/client_test.go index fc37c3eb..8c840d9d 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -440,6 +440,80 @@ func TestClientCapabilities(t *testing.T) { } } +func TestToolCache(t *testing.T) { + tool1 := &Tool{Name: "tool1", Description: "first"} + tool2 := &Tool{Name: "tool2", Description: "second"} + tool1Updated := &Tool{Name: "tool1", Description: "updated"} + + testCases := []struct { + name string + cacheBatches [][]*Tool + lookup string + want *Tool + }{ + { + name: "empty cache", + lookup: "tool1", + want: nil, + }, + { + name: "single tool found", + cacheBatches: [][]*Tool{{tool1}}, + lookup: "tool1", + want: tool1, + }, + { + name: "unknown tool", + cacheBatches: [][]*Tool{{tool1}}, + lookup: "nonexistent", + want: nil, + }, + { + name: "multiple tools single batch", + cacheBatches: [][]*Tool{{tool1, tool2}}, + lookup: "tool2", + want: tool2, + }, + { + name: "additive first tool retained", + cacheBatches: [][]*Tool{{tool1}, {tool2}}, + lookup: "tool1", + want: tool1, + }, + { + name: "additive second tool added", + cacheBatches: [][]*Tool{{tool1}, {tool2}}, + lookup: "tool2", + want: tool2, + }, + { + name: "overwrite existing entry", + cacheBatches: [][]*Tool{{tool1}, {tool1Updated}}, + lookup: "tool1", + want: tool1Updated, + }, + { + name: "empty batch no-op", + cacheBatches: [][]*Tool{{}}, + lookup: "tool1", + want: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cs := &ClientSession{} + for _, batch := range tc.cacheBatches { + cs.cacheTools(batch) + } + got := cs.getCachedTool(tc.lookup) + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("getCachedTool(%q) mismatch (-want +got):\n%s", tc.lookup, diff) + } + }) + } +} + func TestClientCapabilitiesOverWire(t *testing.T) { testCases := []struct { name string diff --git a/mcp/streamable_headers.go b/mcp/streamable_headers.go index b54b1760..846b60cd 100644 --- a/mcp/streamable_headers.go +++ b/mcp/streamable_headers.go @@ -27,9 +27,6 @@ const ( mcpHeaderExtension = "x-mcp-header" ) -// --------------------------------------------------------------------------- -// Shared helpers (used by both client and server) -// --------------------------------------------------------------------------- func extractName(method string, params json.RawMessage) (string, bool) { switch method { @@ -122,10 +119,6 @@ func unmarshalPrimitive(raw json.RawMessage) any { } } -// --------------------------------------------------------------------------- -// Client-side helpers -// --------------------------------------------------------------------------- - // setStandardHeaders populates standard MCP headers. // It requires the protocol version header to be set. func setStandardHeaders(header http.Header, msg jsonrpc.Message) { @@ -279,10 +272,6 @@ func validateHeaderName(name string) error { return nil } -// --------------------------------------------------------------------------- -// Server-side helpers -// --------------------------------------------------------------------------- - func validateMcpHeaders(header http.Header, msg jsonrpc.Message, tool *Tool) error { protocolVersion := header.Get(protocolVersionHeader) if protocolVersion == "" || protocolVersion < minVersionForStandardHeaders { From 8d4e94d8bc5e45302636218a3d7611aff7a20bb6 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 30 Apr 2026 13:23:19 +0000 Subject: [PATCH 14/17] test: remove assertion for non-annotated query parameter headers in streamable_test.go --- mcp/streamable_test.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index b493b1e0..6c18c9a0 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -2251,9 +2251,6 @@ func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) { if got := capturedHeaders.Get(paramHeaderPrefix + "Region"); got != "us-west1" { t.Errorf("Mcp-Param-Region = %q, want %q", got, "us-west1") } - if got := capturedHeaders.Get("Mcp-Param-query"); got != "" { - t.Errorf("non-annotated param got header: Mcp-Param-query = %q", got) - } } // TestStreamableParamHeadersServerValidation verifies that the server From e754ff7f04e3cb514563126d5dcbee8c43d3a259 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 30 Apr 2026 13:26:30 +0000 Subject: [PATCH 15/17] refactor: remove redundant blank line in streamable_headers.go --- mcp/streamable_headers.go | 1 - 1 file changed, 1 deletion(-) diff --git a/mcp/streamable_headers.go b/mcp/streamable_headers.go index 846b60cd..26bf9b3c 100644 --- a/mcp/streamable_headers.go +++ b/mcp/streamable_headers.go @@ -27,7 +27,6 @@ const ( mcpHeaderExtension = "x-mcp-header" ) - func extractName(method string, params json.RawMessage) (string, bool) { switch method { case "tools/call": From 04652eea470dad2148d7e73a9b616ba9e642dadb Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 30 Apr 2026 13:51:06 +0000 Subject: [PATCH 16/17] merge tests --- mcp/streamable_test.go | 178 +++++++++++++++-------------------------- 1 file changed, 66 insertions(+), 112 deletions(-) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 6c18c9a0..2876c4de 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1915,9 +1915,9 @@ func TestStreamableGET(t *testing.T) { } } -// TestStreamableMcpHeaderValidation tests the server-side Mcp-Method and -// Mcp-Name header validation through the full HTTP handler, as specified -// in SEP-2243. +// TestStreamableMcpHeaderValidation tests the server-side Mcp-Method, +// Mcp-Name, and Mcp-Param header validation through the full HTTP handler, +// as specified in SEP-2243. func TestStreamableMcpHeaderValidation(t *testing.T) { // Temporarily register the future version so the handler accepts it. orig := supportedProtocolVersions @@ -1930,6 +1930,25 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { return &CallToolResult{}, nil }) + server.AddTool( + &Tool{ + Name: "execute_sql", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + "query": map[string]any{ + "type": "string", + }, + }, + }, + }, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + }) handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) defer handler.closeAll() @@ -2019,6 +2038,50 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"execute_sql"}, + paramHeaderPrefix + "Region": {"us-west1"}, + }, + messages: []jsonrpc.Message{req(7, "tools/call", &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, + })}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(7, &CallToolResult{Content: []Content{}}, nil)}, + }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"execute_sql"}, + paramHeaderPrefix + "Region": {"eu-central1"}, + }, + messages: []jsonrpc.Message{req(8, "tools/call", &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1"}, + })}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "header mismatch", + }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"execute_sql"}, + }, + messages: []jsonrpc.Message{req(9, "tools/call", &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1"}, + })}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "missing", + }, }) } @@ -2253,115 +2316,6 @@ func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) { } } -// TestStreamableParamHeadersServerValidation verifies that the server -// validates Mcp-Param-* headers against the body for tools with -// x-mcp-header annotations. -func TestStreamableParamHeadersServerValidation(t *testing.T) { - orig := supportedProtocolVersions - supportedProtocolVersions = append(slices.Clone(orig), minVersionForStandardHeaders) - t.Cleanup(func() { supportedProtocolVersions = orig }) - - server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) - server.AddTool( - &Tool{ - Name: "execute_sql", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "region": map[string]any{ - "type": "string", - "x-mcp-header": "Region", - }, - "query": map[string]any{ - "type": "string", - }, - }, - }, - }, - func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { - return &CallToolResult{}, nil - }) - - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) - defer handler.closeAll() - - initReq := req(1, methodInitialize, &InitializeParams{ProtocolVersion: minVersionForStandardHeaders}) - initResp := resp(1, &InitializeResult{ - Capabilities: &ServerCapabilities{ - Logging: &LoggingCapabilities{}, - Tools: &ToolCapabilities{ListChanged: true}, - }, - ProtocolVersion: minVersionForStandardHeaders, - ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, - }, nil) - - testStreamableHandler(t, handler, []streamableRequest{ - { - method: "POST", - messages: []jsonrpc.Message{initReq}, - wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{initResp}, - wantSessionID: true, - }, - { - method: "POST", - headers: http.Header{ - protocolVersionHeader: {minVersionForStandardHeaders}, - methodHeader: {notificationInitialized}, - }, - messages: []jsonrpc.Message{req(0, notificationInitialized, &InitializedParams{})}, - wantStatusCode: http.StatusAccepted, - }, - // Correct param header should succeed. - { - method: "POST", - headers: http.Header{ - protocolVersionHeader: {minVersionForStandardHeaders}, - methodHeader: {"tools/call"}, - nameHeader: {"execute_sql"}, - paramHeaderPrefix + "Region": {"us-west1"}, - }, - messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{ - Name: "execute_sql", - Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, - })}, - wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, - }, - // Mismatched param header value should fail. - { - method: "POST", - headers: http.Header{ - protocolVersionHeader: {minVersionForStandardHeaders}, - methodHeader: {"tools/call"}, - nameHeader: {"execute_sql"}, - paramHeaderPrefix + "Region": {"eu-central1"}, - }, - messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{ - Name: "execute_sql", - Arguments: map[string]any{"region": "us-west1"}, - })}, - wantStatusCode: http.StatusBadRequest, - wantBodyContaining: "header mismatch", - }, - // Missing param header when body has the argument should fail. - { - method: "POST", - headers: http.Header{ - protocolVersionHeader: {minVersionForStandardHeaders}, - methodHeader: {"tools/call"}, - nameHeader: {"execute_sql"}, - }, - messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{ - Name: "execute_sql", - Arguments: map[string]any{"region": "us-west1"}, - })}, - wantStatusCode: http.StatusBadRequest, - wantBodyContaining: "missing", - }, - }) -} - // TestStreamableFilterValidToolsIntegration verifies that invalid tools // (with bad x-mcp-header annotations) are filtered out when the client // calls ListTools. From d7ef2c12fe9d3e5908d1cdec7005583193041685 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 30 Apr 2026 14:30:39 +0000 Subject: [PATCH 17/17] minor fix --- mcp/streamable_headers.go | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/mcp/streamable_headers.go b/mcp/streamable_headers.go index 26bf9b3c..6383439a 100644 --- a/mcp/streamable_headers.go +++ b/mcp/streamable_headers.go @@ -136,27 +136,30 @@ func setStandardHeaders(header http.Header, msg jsonrpc.Message) { } if msg.Method == "tools/call" { if tool, ok := msg.Extra.(*Tool); ok && tool != nil { - setParamHeaders(header, tool, msg.Params) + for k, v := range extractParamHeaders(tool, msg.Params) { + header.Set(k, v) + } } } } } -// setParamHeaders reads x-mcp-header annotations from the tool's InputSchema -// and sets Mcp-Param-{Name} headers on the HTTP request. -func setParamHeaders(header http.Header, tool *Tool, params json.RawMessage) { +// extractParamHeaders reads x-mcp-header annotations from the tool's InputSchema +// and returns the Mcp-Param-{Name} headers to be set on the HTTP request. +func extractParamHeaders(tool *Tool, params json.RawMessage) map[string]string { paramHeaders := extractToolParamHeaders(tool) if len(paramHeaders) == 0 { - return + return nil } var raw struct { Arguments map[string]json.RawMessage `json:"arguments"` } if err := internaljson.Unmarshal(params, &raw); err != nil || raw.Arguments == nil { - return + return nil } + res := make(map[string]string) for paramName, headerName := range paramHeaders { argRaw, ok := raw.Arguments[paramName] if !ok { @@ -173,8 +176,9 @@ func setParamHeaders(header http.Header, tool *Tool, params json.RawMessage) { if !ok { continue } - header.Set(paramHeaderPrefix+headerName, encoded) + res[paramHeaderPrefix+headerName] = encoded } + return res } // filterValidTools returns only tools that have valid @@ -205,6 +209,9 @@ func validateToolParamHeaders(tool *Tool) error { if !ok { continue } + if err := checkForNestedHeaders(ps, propName); err != nil { + return err + } headerNameRaw, exists := ps[mcpHeaderExtension] if !exists { continue @@ -227,16 +234,6 @@ func validateToolParamHeaders(tool *Tool) error { return fmt.Errorf("property %q: x-mcp-header can only be applied to primitive types, got %q", propName, propType) } } - - for propName, propSchema := range props { - ps, ok := propSchema.(map[string]any) - if !ok { - continue - } - if err := checkForNestedHeaders(ps, propName); err != nil { - return err - } - } return nil }