From eea1130b9d02bb8ab102c8055eb269949b92f216 Mon Sep 17 00:00:00 2001 From: "coderabbitai[bot]" <136622811+coderabbitai[bot]@users.noreply.github.com> Date: Wed, 22 Oct 2025 00:53:43 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9D=20Add=20docstrings=20to=20`main`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Docstrings generation was requested by @yuehaii. * https://github.com/mark3labs/mcp-go/pull/620#issuecomment-3430059681 The following files were modified: * `client/client.go` * `client/transport/inprocess.go` * `examples/roots_client/main.go` * `examples/roots_http_client/main.go` * `examples/roots_http_server/main.go` * `examples/roots_server/main.go` * `server/inprocess_session.go` * `server/server.go` * `server/streamable_http.go` --- client/client.go | 102 ++++++++++++++++++---- client/transport/inprocess.go | 18 +++- examples/roots_client/main.go | 136 +++++++++++++++++++++++++++++ examples/roots_http_client/main.go | 122 ++++++++++++++++++++++++++ examples/roots_http_server/main.go | 95 ++++++++++++++++++++ examples/roots_server/main.go | 82 +++++++++++++++++ server/inprocess_session.go | 30 ++++++- server/server.go | 20 ++++- server/streamable_http.go | 73 +++++++++++++++- 9 files changed, 652 insertions(+), 26 deletions(-) create mode 100644 examples/roots_client/main.go create mode 100644 examples/roots_http_client/main.go create mode 100644 examples/roots_http_server/main.go create mode 100644 examples/roots_server/main.go diff --git a/client/client.go b/client/client.go index 7542b7f77..269473b00 100644 --- a/client/client.go +++ b/client/client.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "net/http" "slices" "sync" "sync/atomic" @@ -25,6 +24,7 @@ type Client struct { serverCapabilities mcp.ServerCapabilities protocolVersion string samplingHandler SamplingHandler + rootsHandler RootsHandler elicitationHandler ElicitationHandler } @@ -38,15 +38,26 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption { } // WithSamplingHandler sets the sampling handler for the client. -// When set, the client will declare sampling capability during initialization. +// WithSamplingHandler sets the SamplingHandler on the client and causes the client to declare sampling +// capability during Initialize. The provided `handler` will be invoked for incoming sampling requests. +// The returned ClientOption applies this handler to a Client. func WithSamplingHandler(handler SamplingHandler) ClientOption { return func(c *Client) { c.samplingHandler = handler } } +// WithRootsHandler sets the roots handler for the client. +// WithRootsHandler returns a ClientOption that sets the client's RootsHandler. +// When provided, the client will declare the roots capability (ListChanged) during initialization. +func WithRootsHandler(handler RootsHandler) ClientOption { + return func(c *Client) { + c.rootsHandler = handler + } +} + // WithElicitationHandler sets the elicitation handler for the client. -// When set, the client will declare elicitation capability during initialization. +// to declare elicitation capability during initialization. func WithElicitationHandler(handler ElicitationHandler) ClientOption { return func(c *Client) { c.elicitationHandler = handler @@ -141,7 +152,6 @@ func (c *Client) sendRequest( ctx context.Context, method string, params any, - header http.Header, ) (*json.RawMessage, error) { if !c.initialized && method != "initialize" { return nil, fmt.Errorf("client not initialized") @@ -154,7 +164,6 @@ func (c *Client) sendRequest( ID: mcp.NewRequestId(id), Method: method, Params: params, - Header: header, } response, err := c.transport.SendRequest(ctx, request) @@ -180,6 +189,13 @@ func (c *Client) Initialize( if c.samplingHandler != nil { capabilities.Sampling = &struct{}{} } + if c.rootsHandler != nil { + capabilities.Roots = &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{ + ListChanged: true, + } + } // Add elicitation capability if handler is configured if c.elicitationHandler != nil { capabilities.Elicitation = &struct{}{} @@ -196,7 +212,7 @@ func (c *Client) Initialize( Capabilities: capabilities, } - response, err := c.sendRequest(ctx, "initialize", params, request.Header) + response, err := c.sendRequest(ctx, "initialize", params) if err != nil { return nil, err } @@ -241,7 +257,7 @@ func (c *Client) Initialize( } func (c *Client) Ping(ctx context.Context) error { - _, err := c.sendRequest(ctx, "ping", nil, nil) + _, err := c.sendRequest(ctx, "ping", nil) return err } @@ -322,7 +338,7 @@ func (c *Client) ReadResource( ctx context.Context, request mcp.ReadResourceRequest, ) (*mcp.ReadResourceResult, error) { - response, err := c.sendRequest(ctx, "resources/read", request.Params, request.Header) + response, err := c.sendRequest(ctx, "resources/read", request.Params) if err != nil { return nil, err } @@ -334,7 +350,7 @@ func (c *Client) Subscribe( ctx context.Context, request mcp.SubscribeRequest, ) error { - _, err := c.sendRequest(ctx, "resources/subscribe", request.Params, request.Header) + _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) return err } @@ -342,7 +358,7 @@ func (c *Client) Unsubscribe( ctx context.Context, request mcp.UnsubscribeRequest, ) error { - _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params, request.Header) + _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) return err } @@ -386,7 +402,7 @@ func (c *Client) GetPrompt( ctx context.Context, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, error) { - response, err := c.sendRequest(ctx, "prompts/get", request.Params, request.Header) + response, err := c.sendRequest(ctx, "prompts/get", request.Params) if err != nil { return nil, err } @@ -434,7 +450,7 @@ func (c *Client) CallTool( ctx context.Context, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { - response, err := c.sendRequest(ctx, "tools/call", request.Params, request.Header) + response, err := c.sendRequest(ctx, "tools/call", request.Params) if err != nil { return nil, err } @@ -446,7 +462,7 @@ func (c *Client) SetLevel( ctx context.Context, request mcp.SetLevelRequest, ) error { - _, err := c.sendRequest(ctx, "logging/setLevel", request.Params, request.Header) + _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) return err } @@ -454,7 +470,7 @@ func (c *Client) Complete( ctx context.Context, request mcp.CompleteRequest, ) (*mcp.CompleteResult, error) { - response, err := c.sendRequest(ctx, "completion/complete", request.Params, request.Header) + response, err := c.sendRequest(ctx, "completion/complete", request.Params) if err != nil { return nil, err } @@ -467,6 +483,27 @@ func (c *Client) Complete( return &result, nil } +func (c *Client) RootListChanges( + ctx context.Context, +) error { + // Send root list changes notification + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: mcp.MethodNotificationToolsListChanged, + }, + } + + err := c.transport.SendNotification(ctx, notification) + if err != nil { + return fmt.Errorf( + "failed to send root list change notification: %w", + err, + ) + } + return nil +} + // handleIncomingRequest processes incoming requests from the server. // This is the main entry point for server-to-client requests like sampling and elicitation. func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { @@ -477,6 +514,8 @@ func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JS return c.handleElicitationRequestTransport(ctx, request) case string(mcp.MethodPing): return c.handlePingRequestTransport(ctx, request) + case string(mcp.MethodListRoots): + return c.handleListRootsRequestTransport(ctx, request) default: return nil, fmt.Errorf("unsupported request method: %s", request.Method) } @@ -539,6 +578,37 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra return response, nil } +// handleListRootsRequestTransport handles list roots requests at the transport level. +func (c *Client) handleListRootsRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + if c.rootsHandler == nil { + return nil, fmt.Errorf("no roots handler configured") + } + + // Create the MCP request + mcpRequest := mcp.ListRootsRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodListRoots), + }, + } + + // Call the list roots handler + result, err := c.rootsHandler.ListRoots(ctx, mcpRequest) + if err != nil { + return nil, err + } + + // Marshal the result + resultBytes, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal result: %w", err) + } + + // Create the transport response + response := transport.NewJSONRPCResultResponse(request.ID, resultBytes) + + return response, nil +} + // handleElicitationRequestTransport handles elicitation requests at the transport level. func (c *Client) handleElicitationRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { if c.elicitationHandler == nil { @@ -594,7 +664,7 @@ func listByPage[T any]( request mcp.PaginatedRequest, method string, ) (*T, error) { - response, err := client.sendRequest(ctx, method, request.Params, nil) + response, err := client.sendRequest(ctx, method, request.Params) if err != nil { return nil, err } @@ -635,4 +705,4 @@ func (c *Client) GetSessionId() string { // IsInitialized returns true if the client has been initialized. func (c *Client) IsInitialized() bool { return c.initialized -} +} \ No newline at end of file diff --git a/client/transport/inprocess.go b/client/transport/inprocess.go index 467654265..74fd844eb 100644 --- a/client/transport/inprocess.go +++ b/client/transport/inprocess.go @@ -14,6 +14,7 @@ type InProcessTransport struct { server *server.MCPServer samplingHandler server.SamplingHandler elicitationHandler server.ElicitationHandler + rootsHandler server.RootsHandler session *server.InProcessSession sessionID string @@ -31,12 +32,23 @@ func WithSamplingHandler(handler server.SamplingHandler) InProcessOption { } } +// WithElicitationHandler returns an InProcessOption that sets the elicitation handler on an InProcessTransport. +// The provided handler will be used to handle elicitation requests for the in-process session when the transport is started. func WithElicitationHandler(handler server.ElicitationHandler) InProcessOption { return func(t *InProcessTransport) { t.elicitationHandler = handler } } +// WithRootsHandler returns an InProcessOption that sets the transport's roots handler. +// The provided handler is assigned to the transport's rootsHandler field when the option is applied. +func WithRootsHandler(handler server.RootsHandler) InProcessOption { + return func(t *InProcessTransport) { + t.rootsHandler = handler + } +} + +// NewInProcessTransport creates an InProcessTransport that wraps the provided MCPServer with default (zero-value) configuration. func NewInProcessTransport(server *server.MCPServer) *InProcessTransport { return &InProcessTransport{ server: server, @@ -66,8 +78,8 @@ func (c *InProcessTransport) Start(ctx context.Context) error { c.startedMu.Unlock() // Create and register session if we have handlers - if c.samplingHandler != nil || c.elicitationHandler != nil { - c.session = server.NewInProcessSessionWithHandlers(c.sessionID, c.samplingHandler, c.elicitationHandler) + if c.samplingHandler != nil || c.elicitationHandler != nil || c.rootsHandler != nil { + c.session = server.NewInProcessSessionWithHandlers(c.sessionID, c.samplingHandler, c.elicitationHandler, c.rootsHandler) if err := c.server.RegisterSession(ctx, c.session); err != nil { c.startedMu.Lock() c.started = false @@ -130,4 +142,4 @@ func (c *InProcessTransport) Close() error { func (c *InProcessTransport) GetSessionId() string { return "" -} +} \ No newline at end of file diff --git a/examples/roots_client/main.go b/examples/roots_client/main.go new file mode 100644 index 000000000..8d64bd433 --- /dev/null +++ b/examples/roots_client/main.go @@ -0,0 +1,136 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// MockRootsHandler implements client.RootsHandler for demonstration. +// In a real implementation, this would integrate with an actual LLM API. +type MockRootsHandler struct{} + +func (h *MockRootsHandler) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + result := &mcp.ListRootsResult{ + Roots: []mcp.Root{ + { + Name: "app", + URI: "file:///User/haxxx/app", + }, + { + Name: "test-project", + URI: "file:///User/haxxx/projects/test-project", + }, + }, + } + return result, nil +} + +// main starts a mock MCP roots client that communicates with a subprocess over stdio. +// It expects the server command as the first command-line argument, creates a stdio +// transport and an MCP client with a MockRootsHandler, starts and initializes the +// client, logs server info and available tools, notifies the server of root list +// changes, invokes the "roots" tool and prints any text content returned, and +// shuts down the client gracefully on SIGINT or SIGTERM. +func main() { + if len(os.Args) < 2 { + log.Fatal("Usage: roots_client ") + } + + serverCommand := os.Args[1] + serverArgs := os.Args[2:] + + // Create stdio transport to communicate with the server + stdio := transport.NewStdio(serverCommand, nil, serverArgs...) + + // Create roots handler + rootsHandler := &MockRootsHandler{} + + // Create client with roots capability + mcpClient := client.NewClient(stdio, client.WithRootsHandler(rootsHandler)) + + ctx := context.Background() + + // Start the client + if err := mcpClient.Start(ctx); err != nil { + log.Fatalf("Failed to start client: %v", err) + } + + // Setup graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Create a context that cancels on signal + ctx, cancel := context.WithCancel(ctx) + go func() { + <-sigChan + log.Println("Received shutdown signal, closing client...") + cancel() + }() + + // Move defer after error checking + defer func() { + if err := mcpClient.Close(); err != nil { + log.Printf("Error closing client: %v", err) + } + }() + + // Initialize the connection + initResult, err := mcpClient.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "roots-stdio-server", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{ + // Sampling capability will be automatically added by WithSamplingHandler + }, + }, + }) + if err != nil { + log.Fatalf("Failed to initialize: %v", err) + } + + log.Printf("Connected to server: %s v%s", initResult.ServerInfo.Name, initResult.ServerInfo.Version) + log.Printf("Server capabilities: %+v", initResult.Capabilities) + + // list tools + toolsResult, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + log.Fatalf("Failed to list tools: %v", err) + } + log.Printf("Available tools:") + for _, tool := range toolsResult.Tools { + log.Printf(" - %s: %s", tool.Name, tool.Description) + } + + // mock the root change + if err := mcpClient.RootListChanges(ctx); err != nil { + log.Printf("fail to notify root list change: %v", err) + } + + // call server tool + request := mcp.CallToolRequest{} + request.Params.Name = "roots" + request.Params.Arguments = "{\"testonly\": \"yes\"}" + result, err := mcpClient.CallTool(ctx, request) + if err != nil { + log.Fatalf("failed to call tool roots: %v", err) + } else if len(result.Content) > 0 { + resultStr := "" + for _, content := range result.Content { + if textContent, ok := content.(mcp.TextContent); ok { + resultStr += fmt.Sprintf("%s\n", textContent.Text) + } + } + fmt.Printf("client call tool result: %s", resultStr) + } +} \ No newline at end of file diff --git a/examples/roots_http_client/main.go b/examples/roots_http_client/main.go new file mode 100644 index 000000000..fca64ea3d --- /dev/null +++ b/examples/roots_http_client/main.go @@ -0,0 +1,122 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// MockRootsHandler implements client.RootsHandler for demonstration. +// In a real implementation, this would integrate with an actual LLM API. +type MockRootsHandler struct{} + +func (h *MockRootsHandler) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + result := &mcp.ListRootsResult{ + Roots: []mcp.Root{ + { + Name: "app", + URI: "file:///User/haxxx/app", + }, + { + Name: "test-project", + URI: "file:///User/haxxx/projects/test-project", + }, + }, + } + return result, nil +} + +// then waits for a shutdown signal. +func main() { + // Create roots handler + rootsHandler := &MockRootsHandler{} + + // Create HTTP transport directly + httpTransport, err := transport.NewStreamableHTTP( + "http://localhost:8080/mcp", // Replace with your MCP server URL + transport.WithContinuousListening(), + ) + if err != nil { + log.Fatalf("Failed to create HTTP transport: %v", err) + } + defer httpTransport.Close() + + // Create client with roots support + mcpClient := client.NewClient( + httpTransport, + client.WithRootsHandler(rootsHandler), + ) + + // Start the client + ctx := context.Background() + err = mcpClient.Start(ctx) + if err != nil { + log.Fatalf("Failed to start client: %v", err) + } + + // Initialize the MCP session + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{ + // Roots capability will be automatically added by the client + }, + ClientInfo: mcp.Implementation{ + Name: "roots-http-client", + Version: "1.0.0", + }, + }, + } + + _, err = mcpClient.Initialize(ctx, initRequest) + if err != nil { + log.Fatalf("Failed to initialize MCP session: %v", err) + } + + log.Println("HTTP MCP client with roots support started successfully!") + log.Println("The client is now ready to handle roots requests from the server.") + log.Println("When the server sends a roots request, the MockRootsHandler will process it.") + + // In a real application, you would keep the client running to handle roots requests + // For this example, we'll just demonstrate that it's working + + // mock the root change + if err := mcpClient.RootListChanges(ctx); err != nil { + log.Printf("fail to notify root list change: %v", err) + } + + // call server tool + request := mcp.CallToolRequest{} + request.Params.Name = "roots" + request.Params.Arguments = "{\"testonly\": \"yes\"}" + result, err := mcpClient.CallTool(ctx, request) + if err != nil { + log.Fatalf("failed to call tool roots: %v", err) + } else if len(result.Content) > 0 { + resultStr := "" + for _, content := range result.Content { + if textContent, ok := content.(mcp.TextContent); ok { + resultStr += fmt.Sprintf("%s\n", textContent.Text) + } + } + fmt.Printf("client call tool result: %s", resultStr) + } + + // Keep the client running (in a real app, you'd have your main application logic here) + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + select { + case <-ctx.Done(): + log.Println("Client context cancelled") + case <-sigChan: + log.Println("Received shutdown signal") + } +} \ No newline at end of file diff --git a/examples/roots_http_server/main.go b/examples/roots_http_server/main.go new file mode 100644 index 000000000..239c95727 --- /dev/null +++ b/examples/roots_http_server/main.go @@ -0,0 +1,95 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// handleNotification prints the method name of the received MCP JSON-RPC notification to standard output. +func handleNotification(ctx context.Context, notification mcp.JSONRPCNotification) { + fmt.Printf("notification received: %v", notification.Notification.Method) +} + +// main starts an MCP HTTP server named "roots-http-server" with tool capabilities and roots support. +// It registers a notification handler for ToolsListChanged, adds a "roots" tool that queries the server's roots and returns a textual result, +// logs startup and usage instructions, and launches a streamable HTTP server on port 8080. +func main() { + // Enable roots capability + opts := []server.ServerOption{ + server.WithToolCapabilities(true), + server.WithRoots(), + } + // Create MCP server with roots capability + mcpServer := server.NewMCPServer("roots-http-server", "1.0.0", opts...) + + // Add list root list change notification + mcpServer.AddNotificationHandler(mcp.MethodNotificationToolsListChanged, handleNotification) + + // Add a simple tool to test roots list + mcpServer.AddTool(mcp.Tool{ + Name: "roots", + Description: "list root result", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "testonly": map[string]any{ + "type": "string", + "description": "is this test only?", + }, + }, + Required: []string{"testonly"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + rootRequest := mcp.ListRootsRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodListRoots), + }, + } + + if result, err := mcpServer.RequestRoots(ctx, rootRequest); err == nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Root list: %v", result.Roots), + }, + }, + }, nil + + } else { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Fail to list root, %v", err), + }, + }, + }, err + } + }) + + log.Println("Starting MCP Http server with roots support") + log.Println("Http Endpoint: http://localhost:8080/mcp") + log.Println("") + log.Println("This server supports roots over HTTP transport.") + log.Println("Clients must:") + log.Println("1. Initialize with roots capability") + log.Println("2. Establish SSE connection for bidirectional communication") + log.Println("3. Handle incoming roots requests from the server") + log.Println("4. Send responses back via HTTP POST") + log.Println("") + log.Println("Available tools:") + log.Println("- roots: Send back the list root request)") + + // Create HTTP server + httpOpts := []server.StreamableHTTPOption{} + httpServer := server.NewStreamableHTTPServer(mcpServer, httpOpts...) + fmt.Printf("Starting HTTP server\n") + if err := httpServer.Start(":8080"); err != nil { + fmt.Printf("HTTP server failed: %v\n", err) + } +} \ No newline at end of file diff --git a/examples/roots_server/main.go b/examples/roots_server/main.go new file mode 100644 index 000000000..dd5f49baa --- /dev/null +++ b/examples/roots_server/main.go @@ -0,0 +1,82 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// handleNotification handles JSON-RPC notifications by printing the notification method to standard output. +func handleNotification(ctx context.Context, notification mcp.JSONRPCNotification) { + fmt.Printf("notification received: %v", notification.Notification.Method) +} + +// main sets up and runs an MCP stdio server named "roots-stdio-server" with tool and roots capabilities. +// +// It registers a handler for ToolsListChanged notifications, enables sampling, and adds a "roots" tool +// that requests and returns the current root list. The program serves the MCP server over stdio and +// logs a fatal error if the server fails to start. +func main() { + // Enable roots capability + opts := []server.ServerOption{ + server.WithToolCapabilities(true), + server.WithRoots(), + } + // Create MCP server with roots capability + mcpServer := server.NewMCPServer("roots-stdio-server", "1.0.0", opts...) + + // Add list root list change notification + mcpServer.AddNotificationHandler(mcp.MethodNotificationToolsListChanged, handleNotification) + mcpServer.EnableSampling() + + // Add a simple tool to test roots list + mcpServer.AddTool(mcp.Tool{ + Name: "roots", + Description: "list root result", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "testonly": map[string]any{ + "type": "string", + "description": "is this test only?", + }, + }, + Required: []string{"testonly"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + rootRequest := mcp.ListRootsRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodListRoots), + }, + } + + if result, err := mcpServer.RequestRoots(ctx, rootRequest); err == nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Root list: %v", result.Roots), + }, + }, + }, nil + + } else { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Fail to list root, %v", err), + }, + }, + }, err + } + }) + + // Create stdio server + if err := server.ServeStdio(mcpServer); err != nil { + log.Fatalf("Server Stdio error: %v\n", err) + } +} \ No newline at end of file diff --git a/server/inprocess_session.go b/server/inprocess_session.go index c6fddc601..ce73dee15 100644 --- a/server/inprocess_session.go +++ b/server/inprocess_session.go @@ -20,6 +20,10 @@ type ElicitationHandler interface { Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) } +type RootsHandler interface { + ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) +} + type InProcessSession struct { sessionID string notifications chan mcp.JSONRPCNotification @@ -29,9 +33,12 @@ type InProcessSession struct { clientCapabilities atomic.Value samplingHandler SamplingHandler elicitationHandler ElicitationHandler + rootsHandler RootsHandler mu sync.RWMutex } +// NewInProcessSession creates a new InProcessSession for the provided sessionID and samplingHandler. +// The returned session has a buffered notifications channel and its sampling handler set; elicitation and roots handlers remain unset. func NewInProcessSession(sessionID string, samplingHandler SamplingHandler) *InProcessSession { return &InProcessSession{ sessionID: sessionID, @@ -40,12 +47,16 @@ func NewInProcessSession(sessionID string, samplingHandler SamplingHandler) *InP } } -func NewInProcessSessionWithHandlers(sessionID string, samplingHandler SamplingHandler, elicitationHandler ElicitationHandler) *InProcessSession { +// NewInProcessSessionWithHandlers creates an InProcessSession with the given session ID and handler implementations. +// The session is created with a buffered notifications channel (capacity 100) and the provided sampling, elicitation, +// and roots handlers attached. +func NewInProcessSessionWithHandlers(sessionID string, samplingHandler SamplingHandler, elicitationHandler ElicitationHandler, rootsHandler RootsHandler) *InProcessSession { return &InProcessSession{ sessionID: sessionID, notifications: make(chan mcp.JSONRPCNotification, 100), samplingHandler: samplingHandler, elicitationHandler: elicitationHandler, + rootsHandler: rootsHandler, } } @@ -128,7 +139,19 @@ func (s *InProcessSession) RequestElicitation(ctx context.Context, request mcp.E return handler.Elicit(ctx, request) } -// GenerateInProcessSessionID generates a unique session ID for inprocess clients +func (s *InProcessSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + s.mu.RLock() + handler := s.rootsHandler + s.mu.RUnlock() + + if handler == nil { + return nil, fmt.Errorf("no roots handler available") + } + + return handler.ListRoots(ctx, request) +} + +// GenerateInProcessSessionID returns a session identifier formatted as "inprocess-", where is the current Unix time in nanoseconds, suitable for use as a unique in-process session ID. func GenerateInProcessSessionID() string { return fmt.Sprintf("inprocess-%d", time.Now().UnixNano()) } @@ -140,4 +163,5 @@ var ( _ SessionWithClientInfo = (*InProcessSession)(nil) _ SessionWithSampling = (*InProcessSession)(nil) _ SessionWithElicitation = (*InProcessSession)(nil) -) + _ SessionWithRoots = (*InProcessSession)(nil) +) \ No newline at end of file diff --git a/server/server.go b/server/server.go index f45c03536..c47f1b511 100644 --- a/server/server.go +++ b/server/server.go @@ -183,6 +183,7 @@ type serverCapabilities struct { logging *bool sampling *bool elicitation *bool + roots *bool } // resourceCapabilities defines the supported resource-related features @@ -319,14 +320,23 @@ func WithLogging() ServerOption { } } -// WithElicitation enables elicitation capabilities for the server +// WithElicitation returns a ServerOption that enables the server's elicitation capability. +// When applied to an MCPServer, it sets the server's capabilities.elicitation flag to true. func WithElicitation() ServerOption { return func(s *MCPServer) { s.capabilities.elicitation = mcp.ToBoolPtr(true) } } -// WithInstructions sets the server instructions for the client returned in the initialize response +// WithRoots returns a ServerOption that enables the roots capability on the MCPServer. +func WithRoots() ServerOption { + return func(s *MCPServer) { + s.capabilities.roots = mcp.ToBoolPtr(true) + } +} + +// WithInstructions returns a ServerOption that sets the server instructions sent to clients in the initialize response. +// The provided instructions string is stored on the MCPServer and included in InitializeResult.Instructions. func WithInstructions(instructions string) ServerOption { return func(s *MCPServer) { s.instructions = instructions @@ -696,6 +706,10 @@ func (s *MCPServer) handleInitialize( capabilities.Elicitation = &struct{}{} } + if s.capabilities.roots != nil && *s.capabilities.roots { + capabilities.Roots = &struct{}{} + } + result := mcp.InitializeResult{ ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion), ServerInfo: mcp.Implementation{ @@ -1270,4 +1284,4 @@ func createErrorResponse( ID: mcp.NewRequestId(id), Error: mcp.NewJSONRPCErrorDetails(code, message, nil), } -} +} \ No newline at end of file diff --git a/server/streamable_http.go b/server/streamable_http.go index 8af6f1478..9fd2d6482 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -552,6 +552,20 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) case <-done: return } + case rootsReq := <-session.rootsRequestChan: + // Send list roots request to client via SSE + jsonrpcRequest := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(rootsReq.requestID), + Request: mcp.Request{ + Method: string(mcp.MethodListRoots), + }, + } + select { + case writeChan <- jsonrpcRequest: + case <-done: + return + } case <-done: return } @@ -887,6 +901,13 @@ type elicitationRequestItem struct { response chan samplingResponseItem } +// Roots support types for HTTP transport +type rootsRequestItem struct { + requestID int64 + request mcp.ListRootsRequest + response chan samplingResponseItem +} + // streamableHttpSession is a session for streamable-http transport // When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler. // When in GET handlers(listening), it's a real session, and will be registered in the MCP server. @@ -901,11 +922,14 @@ type streamableHttpSession struct { // Sampling support for bidirectional communication samplingRequestChan chan samplingRequestItem // server -> client sampling requests elicitationRequestChan chan elicitationRequestItem // server -> client elicitation requests + rootsRequestChan chan rootsRequestItem // server -> client list roots requests samplingRequests sync.Map // requestID -> pending sampling request context requestIDCounter atomic.Int64 // for generating unique request IDs } +// newStreamableHttpSession creates and returns a streamableHttpSession initialized with the given session ID and per-session stores. +// The returned session has buffered channels for notifications, sampling, elicitation, and roots requests and holds references to the provided tools, resources, and log level stores. func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, resourcesStore *sessionResourcesStore, levels *sessionLogLevelsStore) *streamableHttpSession { s := &streamableHttpSession{ sessionID: sessionID, @@ -915,6 +939,7 @@ func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, re logLevels: levels, samplingRequestChan: make(chan samplingRequestItem, 10), elicitationRequestChan: make(chan elicitationRequestItem, 10), + rootsRequestChan: make(chan rootsRequestItem, 10), } return s } @@ -1031,6 +1056,51 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp } } +// ListRoots implements SessionWithRoots interface for HTTP transport +func (s *streamableHttpSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + // Generate unique request ID + requestID := s.requestIDCounter.Add(1) + + // Create response channel for this specific request + responseChan := make(chan samplingResponseItem, 1) + + // Create the sampling request item + rootsRequest := rootsRequestItem{ + requestID: requestID, + request: request, + response: responseChan, + } + + // Store the pending request + s.samplingRequests.Store(requestID, responseChan) + defer s.samplingRequests.Delete(requestID) + + // Send the sampling request via the channel (non-blocking) + select { + case s.rootsRequestChan <- rootsRequest: + // Request queued successfully + case <-ctx.Done(): + return nil, ctx.Err() + default: + return nil, fmt.Errorf("list roots request queue is full - server overloaded") + } + + // Wait for response or context cancellation + select { + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + var result mcp.ListRootsResult + if err := json.Unmarshal(response.result, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal list roots response: %v", err) + } + return &result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + // RequestElicitation implements SessionWithElicitation interface for HTTP transport func (s *streamableHttpSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { // Generate unique request ID @@ -1078,6 +1148,7 @@ func (s *streamableHttpSession) RequestElicitation(ctx context.Context, request var _ SessionWithSampling = (*streamableHttpSession)(nil) var _ SessionWithElicitation = (*streamableHttpSession)(nil) +var _ SessionWithRoots = (*streamableHttpSession)(nil) // --- session id manager --- @@ -1207,4 +1278,4 @@ func isJSONEmpty(data json.RawMessage) bool { trimmed[3] == 'l' } return false -} +} \ No newline at end of file