diff --git a/client/client.go b/client/client.go index 220786b68..192588b52 100644 --- a/client/client.go +++ b/client/client.go @@ -502,6 +502,19 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra } } + // Fix content parsing - HTTP transport unmarshals TextContent as map[string]any + // Use the helper function to properly handle content from different transports + for i := range params.Messages { + if contentMap, ok := params.Messages[i].Content.(map[string]any); ok { + // Parse the content map into a proper Content type + content, err := mcp.ParseContent(contentMap) + if err != nil { + return nil, fmt.Errorf("failed to parse content for message %d: %w", i, err) + } + params.Messages[i].Content = content + } + } + // Create the MCP request mcpRequest := mcp.CreateMessageRequest{ Request: mcp.Request{ diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 9d5218139..6339b6110 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -594,10 +594,14 @@ func (c *StreamableHTTP) IsOAuthEnabled() bool { func (c *StreamableHTTP) listenForever(ctx context.Context) { c.logger.Infof("listening to server forever") for { - connectCtx, cancel := context.WithCancel(ctx) - err := c.createGETConnectionToServer(connectCtx) - cancel() - + // Use the original context for continuous listening - no per-iteration timeout + // The SSE connection itself will detect disconnections via the underlying HTTP transport, + // and the context cancellation will propagate from the parent to stop listening gracefully. + // We don't add an artificial timeout here because: + // 1. Persistent SSE connections are meant to stay open indefinitely + // 2. Network-level timeouts and keep-alives handle connection health + // 3. Context cancellation (user-initiated or system shutdown) provides clean shutdown + err := c.createGETConnectionToServer(ctx) if errors.Is(err, ErrGetMethodNotAllowed) { // server does not support listening c.logger.Errorf("server does not support listening") diff --git a/e2e/sampling_http_test.go b/e2e/sampling_http_test.go new file mode 100644 index 000000000..8c81ddcba --- /dev/null +++ b/e2e/sampling_http_test.go @@ -0,0 +1,549 @@ +package e2e + +import ( + "context" + "fmt" + "log" + "net" + "net/http" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// TestSamplingHandler implements client.SamplingHandler for e2e testing +type TestSamplingHandler struct { + responses map[string]string + mutex sync.RWMutex +} + +func NewTestSamplingHandler() *TestSamplingHandler { + return &TestSamplingHandler{ + responses: make(map[string]string), + } +} + +func (h *TestSamplingHandler) SetResponse(question, response string) { + h.mutex.Lock() + defer h.mutex.Unlock() + h.responses[question] = response +} + +func (h *TestSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + log.Printf("[TestSamplingHandler] *** CLIENT RECEIVED SAMPLING REQUEST *** with %d messages", len(request.Messages)) + + if len(request.Messages) == 0 { + log.Printf("[TestSamplingHandler] ERROR: no messages provided") + return nil, fmt.Errorf("no messages provided") + } + + // Get the last user message + lastMessage := request.Messages[len(request.Messages)-1] + userText := "" + if textContent, ok := lastMessage.Content.(mcp.TextContent); ok { + userText = textContent.Text + } + + log.Printf("[TestSamplingHandler] CLIENT processing user text: '%s'", userText) + + h.mutex.RLock() + response, exists := h.responses[userText] + h.mutex.RUnlock() + + if !exists { + response = fmt.Sprintf("Test response to: '%s'", userText) + } + + log.Printf("[TestSamplingHandler] CLIENT Question: %s -> Response: %s", userText, response) + + result := &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: response, + }, + }, + Model: "test-model-v1", + StopReason: "endTurn", + } + + log.Printf("[TestSamplingHandler] *** CLIENT SENDING SAMPLING RESPONSE *** with model: %s", result.Model) + return result, nil +} + +// getAvailablePort finds an available port for testing +func getAvailablePort() (int, error) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + return 0, err + } + defer listener.Close() + return listener.Addr().(*net.TCPAddr).Port, nil +} + +func TestSamplingHTTPE2E(t *testing.T) { + if testing.Short() { + t.Skip("Skipping e2e test in short mode") + } + + log.Printf("[E2E Test] Starting Sampling HTTP E2E Test") + + // Get available port for HTTP server + port, err := getAvailablePort() + if err != nil { + t.Fatalf("Failed to get available port: %v", err) + } + + serverURL := fmt.Sprintf("http://localhost:%d", port) + serverAddr := fmt.Sprintf(":%d", port) + + // Create test sampling handler with predefined responses + samplingHandler := NewTestSamplingHandler() + samplingHandler.SetResponse("What is the capital of France?", "Paris is the capital of France.") + samplingHandler.SetResponse("What is 2+2?", "2+2 equals 4.") + + // Create MCP server with sampling capability + mcpServer := server.NewMCPServer("e2e-sampling-server", "1.0.0") + mcpServer.EnableSampling() + + // Add tool that uses sampling - this is the "question" tool + mcpServer.AddTool(mcp.Tool{ + Name: "question", + Description: "Ask a question and get an answer using sampling", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask", + }, + }, + Required: []string{"question"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + log.Printf("[E2E Test] Tool handler processing question: %s", question) + + // Create sampling request to send back to client + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + MaxTokens: 500, + Temperature: 0.7, + }, + } + + log.Printf("[E2E Test] *** SERVER SENDING SAMPLING REQUEST *** for question: %s", question) + + // Request sampling from client with timeout + samplingCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + serverFromCtx := server.ServerFromContext(ctx) + if serverFromCtx == nil { + log.Printf("[E2E Test] ERROR: No server in context") + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "Error: No server in context", + }, + }, + IsError: true, + }, nil + } + + log.Printf("[E2E Test] SERVER calling RequestSampling...") + + // Check what session we have + session := server.ClientSessionFromContext(ctx) + if session != nil { + log.Printf("[E2E Test] SERVER session ID: %s", session.SessionID()) + } else { + log.Printf("[E2E Test] SERVER ERROR: No session in context") + } + + // This creates the sampling request to the client + result, err := serverFromCtx.RequestSampling(samplingCtx, samplingRequest) + if err != nil { + log.Printf("[E2E Test] *** SERVER SAMPLING REQUEST FAILED ***: %v", err) + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error requesting sampling: %v", err), + }, + }, + IsError: true, + }, nil + } + + log.Printf("[E2E Test] *** SERVER RECEIVED SAMPLING RESPONSE ***, model: %s", result.Model) + + // Extract response text + var responseText string + if textContent, ok := result.Content.(mcp.TextContent); ok { + responseText = textContent.Text + } else { + responseText = fmt.Sprintf("%v", result.Content) + } + + // Return sampling response as the question tool response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Answer: %s (Model: %s)", responseText, result.Model), + }, + }, + }, nil + }) + + // Start HTTP server + httpServer := server.NewStreamableHTTPServer(mcpServer) + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + log.Printf("[E2E Test] Starting HTTP server on %s", serverAddr) + if err := httpServer.Start(serverAddr); err != nil && err != http.ErrServerClosed { + log.Printf("[E2E Test] Server error: %v", err) + } + }() + + // Wait for server to start and be ready + time.Sleep(2 * time.Second) + + // Create HTTP transport for client connection to server - enable continuous listening for sampling + httpTransport, err := transport.NewStreamableHTTP(serverURL+"/mcp", transport.WithContinuousListening()) + if err != nil { + t.Fatalf("Failed to create HTTP transport: %v", err) + } + defer httpTransport.Close() + + log.Printf("[E2E Test] HTTP transport created, will connect to: %s", serverURL+"/mcp") + + // Create HTTP client with sampling handler - this is the actual client connecting over HTTP + httpClient := client.NewClient(httpTransport, client.WithSamplingHandler(samplingHandler)) + defer httpClient.Close() + + // Start the HTTP client + ctx := context.Background() + if err := httpClient.Start(ctx); err != nil { + t.Fatalf("Failed to start HTTP client: %v", err) + } + + // Initialize MCP session over HTTP + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "e2e-http-test-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{ + // Sampling capability will be automatically added by WithSamplingHandler + }, + }, + } + + initResponse, err := httpClient.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize HTTP session: %v", err) + } + + log.Printf("[E2E Test] HTTP session initialized. Server capabilities: %+v", initResponse.Capabilities) + log.Printf("[E2E Test] Client session ID: %s", httpTransport.GetSessionId()) + + // Verify sampling capability is supported + if initResponse.Capabilities.Sampling == nil { + t.Fatal("Server should support sampling capability") + } + + // Wait a bit more for continuous listening to establish + log.Printf("[E2E Test] Waiting for continuous listening connection to be established...") + time.Sleep(3 * time.Second) + log.Printf("[E2E Test] Continuous listening should now be established, proceeding with tests...") + + // Test Case 1: HTTP client calls "question" tool - complete e2e flow + t.Run("HTTPClientCallsQuestionTool", func(t *testing.T) { + log.Printf("[E2E Test] HTTP client calling 'question' tool") + + // Client calls "question" tool over HTTP + result, err := httpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "question", + Arguments: map[string]any{ + "question": "What is the capital of France?", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to call question tool over HTTP: %v", err) + } + + if result.IsError { + t.Fatalf("Question tool returned error: %v", result.Content) + } + + if len(result.Content) == 0 { + t.Fatal("Question tool result should have content") + } + + // Verify response content + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatalf("Expected TextContent, got %T", result.Content[0]) + } + + responseText := textContent.Text + log.Printf("[E2E Test] Question tool response over HTTP: %s", responseText) + + // Verify the complete flow worked: client->server->sampling_request->client->sampling_response->server->tool_response->client + if !strings.Contains(responseText, "Paris is the capital of France") { + t.Errorf("Expected response to contain 'Paris is the capital of France', got: %s", responseText) + } + + if !strings.Contains(responseText, "test-model-v1") { + t.Errorf("Expected response to contain model name, got: %s", responseText) + } + }) + + // Test Case 2: Multiple HTTP sampling requests + t.Run("MultipleHTTPSamplingRequests", func(t *testing.T) { + questions := []string{ + "What is 2+2?", + "What is the capital of France?", + } + + expectedAnswers := []string{ + "2+2 equals 4", + "Paris is the capital of France", + } + + for i, question := range questions { + log.Printf("[E2E Test] HTTP client calling question tool with: %s", question) + result, err := httpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "question", + Arguments: map[string]any{ + "question": question, + }, + }, + }) + if err != nil { + t.Fatalf("Failed to call question tool for '%s': %v", question, err) + } + + if result.IsError { + t.Fatalf("Question tool returned error for '%s': %v", question, result.Content) + } + + if len(result.Content) == 0 { + t.Fatal("Question tool result should have content") + } + + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatalf("Expected TextContent, got %T", result.Content[0]) + } + + responseText := textContent.Text + log.Printf("[E2E Test] HTTP Response for '%s': %s", question, responseText) + + if !strings.Contains(responseText, expectedAnswers[i]) { + t.Errorf("Expected response to contain '%s', got: %s", expectedAnswers[i], responseText) + } + } + }) + + // Cleanup + httpClient.Close() + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := httpServer.Shutdown(shutdownCtx); err != nil { + t.Logf("Server shutdown error: %v", err) + } + + <-serverDone + log.Printf("[E2E Test] HTTP E2E test completed successfully") +} + +// TestSamplingHTTPBasic creates a basic HTTP sampling test +func TestSamplingHTTPBasic(t *testing.T) { + if testing.Short() { + t.Skip("Skipping HTTP test in short mode") + } + + log.Printf("[E2E HTTP Test] Starting basic HTTP sampling test") + + // Get available port + port, err := getAvailablePort() + if err != nil { + t.Fatalf("Failed to get available port: %v", err) + } + + serverURL := fmt.Sprintf("http://localhost:%d", port) + serverAddr := fmt.Sprintf(":%d", port) + + // Create MCP server with sampling capability + mcpServer := server.NewMCPServer("e2e-http-server", "1.0.0") + mcpServer.EnableSampling() + + // Add simple echo tool (no sampling needed) + mcpServer.AddTool(mcp.Tool{ + Name: "echo", + Description: "Echo back the input message", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "message": map[string]any{ + "type": "string", + "description": "The message to echo back", + }, + }, + Required: []string{"message"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + message := request.GetString("message", "") + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Echo: %s", message), + }, + }, + }, nil + }) + + // Start HTTP server + httpServer := server.NewStreamableHTTPServer(mcpServer) + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + log.Printf("[E2E HTTP Test] Starting server on %s", serverAddr) + if err := httpServer.Start(serverAddr); err != nil && err != http.ErrServerClosed { + log.Printf("[E2E HTTP Test] Server error: %v", err) + } + }() + + // Wait for server to start + time.Sleep(500 * time.Millisecond) + + // Create HTTP transport (no continuous listening for simple test) + httpTransport, err := transport.NewStreamableHTTP(serverURL + "/mcp") + if err != nil { + t.Fatalf("Failed to create HTTP transport: %v", err) + } + defer httpTransport.Close() + + // Create simple client (no sampling handler) + mcpClient := client.NewClient(httpTransport) + + // Start client + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err = mcpClient.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize MCP session + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: "e2e-http-test-client", + Version: "1.0.0", + }, + }, + } + + initResponse, err := mcpClient.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize MCP session: %v", err) + } + + log.Printf("[E2E HTTP Test] Session initialized. Server capabilities: %+v", initResponse.Capabilities) + + // Test basic tool call over HTTP + result, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "echo", + Arguments: map[string]any{ + "message": "Hello HTTP MCP!", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to call echo tool: %v", err) + } + + if result.IsError { + t.Fatalf("Tool returned error: %v", result.Content) + } + + if len(result.Content) == 0 { + t.Fatal("Tool result should have content") + } + + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatalf("Expected TextContent, got %T", result.Content[0]) + } + + responseText := textContent.Text + log.Printf("[E2E HTTP Test] Tool response: %s", responseText) + + if !strings.Contains(responseText, "Hello HTTP MCP!") { + t.Errorf("Expected response to contain 'Hello HTTP MCP!', got: %s", responseText) + } + + // Cleanup + mcpClient.Close() + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := httpServer.Shutdown(shutdownCtx); err != nil { + t.Logf("Server shutdown error: %v", err) + } + + <-serverDone + log.Printf("[E2E HTTP Test] HTTP test completed successfully") +} + +// TestMain sets up test environment +func TestMain(m *testing.M) { + // Enable debug logging for better visibility during tests + log.SetOutput(os.Stdout) + log.SetFlags(log.LstdFlags | log.Lshortfile) + + code := m.Run() + os.Exit(code) +} \ No newline at end of file diff --git a/examples/sampling_server/main.go b/examples/sampling_server/main.go index ea887c588..a2ca13baf 100644 --- a/examples/sampling_server/main.go +++ b/examples/sampling_server/main.go @@ -83,7 +83,7 @@ func main() { Content: []mcp.Content{ mcp.TextContent{ Type: "text", - Text: fmt.Sprintf("LLM Response (model: %s): %s", result.Model, getTextFromContent(result.Content)), + Text: fmt.Sprintf("LLM Response (model: %s): %s", result.Model, mcp.GetTextFromContent(result.Content)), }, }, }, nil @@ -125,21 +125,3 @@ func main() { log.Fatalf("Server error: %v", err) } } - -// Helper function to extract text from content -func getTextFromContent(content any) string { - switch c := content.(type) { - case mcp.TextContent: - return c.Text - case map[string]any: - // Handle JSON unmarshaled content - if text, ok := c["text"].(string); ok { - return text - } - return fmt.Sprintf("%v", content) - case string: - return c - default: - return fmt.Sprintf("%v", content) - } -} diff --git a/mcp/utils.go b/mcp/utils.go index 0a3cde236..971fb585b 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -941,3 +941,31 @@ func ParseStringMap(request CallToolRequest, key string, defaultValue map[string func ToBoolPtr(b bool) *bool { return &b } + +// GetTextFromContent extracts text from a Content interface that might be a TextContent struct +// or a map[string]any that was unmarshaled from JSON. This is useful when dealing with content +// that comes from different transport layers that may handle JSON differently. +// +// This function uses fallback behavior for non-text content - it returns a string representation +// via fmt.Sprintf for any content that cannot be extracted as text. This is a lossy operation +// intended for convenience in logging and display scenarios. +// +// For strict type validation, use ParseContent() instead, which returns an error for invalid content. +func GetTextFromContent(content any) string { + switch c := content.(type) { + case TextContent: + return c.Text + case map[string]any: + // Handle JSON unmarshaled content + if contentType, exists := c["type"]; exists && contentType == "text" { + if text, exists := c["text"].(string); exists { + return text + } + } + return fmt.Sprintf("%v", content) + case string: + return c + default: + return fmt.Sprintf("%v", content) + } +} diff --git a/server/stdio.go b/server/stdio.go index d80941c3d..80131f06c 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -606,7 +606,18 @@ func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool { if err := json.Unmarshal(response.Result, &result); err != nil { samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err) } else { - samplingResp.result = &result + // Parse content from map[string]any to proper Content type (TextContent, ImageContent, AudioContent) + if contentMap, ok := result.Content.(map[string]any); ok { + content, err := mcp.ParseContent(contentMap) + if err != nil { + samplingResp.err = fmt.Errorf("failed to parse sampling response content: %w", err) + } else { + result.Content = content + samplingResp.result = &result + } + } else { + samplingResp.result = &result + } } } diff --git a/server/streamable_http.go b/server/streamable_http.go index 056dc876c..9ad37fea1 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -309,7 +309,19 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } } - session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) + // Check if a persistent session exists (for sampling support), otherwise create ephemeral session + // Persistent sessions are created by GET (continuous listening) connections + var session *streamableHttpSession + if sessionInterface, exists := s.activeSessions.Load(sessionID); exists { + if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok { + session = persistentSession + } + } + + // Create ephemeral session if no persistent session exists + if session == nil { + session = newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) + } // Set the client context before handling the message ctx := s.server.WithContext(r.Context(), session) @@ -420,16 +432,23 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) sessionID = uuid.New().String() } - session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) - if err := s.server.RegisterSession(r.Context(), session); err != nil { - http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest) - return + // Get or create session atomically to prevent TOCTOU races + // where concurrent GETs could both create and register duplicate sessions + var session *streamableHttpSession + newSession := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) + actual, loaded := s.activeSessions.LoadOrStore(sessionID, newSession) + session = actual.(*streamableHttpSession) + + if !loaded { + // We created a new session, need to register it + if err := s.server.RegisterSession(r.Context(), session); err != nil { + s.activeSessions.Delete(sessionID) + http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest) + return + } + defer s.server.UnregisterSession(r.Context(), sessionID) + defer s.activeSessions.Delete(sessionID) } - defer s.server.UnregisterSession(r.Context(), sessionID) - - // Register session for sampling response delivery - s.activeSessions.Store(sessionID, session) - defer s.activeSessions.Delete(sessionID) // Set the client context before handling the message w.Header().Set("Content-Type", "text/event-stream") @@ -626,7 +645,8 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r * response.err = fmt.Errorf("sampling error %d: %s", jsonrpcError.Code, jsonrpcError.Message) } } else if responseMessage.Result != nil { - // Parse result + // Store the result to be unmarshaled later + response.result = responseMessage.Result } else { response.err = fmt.Errorf("sampling response has neither result nor error") } @@ -900,6 +920,17 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp if err := json.Unmarshal(response.result, &result); err != nil { return nil, fmt.Errorf("failed to unmarshal sampling response: %v", err) } + + // Parse content from map[string]any to proper Content type (TextContent, ImageContent, AudioContent) + // HTTP transport unmarshals Content as map[string]any, we need to convert it to the proper type + if contentMap, ok := result.Content.(map[string]any); ok { + content, err := mcp.ParseContent(contentMap) + if err != nil { + return nil, fmt.Errorf("failed to parse sampling response content: %w", err) + } + result.Content = content + } + return &result, nil case <-ctx.Done(): return nil, ctx.Err() diff --git a/www/docs/pages/clients/advanced-sampling.mdx b/www/docs/pages/clients/advanced-sampling.mdx index 81a4cc9aa..02fc959a2 100644 --- a/www/docs/pages/clients/advanced-sampling.mdx +++ b/www/docs/pages/clients/advanced-sampling.mdx @@ -6,6 +6,24 @@ Learn how to implement MCP clients that can handle sampling requests from server Sampling allows MCP clients to respond to LLM completion requests from servers. When a server needs to generate content, answer questions, or perform reasoning tasks, it can send a sampling request to the client, which then processes it using an LLM and returns the result. +:::danger[Critical Security Requirement] +Per the [MCP specification](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling#user-interaction-model), sampling implementations **SHOULD** always include a human in the loop with the ability to deny sampling requests. + +**You MUST implement approval flows that:** +- Present each sampling request to the user for review before execution +- Allow users to view and edit prompts before sending to the LLM +- Display generated responses for user approval before returning to the server +- Provide clear UI to accept or reject requests at each stage + +**Without human approval, your implementation:** +- Allows servers to make unauthorized LLM requests without user consent +- May expose sensitive information through unreviewed prompts +- Creates uncontrolled API costs from automated sampling +- Violates user trust and security best practices + +The examples below show basic handler implementation. **You must add approval logic** before using in production. +::: + ## Implementing a Sampling Handler Create a sampling handler by implementing the `SamplingHandler` interface: diff --git a/www/docs/pages/servers/advanced-sampling.mdx b/www/docs/pages/servers/advanced-sampling.mdx index 1bc05eb6e..4ae98fcd5 100644 --- a/www/docs/pages/servers/advanced-sampling.mdx +++ b/www/docs/pages/servers/advanced-sampling.mdx @@ -6,6 +6,25 @@ Learn how to implement MCP servers that can request LLM completions from clients Sampling allows MCP servers to request LLM completions from clients, enabling bidirectional communication where servers can leverage client-side LLM capabilities. This is particularly useful for tools that need to generate content, answer questions, or perform reasoning tasks. +:::info[User Consent Required] +Per the [MCP specification](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling#user-interaction-model), clients **SHOULD** implement human-in-the-loop approval for sampling requests. + +When you request sampling from a client: +- The user will typically be prompted to review and approve your request +- The user may modify your prompts before sending to their LLM +- The user may reject your request entirely +- Response times may be longer due to user interaction + +**Design your tools accordingly:** +- Provide clear descriptions of why sampling is needed +- Use descriptive system prompts explaining the purpose +- Handle rejection errors gracefully +- Consider timeouts for user approval delays +- Don't assume immediate or automatic approval + +Well-designed sampling requests improve user trust and approval rates. +::: + ## Enabling Sampling To enable sampling in your server, call `EnableSampling()` during server setup: diff --git a/www/docs/pages/transports/http.mdx b/www/docs/pages/transports/http.mdx index 39089ad6b..24ed0a392 100644 --- a/www/docs/pages/transports/http.mdx +++ b/www/docs/pages/transports/http.mdx @@ -689,6 +689,160 @@ This works for all MCP request types including: The headers are automatically populated by the transport layer and are available in your handlers without any additional configuration. +## Sampling Support + +StreamableHTTP transport now supports bidirectional sampling, allowing servers to request LLM completions from clients. This enables advanced scenarios where servers can leverage client-side LLM capabilities. + +:::warning[Security: Human-in-the-Loop Required] +Per the [MCP specification](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling), implementations **SHOULD** always include a human in the loop with the ability to deny sampling requests. + +**Your sampling handler implementation MUST:** +- Present sampling requests to users for review before execution +- Allow users to view and edit prompts before sending to the LLM +- Display generated responses for approval before returning to the server +- Provide clear UI to accept or reject sampling requests + +Failing to implement approval flows creates serious security and trust risks, including: +- Servers making unauthorized LLM requests on behalf of users +- Exposure of sensitive data through unreviewed prompts +- Uncontrolled API costs from automated sampling +- Lack of user consent for AI interactions + +See the [example implementation](#example-with-approval-flow) below for a reference approval pattern. +::: + +### Requirements for Sampling + +To enable sampling with StreamableHTTP transport, the client **must** use the `WithContinuousListening()` option: + +```go +// Client setup with sampling support +httpTransport, err := transport.NewStreamableHTTP( + serverURL, + transport.WithContinuousListening(), // Required for sampling +) + +// Create client with sampling handler +mcpClient := client.NewClient(httpTransport, + client.WithSamplingHandler(samplingHandler)) +``` + +Without `WithContinuousListening()`, the client won't maintain a persistent connection to receive sampling requests from the server. + +### Server-Side Implementation + +Enable sampling in your StreamableHTTP server: + +```go +mcpServer := server.NewMCPServer("HTTP Sampling Server", "1.0.0") +mcpServer.EnableSampling() + +// Add a tool that uses sampling +mcpServer.AddTool(mcp.Tool{ + Name: "ask-llm", + Description: "Ask the LLM a question", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "Question to ask", + }, + }, + Required: []string{"question"}, + }, +}, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + question := mcp.ParseString(req, "question", "") + + // Request sampling from client + samplingRequest := mcp.CreateMessageRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodSamplingCreateMessage), + }, + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + MaxTokens: 1000, + }, + } + + result, err := mcpServer.RequestSampling(ctx, samplingRequest) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Sampling failed: %v", err)), nil + } + + return mcp.NewToolResultText(mcp.GetTextFromContent(result.Content)), nil +}) +``` + +### How It Works + +1. **Persistent Connection**: When `WithContinuousListening()` is enabled, the client maintains a persistent SSE connection to the server +2. **Bidirectional Communication**: The server can send sampling requests through the SSE stream +3. **Response Channel**: The client responds to sampling requests via HTTP POST to the same endpoint +4. **Session Correlation**: Responses are correlated using session IDs to ensure they reach the correct handler + +### Limitations + +- Sampling requires `WithContinuousListening()` to maintain the SSE connection +- Without continuous listening, the transport operates in stateless request/response mode only +- Network interruptions may require reconnection and re-establishment of the sampling channel + +### Example with Approval Flow + +Here's a reference implementation showing proper human-in-the-loop approval: + +```go +type ApprovalSamplingHandler struct { + llmClient LLMClient // Your actual LLM client + ui UserInterface // Your UI for presenting requests to users +} + +func (h *ApprovalSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Step 1: Present the sampling request to the user for review + approved, modifiedRequest, err := h.ui.PresentSamplingRequest(ctx, request) + if err != nil { + return nil, fmt.Errorf("failed to get user approval: %w", err) + } + + if !approved { + return nil, fmt.Errorf("user rejected sampling request") + } + + // Step 2: Send the approved/modified request to the LLM + response, err := h.llmClient.CreateCompletion(ctx, modifiedRequest) + if err != nil { + return nil, fmt.Errorf("LLM request failed: %w", err) + } + + // Step 3: Present the response to the user for final approval + approved, modifiedResponse, err := h.ui.PresentSamplingResponse(ctx, response) + if err != nil { + return nil, fmt.Errorf("failed to get response approval: %w", err) + } + + if !approved { + return nil, fmt.Errorf("user rejected sampling response") + } + + // Step 4: Return the approved response to the server + return modifiedResponse, nil +} +``` + +**Key Points:** +- Users must explicitly approve both the request (before sending to LLM) and the response (before returning to server) +- Users can modify prompts or responses before approval +- Rejection at any stage returns an error to the server +- The UI should clearly display what the server is requesting and why + ## Next Steps - **[In-Process Transport](/transports/inprocess)** - Learn about embedded scenarios