From 09263bc5511695a77419b4c39244e7fc6eb0c70b Mon Sep 17 00:00:00 2001 From: Ramon Nogueira Date: Sun, 20 Jul 2025 19:46:51 -0700 Subject: [PATCH 1/7] feat: implement MCP elicitation support (#413) * Add ElicitationRequest, ElicitationResult, and related types to mcp/types.go * Implement server-side RequestElicitation method with session support * Add client-side ElicitationHandler interface and request handling * Implement elicitation in stdio and in-process transports * Add comprehensive tests following sampling patterns * Create elicitation example demonstrating usage patterns * Use 'Elicitation' prefix for type names to maintain clarity --- client/client.go | 66 ++++++- client/elicitation.go | 19 ++ client/elicitation_test.go | 225 +++++++++++++++++++++++ client/inprocess_elicitation_test.go | 206 +++++++++++++++++++++ client/transport/inprocess.go | 21 ++- examples/elicitation/main.go | 208 +++++++++++++++++++++ mcp/types.go | 56 ++++++ server/elicitation.go | 25 +++ server/elicitation_test.go | 263 +++++++++++++++++++++++++++ server/inprocess_session.go | 36 +++- server/server.go | 22 ++- server/session.go | 7 + server/stdio.go | 170 +++++++++++++++-- 13 files changed, 1291 insertions(+), 33 deletions(-) create mode 100644 client/elicitation.go create mode 100644 client/elicitation_test.go create mode 100644 client/inprocess_elicitation_test.go create mode 100644 examples/elicitation/main.go create mode 100644 server/elicitation.go create mode 100644 server/elicitation_test.go diff --git a/client/client.go b/client/client.go index 0d47fcbf3..21969789c 100644 --- a/client/client.go +++ b/client/client.go @@ -25,6 +25,7 @@ type Client struct { serverCapabilities mcp.ServerCapabilities protocolVersion string samplingHandler SamplingHandler + elicitationHandler ElicitationHandler } type ClientOption func(*Client) @@ -44,6 +45,14 @@ func WithSamplingHandler(handler SamplingHandler) ClientOption { } } +// WithElicitationHandler sets the elicitation handler for the client. +// When set, the client will declare elicitation capability during initialization. +func WithElicitationHandler(handler ElicitationHandler) ClientOption { + return func(c *Client) { + c.elicitationHandler = handler + } +} + // WithSession assumes a MCP Session has already been initialized func WithSession() ClientOption { return func(c *Client) { @@ -174,6 +183,10 @@ func (c *Client) Initialize( if c.samplingHandler != nil { capabilities.Sampling = &struct{}{} } + // Add elicitation capability if handler is configured + if c.elicitationHandler != nil { + capabilities.Elicitation = &struct{}{} + } // Ensure we send a params object with all required fields params := struct { @@ -458,11 +471,13 @@ func (c *Client) Complete( } // handleIncomingRequest processes incoming requests from the server. -// This is the main entry point for server-to-client requests like sampling. +// This is the main entry point for server-to-client requests like sampling and elicitation. func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { switch request.Method { case string(mcp.MethodSamplingCreateMessage): return c.handleSamplingRequestTransport(ctx, request) + case string(mcp.MethodElicitationCreate): + return c.handleElicitationRequestTransport(ctx, request) default: return nil, fmt.Errorf("unsupported request method: %s", request.Method) } @@ -515,6 +530,55 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra return response, nil } + +// handleElicitationRequestTransport handles elicitation requests at the transport level. +func (c *Client) handleElicitationRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + if c.elicitationHandler == nil { + return nil, fmt.Errorf("no elicitation handler configured") + } + + // Parse the request parameters + var params mcp.ElicitationParams + if request.Params != nil { + paramsBytes, err := json.Marshal(request.Params) + if err != nil { + return nil, fmt.Errorf("failed to marshal params: %w", err) + } + if err := json.Unmarshal(paramsBytes, ¶ms); err != nil { + return nil, fmt.Errorf("failed to unmarshal params: %w", err) + } + } + + // Create the MCP request + mcpRequest := mcp.ElicitationRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodElicitationCreate), + }, + Params: params, + } + + // Call the elicitation handler + result, err := c.elicitationHandler.Elicit(ctx, mcpRequest) + if err != nil { + return nil, err + } + + // Marshal the result + resultBytes, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal result: %w", err) + } + + // Create the transport response + response := &transport.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Result: json.RawMessage(resultBytes), + } + + return response, nil +} + func listByPage[T any]( ctx context.Context, client *Client, diff --git a/client/elicitation.go b/client/elicitation.go new file mode 100644 index 000000000..92f519bf9 --- /dev/null +++ b/client/elicitation.go @@ -0,0 +1,19 @@ +package client + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// ElicitationHandler defines the interface for handling elicitation requests from servers. +// Clients can implement this interface to request additional information from users. +type ElicitationHandler interface { + // Elicit handles an elicitation request from the server and returns the user's response. + // The implementation should: + // 1. Present the request message to the user + // 2. Validate input against the requested schema + // 3. Allow the user to accept, decline, or cancel + // 4. Return the appropriate response + Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) +} diff --git a/client/elicitation_test.go b/client/elicitation_test.go new file mode 100644 index 000000000..a05be35c5 --- /dev/null +++ b/client/elicitation_test.go @@ -0,0 +1,225 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// mockElicitationHandler implements ElicitationHandler for testing +type mockElicitationHandler struct { + result *mcp.ElicitationResult + err error +} + +func (m *mockElicitationHandler) Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + if m.err != nil { + return nil, m.err + } + return m.result, nil +} + +func TestClient_HandleElicitationRequest(t *testing.T) { + tests := []struct { + name string + handler ElicitationHandler + expectedError string + }{ + { + name: "no handler configured", + handler: nil, + expectedError: "no elicitation handler configured", + }, + { + name: "successful elicitation - accept", + handler: &mockElicitationHandler{ + result: &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeAccept, + Value: map[string]interface{}{ + "name": "test-project", + "framework": "react", + }, + }, + }, + }, + }, + { + name: "successful elicitation - decline", + handler: &mockElicitationHandler{ + result: &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeDecline, + }, + }, + }, + }, + { + name: "successful elicitation - cancel", + handler: &mockElicitationHandler{ + result: &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeCancel, + }, + }, + }, + }, + { + name: "handler returns error", + handler: &mockElicitationHandler{ + err: fmt.Errorf("user interaction failed"), + }, + expectedError: "user interaction failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &Client{elicitationHandler: tt.handler} + + request := transport.JSONRPCRequest{ + ID: mcp.NewRequestId(1), + Method: string(mcp.MethodElicitationCreate), + Params: map[string]interface{}{ + "message": "Please provide project details", + "requestedSchema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{"type": "string"}, + "framework": map[string]interface{}{"type": "string"}, + }, + }, + }, + } + + result, err := client.handleElicitationRequestTransport(context.Background(), request) + + if tt.expectedError != "" { + if err == nil { + t.Errorf("expected error %q, got nil", tt.expectedError) + } else if err.Error() != tt.expectedError { + t.Errorf("expected error %q, got %q", tt.expectedError, err.Error()) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result == nil { + t.Error("expected result, got nil") + } else { + // Verify the response is properly formatted + var elicitationResult mcp.ElicitationResult + if err := json.Unmarshal(result.Result, &elicitationResult); err != nil { + t.Errorf("failed to unmarshal result: %v", err) + } + } + } + }) + } +} + +func TestWithElicitationHandler(t *testing.T) { + handler := &mockElicitationHandler{} + client := &Client{} + + option := WithElicitationHandler(handler) + option(client) + + if client.elicitationHandler != handler { + t.Error("elicitation handler not set correctly") + } +} + +func TestClient_Initialize_WithElicitationHandler(t *testing.T) { + mockTransport := &mockElicitationTransport{ + sendRequestFunc: func(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + // Verify that elicitation capability is included + // The client internally converts the typed params to a map for transport + // So we check if we're getting the initialize request + if request.Method != "initialize" { + t.Fatalf("expected initialize method, got %s", request.Method) + } + + // Return successful initialization response + result := mcp.InitializeResult{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ServerInfo: mcp.Implementation{ + Name: "test-server", + Version: "1.0.0", + }, + Capabilities: mcp.ServerCapabilities{}, + } + + resultBytes, _ := json.Marshal(result) + return &transport.JSONRPCResponse{ + ID: request.ID, + Result: json.RawMessage(resultBytes), + }, nil + }, + sendNotificationFunc: func(ctx context.Context, notification mcp.JSONRPCNotification) error { + return nil + }, + } + + handler := &mockElicitationHandler{} + client := NewClient(mockTransport, WithElicitationHandler(handler)) + + err := client.Start(context.Background()) + if err != nil { + t.Fatalf("failed to start client: %v", err) + } + + _, err = client.Initialize(context.Background(), mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{}, + }, + }) + + if err != nil { + t.Fatalf("failed to initialize: %v", err) + } +} + +// mockElicitationTransport implements transport.Interface for testing +type mockElicitationTransport struct { + sendRequestFunc func(context.Context, transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) + sendNotificationFunc func(context.Context, mcp.JSONRPCNotification) error +} + +func (m *mockElicitationTransport) Start(ctx context.Context) error { + return nil +} + +func (m *mockElicitationTransport) Close() error { + return nil +} + +func (m *mockElicitationTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + if m.sendRequestFunc != nil { + return m.sendRequestFunc(ctx, request) + } + return nil, nil +} + +func (m *mockElicitationTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + if m.sendNotificationFunc != nil { + return m.sendNotificationFunc(ctx, notification) + } + return nil +} + +func (m *mockElicitationTransport) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) { +} + +func (m *mockElicitationTransport) GetSessionId() string { + return "mock-session" +} diff --git a/client/inprocess_elicitation_test.go b/client/inprocess_elicitation_test.go new file mode 100644 index 000000000..4718341c7 --- /dev/null +++ b/client/inprocess_elicitation_test.go @@ -0,0 +1,206 @@ +package client + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// MockElicitationHandler implements ElicitationHandler for testing +type MockElicitationHandler struct { + // Track calls for verification + CallCount int + LastRequest mcp.ElicitationRequest +} + +func (h *MockElicitationHandler) Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + h.CallCount++ + h.LastRequest = request + + // Simulate user accepting and providing data + return &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeAccept, + Value: map[string]interface{}{ + "response": "User provided data", + "accepted": true, + }, + }, + }, nil +} + +func TestInProcessElicitation(t *testing.T) { + // Create server with elicitation enabled + mcpServer := server.NewMCPServer("test-server", "1.0.0", server.WithElicitation()) + + // Add a tool that uses elicitation + mcpServer.AddTool(mcp.Tool{ + Name: "test_elicitation", + Description: "Test elicitation functionality", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "action": map[string]any{ + "type": "string", + "description": "Action to perform", + }, + }, + Required: []string{"action"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + action, err := request.RequireString("action") + if err != nil { + return nil, err + } + + // Create elicitation request + elicitationRequest := mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Need additional information for " + action, + RequestedSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "confirm": map[string]interface{}{ + "type": "boolean", + "description": "Confirm the action", + }, + "details": map[string]interface{}{ + "type": "string", + "description": "Additional details", + }, + }, + "required": []string{"confirm"}, + }, + }, + } + + // Request elicitation from client + result, err := mcpServer.RequestElicitation(ctx, elicitationRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "Elicitation failed: " + err.Error(), + }, + }, + IsError: true, + }, nil + } + + // Handle the response + var responseText string + switch result.Response.Type { + case mcp.ElicitationResponseTypeAccept: + responseText = "User accepted and provided data" + case mcp.ElicitationResponseTypeDecline: + responseText = "User declined to provide information" + case mcp.ElicitationResponseTypeCancel: + responseText = "User cancelled the request" + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: responseText, + }, + }, + }, nil + }) + + // Create handler for elicitation + mockHandler := &MockElicitationHandler{} + + // Create in-process client with elicitation handler + client, err := NewInProcessClientWithElicitationHandler(mcpServer, mockHandler) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Start the client + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize the client + _, err = client.Initialize(context.Background(), mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{ + Elicitation: &struct{}{}, + }, + }, + }) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + // Call the tool that triggers elicitation + result, err := client.CallTool(context.Background(), mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "test_elicitation", + Arguments: map[string]any{ + "action": "test-action", + }, + }, + }) + + if err != nil { + t.Fatalf("Failed to call tool: %v", err) + } + + // Verify the result + if len(result.Content) == 0 { + t.Fatal("Expected content in result") + } + + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatal("Expected text content") + } + + if textContent.Text != "User accepted and provided data" { + t.Errorf("Unexpected result: %s", textContent.Text) + } + + // Verify the handler was called + if mockHandler.CallCount != 1 { + t.Errorf("Expected handler to be called once, got %d", mockHandler.CallCount) + } + + if mockHandler.LastRequest.Params.Message != "Need additional information for test-action" { + t.Errorf("Unexpected elicitation message: %s", mockHandler.LastRequest.Params.Message) + } +} + +// NewInProcessClientWithElicitationHandler creates an in-process client with elicitation support +func NewInProcessClientWithElicitationHandler(server *server.MCPServer, handler ElicitationHandler) (*Client, error) { + // Create a wrapper that implements server.ElicitationHandler + serverHandler := &inProcessElicitationHandlerWrapper{handler: handler} + + inProcessTransport := transport.NewInProcessTransportWithOptions(server, + transport.WithElicitationHandler(serverHandler)) + + client := NewClient(inProcessTransport) + client.elicitationHandler = handler + + return client, nil +} + +// inProcessElicitationHandlerWrapper wraps client.ElicitationHandler to implement server.ElicitationHandler +type inProcessElicitationHandlerWrapper struct { + handler ElicitationHandler +} + +func (w *inProcessElicitationHandlerWrapper) Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + return w.handler.Elicit(ctx, request) +} diff --git a/client/transport/inprocess.go b/client/transport/inprocess.go index 59c70940b..3757664a0 100644 --- a/client/transport/inprocess.go +++ b/client/transport/inprocess.go @@ -11,10 +11,11 @@ import ( ) type InProcessTransport struct { - server *server.MCPServer - samplingHandler server.SamplingHandler - session *server.InProcessSession - sessionID string + server *server.MCPServer + samplingHandler server.SamplingHandler + elicitationHandler server.ElicitationHandler + session *server.InProcessSession + sessionID string onNotification func(mcp.JSONRPCNotification) notifyMu sync.RWMutex @@ -28,6 +29,12 @@ func WithSamplingHandler(handler server.SamplingHandler) InProcessOption { } } +func WithElicitationHandler(handler server.ElicitationHandler) InProcessOption { + return func(t *InProcessTransport) { + t.elicitationHandler = handler + } +} + func NewInProcessTransport(server *server.MCPServer) *InProcessTransport { return &InProcessTransport{ server: server, @@ -48,9 +55,9 @@ func NewInProcessTransportWithOptions(server *server.MCPServer, opts ...InProces } func (c *InProcessTransport) Start(ctx context.Context) error { - // Create and register session if we have a sampling handler - if c.samplingHandler != nil { - c.session = server.NewInProcessSession(c.sessionID, c.samplingHandler) + // Create and register session if we have handlers + if c.samplingHandler != nil || c.elicitationHandler != nil { + c.session = server.NewInProcessSessionWithHandlers(c.sessionID, c.samplingHandler, c.elicitationHandler) if err := c.server.RegisterSession(ctx, c.session); err != nil { return fmt.Errorf("failed to register session: %w", err) } diff --git a/examples/elicitation/main.go b/examples/elicitation/main.go new file mode 100644 index 000000000..8a5e4cfcd --- /dev/null +++ b/examples/elicitation/main.go @@ -0,0 +1,208 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "sync/atomic" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// demoElicitationHandler demonstrates how to use elicitation in a tool +func demoElicitationHandler(s *server.MCPServer) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Create an elicitation request to get project details + elicitationRequest := mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "I need some information to set up your project. Please provide the project details.", + RequestedSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "projectName": map[string]interface{}{ + "type": "string", + "description": "Name of the project", + "minLength": 1, + }, + "framework": map[string]interface{}{ + "type": "string", + "description": "Frontend framework to use", + "enum": []string{"react", "vue", "angular", "none"}, + }, + "includeTests": map[string]interface{}{ + "type": "boolean", + "description": "Include test setup", + "default": true, + }, + }, + "required": []string{"projectName"}, + }, + }, + } + + // Request elicitation from the client + result, err := s.RequestElicitation(ctx, elicitationRequest) + if err != nil { + return nil, fmt.Errorf("failed to request elicitation: %w", err) + } + + // Handle the user's response + switch result.Response.Type { + case mcp.ElicitationResponseTypeAccept: + // User provided the information + data, ok := result.Response.Value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("unexpected response format") + } + + projectName := data["projectName"].(string) + framework := "none" + if fw, ok := data["framework"].(string); ok { + framework = fw + } + includeTests := true + if tests, ok := data["includeTests"].(bool); ok { + includeTests = tests + } + + // Create project based on user input + message := fmt.Sprintf( + "Created project '%s' with framework: %s, tests: %v", + projectName, framework, includeTests, + ) + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(message), + }, + }, nil + + case mcp.ElicitationResponseTypeDecline: + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent("Project creation cancelled - user declined to provide information"), + }, + }, nil + + case mcp.ElicitationResponseTypeCancel: + return nil, fmt.Errorf("project creation cancelled by user") + + default: + return nil, fmt.Errorf("unexpected response type: %s", result.Response.Type) + } + } +} + +var requestCount atomic.Int32 + +func main() { + // Create server with elicitation capability + mcpServer := server.NewMCPServer( + "elicitation-demo-server", + "1.0.0", + server.WithElicitation(), // Enable elicitation + ) + + // Add a tool that uses elicitation + mcpServer.AddTool( + mcp.NewTool( + "create_project", + mcp.WithDescription("Creates a new project with user-specified configuration"), + ), + demoElicitationHandler(mcpServer), + ) + + // Add another tool that demonstrates conditional elicitation + mcpServer.AddTool( + mcp.NewTool( + "process_data", + mcp.WithDescription("Processes data with optional user confirmation"), + mcp.WithString("data", mcp.Required(), mcp.Description("Data to process")), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + data := request.GetArguments()["data"].(string) + + // Only request elicitation if data seems sensitive + if len(data) > 100 { + elicitationRequest := mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: fmt.Sprintf("The data is %d characters long. Do you want to proceed with processing?", len(data)), + RequestedSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "proceed": map[string]interface{}{ + "type": "boolean", + "description": "Whether to proceed with processing", + }, + "reason": map[string]interface{}{ + "type": "string", + "description": "Optional reason for your decision", + }, + }, + "required": []string{"proceed"}, + }, + }, + } + + result, err := mcpServer.RequestElicitation(ctx, elicitationRequest) + if err != nil { + return nil, fmt.Errorf("failed to get confirmation: %w", err) + } + + if result.Response.Type != mcp.ElicitationResponseTypeAccept { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent("Processing cancelled by user"), + }, + }, nil + } + + responseData := result.Response.Value.(map[string]interface{}) + if proceed, ok := responseData["proceed"].(bool); !ok || !proceed { + reason := "No reason provided" + if r, ok := responseData["reason"].(string); ok && r != "" { + reason = r + } + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(fmt.Sprintf("Processing declined: %s", reason)), + }, + }, nil + } + } + + // Process the data + processed := fmt.Sprintf("Processed %d characters of data", len(data)) + count := requestCount.Add(1) + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(fmt.Sprintf("%s (request #%d)", processed, count)), + }, + }, nil + }, + ) + + // Create and start stdio server + stdioServer := server.NewStdioServer(mcpServer) + + // Handle graceful shutdown + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt) + + go func() { + <-sigChan + cancel() + }() + + fmt.Fprintln(os.Stderr, "Elicitation demo server started") + if err := stdioServer.Listen(ctx, os.Stdin, os.Stdout); err != nil { + log.Fatal(err) + } +} diff --git a/mcp/types.go b/mcp/types.go index f871b7d9d..39dc811d0 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -56,6 +56,10 @@ const ( // https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging MethodSetLogLevel MCPMethod = "logging/setLevel" + // MethodElicitationCreate requests additional information from the user during interactions. + // https://modelcontextprotocol.io/docs/concepts/elicitation + MethodElicitationCreate MCPMethod = "elicitation/create" + // MethodNotificationResourcesListChanged notifies when the list of available resources changes. // https://modelcontextprotocol.io/specification/2025-03-26/server/resources#list-changed-notification MethodNotificationResourcesListChanged = "notifications/resources/list_changed" @@ -462,6 +466,8 @@ type ClientCapabilities struct { } `json:"roots,omitempty"` // Present if the client supports sampling from an LLM. Sampling *struct{} `json:"sampling,omitempty"` + // Present if the client supports elicitation requests from the server. + Elicitation *struct{} `json:"elicitation,omitempty"` } // ServerCapabilities represents capabilities that a server may support. Known @@ -492,6 +498,8 @@ type ServerCapabilities struct { // Whether this server supports notifications for changes to the tool list. ListChanged bool `json:"listChanged,omitempty"` } `json:"tools,omitempty"` + // Present if the server supports elicitation requests to the client. + Elicitation *struct{} `json:"elicitation,omitempty"` } // Implementation describes the name and version of an MCP implementation. @@ -814,6 +822,54 @@ func (l LoggingLevel) ShouldSendTo(minLevel LoggingLevel) bool { return ia >= ib } +/* Elicitation */ + +// ElicitationRequest is a request from the server to the client to request additional +// information from the user during an interaction. +type ElicitationRequest struct { + Request + Params ElicitationParams `json:"params"` +} + +// ElicitationParams contains the parameters for an elicitation request. +type ElicitationParams struct { + // A human-readable message explaining what information is being requested and why. + Message string `json:"message"` + // A JSON Schema defining the expected structure of the user's response. + RequestedSchema any `json:"requestedSchema"` +} + +// ElicitationResult represents the result of an elicitation request. +type ElicitationResult struct { + Result + // The user's response, which could be: + // - The requested information (if user accepted) + // - A decline indicator (if user declined) + // - A cancel indicator (if user cancelled) + Response ElicitationResponse `json:"response"` +} + +// ElicitationResponse represents the user's response to an elicitation request. +type ElicitationResponse struct { + // Type indicates whether the user accepted, declined, or cancelled. + Type ElicitationResponseType `json:"type"` + // Value contains the user's response data if they accepted. + // Should conform to the requestedSchema from the ElicitationRequest. + Value any `json:"value,omitempty"` +} + +// ElicitationResponseType indicates how the user responded to an elicitation request. +type ElicitationResponseType string + +const ( + // ElicitationResponseTypeAccept indicates the user provided the requested information. + ElicitationResponseTypeAccept ElicitationResponseType = "accept" + // ElicitationResponseTypeDecline indicates the user explicitly declined to provide information. + ElicitationResponseTypeDecline ElicitationResponseType = "decline" + // ElicitationResponseTypeCancel indicates the user cancelled without making a choice. + ElicitationResponseTypeCancel ElicitationResponseType = "cancel" +) + /* Sampling */ const ( diff --git a/server/elicitation.go b/server/elicitation.go new file mode 100644 index 000000000..8deee383f --- /dev/null +++ b/server/elicitation.go @@ -0,0 +1,25 @@ +package server + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// RequestElicitation sends an elicitation request to the client. +// The client must have declared elicitation capability during initialization. +// The session must implement SessionWithElicitation to support this operation. +func (s *MCPServer) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + // Check if the session supports elicitation requests + if elicitationSession, ok := session.(SessionWithElicitation); ok { + return elicitationSession.RequestElicitation(ctx, request) + } + + return nil, fmt.Errorf("session does not support elicitation") +} diff --git a/server/elicitation_test.go b/server/elicitation_test.go new file mode 100644 index 000000000..ed6feb51f --- /dev/null +++ b/server/elicitation_test.go @@ -0,0 +1,263 @@ +package server + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockBasicSession implements ClientSession for testing (without elicitation support) +type mockBasicSession struct { + sessionID string +} + +func (m *mockBasicSession) SessionID() string { + return m.sessionID +} + +func (m *mockBasicSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return make(chan mcp.JSONRPCNotification, 1) +} + +func (m *mockBasicSession) Initialize() {} + +func (m *mockBasicSession) Initialized() bool { + return true +} + +// mockElicitationSession implements SessionWithElicitation for testing +type mockElicitationSession struct { + sessionID string + result *mcp.ElicitationResult + err error +} + +func (m *mockElicitationSession) SessionID() string { + return m.sessionID +} + +func (m *mockElicitationSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return make(chan mcp.JSONRPCNotification, 1) +} + +func (m *mockElicitationSession) Initialize() {} + +func (m *mockElicitationSession) Initialized() bool { + return true +} + +func (m *mockElicitationSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + if m.err != nil { + return nil, m.err + } + return m.result, nil +} + +func TestMCPServer_RequestElicitation_NoSession(t *testing.T) { + server := NewMCPServer("test", "1.0.0") + server.capabilities.elicitation = mcp.ToBoolPtr(true) + + request := mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Need some information", + RequestedSchema: map[string]interface{}{ + "type": "object", + }, + }, + } + + _, err := server.RequestElicitation(context.Background(), request) + + if err == nil { + t.Error("expected error when no session available") + } + + expectedError := "no active session" + if err.Error() != expectedError { + t.Errorf("expected error %q, got %q", expectedError, err.Error()) + } +} + +func TestMCPServer_RequestElicitation_SessionDoesNotSupportElicitation(t *testing.T) { + server := NewMCPServer("test", "1.0.0", WithElicitation()) + + // Use a regular session that doesn't implement SessionWithElicitation + mockSession := &mockBasicSession{sessionID: "test-session"} + + ctx := context.Background() + ctx = server.WithContext(ctx, mockSession) + + request := mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Need some information", + RequestedSchema: map[string]interface{}{ + "type": "object", + }, + }, + } + + _, err := server.RequestElicitation(ctx, request) + + if err == nil { + t.Error("expected error when session doesn't support elicitation") + } + + expectedError := "session does not support elicitation" + if err.Error() != expectedError { + t.Errorf("expected error %q, got %q", expectedError, err.Error()) + } +} + +func TestMCPServer_RequestElicitation_Success(t *testing.T) { + server := NewMCPServer("test", "1.0.0", WithElicitation()) + + // Create a mock elicitation session + mockSession := &mockElicitationSession{ + sessionID: "test-session", + result: &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeAccept, + Value: map[string]interface{}{ + "projectName": "my-project", + "framework": "react", + }, + }, + }, + } + + // Create context with session + ctx := context.Background() + ctx = server.WithContext(ctx, mockSession) + + request := mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Please provide project details", + RequestedSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "projectName": map[string]interface{}{"type": "string"}, + "framework": map[string]interface{}{"type": "string"}, + }, + }, + }, + } + + result, err := server.RequestElicitation(ctx, request) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if result == nil { + t.Error("expected result, got nil") + return + } + + if result.Response.Type != mcp.ElicitationResponseTypeAccept { + t.Errorf("expected response type %q, got %q", mcp.ElicitationResponseTypeAccept, result.Response.Type) + } + + value, ok := result.Response.Value.(map[string]interface{}) + if !ok { + t.Error("expected value to be a map") + return + } + + if value["projectName"] != "my-project" { + t.Errorf("expected projectName %q, got %q", "my-project", value["projectName"]) + } +} + +func TestRequestElicitation(t *testing.T) { + tests := []struct { + name string + session ClientSession + request mcp.ElicitationRequest + expectedError string + expectedType mcp.ElicitationResponseType + }{ + { + name: "successful elicitation with accept", + session: &mockElicitationSession{ + sessionID: "test-1", + result: &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeAccept, + Value: map[string]interface{}{ + "name": "test-project", + "framework": "react", + }, + }, + }, + }, + request: mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Please provide project details", + RequestedSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{"type": "string"}, + "framework": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + expectedType: mcp.ElicitationResponseTypeAccept, + }, + { + name: "elicitation declined by user", + session: &mockElicitationSession{ + sessionID: "test-2", + result: &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeDecline, + }, + }, + }, + request: mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Need some info", + RequestedSchema: map[string]interface{}{"type": "object"}, + }, + }, + expectedType: mcp.ElicitationResponseTypeDecline, + }, + { + name: "session does not support elicitation", + session: &fakeSession{sessionID: "test-3"}, + request: mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Need info", + RequestedSchema: map[string]interface{}{"type": "object"}, + }, + }, + expectedError: "session does not support elicitation", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := NewMCPServer("test", "1.0", WithElicitation()) + ctx := server.WithContext(context.Background(), tt.session) + + result, err := server.RequestElicitation(ctx, tt.request) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + return + } + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, tt.expectedType, result.Response.Type) + + if tt.expectedType == mcp.ElicitationResponseTypeAccept { + assert.NotNil(t, result.Response.Value) + } + }) + } +} diff --git a/server/inprocess_session.go b/server/inprocess_session.go index daaf28a5c..c6fddc601 100644 --- a/server/inprocess_session.go +++ b/server/inprocess_session.go @@ -15,6 +15,11 @@ type SamplingHandler interface { CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) } +// ElicitationHandler defines the interface for handling elicitation requests from servers. +type ElicitationHandler interface { + Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) +} + type InProcessSession struct { sessionID string notifications chan mcp.JSONRPCNotification @@ -23,6 +28,7 @@ type InProcessSession struct { clientInfo atomic.Value clientCapabilities atomic.Value samplingHandler SamplingHandler + elicitationHandler ElicitationHandler mu sync.RWMutex } @@ -34,6 +40,15 @@ func NewInProcessSession(sessionID string, samplingHandler SamplingHandler) *InP } } +func NewInProcessSessionWithHandlers(sessionID string, samplingHandler SamplingHandler, elicitationHandler ElicitationHandler) *InProcessSession { + return &InProcessSession{ + sessionID: sessionID, + notifications: make(chan mcp.JSONRPCNotification, 100), + samplingHandler: samplingHandler, + elicitationHandler: elicitationHandler, + } +} + func (s *InProcessSession) SessionID() string { return s.sessionID } @@ -101,6 +116,18 @@ func (s *InProcessSession) RequestSampling(ctx context.Context, request mcp.Crea return handler.CreateMessage(ctx, request) } +func (s *InProcessSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + s.mu.RLock() + handler := s.elicitationHandler + s.mu.RUnlock() + + if handler == nil { + return nil, fmt.Errorf("no elicitation handler available") + } + + return handler.Elicit(ctx, request) +} + // GenerateInProcessSessionID generates a unique session ID for inprocess clients func GenerateInProcessSessionID() string { return fmt.Sprintf("inprocess-%d", time.Now().UnixNano()) @@ -108,8 +135,9 @@ func GenerateInProcessSessionID() string { // Ensure interface compliance var ( - _ ClientSession = (*InProcessSession)(nil) - _ SessionWithLogging = (*InProcessSession)(nil) - _ SessionWithClientInfo = (*InProcessSession)(nil) - _ SessionWithSampling = (*InProcessSession)(nil) + _ ClientSession = (*InProcessSession)(nil) + _ SessionWithLogging = (*InProcessSession)(nil) + _ SessionWithClientInfo = (*InProcessSession)(nil) + _ SessionWithSampling = (*InProcessSession)(nil) + _ SessionWithElicitation = (*InProcessSession)(nil) ) diff --git a/server/server.go b/server/server.go index b9fb3612b..95e95fe08 100644 --- a/server/server.go +++ b/server/server.go @@ -182,11 +182,12 @@ func WithPaginationLimit(limit int) ServerOption { // serverCapabilities defines the supported features of the MCP server type serverCapabilities struct { - tools *toolCapabilities - resources *resourceCapabilities - prompts *promptCapabilities - logging *bool - sampling *bool + tools *toolCapabilities + resources *resourceCapabilities + prompts *promptCapabilities + logging *bool + sampling *bool + elicitation *bool } // resourceCapabilities defines the supported resource-related features @@ -323,6 +324,13 @@ func WithLogging() ServerOption { } } +// WithElicitation enables elicitation capabilities for the server +func WithElicitation() ServerOption { + return func(s *MCPServer) { + s.capabilities.elicitation = mcp.ToBoolPtr(true) + } +} + // WithInstructions sets the server instructions for the client returned in the initialize response func WithInstructions(instructions string) ServerOption { return func(s *MCPServer) { @@ -689,6 +697,10 @@ func (s *MCPServer) handleInitialize( capabilities.Sampling = &struct{}{} } + if s.capabilities.elicitation != nil && *s.capabilities.elicitation { + capabilities.Elicitation = &struct{}{} + } + result := mcp.InitializeResult{ ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion), ServerInfo: mcp.Implementation{ diff --git a/server/session.go b/server/session.go index 11ee8a2f1..3d11df932 100644 --- a/server/session.go +++ b/server/session.go @@ -52,6 +52,13 @@ type SessionWithClientInfo interface { SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) } +// SessionWithElicitation is an extension of ClientSession that can send elicitation requests +type SessionWithElicitation interface { + ClientSession + // RequestElicitation sends an elicitation request to the client and waits for response + RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) +} + // SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations type SessionWithStreamableHTTPConfig interface { ClientSession diff --git a/server/stdio.go b/server/stdio.go index 8c270e18b..4f4805270 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -92,16 +92,17 @@ func WithQueueSize(size int) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - clientInfo atomic.Value // stores session-specific client info - clientCapabilities atomic.Value // stores session-specific client capabilities - writer io.Writer // for sending requests to client - requestID atomic.Int64 // for generating unique request IDs - mu sync.RWMutex // protects writer - pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests - pendingMu sync.RWMutex // protects pendingRequests + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value // stores session-specific client info + clientCapabilities atomic.Value // stores session-specific client capabilities + writer io.Writer // for sending requests to client + requestID atomic.Int64 // for generating unique request IDs + mu sync.RWMutex // protects writer + pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests + pendingElicitations map[int64]chan *elicitationResponse // for tracking pending elicitation requests + pendingMu sync.RWMutex // protects pendingRequests and pendingElicitations } // samplingResponse represents a response to a sampling request @@ -110,6 +111,12 @@ type samplingResponse struct { err error } +// elicitationResponse represents a response to an elicitation request +type elicitationResponse struct { + result *mcp.ElicitationResult + err error +} + func (s *stdioSession) SessionID() string { return "stdio" } @@ -229,6 +236,69 @@ func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMe } } +// RequestElicitation sends an elicitation request to the client and waits for the response. +func (s *stdioSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + s.mu.RLock() + writer := s.writer + s.mu.RUnlock() + + if writer == nil { + return nil, fmt.Errorf("no writer available for sending requests") + } + + // Generate a unique request ID + id := s.requestID.Add(1) + + // Create a response channel for this request + responseChan := make(chan *elicitationResponse, 1) + s.pendingMu.Lock() + s.pendingElicitations[id] = responseChan + s.pendingMu.Unlock() + + // Cleanup function to remove the pending request + cleanup := func() { + s.pendingMu.Lock() + delete(s.pendingElicitations, id) + s.pendingMu.Unlock() + } + defer cleanup() + + // Create the JSON-RPC request + jsonRPCRequest := struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params mcp.ElicitationParams `json:"params"` + }{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: string(mcp.MethodElicitationCreate), + Params: request.Params, + } + + // Marshal and send the request + requestBytes, err := json.Marshal(jsonRPCRequest) + if err != nil { + return nil, fmt.Errorf("failed to marshal elicitation request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := writer.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write elicitation request: %w", err) + } + + // Wait for the response or context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + } +} + // SetWriter sets the writer for sending requests to the client. func (s *stdioSession) SetWriter(writer io.Writer) { s.mu.Lock() @@ -237,15 +307,17 @@ func (s *stdioSession) SetWriter(writer io.Writer) { } var ( - _ ClientSession = (*stdioSession)(nil) - _ SessionWithLogging = (*stdioSession)(nil) - _ SessionWithClientInfo = (*stdioSession)(nil) - _ SessionWithSampling = (*stdioSession)(nil) + _ ClientSession = (*stdioSession)(nil) + _ SessionWithLogging = (*stdioSession)(nil) + _ SessionWithClientInfo = (*stdioSession)(nil) + _ SessionWithSampling = (*stdioSession)(nil) + _ SessionWithElicitation = (*stdioSession)(nil) ) var stdioSessionInstance = stdioSession{ - notifications: make(chan mcp.JSONRPCNotification, 100), - pendingRequests: make(map[int64]chan *samplingResponse), + notifications: make(chan mcp.JSONRPCNotification, 100), + pendingRequests: make(map[int64]chan *samplingResponse), + pendingElicitations: make(map[int64]chan *elicitationResponse), } // NewStdioServer creates a new stdio server wrapper around an MCPServer. @@ -445,6 +517,11 @@ func (s *StdioServer) processMessage( return nil } + // Check if this is a response to an elicitation request + if s.handleElicitationResponse(rawMessage) { + return nil + } + // Check if this is a tool call that might need sampling (and thus should be processed concurrently) var baseMessage struct { Method string `json:"method"` @@ -543,6 +620,67 @@ func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool { return true } +// handleElicitationResponse checks if the message is a response to an elicitation request +// and routes it to the appropriate pending request channel. +func (s *StdioServer) handleElicitationResponse(rawMessage json.RawMessage) bool { + return stdioSessionInstance.handleElicitationResponse(rawMessage) +} + +// handleElicitationResponse handles incoming elicitation responses for this session +func (s *stdioSession) handleElicitationResponse(rawMessage json.RawMessage) bool { + // Try to parse as a JSON-RPC response + var response struct { + JSONRPC string `json:"jsonrpc"` + ID json.Number `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal(rawMessage, &response); err != nil { + return false + } + // Parse the ID as int64 + id, err := response.ID.Int64() + if err != nil { + return false + } + + // Check if we have a pending elicitation request with this ID + s.pendingMu.RLock() + responseChan, exists := s.pendingElicitations[id] + s.pendingMu.RUnlock() + + if !exists { + return false + } + + // Parse and send the response + elicitationResp := &elicitationResponse{} + + if response.Error != nil { + elicitationResp.err = fmt.Errorf("elicitation request failed: %s", response.Error.Message) + } else { + var result mcp.ElicitationResult + if err := json.Unmarshal(response.Result, &result); err != nil { + elicitationResp.err = fmt.Errorf("failed to unmarshal elicitation response: %w", err) + } else { + elicitationResp.result = &result + } + } + + // Send the response (non-blocking) + select { + case responseChan <- elicitationResp: + default: + // Channel is full or closed, ignore + } + + return true +} + // writeResponse marshals and writes a JSON-RPC response message followed by a newline. // Returns an error if marshaling or writing fails. func (s *StdioServer) writeResponse( From 2759c2b4b2f7c97741862a24dc77fe815699518c Mon Sep 17 00:00:00 2001 From: Ramon Nogueira Date: Sat, 23 Aug 2025 17:43:19 -0600 Subject: [PATCH 2/7] Address review comments and auto-format --- client/inprocess_elicitation_test.go | 19 ++------ examples/elicitation/main.go | 72 +++++++++++++++++++++++----- server/stdio.go | 2 +- 3 files changed, 67 insertions(+), 26 deletions(-) diff --git a/client/inprocess_elicitation_test.go b/client/inprocess_elicitation_test.go index 4718341c7..a03575671 100644 --- a/client/inprocess_elicitation_test.go +++ b/client/inprocess_elicitation_test.go @@ -37,20 +37,11 @@ func TestInProcessElicitation(t *testing.T) { mcpServer := server.NewMCPServer("test-server", "1.0.0", server.WithElicitation()) // Add a tool that uses elicitation - mcpServer.AddTool(mcp.Tool{ - Name: "test_elicitation", - Description: "Test elicitation functionality", - InputSchema: mcp.ToolInputSchema{ - Type: "object", - Properties: map[string]any{ - "action": map[string]any{ - "type": "string", - "description": "Action to perform", - }, - }, - Required: []string{"action"}, - }, - }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool( + "test_elicitation", + mcp.WithDescription("Test elicitation functionality"), + mcp.WithString("action", mcp.Description("Action to perform"), mcp.Required()), + ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { action, err := request.RequireString("action") if err != nil { return nil, err diff --git a/examples/elicitation/main.go b/examples/elicitation/main.go index 8a5e4cfcd..825c73f9a 100644 --- a/examples/elicitation/main.go +++ b/examples/elicitation/main.go @@ -55,17 +55,40 @@ func demoElicitationHandler(s *server.MCPServer) server.ToolHandlerFunc { // User provided the information data, ok := result.Response.Value.(map[string]interface{}) if !ok { - return nil, fmt.Errorf("unexpected response format") + return nil, fmt.Errorf("unexpected response format: expected map[string]interface{}, got %T", result.Response.Value) } - projectName := data["projectName"].(string) + // Safely extract projectName (required field) + projectNameRaw, exists := data["projectName"] + if !exists { + return nil, fmt.Errorf("required field 'projectName' is missing from response") + } + projectName, ok := projectNameRaw.(string) + if !ok { + return nil, fmt.Errorf("field 'projectName' must be a string, got %T", projectNameRaw) + } + if projectName == "" { + return nil, fmt.Errorf("field 'projectName' cannot be empty") + } + + // Safely extract framework (optional field) framework := "none" - if fw, ok := data["framework"].(string); ok { - framework = fw + if frameworkRaw, exists := data["framework"]; exists { + if fw, ok := frameworkRaw.(string); ok { + framework = fw + } else { + return nil, fmt.Errorf("field 'framework' must be a string, got %T", frameworkRaw) + } } + + // Safely extract includeTests (optional field) includeTests := true - if tests, ok := data["includeTests"].(bool); ok { - includeTests = tests + if testsRaw, exists := data["includeTests"]; exists { + if tests, ok := testsRaw.(bool); ok { + includeTests = tests + } else { + return nil, fmt.Errorf("field 'includeTests' must be a boolean, got %T", testsRaw) + } } // Create project based on user input @@ -123,7 +146,15 @@ func main() { mcp.WithString("data", mcp.Required(), mcp.Description("Data to process")), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - data := request.GetArguments()["data"].(string) + // Safely extract data argument + dataRaw, exists := request.GetArguments()["data"] + if !exists { + return nil, fmt.Errorf("required parameter 'data' is missing") + } + data, ok := dataRaw.(string) + if !ok { + return nil, fmt.Errorf("parameter 'data' must be a string, got %T", dataRaw) + } // Only request elicitation if data seems sensitive if len(data) > 100 { @@ -160,11 +191,30 @@ func main() { }, nil } - responseData := result.Response.Value.(map[string]interface{}) - if proceed, ok := responseData["proceed"].(bool); !ok || !proceed { + // Safely extract response data + responseData, ok := result.Response.Value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("unexpected response format: expected map[string]interface{}, got %T", result.Response.Value) + } + + // Safely extract proceed field + proceedRaw, exists := responseData["proceed"] + if !exists { + return nil, fmt.Errorf("required field 'proceed' is missing from response") + } + proceed, ok := proceedRaw.(bool) + if !ok { + return nil, fmt.Errorf("field 'proceed' must be a boolean, got %T", proceedRaw) + } + + if !proceed { reason := "No reason provided" - if r, ok := responseData["reason"].(string); ok && r != "" { - reason = r + if reasonRaw, exists := responseData["reason"]; exists { + if r, ok := reasonRaw.(string); ok && r != "" { + reason = r + } else if reasonRaw != nil { + return nil, fmt.Errorf("field 'reason' must be a string, got %T", reasonRaw) + } } return &mcp.CallToolResult{ Content: []mcp.Content{ diff --git a/server/stdio.go b/server/stdio.go index 4f4805270..d80941c3d 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -644,7 +644,7 @@ func (s *stdioSession) handleElicitationResponse(rawMessage json.RawMessage) boo } // Parse the ID as int64 id, err := response.ID.Int64() - if err != nil { + if err != nil || (response.Result == nil && response.Error == nil) { return false } From 5f8ceba62ec0e6e6e4a9cbf87b3d92379b1e43b0 Mon Sep 17 00:00:00 2001 From: Ramon Nogueira Date: Sat, 23 Aug 2025 18:19:55 -0600 Subject: [PATCH 3/7] Address further minor review comments --- client/elicitation_test.go | 12 +++++------ client/inprocess_elicitation_test.go | 31 ++++++++++++++-------------- examples/elicitation/main.go | 26 +++++++++++------------ server/elicitation_test.go | 30 +++++++++++++-------------- 4 files changed, 50 insertions(+), 49 deletions(-) diff --git a/client/elicitation_test.go b/client/elicitation_test.go index a05be35c5..425b36f32 100644 --- a/client/elicitation_test.go +++ b/client/elicitation_test.go @@ -40,7 +40,7 @@ func TestClient_HandleElicitationRequest(t *testing.T) { result: &mcp.ElicitationResult{ Response: mcp.ElicitationResponse{ Type: mcp.ElicitationResponseTypeAccept, - Value: map[string]interface{}{ + Value: map[string]any{ "name": "test-project", "framework": "react", }, @@ -84,13 +84,13 @@ func TestClient_HandleElicitationRequest(t *testing.T) { request := transport.JSONRPCRequest{ ID: mcp.NewRequestId(1), Method: string(mcp.MethodElicitationCreate), - Params: map[string]interface{}{ + Params: map[string]any{ "message": "Please provide project details", - "requestedSchema": map[string]interface{}{ + "requestedSchema": map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "name": map[string]interface{}{"type": "string"}, - "framework": map[string]interface{}{"type": "string"}, + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "framework": map[string]any{"type": "string"}, }, }, }, diff --git a/client/inprocess_elicitation_test.go b/client/inprocess_elicitation_test.go index a03575671..19dd1a5c0 100644 --- a/client/inprocess_elicitation_test.go +++ b/client/inprocess_elicitation_test.go @@ -24,9 +24,9 @@ func (h *MockElicitationHandler) Elicit(ctx context.Context, request mcp.Elicita return &mcp.ElicitationResult{ Response: mcp.ElicitationResponse{ Type: mcp.ElicitationResponseTypeAccept, - Value: map[string]interface{}{ - "response": "User provided data", - "accepted": true, + Value: map[string]any{ + "confirm": true, + "details": "User provided additional details", }, }, }, nil @@ -51,14 +51,14 @@ func TestInProcessElicitation(t *testing.T) { elicitationRequest := mcp.ElicitationRequest{ Params: mcp.ElicitationParams{ Message: "Need additional information for " + action, - RequestedSchema: map[string]interface{}{ + RequestedSchema: map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "confirm": map[string]interface{}{ + "properties": map[string]any{ + "confirm": map[string]any{ "type": "boolean", "description": "Confirm the action", }, - "details": map[string]interface{}{ + "details": map[string]any{ "type": "string", "description": "Additional details", }, @@ -107,10 +107,7 @@ func TestInProcessElicitation(t *testing.T) { mockHandler := &MockElicitationHandler{} // Create in-process client with elicitation handler - client, err := NewInProcessClientWithElicitationHandler(mcpServer, mockHandler) - if err != nil { - t.Fatalf("Failed to create client: %v", err) - } + client := NewInProcessClientWithElicitationHandler(mcpServer, mockHandler) defer client.Close() // Start the client @@ -119,7 +116,7 @@ func TestInProcessElicitation(t *testing.T) { } // Initialize the client - _, err = client.Initialize(context.Background(), mcp.InitializeRequest{ + _, err := client.Initialize(context.Background(), mcp.InitializeRequest{ Params: mcp.InitializeParams{ ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, ClientInfo: mcp.Implementation{ @@ -154,6 +151,11 @@ func TestInProcessElicitation(t *testing.T) { t.Fatal("Expected content in result") } + // Assert that the result is not flagged as an error for the accept path + if result.IsError { + t.Error("Expected result to not be flagged as error for accept response") + } + textContent, ok := result.Content[0].(mcp.TextContent) if !ok { t.Fatal("Expected text content") @@ -174,7 +176,7 @@ func TestInProcessElicitation(t *testing.T) { } // NewInProcessClientWithElicitationHandler creates an in-process client with elicitation support -func NewInProcessClientWithElicitationHandler(server *server.MCPServer, handler ElicitationHandler) (*Client, error) { +func NewInProcessClientWithElicitationHandler(server *server.MCPServer, handler ElicitationHandler) *Client { // Create a wrapper that implements server.ElicitationHandler serverHandler := &inProcessElicitationHandlerWrapper{handler: handler} @@ -182,9 +184,8 @@ func NewInProcessClientWithElicitationHandler(server *server.MCPServer, handler transport.WithElicitationHandler(serverHandler)) client := NewClient(inProcessTransport) - client.elicitationHandler = handler - return client, nil + return client } // inProcessElicitationHandlerWrapper wraps client.ElicitationHandler to implement server.ElicitationHandler diff --git a/examples/elicitation/main.go b/examples/elicitation/main.go index 825c73f9a..6892f8146 100644 --- a/examples/elicitation/main.go +++ b/examples/elicitation/main.go @@ -19,20 +19,20 @@ func demoElicitationHandler(s *server.MCPServer) server.ToolHandlerFunc { elicitationRequest := mcp.ElicitationRequest{ Params: mcp.ElicitationParams{ Message: "I need some information to set up your project. Please provide the project details.", - RequestedSchema: map[string]interface{}{ + RequestedSchema: map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "projectName": map[string]interface{}{ + "properties": map[string]any{ + "projectName": map[string]any{ "type": "string", "description": "Name of the project", "minLength": 1, }, - "framework": map[string]interface{}{ + "framework": map[string]any{ "type": "string", "description": "Frontend framework to use", "enum": []string{"react", "vue", "angular", "none"}, }, - "includeTests": map[string]interface{}{ + "includeTests": map[string]any{ "type": "boolean", "description": "Include test setup", "default": true, @@ -53,9 +53,9 @@ func demoElicitationHandler(s *server.MCPServer) server.ToolHandlerFunc { switch result.Response.Type { case mcp.ElicitationResponseTypeAccept: // User provided the information - data, ok := result.Response.Value.(map[string]interface{}) + data, ok := result.Response.Value.(map[string]any) if !ok { - return nil, fmt.Errorf("unexpected response format: expected map[string]interface{}, got %T", result.Response.Value) + return nil, fmt.Errorf("unexpected response format: expected map[string]any, got %T", result.Response.Value) } // Safely extract projectName (required field) @@ -161,14 +161,14 @@ func main() { elicitationRequest := mcp.ElicitationRequest{ Params: mcp.ElicitationParams{ Message: fmt.Sprintf("The data is %d characters long. Do you want to proceed with processing?", len(data)), - RequestedSchema: map[string]interface{}{ + RequestedSchema: map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "proceed": map[string]interface{}{ + "properties": map[string]any{ + "proceed": map[string]any{ "type": "boolean", "description": "Whether to proceed with processing", }, - "reason": map[string]interface{}{ + "reason": map[string]any{ "type": "string", "description": "Optional reason for your decision", }, @@ -192,9 +192,9 @@ func main() { } // Safely extract response data - responseData, ok := result.Response.Value.(map[string]interface{}) + responseData, ok := result.Response.Value.(map[string]any) if !ok { - return nil, fmt.Errorf("unexpected response format: expected map[string]interface{}, got %T", result.Response.Value) + return nil, fmt.Errorf("unexpected response format: expected map[string]any, got %T", result.Response.Value) } // Safely extract proceed field diff --git a/server/elicitation_test.go b/server/elicitation_test.go index ed6feb51f..4f6aadfd6 100644 --- a/server/elicitation_test.go +++ b/server/elicitation_test.go @@ -63,7 +63,7 @@ func TestMCPServer_RequestElicitation_NoSession(t *testing.T) { request := mcp.ElicitationRequest{ Params: mcp.ElicitationParams{ Message: "Need some information", - RequestedSchema: map[string]interface{}{ + RequestedSchema: map[string]any{ "type": "object", }, }, @@ -93,7 +93,7 @@ func TestMCPServer_RequestElicitation_SessionDoesNotSupportElicitation(t *testin request := mcp.ElicitationRequest{ Params: mcp.ElicitationParams{ Message: "Need some information", - RequestedSchema: map[string]interface{}{ + RequestedSchema: map[string]any{ "type": "object", }, }, @@ -120,7 +120,7 @@ func TestMCPServer_RequestElicitation_Success(t *testing.T) { result: &mcp.ElicitationResult{ Response: mcp.ElicitationResponse{ Type: mcp.ElicitationResponseTypeAccept, - Value: map[string]interface{}{ + Value: map[string]any{ "projectName": "my-project", "framework": "react", }, @@ -135,11 +135,11 @@ func TestMCPServer_RequestElicitation_Success(t *testing.T) { request := mcp.ElicitationRequest{ Params: mcp.ElicitationParams{ Message: "Please provide project details", - RequestedSchema: map[string]interface{}{ + RequestedSchema: map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "projectName": map[string]interface{}{"type": "string"}, - "framework": map[string]interface{}{"type": "string"}, + "properties": map[string]any{ + "projectName": map[string]any{"type": "string"}, + "framework": map[string]any{"type": "string"}, }, }, }, @@ -160,7 +160,7 @@ func TestMCPServer_RequestElicitation_Success(t *testing.T) { t.Errorf("expected response type %q, got %q", mcp.ElicitationResponseTypeAccept, result.Response.Type) } - value, ok := result.Response.Value.(map[string]interface{}) + value, ok := result.Response.Value.(map[string]any) if !ok { t.Error("expected value to be a map") return @@ -186,7 +186,7 @@ func TestRequestElicitation(t *testing.T) { result: &mcp.ElicitationResult{ Response: mcp.ElicitationResponse{ Type: mcp.ElicitationResponseTypeAccept, - Value: map[string]interface{}{ + Value: map[string]any{ "name": "test-project", "framework": "react", }, @@ -196,11 +196,11 @@ func TestRequestElicitation(t *testing.T) { request: mcp.ElicitationRequest{ Params: mcp.ElicitationParams{ Message: "Please provide project details", - RequestedSchema: map[string]interface{}{ + RequestedSchema: map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "name": map[string]interface{}{"type": "string"}, - "framework": map[string]interface{}{"type": "string"}, + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "framework": map[string]any{"type": "string"}, }, }, }, @@ -220,7 +220,7 @@ func TestRequestElicitation(t *testing.T) { request: mcp.ElicitationRequest{ Params: mcp.ElicitationParams{ Message: "Need some info", - RequestedSchema: map[string]interface{}{"type": "object"}, + RequestedSchema: map[string]any{"type": "object"}, }, }, expectedType: mcp.ElicitationResponseTypeDecline, @@ -231,7 +231,7 @@ func TestRequestElicitation(t *testing.T) { request: mcp.ElicitationRequest{ Params: mcp.ElicitationParams{ Message: "Need info", - RequestedSchema: map[string]interface{}{"type": "object"}, + RequestedSchema: map[string]any{"type": "object"}, }, }, expectedError: "session does not support elicitation", From e0d70ddf804650077fb1385f5f86ca3507504ae7 Mon Sep 17 00:00:00 2001 From: Ramon Nogueira Date: Sat, 23 Aug 2025 18:59:12 -0600 Subject: [PATCH 4/7] Add sentinel errors --- server/elicitation.go | 13 ++++++++++--- server/elicitation_test.go | 19 +++++++++---------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/server/elicitation.go b/server/elicitation.go index 8deee383f..d3e6d3d4c 100644 --- a/server/elicitation.go +++ b/server/elicitation.go @@ -2,18 +2,25 @@ package server import ( "context" - "fmt" + "errors" "github.com/mark3labs/mcp-go/mcp" ) +var ( + // ErrNoActiveSession is returned when there is no active session in the context + ErrNoActiveSession = errors.New("no active session") + // ErrElicitationNotSupported is returned when the session does not support elicitation + ErrElicitationNotSupported = errors.New("session does not support elicitation") +) + // RequestElicitation sends an elicitation request to the client. // The client must have declared elicitation capability during initialization. // The session must implement SessionWithElicitation to support this operation. func (s *MCPServer) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { session := ClientSessionFromContext(ctx) if session == nil { - return nil, fmt.Errorf("no active session") + return nil, ErrNoActiveSession } // Check if the session supports elicitation requests @@ -21,5 +28,5 @@ func (s *MCPServer) RequestElicitation(ctx context.Context, request mcp.Elicitat return elicitationSession.RequestElicitation(ctx, request) } - return nil, fmt.Errorf("session does not support elicitation") + return nil, ErrElicitationNotSupported } diff --git a/server/elicitation_test.go b/server/elicitation_test.go index 4f6aadfd6..47868f813 100644 --- a/server/elicitation_test.go +++ b/server/elicitation_test.go @@ -2,6 +2,7 @@ package server import ( "context" + "errors" "testing" "github.com/mark3labs/mcp-go/mcp" @@ -75,9 +76,8 @@ func TestMCPServer_RequestElicitation_NoSession(t *testing.T) { t.Error("expected error when no session available") } - expectedError := "no active session" - if err.Error() != expectedError { - t.Errorf("expected error %q, got %q", expectedError, err.Error()) + if !errors.Is(err, ErrNoActiveSession) { + t.Errorf("expected ErrNoActiveSession, got %v", err) } } @@ -105,9 +105,8 @@ func TestMCPServer_RequestElicitation_SessionDoesNotSupportElicitation(t *testin t.Error("expected error when session doesn't support elicitation") } - expectedError := "session does not support elicitation" - if err.Error() != expectedError { - t.Errorf("expected error %q, got %q", expectedError, err.Error()) + if !errors.Is(err, ErrElicitationNotSupported) { + t.Errorf("expected ErrElicitationNotSupported, got %v", err) } } @@ -176,7 +175,7 @@ func TestRequestElicitation(t *testing.T) { name string session ClientSession request mcp.ElicitationRequest - expectedError string + expectedError error expectedType mcp.ElicitationResponseType }{ { @@ -234,7 +233,7 @@ func TestRequestElicitation(t *testing.T) { RequestedSchema: map[string]any{"type": "object"}, }, }, - expectedError: "session does not support elicitation", + expectedError: ErrElicitationNotSupported, }, } @@ -245,9 +244,9 @@ func TestRequestElicitation(t *testing.T) { result, err := server.RequestElicitation(ctx, tt.request) - if tt.expectedError != "" { + if tt.expectedError != nil { require.Error(t, err) - assert.Contains(t, err.Error(), tt.expectedError) + assert.True(t, errors.Is(err, tt.expectedError), "expected %v, got %v", tt.expectedError, err) return } From 6cd58d8f4496239fdf6582fc57b6dba4b6539691 Mon Sep 17 00:00:00 2001 From: Ramon Nogueira Date: Sat, 23 Aug 2025 19:04:53 -0600 Subject: [PATCH 5/7] Revert sampling formatting changes --- .../streamable_http_sampling_test.go | 84 +++++++++---------- examples/sampling_client/main.go | 4 +- examples/sampling_http_client/main.go | 10 +-- examples/sampling_http_server/main.go | 2 +- server/sampling.go | 2 +- server/sampling_test.go | 12 +-- server/streamable_http_sampling_test.go | 2 +- 7 files changed, 58 insertions(+), 58 deletions(-) diff --git a/client/transport/streamable_http_sampling_test.go b/client/transport/streamable_http_sampling_test.go index 4a38f280e..edba61eac 100644 --- a/client/transport/streamable_http_sampling_test.go +++ b/client/transport/streamable_http_sampling_test.go @@ -16,27 +16,27 @@ import ( // TestStreamableHTTP_SamplingFlow tests the complete sampling flow with HTTP transport func TestStreamableHTTP_SamplingFlow(t *testing.T) { - // Create simple test server + // Create simple test server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Just respond OK to any requests w.WriteHeader(http.StatusOK) })) defer server.Close() - + // Create HTTP client transport client, err := NewStreamableHTTP(server.URL) if err != nil { t.Fatalf("Failed to create client: %v", err) } defer client.Close() - + // Set up sampling request handler var handledRequest *JSONRPCRequest handlerCalled := make(chan struct{}) client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { handledRequest = &request close(handlerCalled) - + // Simulate sampling handler response result := map[string]any{ "role": "assistant", @@ -47,25 +47,25 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { "model": "test-model", "stopReason": "stop_sequence", } - + resultBytes, _ := json.Marshal(result) - + return &JSONRPCResponse{ JSONRPC: "2.0", ID: request.ID, Result: resultBytes, }, nil }) - + // Start the client ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - + err = client.Start(ctx) if err != nil { t.Fatalf("Failed to start client: %v", err) } - + // Test direct request handling (simulating a sampling request) samplingRequest := JSONRPCRequest{ JSONRPC: "2.0", @@ -83,10 +83,10 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { }, }, } - + // Directly test request handling client.handleIncomingRequest(ctx, samplingRequest) - + // Wait for handler to be called select { case <-handlerCalled: @@ -94,12 +94,12 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { case <-time.After(1 * time.Second): t.Fatal("Handler was not called within timeout") } - + // Verify the request was handled if handledRequest == nil { t.Fatal("Sampling request was not handled") } - + if handledRequest.Method != string(mcp.MethodSamplingCreateMessage) { t.Errorf("Expected method %s, got %s", mcp.MethodSamplingCreateMessage, handledRequest.Method) } @@ -109,7 +109,7 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { var errorHandled sync.WaitGroup errorHandled.Add(1) - + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { var body map[string]any @@ -118,7 +118,7 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { w.WriteHeader(http.StatusOK) return } - + // Check if this is an error response if errorField, ok := body["error"]; ok { errorMap := errorField.(map[string]any) @@ -132,25 +132,25 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { w.WriteHeader(http.StatusOK) })) defer server.Close() - + client, err := NewStreamableHTTP(server.URL) if err != nil { t.Fatalf("Failed to create client: %v", err) } defer client.Close() - + // Set up request handler that returns an error client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { return nil, fmt.Errorf("sampling failed") }) - + // Start the client ctx := context.Background() err = client.Start(ctx) if err != nil { t.Fatalf("Failed to start client: %v", err) } - + // Simulate incoming sampling request samplingRequest := JSONRPCRequest{ JSONRPC: "2.0", @@ -158,10 +158,10 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { Method: string(mcp.MethodSamplingCreateMessage), Params: map[string]any{}, } - + // This should trigger error handling client.handleIncomingRequest(ctx, samplingRequest) - + // Wait for error to be handled errorHandled.Wait() } @@ -170,7 +170,7 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { var errorReceived bool errorReceivedChan := make(chan struct{}) - + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { var body map[string]any @@ -179,12 +179,12 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { w.WriteHeader(http.StatusOK) return } - + // Check if this is an error response with method not found if errorField, ok := body["error"]; ok { errorMap := errorField.(map[string]any) if code, ok := errorMap["code"].(float64); ok && code == -32601 { - if message, ok := errorMap["message"].(string); ok && + if message, ok := errorMap["message"].(string); ok && strings.Contains(message, "no handler configured") { errorReceived = true close(errorReceivedChan) @@ -195,21 +195,21 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { w.WriteHeader(http.StatusOK) })) defer server.Close() - + client, err := NewStreamableHTTP(server.URL) if err != nil { t.Fatalf("Failed to create client: %v", err) } defer client.Close() - + // Don't set any request handler - + ctx := context.Background() err = client.Start(ctx) if err != nil { t.Fatalf("Failed to start client: %v", err) } - + // Simulate incoming sampling request samplingRequest := JSONRPCRequest{ JSONRPC: "2.0", @@ -217,10 +217,10 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { Method: string(mcp.MethodSamplingCreateMessage), Params: map[string]any{}, } - + // This should trigger "method not found" error client.handleIncomingRequest(ctx, samplingRequest) - + // Wait for error to be received select { case <-errorReceivedChan: @@ -228,7 +228,7 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { case <-time.After(1 * time.Second): t.Fatal("Method not found error was not received within timeout") } - + if !errorReceived { t.Error("Expected method not found error, but didn't receive it") } @@ -241,13 +241,13 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { t.Fatalf("Failed to create client: %v", err) } defer client.Close() - + // Verify it implements BidirectionalInterface _, ok := any(client).(BidirectionalInterface) if !ok { t.Error("StreamableHTTP should implement BidirectionalInterface") } - + // Test SetRequestHandler handlerSet := false handlerSetChan := make(chan struct{}) @@ -256,7 +256,7 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { close(handlerSetChan) return nil, nil }) - + // Verify handler was set by triggering it ctx := context.Background() client.handleIncomingRequest(ctx, JSONRPCRequest{ @@ -264,7 +264,7 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { ID: mcp.NewRequestId(1), Method: "test", }) - + // Wait for handler to be called select { case <-handlerSetChan: @@ -272,7 +272,7 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { case <-time.After(1 * time.Second): t.Fatal("Handler was not called within timeout") } - + if !handlerSet { t.Error("Request handler was not properly set or called") } @@ -315,16 +315,16 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { // Track which requests have been received and their completion order var requestOrder []int var orderMutex sync.Mutex - + // Set up request handler that simulates different processing times client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { // Extract request ID to determine processing time requestIDValue := request.ID.Value() - + var delay time.Duration var responseText string var requestNum int - + // First request (ID 1) takes longer, second request (ID 2) completes faster if requestIDValue == int64(1) { delay = 100 * time.Millisecond @@ -341,7 +341,7 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { // Simulate processing time time.Sleep(delay) - + // Record completion order orderMutex.Lock() requestOrder = append(requestOrder, requestNum) @@ -428,7 +428,7 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { // Verify completion order: request 2 should complete first orderMutex.Lock() defer orderMutex.Unlock() - + if len(requestOrder) != 2 { t.Fatalf("Expected 2 completed requests, got %d", len(requestOrder)) } @@ -493,4 +493,4 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { } } } -} +} \ No newline at end of file diff --git a/examples/sampling_client/main.go b/examples/sampling_client/main.go index a655fde62..093b59817 100644 --- a/examples/sampling_client/main.go +++ b/examples/sampling_client/main.go @@ -95,7 +95,7 @@ func main() { // Setup graceful shutdown sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - + // Create a context that cancels on signal ctx, cancel := context.WithCancel(ctx) go func() { @@ -103,7 +103,7 @@ func main() { log.Println("Received shutdown signal, closing client...") cancel() }() - + // Move defer after error checking defer func() { if err := mcpClient.Close(); err != nil { diff --git a/examples/sampling_http_client/main.go b/examples/sampling_http_client/main.go index 946223f7b..98817e6f8 100644 --- a/examples/sampling_http_client/main.go +++ b/examples/sampling_http_client/main.go @@ -63,7 +63,7 @@ func main() { log.Fatalf("Failed to create HTTP transport: %v", err) } defer httpTransport.Close() - + // Create client with sampling support mcpClient := client.NewClient( httpTransport, @@ -81,7 +81,7 @@ func main() { initRequest := mcp.InitializeRequest{ Params: mcp.InitializeParams{ ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, - Capabilities: mcp.ClientCapabilities{ + Capabilities: mcp.ClientCapabilities{ // Sampling capability will be automatically added by the client }, ClientInfo: mcp.Implementation{ @@ -90,7 +90,7 @@ func main() { }, }, } - + _, err = mcpClient.Initialize(ctx, initRequest) if err != nil { log.Fatalf("Failed to initialize MCP session: %v", err) @@ -102,7 +102,7 @@ func main() { // In a real application, you would keep the client running to handle sampling requests // For this example, we'll just demonstrate that it's working - + // Keep the client running (in a real app, you'd have your main application logic here) sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) @@ -113,4 +113,4 @@ func main() { case <-sigChan: log.Println("Received shutdown signal") } -} +} \ No newline at end of file diff --git a/examples/sampling_http_server/main.go b/examples/sampling_http_server/main.go index a178accee..95a2bf29b 100644 --- a/examples/sampling_http_server/main.go +++ b/examples/sampling_http_server/main.go @@ -147,4 +147,4 @@ func main() { if err := httpServer.Start(":8080"); err != nil { log.Fatalf("Server failed to start: %v", err) } -} +} \ No newline at end of file diff --git a/server/sampling.go b/server/sampling.go index 2118db155..4423ccf5f 100644 --- a/server/sampling.go +++ b/server/sampling.go @@ -12,7 +12,7 @@ import ( func (s *MCPServer) EnableSampling() { s.capabilitiesMu.Lock() defer s.capabilitiesMu.Unlock() - + enabled := true s.capabilities.sampling = &enabled } diff --git a/server/sampling_test.go b/server/sampling_test.go index 012bf2fd9..fbecdd70d 100644 --- a/server/sampling_test.go +++ b/server/sampling_test.go @@ -116,7 +116,7 @@ func TestMCPServer_RequestSampling_Success(t *testing.T) { func TestMCPServer_EnableSampling_SetsCapability(t *testing.T) { server := NewMCPServer("test", "1.0.0") - + // Verify sampling capability is not set initially ctx := context.Background() initRequest := mcp.InitializeRequest{ @@ -129,25 +129,25 @@ func TestMCPServer_EnableSampling_SetsCapability(t *testing.T) { Capabilities: mcp.ClientCapabilities{}, }, } - + result, err := server.handleInitialize(ctx, 1, initRequest) if err != nil { t.Fatalf("unexpected error: %v", err) } - + if result.Capabilities.Sampling != nil { t.Error("sampling capability should not be set before EnableSampling() is called") } - + // Enable sampling server.EnableSampling() - + // Verify sampling capability is now set result, err = server.handleInitialize(ctx, 2, initRequest) if err != nil { t.Fatalf("unexpected error after EnableSampling(): %v", err) } - + if result.Capabilities.Sampling == nil { t.Error("sampling capability should be set after EnableSampling() is called") } diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go index 50be27fa7..4cf57838c 100644 --- a/server/streamable_http_sampling_test.go +++ b/server/streamable_http_sampling_test.go @@ -213,4 +213,4 @@ func TestStreamableHTTPServer_SamplingQueueFull(t *testing.T) { if !strings.Contains(err.Error(), "queue is full") { t.Errorf("Expected queue full error, got: %v", err) } -} +} \ No newline at end of file From aef7c8ddc7540a4556c52c2fe067c1a25b78ff9d Mon Sep 17 00:00:00 2001 From: Miguel Date: Thu, 11 Sep 2025 16:02:03 -0400 Subject: [PATCH 6/7] Update elicitation response to match spec Updating elicitation response to match MCP spec document https://modelcontextprotocol.io/specification/draft/client/elicitation --- client/elicitation_test.go | 14 ++++++------ client/inprocess_elicitation_test.go | 14 ++++++------ examples/elicitation/main.go | 20 ++++++++-------- mcp/types.go | 30 +++++++++++------------- server/elicitation_test.go | 34 ++++++++++++++-------------- 5 files changed, 54 insertions(+), 58 deletions(-) diff --git a/client/elicitation_test.go b/client/elicitation_test.go index 425b36f32..d21b1f079 100644 --- a/client/elicitation_test.go +++ b/client/elicitation_test.go @@ -38,9 +38,9 @@ func TestClient_HandleElicitationRequest(t *testing.T) { name: "successful elicitation - accept", handler: &mockElicitationHandler{ result: &mcp.ElicitationResult{ - Response: mcp.ElicitationResponse{ - Type: mcp.ElicitationResponseTypeAccept, - Value: map[string]any{ + ElicitationResponse: mcp.ElicitationResponse{ + Action: mcp.ElicitationResponseActionAccept, + Content: map[string]any{ "name": "test-project", "framework": "react", }, @@ -52,8 +52,8 @@ func TestClient_HandleElicitationRequest(t *testing.T) { name: "successful elicitation - decline", handler: &mockElicitationHandler{ result: &mcp.ElicitationResult{ - Response: mcp.ElicitationResponse{ - Type: mcp.ElicitationResponseTypeDecline, + ElicitationResponse: mcp.ElicitationResponse{ + Action: mcp.ElicitationResponseActionDecline, }, }, }, @@ -62,8 +62,8 @@ func TestClient_HandleElicitationRequest(t *testing.T) { name: "successful elicitation - cancel", handler: &mockElicitationHandler{ result: &mcp.ElicitationResult{ - Response: mcp.ElicitationResponse{ - Type: mcp.ElicitationResponseTypeCancel, + ElicitationResponse: mcp.ElicitationResponse{ + Action: mcp.ElicitationResponseActionCancel, }, }, }, diff --git a/client/inprocess_elicitation_test.go b/client/inprocess_elicitation_test.go index 19dd1a5c0..f659bbb10 100644 --- a/client/inprocess_elicitation_test.go +++ b/client/inprocess_elicitation_test.go @@ -22,9 +22,9 @@ func (h *MockElicitationHandler) Elicit(ctx context.Context, request mcp.Elicita // Simulate user accepting and providing data return &mcp.ElicitationResult{ - Response: mcp.ElicitationResponse{ - Type: mcp.ElicitationResponseTypeAccept, - Value: map[string]any{ + ElicitationResponse: mcp.ElicitationResponse{ + Action: mcp.ElicitationResponseActionAccept, + Content: map[string]any{ "confirm": true, "details": "User provided additional details", }, @@ -84,12 +84,12 @@ func TestInProcessElicitation(t *testing.T) { // Handle the response var responseText string - switch result.Response.Type { - case mcp.ElicitationResponseTypeAccept: + switch result.Action { + case mcp.ElicitationResponseActionAccept: responseText = "User accepted and provided data" - case mcp.ElicitationResponseTypeDecline: + case mcp.ElicitationResponseActionDecline: responseText = "User declined to provide information" - case mcp.ElicitationResponseTypeCancel: + case mcp.ElicitationResponseActionCancel: responseText = "User cancelled the request" } diff --git a/examples/elicitation/main.go b/examples/elicitation/main.go index 6892f8146..742d036ad 100644 --- a/examples/elicitation/main.go +++ b/examples/elicitation/main.go @@ -50,12 +50,12 @@ func demoElicitationHandler(s *server.MCPServer) server.ToolHandlerFunc { } // Handle the user's response - switch result.Response.Type { - case mcp.ElicitationResponseTypeAccept: + switch result.Action { + case mcp.ElicitationResponseActionAccept: // User provided the information - data, ok := result.Response.Value.(map[string]any) + data, ok := result.Content.(map[string]any) if !ok { - return nil, fmt.Errorf("unexpected response format: expected map[string]any, got %T", result.Response.Value) + return nil, fmt.Errorf("unexpected response format: expected map[string]any, got %T", result.Content) } // Safely extract projectName (required field) @@ -103,18 +103,18 @@ func demoElicitationHandler(s *server.MCPServer) server.ToolHandlerFunc { }, }, nil - case mcp.ElicitationResponseTypeDecline: + case mcp.ElicitationResponseActionDecline: return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.NewTextContent("Project creation cancelled - user declined to provide information"), }, }, nil - case mcp.ElicitationResponseTypeCancel: + case mcp.ElicitationResponseActionCancel: return nil, fmt.Errorf("project creation cancelled by user") default: - return nil, fmt.Errorf("unexpected response type: %s", result.Response.Type) + return nil, fmt.Errorf("unexpected response action: %s", result.Action) } } } @@ -183,7 +183,7 @@ func main() { return nil, fmt.Errorf("failed to get confirmation: %w", err) } - if result.Response.Type != mcp.ElicitationResponseTypeAccept { + if result.Action != mcp.ElicitationResponseActionAccept { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.NewTextContent("Processing cancelled by user"), @@ -192,9 +192,9 @@ func main() { } // Safely extract response data - responseData, ok := result.Response.Value.(map[string]any) + responseData, ok := result.Content.(map[string]any) if !ok { - return nil, fmt.Errorf("unexpected response format: expected map[string]any, got %T", result.Response.Value) + return nil, fmt.Errorf("unexpected response format: expected map[string]any, got %T", result.Content) } // Safely extract proceed field diff --git a/mcp/types.go b/mcp/types.go index 39dc811d0..e28e72e90 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -842,32 +842,28 @@ type ElicitationParams struct { // ElicitationResult represents the result of an elicitation request. type ElicitationResult struct { Result - // The user's response, which could be: - // - The requested information (if user accepted) - // - A decline indicator (if user declined) - // - A cancel indicator (if user cancelled) - Response ElicitationResponse `json:"response"` + ElicitationResponse } // ElicitationResponse represents the user's response to an elicitation request. type ElicitationResponse struct { - // Type indicates whether the user accepted, declined, or cancelled. - Type ElicitationResponseType `json:"type"` - // Value contains the user's response data if they accepted. + // Action indicates whether the user accepted, declined, or cancelled. + Action ElicitationResponseAction `json:"action"` + // Content contains the user's response data if they accepted. // Should conform to the requestedSchema from the ElicitationRequest. - Value any `json:"value,omitempty"` + Content any `json:"content,omitempty"` } -// ElicitationResponseType indicates how the user responded to an elicitation request. -type ElicitationResponseType string +// ElicitationResponseAction indicates how the user responded to an elicitation request. +type ElicitationResponseAction string const ( - // ElicitationResponseTypeAccept indicates the user provided the requested information. - ElicitationResponseTypeAccept ElicitationResponseType = "accept" - // ElicitationResponseTypeDecline indicates the user explicitly declined to provide information. - ElicitationResponseTypeDecline ElicitationResponseType = "decline" - // ElicitationResponseTypeCancel indicates the user cancelled without making a choice. - ElicitationResponseTypeCancel ElicitationResponseType = "cancel" + // ElicitationResponseActionAccept indicates the user provided the requested information. + ElicitationResponseActionAccept ElicitationResponseAction = "accept" + // ElicitationResponseActionDecline indicates the user explicitly declined to provide information. + ElicitationResponseActionDecline ElicitationResponseAction = "decline" + // ElicitationResponseActionCancel indicates the user cancelled without making a choice. + ElicitationResponseActionCancel ElicitationResponseAction = "cancel" ) /* Sampling */ diff --git a/server/elicitation_test.go b/server/elicitation_test.go index 47868f813..5356b3c25 100644 --- a/server/elicitation_test.go +++ b/server/elicitation_test.go @@ -117,9 +117,9 @@ func TestMCPServer_RequestElicitation_Success(t *testing.T) { mockSession := &mockElicitationSession{ sessionID: "test-session", result: &mcp.ElicitationResult{ - Response: mcp.ElicitationResponse{ - Type: mcp.ElicitationResponseTypeAccept, - Value: map[string]any{ + ElicitationResponse: mcp.ElicitationResponse{ + Action: mcp.ElicitationResponseActionAccept, + Content: map[string]any{ "projectName": "my-project", "framework": "react", }, @@ -155,11 +155,11 @@ func TestMCPServer_RequestElicitation_Success(t *testing.T) { return } - if result.Response.Type != mcp.ElicitationResponseTypeAccept { - t.Errorf("expected response type %q, got %q", mcp.ElicitationResponseTypeAccept, result.Response.Type) + if result.Action != mcp.ElicitationResponseActionAccept { + t.Errorf("expected response type %q, got %q", mcp.ElicitationResponseActionAccept, result.Action) } - value, ok := result.Response.Value.(map[string]any) + value, ok := result.Content.(map[string]any) if !ok { t.Error("expected value to be a map") return @@ -176,16 +176,16 @@ func TestRequestElicitation(t *testing.T) { session ClientSession request mcp.ElicitationRequest expectedError error - expectedType mcp.ElicitationResponseType + expectedType mcp.ElicitationResponseAction }{ { name: "successful elicitation with accept", session: &mockElicitationSession{ sessionID: "test-1", result: &mcp.ElicitationResult{ - Response: mcp.ElicitationResponse{ - Type: mcp.ElicitationResponseTypeAccept, - Value: map[string]any{ + ElicitationResponse: mcp.ElicitationResponse{ + Action: mcp.ElicitationResponseActionAccept, + Content: map[string]any{ "name": "test-project", "framework": "react", }, @@ -204,15 +204,15 @@ func TestRequestElicitation(t *testing.T) { }, }, }, - expectedType: mcp.ElicitationResponseTypeAccept, + expectedType: mcp.ElicitationResponseActionAccept, }, { name: "elicitation declined by user", session: &mockElicitationSession{ sessionID: "test-2", result: &mcp.ElicitationResult{ - Response: mcp.ElicitationResponse{ - Type: mcp.ElicitationResponseTypeDecline, + ElicitationResponse: mcp.ElicitationResponse{ + Action: mcp.ElicitationResponseActionDecline, }, }, }, @@ -222,7 +222,7 @@ func TestRequestElicitation(t *testing.T) { RequestedSchema: map[string]any{"type": "object"}, }, }, - expectedType: mcp.ElicitationResponseTypeDecline, + expectedType: mcp.ElicitationResponseActionDecline, }, { name: "session does not support elicitation", @@ -252,10 +252,10 @@ func TestRequestElicitation(t *testing.T) { require.NoError(t, err) require.NotNil(t, result) - assert.Equal(t, tt.expectedType, result.Response.Type) + assert.Equal(t, tt.expectedType, result.Action) - if tt.expectedType == mcp.ElicitationResponseTypeAccept { - assert.NotNil(t, result.Response.Value) + if tt.expectedType == mcp.ElicitationResponseActionAccept { + assert.NotNil(t, result.Action) } }) } From 23cee611929d3919a6d047455de01ca3f3e43098 Mon Sep 17 00:00:00 2001 From: Jebx Date: Fri, 19 Sep 2025 14:01:35 +0200 Subject: [PATCH 7/7] feat(streamable_http): elicitation request Author: Ghosthell --- client/client.go | 11 ++++ server/streamable_http.go | 102 ++++++++++++++++++++++++++++++++------ 2 files changed, 97 insertions(+), 16 deletions(-) diff --git a/client/client.go b/client/client.go index 21969789c..8c7a9adcc 100644 --- a/client/client.go +++ b/client/client.go @@ -478,6 +478,8 @@ func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JS return c.handleSamplingRequestTransport(ctx, request) case string(mcp.MethodElicitationCreate): return c.handleElicitationRequestTransport(ctx, request) + case string(mcp.MethodPing): + return c.handlePingRequestTransport(ctx, request) default: return nil, fmt.Errorf("unsupported request method: %s", request.Method) } @@ -579,6 +581,15 @@ func (c *Client) handleElicitationRequestTransport(ctx context.Context, request return response, nil } +func (c *Client) handlePingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + b, _ := json.Marshal(&mcp.EmptyResult{}) + return &transport.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Result: b, + }, nil +} + func listByPage[T any]( ctx context.Context, client *Client, diff --git a/server/streamable_http.go b/server/streamable_http.go index fe9b0d763..056dc876c 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -473,6 +473,21 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) case <-done: return } + case elicitationReq := <-session.elicitationRequestChan: + // Send elicitation request to client via SSE + jsonrpcRequest := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(elicitationReq.requestID), + Request: mcp.Request{ + Method: string(mcp.MethodElicitationCreate), + }, + Params: elicitationReq.request.Params, + } + select { + case writeChan <- jsonrpcRequest: + case <-done: + return + } case <-done: return } @@ -612,12 +627,6 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r * } } else if responseMessage.Result != nil { // Parse result - var result mcp.CreateMessageResult - if err := json.Unmarshal(responseMessage.Result, &result); err != nil { - response.err = fmt.Errorf("failed to parse sampling result: %v", err) - } else { - response.result = &result - } } else { response.err = fmt.Errorf("sampling response has neither result nor error") } @@ -764,10 +773,17 @@ type samplingRequestItem struct { type samplingResponseItem struct { requestID int64 - result *mcp.CreateMessageResult + result json.RawMessage err error } +// Elicitation support types for HTTP transport +type elicitationRequestItem struct { + requestID int64 + request mcp.ElicitationRequest + response chan samplingResponseItem +} + // streamableHttpSession is a session for streamable-http transport // When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler. // When in GET handlers(listening), it's a real session, and will be registered in the MCP server. @@ -779,18 +795,21 @@ type streamableHttpSession struct { logLevels *sessionLogLevelsStore // Sampling support for bidirectional communication - samplingRequestChan chan samplingRequestItem // server -> client sampling requests - samplingRequests sync.Map // requestID -> pending sampling request context - requestIDCounter atomic.Int64 // for generating unique request IDs + samplingRequestChan chan samplingRequestItem // server -> client sampling requests + elicitationRequestChan chan elicitationRequestItem // server -> client elicitation requests + + samplingRequests sync.Map // requestID -> pending sampling request context + requestIDCounter atomic.Int64 // for generating unique request IDs } func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession { s := &streamableHttpSession{ - sessionID: sessionID, - notificationChannel: make(chan mcp.JSONRPCNotification, 100), - tools: toolStore, - logLevels: levels, - samplingRequestChan: make(chan samplingRequestItem, 10), + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + tools: toolStore, + logLevels: levels, + samplingRequestChan: make(chan samplingRequestItem, 10), + elicitationRequestChan: make(chan elicitationRequestItem, 10), } return s } @@ -877,13 +896,63 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp if response.err != nil { return nil, response.err } - return response.result, nil + var result mcp.CreateMessageResult + if err := json.Unmarshal(response.result, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal sampling response: %v", err) + } + return &result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// RequestElicitation implements SessionWithElicitation interface for HTTP transport +func (s *streamableHttpSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + // Generate unique request ID + requestID := s.requestIDCounter.Add(1) + + // Create response channel for this specific request + responseChan := make(chan samplingResponseItem, 1) + + // Create the sampling request item + elicitationRequest := elicitationRequestItem{ + requestID: requestID, + request: request, + response: responseChan, + } + + // Store the pending request + s.samplingRequests.Store(requestID, responseChan) + defer s.samplingRequests.Delete(requestID) + + // Send the sampling request via the channel (non-blocking) + select { + case s.elicitationRequestChan <- elicitationRequest: + // Request queued successfully + case <-ctx.Done(): + return nil, ctx.Err() + default: + return nil, fmt.Errorf("elicitation request queue is full - server overloaded") + } + + // Wait for response or context cancellation + select { + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + var result mcp.ElicitationResult + if err := json.Unmarshal(response.result, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal elicitation response: %v", err) + } + return &result, nil case <-ctx.Done(): return nil, ctx.Err() } } var _ SessionWithSampling = (*streamableHttpSession)(nil) +var _ SessionWithElicitation = (*streamableHttpSession)(nil) // --- session id manager --- @@ -952,6 +1021,7 @@ func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption // - null // - empty object: {} // - empty array: [] +// // It also treats nil/whitespace-only input as empty. // It does NOT treat 0, false, "" or non-empty composites as empty. func isJSONEmpty(data json.RawMessage) bool {