From 4589325d16e5c09256b81953367bddd57ef1c1ac Mon Sep 17 00:00:00 2001 From: Eron Wright Date: Wed, 3 Jun 2026 19:10:55 -0700 Subject: [PATCH 1/2] Support tool callbacks in MCP sampling (sampling/createMessage) Adds a parallel SamplingWithToolsHandler alongside the existing SamplingHandler so MCP servers can include a tools array in sampling/createMessage requests. The host drives its model with those tools and returns any tool_use blocks as ToolUseContent; the server remains responsible for executing the tool and continuing the loop in a follow-up sampling request. The initialize handshake now advertises sampling.tools capability, and the MCP toolset selects the appropriate go-sdk handler (basic vs. with-tools) based on which handler is registered. --- pkg/runtime/loop.go | 1 + pkg/runtime/sampling.go | 380 ++++++++++++++++++++++++++++- pkg/runtime/sampling_test.go | 419 ++++++++++++++++++++++++++++++++ pkg/tools/capabilities.go | 19 +- pkg/tools/mcp/mcp.go | 20 +- pkg/tools/mcp/mcp_test.go | 2 + pkg/tools/mcp/reconnect_test.go | 12 +- pkg/tools/mcp/remote.go | 9 +- pkg/tools/mcp/session_client.go | 35 +++ pkg/tools/mcp/stdio.go | 11 +- pkg/tools/sampling.go | 8 + 11 files changed, 896 insertions(+), 20 deletions(-) diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 84b66129a..df649f9ab 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -839,6 +839,7 @@ func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events EventSink tools.ConfigureHandlers(toolset, r.elicitationHandler, r.samplingHandler, + r.samplingWithToolsHandler, func() { events.Emit(Authorization(tools.ElicitationActionAccept, a.Name())) }, r.managedOAuth, r.unmanagedOAuthRedirectURI, diff --git a/pkg/runtime/sampling.go b/pkg/runtime/sampling.go index c1ff91559..360c0d908 100644 --- a/pkg/runtime/sampling.go +++ b/pkg/runtime/sampling.go @@ -3,6 +3,7 @@ package runtime import ( "context" "encoding/base64" + "encoding/json" "errors" "fmt" "io" @@ -14,6 +15,7 @@ import ( 
"github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/tools" ) // Limits applied to inbound sampling requests to keep a misbehaving or @@ -29,6 +31,13 @@ const ( // maxSamplingBinaryBytes caps the size of an individual image/audio // block before we refuse to inline it as a data URL. maxSamplingBinaryBytes = 8 << 20 // 8 MiB + // maxSamplingTools caps the number of tool definitions a server can + // inject into a single sampling-with-tools request. + maxSamplingTools = 64 + // maxSamplingToolCalls caps the number of tool calls we will return + // from a single sampling-with-tools completion. Per-call argument size + // is bounded by maxSamplingTextBytes. + maxSamplingToolCalls = 32 ) // samplingHandler is the MCP-toolset-side hook that satisfies an inbound @@ -214,12 +223,21 @@ func dataURL(mimeType string, data []byte) string { // reshape the model's reply into something the MCP server didn't ask // for. func samplingModelOptions(req *mcp.CreateMessageParams) []options.Opt { + return samplingModelOptionsFor(req.MaxTokens) +} + +// samplingModelOptionsFor returns the per-request model options shared by the +// basic and with-tools sampling handlers. Structured output is cleared so a +// request cannot inherit the agent's JSON-schema response format; thinking is +// disabled because sampling is a delegated one-shot call rather than an agent +// turn; MaxTokens is honoured when non-zero. 
+func samplingModelOptionsFor(maxTokens int64) []options.Opt { opts := []options.Opt{ options.WithStructuredOutput(nil), options.WithNoThinking(), } - if req.MaxTokens > 0 { - opts = append(opts, options.WithMaxTokens(req.MaxTokens)) + if maxTokens > 0 { + opts = append(opts, options.WithMaxTokens(maxTokens)) } return opts } @@ -265,3 +283,361 @@ func stopReason(fr chat.FinishReason) string { return "endTurn" } } + +// samplingWithToolsHandler is the MCP-toolset-side hook that satisfies an +// inbound sampling/createMessage request that carries a tools array. It +// forwards the server-supplied tool definitions to the host's model and +// returns any tool_use blocks the model emits; the requesting MCP server +// then executes the tool and continues the loop in a follow-up sampling +// request with tool_result blocks added. +// +// The host never executes the server-supplied tools itself — they exist +// only to inform the model's response. The placeholder handler attached to +// each converted tool surfaces an error if a downstream call site mistakes +// these for ordinary agent tools. 
+func (r *LocalRuntime) samplingWithToolsHandler(ctx context.Context, req *mcp.CreateMessageWithToolsParams) (*mcp.CreateMessageWithToolsResult, error) { + if req == nil { + return nil, errors.New("sampling request is nil") + } + if len(req.Tools) > maxSamplingTools { + return nil, fmt.Errorf("sampling request includes %d tools (limit %d)", + len(req.Tools), maxSamplingTools) + } + + slog.InfoContext(ctx, "Sampling-with-tools request received from MCP server", + "messages", len(req.Messages), + "tools", len(req.Tools), + "max_tokens", req.MaxTokens, + "system_prompt", req.SystemPrompt != "", + ) + + a := r.CurrentAgent() + if a == nil { + return nil, errors.New("no current agent available to handle sampling request") + } + + messages, err := samplingMessagesV2ToChat(req) + if err != nil { + return nil, fmt.Errorf("converting sampling messages: %w", err) + } + + chatTools := samplingToolsToChat(req.Tools) + + baseModel := a.Model(ctx) + if baseModel == nil { + return nil, errors.New("current agent has no model configured") + } + + model := provider.CloneWithOptions(ctx, baseModel, samplingModelOptionsFor(req.MaxTokens)...) 
+ + stream, err := model.CreateChatCompletionStream(ctx, messages, chatTools) + if err != nil { + return nil, fmt.Errorf("creating sampling completion stream: %w", err) + } + + text, toolCalls, finishReason, err := drainSamplingStreamWithTools(stream) + if err != nil { + return nil, fmt.Errorf("reading sampling completion stream: %w", err) + } + + if len(toolCalls) > maxSamplingToolCalls { + return nil, fmt.Errorf("model emitted %d tool calls (limit %d)", + len(toolCalls), maxSamplingToolCalls) + } + + sr := stopReason(finishReason) + if len(toolCalls) > 0 { + sr = "toolUse" + } + + slog.DebugContext(ctx, "Sampling-with-tools request completed", + "agent", a.Name(), + "model", model.ID().String(), + "finish_reason", finishReason, + "stop_reason", sr, + "tool_calls", len(toolCalls), + "content_bytes", len(text), + ) + + return &mcp.CreateMessageWithToolsResult{ + Role: mcp.Role("assistant"), + Model: model.ID().String(), + Content: buildSamplingWithToolsContent(text, toolCalls), + StopReason: sr, + }, nil +} + +// samplingMessagesV2ToChat converts a CreateMessageWithToolsParams (V2 +// messages with multi-block content) into chat.Messages. The optional system +// prompt is prepended; per-message blocks are folded into one or more chat +// messages depending on which content types are present. 
+func samplingMessagesV2ToChat(req *mcp.CreateMessageWithToolsParams) ([]chat.Message, error) { + if len(req.Messages) == 0 { + return nil, errors.New("sampling request contains no messages") + } + if len(req.Messages) > maxSamplingMessages { + return nil, fmt.Errorf("sampling request contains %d messages (limit %d)", + len(req.Messages), maxSamplingMessages) + } + + messages := make([]chat.Message, 0, len(req.Messages)+1) + if req.SystemPrompt != "" { + if len(req.SystemPrompt) > maxSamplingTextBytes { + return nil, fmt.Errorf("sampling system prompt is too large (%d bytes, limit %d)", + len(req.SystemPrompt), maxSamplingTextBytes) + } + messages = append(messages, chat.Message{ + Role: chat.MessageRoleSystem, + Content: req.SystemPrompt, + }) + } + for i, m := range req.Messages { + if m == nil { + return nil, fmt.Errorf("sampling message at index %d is nil", i) + } + role, err := samplingRoleToChat(m.Role) + if err != nil { + return nil, err + } + converted, err := samplingV2BlocksToMessages(role, m.Content) + if err != nil { + return nil, fmt.Errorf("sampling message at index %d: %w", i, err) + } + messages = append(messages, converted...) + } + return messages, nil +} + +// samplingV2BlocksToMessages converts a single V2 message's content blocks +// into one or more chat.Messages. Plain blocks (text, image, audio) collapse +// into a single message at the supplied role; tool_use blocks attach as +// ToolCalls on an assistant message; tool_result blocks expand into one +// MessageRoleTool row per result (matching how chat history represents +// parallel tool calls). 
+func samplingV2BlocksToMessages(role chat.MessageRole, blocks []mcp.Content) ([]chat.Message, error) { + var text strings.Builder + var parts []chat.MessagePart + var toolCalls []tools.ToolCall + var toolResults []chat.Message + + for _, c := range blocks { + switch v := c.(type) { + case nil: + continue + case *mcp.ToolUseContent: + args, err := json.Marshal(v.Input) + if err != nil { + args = []byte("{}") + } + toolCalls = append(toolCalls, tools.ToolCall{ + ID: v.ID, + Type: "function", + Function: tools.FunctionCall{ + Name: v.Name, + Arguments: string(args), + }, + }) + case *mcp.ToolResultContent: + resultText, err := samplingToolResultText(v.Content) + if err != nil { + return nil, fmt.Errorf("tool_result content: %w", err) + } + toolResults = append(toolResults, chat.Message{ + Role: chat.MessageRoleTool, + Content: resultText, + ToolCallID: v.ToolUseID, + IsError: v.IsError, + }) + default: + t, p, err := samplingContentToChat(c) + if err != nil { + return nil, err + } + if t != "" { + if text.Len() > 0 { + text.WriteString("\n") + } + text.WriteString(t) + } + parts = append(parts, p...) + } + } + + var out []chat.Message + if text.Len() > 0 || len(parts) > 0 || (len(toolCalls) > 0 && role == chat.MessageRoleAssistant) { + msg := chat.Message{ + Role: role, + Content: text.String(), + } + if len(parts) > 0 { + msg.MultiContent = parts + } + if role == chat.MessageRoleAssistant && len(toolCalls) > 0 { + msg.ToolCalls = toolCalls + } + out = append(out, msg) + } + out = append(out, toolResults...) + return out, nil +} + +// samplingToolResultText flattens the nested content of a tool_result block +// into a single text string. chat.MessageRoleTool messages don't carry +// multi-part content, so non-text blocks render as a placeholder. 
+func samplingToolResultText(blocks []mcp.Content) (string, error) { + var b strings.Builder + var nonText int + for _, c := range blocks { + t, parts, err := samplingContentToChat(c) + if err != nil { + return "", err + } + if t != "" { + if b.Len() > 0 { + b.WriteString("\n") + } + b.WriteString(t) + } + nonText += len(parts) + } + if b.Len() == 0 && nonText > 0 { + b.WriteString("[tool returned non-text content]") + } + return b.String(), nil +} + +// samplingToolsToChat converts the server-supplied MCP tool definitions into +// the host's tools.Tool representation so the model can be told which tools +// it may call. The Handler is a no-op: the LLM's tool_use response is sent +// back to the requesting MCP server for execution, never invoked here. +func samplingToolsToChat(mcpTools []*mcp.Tool) []tools.Tool { + if len(mcpTools) == 0 { + return nil + } + out := make([]tools.Tool, 0, len(mcpTools)) + for _, t := range mcpTools { + if t == nil { + continue + } + out = append(out, tools.Tool{ + Name: t.Name, + Category: "mcp-sampling", + Description: t.Description, + Parameters: t.InputSchema, + OutputSchema: t.OutputSchema, + Handler: noOpSamplingToolHandler, + }) + } + return out +} + +func noOpSamplingToolHandler(_ context.Context, _ tools.ToolCall) (*tools.ToolCallResult, error) { + return tools.ResultError("sampling tool execution belongs to the requesting MCP server"), nil +} + +// drainSamplingStreamWithTools reads a chat completion stream to completion +// and returns the concatenated assistant text, aggregated tool calls, and +// the final finish reason. It mirrors the tool-call aggregation in +// pkg/runtime/streaming.go::handleStream but omits agent events, telemetry, +// session bookkeeping, and the XML fallback — none of which apply to a +// one-shot delegated completion. 
+func drainSamplingStreamWithTools(stream chat.MessageStream) (string, []tools.ToolCall, chat.FinishReason, error) { + defer stream.Close() + + var text strings.Builder + var toolCalls []tools.ToolCall + toolIndex := make(map[string]int) + var providerFinishReason chat.FinishReason + + for { + response, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return "", nil, "", err + } + if len(response.Choices) == 0 { + continue + } + choice := response.Choices[0] + + if choice.Delta.Content != "" { + text.WriteString(choice.Delta.Content) + } + + for _, delta := range choice.Delta.ToolCalls { + idx, ok := toolIndex[delta.ID] + if !ok { + idx = len(toolCalls) + toolIndex[delta.ID] = idx + toolCalls = append(toolCalls, tools.ToolCall{ + ID: delta.ID, + Type: delta.Type, + }) + } + tc := &toolCalls[idx] + if delta.Type != "" { + tc.Type = delta.Type + } + if delta.Function.Name != "" { + tc.Function.Name = delta.Function.Name + } + if delta.Function.Arguments != "" { + tc.Function.Arguments += delta.Function.Arguments + } + } + + if choice.FinishReason != "" { + providerFinishReason = choice.FinishReason + } + if choice.FinishReason == chat.FinishReasonStop || choice.FinishReason == chat.FinishReasonLength { + break + } + } + + finishReason := providerFinishReason + if finishReason == "" { + switch { + case len(toolCalls) > 0: + finishReason = chat.FinishReasonToolCalls + case text.Len() > 0: + finishReason = chat.FinishReasonStop + default: + finishReason = chat.FinishReasonNull + } + } + switch { + case finishReason == chat.FinishReasonToolCalls && len(toolCalls) == 0: + finishReason = chat.FinishReasonNull + case finishReason == chat.FinishReasonStop && len(toolCalls) > 0: + finishReason = chat.FinishReasonToolCalls + } + + return text.String(), toolCalls, finishReason, nil +} + +// buildSamplingWithToolsContent assembles the assistant response Content +// slice. 
Any leading text becomes a TextContent block; each tool call +// becomes a ToolUseContent block with the function arguments parsed as a +// JSON object. Malformed arguments fall back to an empty input map so the +// server still sees the call (and can report a tool-side validation error) +// rather than the loop terminating on the client. +func buildSamplingWithToolsContent(text string, toolCalls []tools.ToolCall) []mcp.Content { + var blocks []mcp.Content + if strings.TrimSpace(text) != "" { + blocks = append(blocks, &mcp.TextContent{Text: text}) + } + for _, tc := range toolCalls { + input := map[string]any{} + if tc.Function.Arguments != "" { + _ = json.Unmarshal([]byte(tc.Function.Arguments), &input) + } + blocks = append(blocks, &mcp.ToolUseContent{ + ID: tc.ID, + Name: tc.Function.Name, + Input: input, + }) + } + return blocks +} diff --git a/pkg/runtime/sampling_test.go b/pkg/runtime/sampling_test.go index 70f55ca4b..304dafa59 100644 --- a/pkg/runtime/sampling_test.go +++ b/pkg/runtime/sampling_test.go @@ -1,6 +1,8 @@ package runtime import ( + "encoding/json" + "io" "strings" "testing" @@ -9,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/tools" ) func TestSamplingMessagesToChat(t *testing.T) { @@ -219,3 +222,419 @@ func TestDataURL(t *testing.T) { assert.Equal(t, "data:image/png;base64,UE5HQllURVM=", dataURL("image/png", []byte("PNGBYTES"))) assert.Equal(t, "data:application/octet-stream;base64,YQ==", dataURL("", []byte("a"))) } + +func TestSamplingMessagesV2ToChat(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + req *mcp.CreateMessageWithToolsParams + want []chat.Message + wantErr bool + }{ + { + name: "single user text block", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{&mcp.TextContent{Text: "hello"}}}, + }, + }, + want: []chat.Message{ + {Role: chat.MessageRoleUser, 
Content: "hello"}, + }, + }, + { + name: "system prompt is prepended", + req: &mcp.CreateMessageWithToolsParams{ + SystemPrompt: "be terse", + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{&mcp.TextContent{Text: "hi"}}}, + }, + }, + want: []chat.Message{ + {Role: chat.MessageRoleSystem, Content: "be terse"}, + {Role: chat.MessageRoleUser, Content: "hi"}, + }, + }, + { + name: "multiple text blocks are concatenated", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{ + &mcp.TextContent{Text: "first"}, + &mcp.TextContent{Text: "second"}, + }}, + }, + }, + want: []chat.Message{ + {Role: chat.MessageRoleUser, Content: "first\nsecond"}, + }, + }, + { + name: "text and image in one message", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{ + &mcp.TextContent{Text: "describe"}, + &mcp.ImageContent{Data: []byte("PNG"), MIMEType: "image/png"}, + }}, + }, + }, + want: []chat.Message{ + { + Role: chat.MessageRoleUser, + Content: "describe", + MultiContent: []chat.MessagePart{{ + Type: chat.MessagePartTypeImageURL, + ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,UE5H"}, + }}, + }, + }, + }, + { + name: "tool_use becomes assistant ToolCalls", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "assistant", Content: []mcp.Content{ + &mcp.ToolUseContent{ + ID: "call_1", + Name: "get_weather", + Input: map[string]any{"city": "Paris"}, + }, + }}, + }, + }, + want: []chat.Message{ + { + Role: chat.MessageRoleAssistant, + ToolCalls: []tools.ToolCall{{ + ID: "call_1", + Type: "function", + Function: tools.FunctionCall{ + Name: "get_weather", + Arguments: `{"city":"Paris"}`, + }, + }}, + }, + }, + }, + { + name: "tool_result expands to tool-role message", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: 
[]mcp.Content{ + &mcp.ToolResultContent{ + ToolUseID: "call_1", + Content: []mcp.Content{&mcp.TextContent{Text: "sunny, 22C"}}, + }, + }}, + }, + }, + want: []chat.Message{ + { + Role: chat.MessageRoleTool, + Content: "sunny, 22C", + ToolCallID: "call_1", + }, + }, + }, + { + name: "tool_result IsError surfaces", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{ + &mcp.ToolResultContent{ + ToolUseID: "call_1", + Content: []mcp.Content{&mcp.TextContent{Text: "no such city"}}, + IsError: true, + }, + }}, + }, + }, + want: []chat.Message{ + { + Role: chat.MessageRoleTool, + Content: "no such city", + ToolCallID: "call_1", + IsError: true, + }, + }, + }, + { + name: "parallel tool_results expand to multiple rows", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{ + &mcp.ToolResultContent{ToolUseID: "a", Content: []mcp.Content{&mcp.TextContent{Text: "1"}}}, + &mcp.ToolResultContent{ToolUseID: "b", Content: []mcp.Content{&mcp.TextContent{Text: "2"}}}, + }}, + }, + }, + want: []chat.Message{ + {Role: chat.MessageRoleTool, Content: "1", ToolCallID: "a"}, + {Role: chat.MessageRoleTool, Content: "2", ToolCallID: "b"}, + }, + }, + { + name: "too many messages is rejected", + req: &mcp.CreateMessageWithToolsParams{ + Messages: tooManyV2Messages(maxSamplingMessages + 1), + }, + wantErr: true, + }, + { + name: "nil message entry is rejected", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{nil}, + }, + wantErr: true, + }, + { + name: "oversize text block is rejected", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{ + &mcp.TextContent{Text: strings.Repeat("a", maxSamplingTextBytes+1)}, + }}, + }, + }, + wantErr: true, + }, + { + name: "empty messages is rejected", + req: &mcp.CreateMessageWithToolsParams{}, + wantErr: true, + }, + } + + 
for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := samplingMessagesV2ToChat(tt.req) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func tooManyV2Messages(n int) []*mcp.SamplingMessageV2 { + out := make([]*mcp.SamplingMessageV2, n) + for i := range out { + out[i] = &mcp.SamplingMessageV2{ + Role: "user", + Content: []mcp.Content{&mcp.TextContent{Text: "x"}}, + } + } + return out +} + +func TestSamplingToolsToChat(t *testing.T) { + t.Parallel() + + t.Run("nil input returns nil", func(t *testing.T) { + t.Parallel() + assert.Nil(t, samplingToolsToChat(nil)) + }) + + t.Run("converts and preserves schema", func(t *testing.T) { + t.Parallel() + schema := map[string]any{"type": "object"} + got := samplingToolsToChat([]*mcp.Tool{ + {Name: "lookup", Description: "look up a thing", InputSchema: schema}, + nil, // skipped + {Name: "other"}, + }) + require.Len(t, got, 2) + assert.Equal(t, "lookup", got[0].Name) + assert.Equal(t, "mcp-sampling", got[0].Category) + assert.Equal(t, "look up a thing", got[0].Description) + assert.Equal(t, schema, got[0].Parameters) + assert.NotNil(t, got[0].Handler) + assert.Equal(t, "other", got[1].Name) + }) + + t.Run("noOp handler returns error result", func(t *testing.T) { + t.Parallel() + res, err := noOpSamplingToolHandler(t.Context(), tools.ToolCall{}) + require.NoError(t, err) + require.NotNil(t, res) + assert.True(t, res.IsError) + }) +} + +func TestBuildSamplingWithToolsContent(t *testing.T) { + t.Parallel() + + t.Run("text only", func(t *testing.T) { + t.Parallel() + got := buildSamplingWithToolsContent("hello world", nil) + require.Len(t, got, 1) + text, ok := got[0].(*mcp.TextContent) + require.True(t, ok) + assert.Equal(t, "hello world", text.Text) + }) + + t.Run("tool calls only — empty text is dropped", func(t *testing.T) { + t.Parallel() + got := buildSamplingWithToolsContent(" ", []tools.ToolCall{ + {ID: 
"a", Function: tools.FunctionCall{Name: "fn", Arguments: `{"x":1}`}}, + }) + require.Len(t, got, 1) + tu, ok := got[0].(*mcp.ToolUseContent) + require.True(t, ok) + assert.Equal(t, "a", tu.ID) + assert.Equal(t, "fn", tu.Name) + assert.Equal(t, map[string]any{"x": float64(1)}, tu.Input) + }) + + t.Run("text plus parallel tool calls", func(t *testing.T) { + t.Parallel() + got := buildSamplingWithToolsContent("ok", []tools.ToolCall{ + {ID: "a", Function: tools.FunctionCall{Name: "fn1", Arguments: `{}`}}, + {ID: "b", Function: tools.FunctionCall{Name: "fn2", Arguments: `{}`}}, + }) + require.Len(t, got, 3) + _, isText := got[0].(*mcp.TextContent) + _, isToolA := got[1].(*mcp.ToolUseContent) + _, isToolB := got[2].(*mcp.ToolUseContent) + assert.True(t, isText) + assert.True(t, isToolA) + assert.True(t, isToolB) + }) + + t.Run("malformed JSON args fall back to empty input", func(t *testing.T) { + t.Parallel() + got := buildSamplingWithToolsContent("", []tools.ToolCall{ + {ID: "a", Function: tools.FunctionCall{Name: "fn", Arguments: `not json`}}, + }) + require.Len(t, got, 1) + tu, ok := got[0].(*mcp.ToolUseContent) + require.True(t, ok) + assert.Equal(t, map[string]any{}, tu.Input) + }) +} + +func TestSamplingWithToolsHandler_LimitRejection(t *testing.T) { + t.Parallel() + + r := &LocalRuntime{} + _, err := r.samplingWithToolsHandler(t.Context(), &mcp.CreateMessageWithToolsParams{ + Tools: make([]*mcp.Tool, maxSamplingTools+1), + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{&mcp.TextContent{Text: "hi"}}}, + }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "tools") +} + +// fakeStream feeds a fixed sequence of MessageStreamResponse values into +// drainSamplingStreamWithTools for unit testing. 
+type fakeStream struct { + responses []chat.MessageStreamResponse + idx int + closed bool +} + +func (f *fakeStream) Recv() (chat.MessageStreamResponse, error) { + if f.idx >= len(f.responses) { + return chat.MessageStreamResponse{}, io.EOF + } + resp := f.responses[f.idx] + f.idx++ + return resp, nil +} + +func (f *fakeStream) Close() { + f.closed = true +} + +func TestDrainSamplingStreamWithTools(t *testing.T) { + t.Parallel() + + t.Run("plain text completion", func(t *testing.T) { + t.Parallel() + s := &fakeStream{responses: []chat.MessageStreamResponse{ + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{Content: "hello "}}}}, + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{Content: "world"}, FinishReason: chat.FinishReasonStop}}}, + }} + text, calls, fr, err := drainSamplingStreamWithTools(s) + require.NoError(t, err) + assert.Equal(t, "hello world", text) + assert.Empty(t, calls) + assert.Equal(t, chat.FinishReasonStop, fr) + assert.True(t, s.closed) + }) + + t.Run("tool call aggregation across chunks", func(t *testing.T) { + t.Parallel() + s := &fakeStream{responses: []chat.MessageStreamResponse{ + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{ToolCalls: []tools.ToolCall{ + {ID: "c1", Type: "function", Function: tools.FunctionCall{Name: "fn", Arguments: `{"a":`}}, + }}}}}, + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{ToolCalls: []tools.ToolCall{ + {ID: "c1", Function: tools.FunctionCall{Arguments: `1}`}}, + }}, FinishReason: chat.FinishReasonToolCalls}}}, + }} + text, calls, fr, err := drainSamplingStreamWithTools(s) + require.NoError(t, err) + assert.Empty(t, text) + require.Len(t, calls, 1) + assert.Equal(t, "c1", calls[0].ID) + assert.Equal(t, "fn", calls[0].Function.Name) + assert.Equal(t, `{"a":1}`, calls[0].Function.Arguments) + assert.Equal(t, chat.FinishReasonToolCalls, fr) + // Sanity-check that the JSON we accumulated is parseable. 
+ var v map[string]any + require.NoError(t, json.Unmarshal([]byte(calls[0].Function.Arguments), &v)) + }) + + t.Run("parallel tool calls collected by ID", func(t *testing.T) { + t.Parallel() + s := &fakeStream{responses: []chat.MessageStreamResponse{ + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{ToolCalls: []tools.ToolCall{ + {ID: "a", Function: tools.FunctionCall{Name: "fn1", Arguments: `{}`}}, + {ID: "b", Function: tools.FunctionCall{Name: "fn2", Arguments: `{}`}}, + }}, FinishReason: chat.FinishReasonToolCalls}}}, + }} + _, calls, fr, err := drainSamplingStreamWithTools(s) + require.NoError(t, err) + require.Len(t, calls, 2) + assert.Equal(t, "a", calls[0].ID) + assert.Equal(t, "b", calls[1].ID) + assert.Equal(t, chat.FinishReasonToolCalls, fr) + }) + + t.Run("inferred tool_calls when provider omits finish reason", func(t *testing.T) { + t.Parallel() + s := &fakeStream{responses: []chat.MessageStreamResponse{ + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{ToolCalls: []tools.ToolCall{ + {ID: "x", Function: tools.FunctionCall{Name: "fn", Arguments: `{}`}}, + }}}}}, + }} + _, calls, fr, err := drainSamplingStreamWithTools(s) + require.NoError(t, err) + require.Len(t, calls, 1) + assert.Equal(t, chat.FinishReasonToolCalls, fr) + }) + + t.Run("stop reconciled to tool_calls when calls present", func(t *testing.T) { + t.Parallel() + // Provider says "stop" but also emits tool calls — reconciliation + // should treat this as a tool-call turn (the early-exit on stop fires + // in handleStream-style aggregation, then reconciliation upgrades). 
+ s := &fakeStream{responses: []chat.MessageStreamResponse{ + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{ToolCalls: []tools.ToolCall{ + {ID: "x", Function: tools.FunctionCall{Name: "fn", Arguments: `{}`}}, + }}, FinishReason: chat.FinishReasonStop}}}, + }} + _, _, fr, err := drainSamplingStreamWithTools(s) + require.NoError(t, err) + assert.Equal(t, chat.FinishReasonToolCalls, fr) + }) +} diff --git a/pkg/tools/capabilities.go b/pkg/tools/capabilities.go index 878c7feb9..3609f44f0 100644 --- a/pkg/tools/capabilities.go +++ b/pkg/tools/capabilities.go @@ -53,6 +53,14 @@ type Sampleable interface { SetSamplingHandler(handler SamplingHandler) } +// SampleableWithTools is implemented by toolsets that support MCP sampling +// requests carrying a tools array (sampling-with-tools). The handler is +// invoked instead of the basic SamplingHandler when both are registered and +// the SDK negotiates the with-tools variant on the wire. +type SampleableWithTools interface { + SetSamplingWithToolsHandler(handler SamplingWithToolsHandler) +} + // OAuthCapable is implemented by toolsets that support OAuth flows. type OAuthCapable interface { SetOAuthSuccessHandler(handler func()) @@ -81,16 +89,19 @@ type ChangeNotifier interface { } // ConfigureHandlers sets all applicable handlers on a toolset. -// It checks for Elicitable, Sampleable and OAuthCapable interfaces and -// configures them. This is a convenience function that handles the capability -// checking internally. -func ConfigureHandlers(ts ToolSet, elicitHandler ElicitationHandler, samplingHandler SamplingHandler, oauthHandler func(), managedOAuth bool, unmanagedOAuthRedirectURI string) { +// It checks for Elicitable, Sampleable, SampleableWithTools, and OAuthCapable +// interfaces and configures them. This is a convenience function that handles +// the capability checking internally. 
+func ConfigureHandlers(ts ToolSet, elicitHandler ElicitationHandler, samplingHandler SamplingHandler, samplingWithToolsHandler SamplingWithToolsHandler, oauthHandler func(), managedOAuth bool, unmanagedOAuthRedirectURI string) { if e, ok := As[Elicitable](ts); ok { e.SetElicitationHandler(elicitHandler) } if s, ok := As[Sampleable](ts); ok { s.SetSamplingHandler(samplingHandler) } + if s, ok := As[SampleableWithTools](ts); ok { + s.SetSamplingWithToolsHandler(samplingWithToolsHandler) + } if o, ok := As[OAuthCapable](ts); ok { o.SetOAuthSuccessHandler(oauthHandler) o.SetManagedOAuth(managedOAuth) diff --git a/pkg/tools/mcp/mcp.go b/pkg/tools/mcp/mcp.go index 1b08f1e89..ffa4cf2eb 100644 --- a/pkg/tools/mcp/mcp.go +++ b/pkg/tools/mcp/mcp.go @@ -137,6 +137,7 @@ type mcpClient interface { GetPrompt(ctx context.Context, request *mcp.GetPromptParams) (*mcp.GetPromptResult, error) SetElicitationHandler(handler tools.ElicitationHandler) SetSamplingHandler(handler tools.SamplingHandler) + SetSamplingWithToolsHandler(handler tools.SamplingWithToolsHandler) SetOAuthSuccessHandler(handler func()) SetManagedOAuth(managed bool) SetUnmanagedOAuthRedirectURI(uri string) @@ -198,11 +199,12 @@ var ( // Verify that Toolset implements optional capability interfaces var ( - _ tools.Instructable = (*Toolset)(nil) - _ tools.Elicitable = (*Toolset)(nil) - _ tools.Sampleable = (*Toolset)(nil) - _ tools.OAuthCapable = (*Toolset)(nil) - _ tools.ChangeNotifier = (*Toolset)(nil) + _ tools.Instructable = (*Toolset)(nil) + _ tools.Elicitable = (*Toolset)(nil) + _ tools.Sampleable = (*Toolset)(nil) + _ tools.SampleableWithTools = (*Toolset)(nil) + _ tools.OAuthCapable = (*Toolset)(nil) + _ tools.ChangeNotifier = (*Toolset)(nil) ) // NewToolsetCommand creates a new MCP toolset from a command. 
@@ -501,7 +503,9 @@ func (c *clientConnector) Connect(ctx context.Context) (lifecycle.Session, error Form: &mcp.FormElicitationCapabilities{}, URL: &mcp.URLElicitationCapabilities{}, }, - Sampling: &mcp.SamplingCapabilities{}, + Sampling: &mcp.SamplingCapabilities{ + Tools: &mcp.SamplingToolsCapabilities{}, + }, }, }, } @@ -845,6 +849,10 @@ func (ts *Toolset) SetSamplingHandler(handler tools.SamplingHandler) { ts.mcpClient.SetSamplingHandler(handler) } +func (ts *Toolset) SetSamplingWithToolsHandler(handler tools.SamplingWithToolsHandler) { + ts.mcpClient.SetSamplingWithToolsHandler(handler) +} + func (ts *Toolset) SetOAuthSuccessHandler(handler func()) { ts.mcpClient.SetOAuthSuccessHandler(handler) } diff --git a/pkg/tools/mcp/mcp_test.go b/pkg/tools/mcp/mcp_test.go index dd89d82c1..8fa4e3755 100644 --- a/pkg/tools/mcp/mcp_test.go +++ b/pkg/tools/mcp/mcp_test.go @@ -44,6 +44,8 @@ func (m *mockMCPClient) SetElicitationHandler(tools.ElicitationHandler) {} func (m *mockMCPClient) SetSamplingHandler(tools.SamplingHandler) {} +func (m *mockMCPClient) SetSamplingWithToolsHandler(tools.SamplingWithToolsHandler) {} + func (m *mockMCPClient) SetOAuthSuccessHandler(func()) {} func (m *mockMCPClient) SetManagedOAuth(bool) {} diff --git a/pkg/tools/mcp/reconnect_test.go b/pkg/tools/mcp/reconnect_test.go index 8ef8aaa0c..cd4a32271 100644 --- a/pkg/tools/mcp/reconnect_test.go +++ b/pkg/tools/mcp/reconnect_test.go @@ -69,11 +69,13 @@ func (m *failingInitClient) GetPrompt(context.Context, *gomcp.GetPromptParams) ( func (m *failingInitClient) SetElicitationHandler(tools.ElicitationHandler) {} func (m *failingInitClient) SetSamplingHandler(tools.SamplingHandler) {} -func (m *failingInitClient) SetOAuthSuccessHandler(func()) {} -func (m *failingInitClient) SetManagedOAuth(bool) {} -func (m *failingInitClient) SetUnmanagedOAuthRedirectURI(string) {} -func (m *failingInitClient) SetToolListChangedHandler(func()) {} -func (m *failingInitClient) SetPromptListChangedHandler(func()) {} 
+func (m *failingInitClient) SetSamplingWithToolsHandler(tools.SamplingWithToolsHandler) { +} +func (m *failingInitClient) SetOAuthSuccessHandler(func()) {} +func (m *failingInitClient) SetManagedOAuth(bool) {} +func (m *failingInitClient) SetUnmanagedOAuthRedirectURI(string) {} +func (m *failingInitClient) SetToolListChangedHandler(func()) {} +func (m *failingInitClient) SetPromptListChangedHandler(func()) {} func (m *failingInitClient) Wait() error { m.mu.Lock() diff --git a/pkg/tools/mcp/remote.go b/pkg/tools/mcp/remote.go index f4c5d7f54..e633f8163 100644 --- a/pkg/tools/mcp/remote.go +++ b/pkg/tools/mcp/remote.go @@ -90,12 +90,19 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *gomcp.InitializeReq toolChanged, promptChanged := c.notificationHandlers() + // Sampling: prefer the with-tools handler when registered. The SDK's two + // CreateMessage* handlers are mutually exclusive, so populate at most one. opts := &gomcp.ClientOptions{ ElicitationHandler: c.handleElicitationRequest, - CreateMessageHandler: c.handleSamplingRequest, ToolListChangedHandler: toolChanged, PromptListChangedHandler: promptChanged, } + switch { + case c.samplingWithToolsHandler != nil: + opts.CreateMessageWithToolsHandler = c.handleSamplingWithToolsRequest + case c.samplingHandler != nil: + opts.CreateMessageHandler = c.handleSamplingRequest + } client := gomcp.NewClient(impl, opts) diff --git a/pkg/tools/mcp/session_client.go b/pkg/tools/mcp/session_client.go index 86852a861..473a912e1 100644 --- a/pkg/tools/mcp/session_client.go +++ b/pkg/tools/mcp/session_client.go @@ -23,6 +23,7 @@ type sessionClient struct { promptListChangedHandler func() elicitationHandler tools.ElicitationHandler samplingHandler tools.SamplingHandler + samplingWithToolsHandler tools.SamplingWithToolsHandler oauthSuccessHandler func() mu sync.RWMutex } @@ -188,6 +189,40 @@ func (c *sessionClient) SetSamplingHandler(handler tools.SamplingHandler) { c.mu.Unlock() } +// handleSamplingWithToolsRequest 
forwards incoming sampling/createMessage +// requests that may include tools to the registered handler. It is used as +// the gomcp CreateMessageWithToolsHandler callback for both stdio and remote +// clients when the with-tools handler is registered. +func (c *sessionClient) handleSamplingWithToolsRequest(ctx context.Context, req *gomcp.CreateMessageWithToolsRequest) (*gomcp.CreateMessageWithToolsResult, error) { + slog.DebugContext(ctx, "Received sampling-with-tools request from MCP server", + "messages", len(req.Params.Messages), + "tools", len(req.Params.Tools), + ) + + c.mu.RLock() + handler := c.samplingWithToolsHandler + c.mu.RUnlock() + + if handler == nil { + return nil, errors.New("no sampling-with-tools handler configured") + } + + result, err := handler(ctx, req.Params) + if err != nil { + return nil, fmt.Errorf("sampling failed: %w", err) + } + + return result, nil +} + +// SetSamplingWithToolsHandler sets the handler that processes sampling +// requests carrying a tools array from the MCP server. +func (c *sessionClient) SetSamplingWithToolsHandler(handler tools.SamplingWithToolsHandler) { + c.mu.Lock() + c.samplingWithToolsHandler = handler + c.mu.Unlock() +} + // requestElicitation invokes the registered elicitation handler directly. // This is used by the OAuth transport to trigger elicitation outside of // the normal MCP request flow. diff --git a/pkg/tools/mcp/stdio.go b/pkg/tools/mcp/stdio.go index feb4e3ac5..5234bb686 100644 --- a/pkg/tools/mcp/stdio.go +++ b/pkg/tools/mcp/stdio.go @@ -38,13 +38,20 @@ func (c *stdioMCPClient) Initialize(ctx context.Context, _ *gomcp.InitializeRequ toolChanged, promptChanged := c.notificationHandlers() - // Create client options with elicitation, sampling, and notification support + // Create client options with elicitation, sampling, and notification support. + // Sampling: prefer the with-tools handler when registered. The SDK's two + // CreateMessage* handlers are mutually exclusive, so populate at most one. 
opts := &gomcp.ClientOptions{ ElicitationHandler: c.handleElicitationRequest, - CreateMessageHandler: c.handleSamplingRequest, ToolListChangedHandler: toolChanged, PromptListChangedHandler: promptChanged, } + switch { + case c.samplingWithToolsHandler != nil: + opts.CreateMessageWithToolsHandler = c.handleSamplingWithToolsRequest + case c.samplingHandler != nil: + opts.CreateMessageHandler = c.handleSamplingRequest + } client := gomcp.NewClient(&gomcp.Implementation{ Name: "docker agent", diff --git a/pkg/tools/sampling.go b/pkg/tools/sampling.go index 0bdb24e35..5af98cd52 100644 --- a/pkg/tools/sampling.go +++ b/pkg/tools/sampling.go @@ -15,3 +15,11 @@ import ( // expected to call the host's model with the supplied messages and return // the model's response (or an error if the request was declined or failed). type SamplingHandler func(ctx context.Context, req *mcp.CreateMessageParams) (*mcp.CreateMessageResult, error) + +// SamplingWithToolsHandler handles sampling/createMessage requests that may +// involve tool use. The request carries a tools array and supports messages +// with multi-block content (tool_use, tool_result). The handler is expected +// to forward the tools to the host's model and return any tool_use blocks +// the model emits — the requesting MCP server executes the tools and +// continues the loop in a follow-up sampling request. +type SamplingWithToolsHandler func(ctx context.Context, req *mcp.CreateMessageWithToolsParams) (*mcp.CreateMessageWithToolsResult, error) From 472d8d7df24a4dbcd70ef981bd541fc0bc665772 Mon Sep 17 00:00:00 2001 From: Eron Wright Date: Sat, 6 Jun 2026 17:52:43 -0700 Subject: [PATCH 2/2] Add e2e test for MCP sampling-with-tools round-trip Mounts an in-process gomcp.NewServer on an httptest server via StreamableHTTPHandler. 
Its one tool, ask_with_calculator, runs a sampling loop: sends sampling/createMessage with a calculator tool, gets a tool_use back from the host LLM, "executes" the calculator, sends a follow-up sampling request carrying the tool_result, and returns the final text. The Gemini side is recorded once and replayed on subsequent runs, so the test runs offline in CI. --- e2e/sampling_test.go | 191 ++++++++++++++++++ .../TestExec_Gemini_SamplingWithTools.yaml | 99 +++++++++ 2 files changed, 290 insertions(+) create mode 100644 e2e/sampling_test.go create mode 100644 e2e/testdata/cassettes/TestExec_Gemini_SamplingWithTools.yaml diff --git a/e2e/sampling_test.go b/e2e/sampling_test.go new file mode 100644 index 000000000..f343f5471 --- /dev/null +++ b/e2e/sampling_test.go @@ -0,0 +1,191 @@ +package e2e_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + gomcp "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/require" +) + +// TestExec_Gemini_SamplingWithTools exercises the MCP sampling-with-tools +// round-trip end-to-end. An in-process gomcp.NewServer is mounted on an +// httptest server via StreamableHTTPHandler. It exposes one tool, +// ask_with_calculator, whose handler runs a sampling loop: it sends +// sampling/createMessage with a tools array, gets a tool_use back from the +// host LLM, "executes" the calculator, sends a follow-up sampling request +// carrying the tool_result, and returns the final text. +func TestExec_Gemini_SamplingWithTools(t *testing.T) { + mcpURL := startSamplingToolsServer(t) + yamlPath := writeSamplingToolsAgent(t, mcpURL) + + out := runCLI(t, "run", "--exec", "--yolo", yamlPath, "--model=google/gemini-2.5-flash", "What is 17 plus 25?") + + require.Contains(t, out, "ask_with_calculator") + require.Contains(t, out, "42") +} + +// startSamplingToolsServer mounts an MCP server on an httptest server and +// returns its URL. 
The server exposes a single tool whose handler drives a +// sampling-with-tools loop against the connecting client. +func startSamplingToolsServer(t *testing.T) string { + t.Helper() + + server := gomcp.NewServer(&gomcp.Implementation{ + Name: "sampling-tools-test", + Version: "0.0.1", + }, nil) + + gomcp.AddTool(server, &gomcp.Tool{ + Name: "ask_with_calculator", + Description: "Answer a math word problem by asking the host LLM for help, with access to a calculator tool the server provides.", + }, askWithCalculator) + + handler := gomcp.NewStreamableHTTPHandler( + func(*http.Request) *gomcp.Server { return server }, + nil, + ) + httpSrv := httptest.NewServer(handler) + t.Cleanup(httpSrv.Close) + + return httpSrv.URL +} + +func writeSamplingToolsAgent(t *testing.T, mcpURL string) string { + t.Helper() + yamlPath := filepath.Join(t.TempDir(), "agent.yaml") + agentYAML := fmt.Appendf(nil, `agents: + root: + model: google/gemini-2.5-flash + description: "Test agent for MCP sampling-with-tools end-to-end verification" + instruction: | + You have access to one tool: ask_with_calculator. Whenever the user asks + a math word problem, call ask_with_calculator with the user's question. + Then report its answer verbatim to the user. + toolsets: + - type: mcp + allow_private_ips: true + remote: + url: %s + transport_type: streamable +`, mcpURL) + require.NoError(t, os.WriteFile(yamlPath, agentYAML, 0o644)) + return yamlPath +} + +type askInput struct { + Question string `json:"question" jsonschema:"the natural-language question to answer with help of the calculator"` +} + +func askWithCalculator(ctx context.Context, req *gomcp.CallToolRequest, in askInput) (*gomcp.CallToolResult, any, error) { + calculator := &gomcp.Tool{ + Name: "calculator", + Description: "Add two integers. 
Returns the sum.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "x": map[string]any{"type": "integer"}, + "y": map[string]any{"type": "integer"}, + }, + "required": []string{"x", "y"}, + }, + } + + messages := []*gomcp.SamplingMessageV2{{ + Role: "user", + Content: []gomcp.Content{&gomcp.TextContent{Text: in.Question}}, + }} + + for round := 1; round <= 4; round++ { + res, err := req.Session.CreateMessageWithTools(ctx, &gomcp.CreateMessageWithToolsParams{ + MaxTokens: 1024, + Messages: messages, + Tools: []*gomcp.Tool{calculator}, + SystemPrompt: "You are a careful assistant. Use the calculator tool whenever you need to add two integers. After you have the sum, answer the user's question in one short sentence.", + }) + if err != nil { + return nil, nil, fmt.Errorf("sampling round %d: %w", round, err) + } + + messages = append(messages, &gomcp.SamplingMessageV2{ + Role: res.Role, + Content: res.Content, + }) + + var toolUses []*gomcp.ToolUseContent + var finalText strings.Builder + for _, c := range res.Content { + switch v := c.(type) { + case *gomcp.ToolUseContent: + toolUses = append(toolUses, v) + case *gomcp.TextContent: + finalText.WriteString(v.Text) + } + } + + if len(toolUses) == 0 { + return &gomcp.CallToolResult{ + Content: []gomcp.Content{&gomcp.TextContent{Text: finalText.String()}}, + }, nil, nil + } + + toolResults := make([]gomcp.Content, 0, len(toolUses)) + for _, tu := range toolUses { + result, err := runCalculator(tu) + if err != nil { + toolResults = append(toolResults, &gomcp.ToolResultContent{ + ToolUseID: tu.ID, + Content: []gomcp.Content{&gomcp.TextContent{Text: err.Error()}}, + IsError: true, + }) + continue + } + toolResults = append(toolResults, &gomcp.ToolResultContent{ + ToolUseID: tu.ID, + Content: []gomcp.Content{&gomcp.TextContent{Text: result}}, + }) + } + + messages = append(messages, &gomcp.SamplingMessageV2{ + Role: "user", + Content: toolResults, + }) + } + + return nil, nil, 
fmt.Errorf("sampling loop did not terminate within 4 rounds") +} + +func runCalculator(tu *gomcp.ToolUseContent) (string, error) { + if tu.Name != "calculator" { + return "", fmt.Errorf("unknown tool: %s", tu.Name) + } + x, errX := toInt(tu.Input["x"]) + y, errY := toInt(tu.Input["y"]) + if errX != nil || errY != nil { + raw, _ := json.Marshal(tu.Input) + return "", fmt.Errorf("calculator expects integer x and y, got %s", raw) + } + return fmt.Sprintf("%d", x+y), nil +} + +func toInt(v any) (int64, error) { + switch n := v.(type) { + case float64: + return int64(n), nil + case int64: + return n, nil + case int: + return int64(n), nil + case json.Number: + return n.Int64() + default: + return 0, fmt.Errorf("not a number: %T", v) + } +} diff --git a/e2e/testdata/cassettes/TestExec_Gemini_SamplingWithTools.yaml b/e2e/testdata/cassettes/TestExec_Gemini_SamplingWithTools.yaml new file mode 100644 index 000000000..74a1bb42b --- /dev/null +++ b/e2e/testdata/cassettes/TestExec_Gemini_SamplingWithTools.yaml @@ -0,0 +1,99 @@ +--- +version: 2 +interactions: + - id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: generativelanguage.googleapis.com + body: | + {"contents":[{"parts":[{"text":"You have access to one tool: ask_with_calculator. 
Whenever the user asks\na math word problem, call ask_with_calculator with the user's question.\nThen report its answer verbatim to the user.\n"}],"role":"user"},{"parts":[{"text":"What is 17 plus 25?"}],"role":"user"}],"generationConfig":{},"toolConfig":{"functionCallingConfig":{"mode":"AUTO"}},"tools":[{"functionDeclarations":[{"description":"Answer a math word problem by asking the host LLM for help, with access to a calculator tool the server provides.","name":"ask_with_calculator","parameters":{"properties":{"question":{"description":"the natural-language question to answer with help of the calculator","type":"string"}},"required":["question"],"type":"object"}}]}]} + form: + alt: + - sse + url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: "data: {\"candidates\": [{\"content\": {\"parts\": [{\"functionCall\": {\"name\": \"ask_with_calculator\",\"args\": {\"question\": \"What is 17 plus 25?\"}},\"thoughtSignature\": \"CiQBDDnWx0f32fnXCeUit6K5jCaDOjrvkvnvsHY5MIV4ozFP2DkKWQEMOdbH80G2+BEkAXkEaOtaQkTeMtam0g5v81QW4CbZq0EWzMpBvR4qorY0uY1UsLTkkoSC2Jrwtbx9HKcx2O5eg8e7glVpviskbssHnvs/Ym050fnXqUPVCswBAQw51sfCV/T8djp98wEvfztWe9iFUQ6D3oT94v4X9lE0WSuXqEhk/GceAChWa2DgtoR9Su4AzkZJjl1Yi4mR4DpAZp1/+jQBMgm2v8g+zUWqWt4beV/YmMmJREjaI56NpBq19U+t1WzlTk5aQyT7KcH6EoFdJaEyCiY+B1pdo4Nm4HhDb9J+C0evbqlU7jXGy9bn5GBjeAS7xMdVh/IJ2oLmabcwdA4htWMWKLYEyUb87kU+cvumMYbhTQvn40csP+j3IKVazFhsNcmQ\"}],\"role\": \"model\"},\"finishReason\": \"STOP\",\"index\": 0,\"finishMessage\": \"Model generated function call(s).\"}],\"usageMetadata\": {\"promptTokenCount\": 128,\"candidatesTokenCount\": 26,\"totalTokenCount\": 215,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 128}],\"thoughtsTokenCount\": 61,\"serviceTier\": \"standard\"},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"hL8kaue3J7CEz7IPw8H7kAY\"}\r\n\r\n" + 
headers: {} + status: 200 OK + code: 200 + duration: 1.079262793s + - id: 1 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: generativelanguage.googleapis.com + body: | + {"contents":[{"parts":[{"text":"You are a careful assistant. Use the calculator tool whenever you need to add two integers. After you have the sum, answer the user's question in one short sentence."}],"role":"user"},{"parts":[{"text":"What is 17 plus 25?"}],"role":"user"}],"generationConfig":{"maxOutputTokens":1024,"thinkingConfig":{"thinkingBudget":0}},"toolConfig":{"functionCallingConfig":{"mode":"AUTO"}},"tools":[{"functionDeclarations":[{"description":"Add two integers. Returns the sum.","name":"calculator","parameters":{"properties":{"x":{"type":"integer"},"y":{"type":"integer"}},"required":["x","y"],"type":"object"}}]}]} + form: + alt: + - sse + url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: "data: {\"candidates\": [{\"content\": {\"parts\": [{\"functionCall\": {\"name\": \"calculator\",\"args\": {\"y\": 25,\"x\": 17}}}],\"role\": \"model\"},\"finishReason\": \"STOP\",\"index\": 0,\"finishMessage\": \"Model generated function call(s).\"}],\"usageMetadata\": {\"promptTokenCount\": 96,\"candidatesTokenCount\": 20,\"totalTokenCount\": 116,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 96}],\"serviceTier\": \"standard\"},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"hb8kas6hIuPqz7IP16znuQ8\"}\r\n\r\n" + headers: {} + status: 200 OK + code: 200 + duration: 788.273412ms + - id: 2 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: generativelanguage.googleapis.com + body: | + {"contents":[{"parts":[{"text":"You are a careful assistant. Use the calculator tool whenever you need to add two integers. 
After you have the sum, answer the user's question in one short sentence."}],"role":"user"},{"parts":[{"text":"What is 17 plus 25?"}],"role":"user"},{"parts":[{"functionCall":{"args":{"x":17,"y":25},"name":"calculator"},"thoughtSignature":"c2tpcF90aG91Z2h0X3NpZ25hdHVyZV92YWxpZGF0b3I="}],"role":"model"},{"parts":[{"functionResponse":{"name":"call_f15543e2-b952-4c7d-ab72-bcc56f00c6f6","response":{"result":"42"}}}],"role":"user"}],"generationConfig":{"maxOutputTokens":1024,"thinkingConfig":{"thinkingBudget":0}},"toolConfig":{"functionCallingConfig":{"mode":"AUTO"}},"tools":[{"functionDeclarations":[{"description":"Add two integers. Returns the sum.","name":"calculator","parameters":{"properties":{"x":{"type":"integer"},"y":{"type":"integer"}},"required":["x","y"],"type":"object"}}]}]} + form: + alt: + - sse + url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"1\"}],\"role\": \"model\"},\"index\": 0}],\"usageMetadata\": {\"promptTokenCount\": 164,\"candidatesTokenCount\": 1,\"totalTokenCount\": 165,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 164}],\"serviceTier\": \"standard\"},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"hr8kapnPFPDrz7IPoInsiQQ\"}\r\n\r\ndata: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"7 plus 25 is 42.\"}],\"role\": \"model\"},\"finishReason\": \"STOP\",\"index\": 0}],\"usageMetadata\": {\"promptTokenCount\": 164,\"candidatesTokenCount\": 11,\"totalTokenCount\": 175,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 164}],\"serviceTier\": \"standard\"},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"hr8kapnPFPDrz7IPoInsiQQ\"}\r\n\r\n" + headers: {} + status: 200 OK + code: 200 + duration: 635.846943ms + - id: 3 + request: + proto: HTTP/1.1 + proto_major: 
1 + proto_minor: 1 + content_length: 0 + host: generativelanguage.googleapis.com + body: | + {"contents":[{"parts":[{"text":"You have access to one tool: ask_with_calculator. Whenever the user asks\na math word problem, call ask_with_calculator with the user's question.\nThen report its answer verbatim to the user.\n"}],"role":"user"},{"parts":[{"text":"What is 17 plus 25?"}],"role":"user"},{"parts":[{"functionCall":{"args":{"question":"What is 17 plus 25?"},"name":"ask_with_calculator"},"thoughtSignature":"CiQBDDnWx0f32fnXCeUit6K5jCaDOjrvkvnvsHY5MIV4ozFP2DkKWQEMOdbH80G2+BEkAXkEaOtaQkTeMtam0g5v81QW4CbZq0EWzMpBvR4qorY0uY1UsLTkkoSC2Jrwtbx9HKcx2O5eg8e7glVpviskbssHnvs/Ym050fnXqUPVCswBAQw51sfCV/T8djp98wEvfztWe9iFUQ6D3oT94v4X9lE0WSuXqEhk/GceAChWa2DgtoR9Su4AzkZJjl1Yi4mR4DpAZp1/+jQBMgm2v8g+zUWqWt4beV/YmMmJREjaI56NpBq19U+t1WzlTk5aQyT7KcH6EoFdJaEyCiY+B1pdo4Nm4HhDb9J+C0evbqlU7jXGy9bn5GBjeAS7xMdVh/IJ2oLmabcwdA4htWMWKLYEyUb87kU+cvumMYbhTQvn40csP+j3IKVazFhsNcmQ"}],"role":"model"},{"parts":[{"functionResponse":{"name":"call_546b970f-9898-4b1f-94e0-e6b140fcbdea","response":{"result":"17 plus 25 is 42."}}}],"role":"user"}],"generationConfig":{},"toolConfig":{"functionCallingConfig":{"mode":"AUTO"}},"tools":[{"functionDeclarations":[{"description":"Answer a math word problem by asking the host LLM for help, with access to a calculator tool the server provides.","name":"ask_with_calculator","parameters":{"properties":{"question":{"description":"the natural-language question to answer with help of the calculator","type":"string"}},"required":["question"],"type":"object"}}]}]} + form: + alt: + - sse + url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"17 plus \"}],\"role\": \"model\"},\"index\": 0}],\"usageMetadata\": {\"promptTokenCount\": 
272,\"candidatesTokenCount\": 3,\"totalTokenCount\": 275,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 272}],\"serviceTier\": \"standard\"},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"hr8kavfTPOTtz7IPmZDEwQE\"}\r\n\r\ndata: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"25 is 42.\"}],\"role\": \"model\"},\"finishReason\": \"STOP\",\"index\": 0}],\"usageMetadata\": {\"promptTokenCount\": 272,\"candidatesTokenCount\": 10,\"totalTokenCount\": 282,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 272}],\"serviceTier\": \"standard\"},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"hr8kavfTPOTtz7IPmZDEwQE\"}\r\n\r\n" + headers: {} + status: 200 OK + code: 200 + duration: 537.623295ms