diff --git a/.gitignore b/.gitignore index 1d4dcd5cb..7862cf9bd 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ .claude coverage.out coverage.txt +.vscode/launch.json diff --git a/client/client.go b/client/client.go index 929785cd8..59b9651f8 100644 --- a/client/client.go +++ b/client/client.go @@ -24,6 +24,7 @@ type Client struct { serverCapabilities mcp.ServerCapabilities protocolVersion string samplingHandler SamplingHandler + rootsHandler RootsHandler elicitationHandler ElicitationHandler } @@ -44,6 +45,15 @@ func WithSamplingHandler(handler SamplingHandler) ClientOption { } } +// WithRootsHandler sets the roots handler for the client. +// WithRootsHandler returns a ClientOption that sets the client's RootsHandler. +// When provided, the client will declare the roots capability (ListChanged) during initialization. +func WithRootsHandler(handler RootsHandler) ClientOption { + return func(c *Client) { + c.rootsHandler = handler + } +} + // WithElicitationHandler sets the elicitation handler for the client. // When set, the client will declare elicitation capability during initialization. func WithElicitationHandler(handler ElicitationHandler) ClientOption { @@ -177,6 +187,13 @@ func (c *Client) Initialize( if c.samplingHandler != nil { capabilities.Sampling = &struct{}{} } + if c.rootsHandler != nil { + capabilities.Roots = &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{ + ListChanged: true, + } + } // Add elicitation capability if handler is configured if c.elicitationHandler != nil { capabilities.Elicitation = &struct{}{} @@ -464,6 +481,28 @@ func (c *Client) Complete( return &result, nil } +// RootListChanges sends a roots list-changed notification to the server. +func (c *Client) RootListChanges( + ctx context.Context, +) error { + // Send root list changes notification + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: mcp.MethodNotificationRootsListChanged, + }, + } + + err := c.transport.SendNotification(ctx, notification) + if err != nil { + return fmt.Errorf( + "failed to send root list change notification: %w", + err, + ) + } + return nil +} + // handleIncomingRequest processes incoming requests from the server. // This is the main entry point for server-to-client requests like sampling and elicitation. func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { @@ -474,6 +513,8 @@ func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JS return c.handleElicitationRequestTransport(ctx, request) case string(mcp.MethodPing): return c.handlePingRequestTransport(ctx, request) + case string(mcp.MethodListRoots): + return c.handleListRootsRequestTransport(ctx, request) default: return nil, fmt.Errorf("unsupported request method: %s", request.Method) } @@ -536,6 +577,37 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra return response, nil } +// handleListRootsRequestTransport handles list roots requests at the transport level. +func (c *Client) handleListRootsRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + if c.rootsHandler == nil { + return nil, fmt.Errorf("no roots handler configured") + } + + // Create the MCP request + mcpRequest := mcp.ListRootsRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodListRoots), + }, + } + + // Call the list roots handler + result, err := c.rootsHandler.ListRoots(ctx, mcpRequest) + if err != nil { + return nil, err + } + + // Marshal the result + resultBytes, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal result: %w", err) + } + + // Create the transport response + response := transport.NewJSONRPCResultResponse(request.ID, json.RawMessage(resultBytes)) + + return response, nil +} + // handleElicitationRequestTransport handles elicitation requests at the transport level. func (c *Client) handleElicitationRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { if c.elicitationHandler == nil { diff --git a/client/roots.go b/client/roots.go new file mode 100644 index 000000000..0a17aaf7a --- /dev/null +++ b/client/roots.go @@ -0,0 +1,17 @@ +package client + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// RootsHandler defines the interface for handling roots requests from servers. +// Clients can implement this interface to provide roots list to servers. +type RootsHandler interface { + // ListRoots handles a list root request from the server and returns the roots list. + // The implementation should: + // 1. Validate input against the requested schema + // 2. Return the appropriate response + ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) +} diff --git a/client/transport/inprocess.go b/client/transport/inprocess.go index 467654265..fe17d97f7 100644 --- a/client/transport/inprocess.go +++ b/client/transport/inprocess.go @@ -14,6 +14,7 @@ type InProcessTransport struct { server *server.MCPServer samplingHandler server.SamplingHandler elicitationHandler server.ElicitationHandler + rootsHandler server.RootsHandler session *server.InProcessSession sessionID string @@ -37,6 +38,12 @@ func WithElicitationHandler(handler server.ElicitationHandler) InProcessOption { } } +func WithRootsHandler(handler server.RootsHandler) InProcessOption { + return func(t *InProcessTransport) { + t.rootsHandler = handler + } +} + func NewInProcessTransport(server *server.MCPServer) *InProcessTransport { return &InProcessTransport{ server: server, @@ -66,8 +73,8 @@ func (c *InProcessTransport) Start(ctx context.Context) error { c.startedMu.Unlock() // Create and register session if we have handlers - if c.samplingHandler != nil || c.elicitationHandler != nil { - c.session = server.NewInProcessSessionWithHandlers(c.sessionID, c.samplingHandler, c.elicitationHandler) + if c.samplingHandler != nil || c.elicitationHandler != nil || c.rootsHandler != nil { + c.session = server.NewInProcessSessionWithHandlers(c.sessionID, c.samplingHandler, c.elicitationHandler, c.rootsHandler) if err := c.server.RegisterSession(ctx, c.session); err != nil { c.startedMu.Lock() c.started = false diff --git a/examples/roots_client/main.go b/examples/roots_client/main.go new file mode 100644 index 000000000..325be98e7 --- /dev/null +++ b/examples/roots_client/main.go @@ -0,0 +1,163 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/url" + "os" + "os/signal" + "path/filepath" + "strings" + "syscall" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// fileURI returns a file:// URI for both Unix and Windows absolute paths. +func fileURI(p string) string { + p = filepath.ToSlash(p) + if !strings.HasPrefix(p, "/") { // e.g., "C:/Users/..." on Windows + p = "/" + p + } + return (&url.URL{Scheme: "file", Path: p}).String() +} + +// MockRootsHandler implements client.RootsHandler for demonstration. +// In a real implementation, this would enumerate workspace/project roots. +type MockRootsHandler struct{} + +// ListRoots implements client.RootsHandler by returning example workspace roots. +func (h *MockRootsHandler) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + home, err := os.UserHomeDir() + if err != nil { + log.Printf("Warning: failed to get home directory: %v", err) + home = "/tmp" // fallback for demonstration + } + app := filepath.ToSlash(filepath.Join(home, "app")) + proj := filepath.ToSlash(filepath.Join(home, "projects", "test-project")) + result := &mcp.ListRootsResult{ + Roots: []mcp.Root{ + { + Name: "app", + URI: fileURI(app), + }, + { + Name: "test-project", + URI: fileURI(proj), + }, + }, + } + return result, nil +} + +// main starts a mock MCP roots client that communicates with a subprocess over stdio. +// It expects the server command as the first command-line argument, creates a stdio +// transport and an MCP client with a MockRootsHandler, starts and initializes the +// client, logs server info and available tools, notifies the server of root list +// changes, invokes the "roots" tool and prints any text content returned, and +// shuts down the client gracefully on SIGINT or SIGTERM. +func main() { + if len(os.Args) < 2 { + log.Fatal("Usage: roots_client ") + } + + serverCommand := os.Args[1] + serverArgs := os.Args[2:] + + // Create stdio transport to communicate with the server + stdio := transport.NewStdio(serverCommand, nil, serverArgs...) + + // Create roots handler + rootsHandler := &MockRootsHandler{} + + // Create client with roots capability + mcpClient := client.NewClient(stdio, client.WithRootsHandler(rootsHandler)) + + ctx := context.Background() + + // Start the client + if err := mcpClient.Start(ctx); err != nil { + log.Fatalf("Failed to start client: %v", err) + } + + // Setup graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Create a context that cancels on signal + ctx, cancel := context.WithCancel(ctx) + go func() { + <-sigChan + log.Println("Received shutdown signal, closing client...") + cancel() + }() + + // Move defer after error checking + defer func() { + if err := mcpClient.Close(); err != nil { + log.Printf("Error closing client: %v", err) + } + }() + + // Initialize the connection + initResult, err := mcpClient.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "roots-stdio-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{ + // Roots capability will be automatically added by WithRootsHandler + }, + }, + }) + if err != nil { + log.Fatalf("Failed to initialize: %v", err) + } + + log.Printf("Connected to server: %s v%s", initResult.ServerInfo.Name, initResult.ServerInfo.Version) + log.Printf("Server capabilities: %+v", initResult.Capabilities) + + // list tools + toolsResult, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + log.Fatalf("Failed to list tools: %v", err) + } + log.Printf("Available tools:") + for _, tool := range toolsResult.Tools { + log.Printf(" - %s: %s", tool.Name, tool.Description) + } + + // call server tool + request := mcp.CallToolRequest{} + request.Params.Name = "roots" + request.Params.Arguments = map[string]any{"testonly": "yes"} + result, err := mcpClient.CallTool(ctx, request) + if err != nil { + log.Fatalf("failed to call tool roots: %v", err) + } else if result.IsError { + log.Printf("tool reported error") + } else if len(result.Content) > 0 { + resultStr := "" + for _, content := range result.Content { + switch tc := content.(type) { + case mcp.TextContent: + resultStr += fmt.Sprintf("%s\n", tc.Text) + } + } + fmt.Printf("client call tool result: %s\n", resultStr) + } + + // mock the root change + if err := mcpClient.RootListChanges(ctx); err != nil { + log.Printf("failed to notify root list change: %v", err) + } + + // Keep running until cancelled by signal + <-ctx.Done() + log.Println("Client context cancelled") +} diff --git a/examples/roots_http_client/main.go b/examples/roots_http_client/main.go new file mode 100644 index 000000000..783378336 --- /dev/null +++ b/examples/roots_http_client/main.go @@ -0,0 +1,149 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/url" + "os" + "os/signal" + "path/filepath" + "strings" + "syscall" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// fileURI returns a file:// URI for both Unix and Windows absolute paths. +func fileURI(p string) string { + p = filepath.ToSlash(p) + if !strings.HasPrefix(p, "/") { // e.g., "C:/Users/..." on Windows + p = "/" + p + } + return (&url.URL{Scheme: "file", Path: p}).String() +} + +// MockRootsHandler implements client.RootsHandler for demonstration. +// In a real implementation, this would enumerate workspace/project roots. +type MockRootsHandler struct{} + +// ListRoots implements client.RootsHandler by returning example workspace roots. +func (h *MockRootsHandler) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + home, err := os.UserHomeDir() + if err != nil { + log.Printf("Warning: failed to get home directory: %v", err) + home = "/tmp" // fallback for demonstration + } + app := filepath.ToSlash(filepath.Join(home, "app")) + proj := filepath.ToSlash(filepath.Join(home, "projects", "test-project")) + result := &mcp.ListRootsResult{ + Roots: []mcp.Root{ + { + Name: "app", + URI: fileURI(app), + }, + { + Name: "test-project", + URI: fileURI(proj), + }, + }, + } + return result, nil +} + +// main starts a mock MCP roots client over HTTP. +// The server tool triggers a roots/list request on the client. +// The client shuts down gracefully on SIGINT or SIGTERM. +func main() { + // Create roots handler + rootsHandler := &MockRootsHandler{} + + // Create HTTP transport directly + httpTransport, err := transport.NewStreamableHTTP( + "http://localhost:8080/mcp", // Replace with your MCP server URL + transport.WithContinuousListening(), + ) + if err != nil { + log.Fatalf("Failed to create HTTP transport: %v", err) + } + + // Create client with roots support + mcpClient := client.NewClient( + httpTransport, + client.WithRootsHandler(rootsHandler), + ) + + // Start the client + ctx := context.Background() + err = mcpClient.Start(ctx) + if err != nil { + log.Fatalf("Failed to start client: %v", err) + } + defer func() { + if cerr := mcpClient.Close(); cerr != nil { + log.Printf("Error closing client: %v", cerr) + } + }() + + // Initialize the MCP session + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{ + // Roots capability will be automatically added by the client + }, + ClientInfo: mcp.Implementation{ + Name: "roots-http-client", + Version: "1.0.0", + }, + }, + } + + _, err = mcpClient.Initialize(ctx, initRequest) + if err != nil { + log.Fatalf("Failed to initialize MCP session: %v", err) + } + + log.Println("HTTP MCP client with roots support started successfully!") + log.Println("The client is now ready to handle roots requests from the server.") + log.Println("When the server sends a roots request, the MockRootsHandler will process it.") + + // In a real application, you would keep the client running to handle roots requests + // For this example, we'll just demonstrate that it's working + + // mock the root change + if err := mcpClient.RootListChanges(ctx); err != nil { + log.Printf("failed to notify root list change: %v", err) + } + + // call server tool + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "roots", + Arguments: map[string]any{}, + }, + } + result, err := mcpClient.CallTool(ctx, request) + if err != nil { + log.Fatalf("failed to call tool roots: %v", err) + } else if result.IsError { + log.Printf("tool reported error") + } else if len(result.Content) > 0 { + resultStr := "" + for _, content := range result.Content { + switch tc := content.(type) { + case mcp.TextContent: + resultStr += fmt.Sprintf("%s\n", tc.Text) + } + } + fmt.Printf("client call tool result: %s\n", resultStr) + } + + // Keep the client running (in a real app, you'd have your main application logic here) + waitCtx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) + defer stop() + <-waitCtx.Done() + log.Println("Received shutdown signal") +} diff --git a/examples/roots_http_server/main.go b/examples/roots_http_server/main.go new file mode 100644 index 000000000..f3daa84ed --- /dev/null +++ b/examples/roots_http_server/main.go @@ -0,0 +1,84 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// handleNotification prints the method name of the received MCP JSON-RPC notification to standard output. +func handleNotification(ctx context.Context, notification mcp.JSONRPCNotification) { + fmt.Printf("notification received: %v", notification.Method) +} + +// main starts an MCP HTTP server named "roots-http-server" with tool capabilities and roots support. +// It registers a notification handler for ToolsListChanged, adds a "roots" tool that queries the server's roots and returns a textual result, +// logs startup and usage instructions, and launches a streamable HTTP server on port 8080. +func main() { + // Enable roots capability + opts := []server.ServerOption{ + server.WithToolCapabilities(true), + server.WithRoots(), + } + // Create MCP server with roots capability + mcpServer := server.NewMCPServer("roots-http-server", "1.0.0", opts...) + + // Add list root list change notification + mcpServer.AddNotificationHandler(mcp.MethodNotificationRootsListChanged, handleNotification) + + // Add a simple tool to test roots list + mcpServer.AddTool(mcp.Tool{ + Name: "roots", + Description: "list root result", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "testonly": map[string]any{ + "type": "string", + "description": "is this test only?", + }, + }, + Required: []string{"testonly"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + rootRequest := mcp.ListRootsRequest{} + + if result, err := mcpServer.RequestRoots(ctx, rootRequest); err == nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Root list: %v", result.Roots), + }, + }, + }, nil + + } else { + return nil, err + } + }) + + log.Println("Starting MCP Http server with roots support") + log.Println("Http Endpoint: http://localhost:8080/mcp") + log.Println("") + log.Println("This server supports roots over HTTP transport.") + log.Println("Clients must:") + log.Println("1. Initialize with roots capability") + log.Println("2. Establish SSE connection for bidirectional communication") + log.Println("3. Handle incoming roots requests from the server") + log.Println("4. Send responses back via HTTP POST") + log.Println("") + log.Println("Available tools:") + log.Println("- roots: Send back the list root request)") + + // Create HTTP server + httpOpts := []server.StreamableHTTPOption{} + httpServer := server.NewStreamableHTTPServer(mcpServer, httpOpts...) + fmt.Printf("Starting HTTP server\n") + if err := httpServer.Start(":8080"); err != nil { + fmt.Printf("HTTP server failed: %v\n", err) + } +} diff --git a/examples/roots_server/main.go b/examples/roots_server/main.go new file mode 100644 index 000000000..9aea5dd9c --- /dev/null +++ b/examples/roots_server/main.go @@ -0,0 +1,74 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// handleNotification handles JSON-RPC notifications by printing the notification method to standard output. +func handleNotification(ctx context.Context, notification mcp.JSONRPCNotification) { + fmt.Printf("notification received: %v\n", notification.Method) +} + +// main sets up and runs an MCP stdio server named "roots-stdio-server" with tool and roots capabilities. +// It registers a handler for RootsListChanged notifications and adds a "roots" tool +// that requests and returns the current roots list. The program serves the MCP server over stdio and +// logs a fatal error if the server fails to start. +func main() { + // Enable roots capability + opts := []server.ServerOption{ + server.WithToolCapabilities(true), + server.WithRoots(), + } + // Create MCP server with roots capability + mcpServer := server.NewMCPServer("roots-stdio-server", "1.0.0", opts...) + + // Register roots list-change notification handler + mcpServer.AddNotificationHandler(mcp.MethodNotificationRootsListChanged, handleNotification) + + // Add a simple tool to test roots list + mcpServer.AddTool(mcp.Tool{ + Name: "roots", + Description: "Requests and returns the current list of roots from the connected client", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + rootRequest := mcp.ListRootsRequest{} + + if result, err := mcpServer.RequestRoots(ctx, rootRequest); err == nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Root list: %v", result.Roots), + }, + }, + StructuredContent: map[string]any{ + "roots": result.Roots, + }, + }, nil + + } else { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Failed to list roots: %v", err), + }, + }, + IsError: true, + }, nil + } + }) + + // Create stdio server + if err := server.ServeStdio(mcpServer); err != nil { + log.Fatalf("Server Stdio error: %v\n", err) + } +} diff --git a/mcp/types.go b/mcp/types.go index 0f97821b4..1d147a826 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -59,6 +59,10 @@ const ( // https://modelcontextprotocol.io/docs/concepts/elicitation MethodElicitationCreate MCPMethod = "elicitation/create" + // MethodListRoots requests roots list from the client during interactions. + // https://modelcontextprotocol.io/specification/2025-06-18/client/roots + MethodListRoots MCPMethod = "roots/list" + // MethodNotificationResourcesListChanged notifies when the list of available resources changes. // https://modelcontextprotocol.io/specification/2025-03-26/server/resources#list-changed-notification MethodNotificationResourcesListChanged = "notifications/resources/list_changed" @@ -70,8 +74,12 @@ const ( MethodNotificationPromptsListChanged = "notifications/prompts/list_changed" // MethodNotificationToolsListChanged notifies when the list of available tools changes. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/list_changed/ + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#list-changed-notification MethodNotificationToolsListChanged = "notifications/tools/list_changed" + + // MethodNotificationRootsListChanged notifies when the list of available roots changes. + // https://modelcontextprotocol.io/specification/2025-06-18/client/roots#root-list-changes + MethodNotificationRootsListChanged = "notifications/roots/list_changed" ) type URITemplate struct { @@ -515,6 +523,8 @@ type ServerCapabilities struct { } `json:"tools,omitempty"` // Present if the server supports elicitation requests to the client. Elicitation *struct{} `json:"elicitation,omitempty"` + // Present if the server supports roots requests to the client. + Roots *struct{} `json:"roots,omitempty"` } // Implementation describes the name and version of an MCP implementation. @@ -1143,7 +1153,6 @@ type PromptReference struct { // structure or access specific locations that the client has permission to read from. type ListRootsRequest struct { Request - Header http.Header `json:"-"` } // ListRootsResult is the client's response to a roots/list request from the server. diff --git a/server/inprocess_session.go b/server/inprocess_session.go index c6fddc601..59ab0f366 100644 --- a/server/inprocess_session.go +++ b/server/inprocess_session.go @@ -20,6 +20,11 @@ type ElicitationHandler interface { Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) } +// RootsHandler defines the interface for handling roots list requests from servers. +type RootsHandler interface { + ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) +} + type InProcessSession struct { sessionID string notifications chan mcp.JSONRPCNotification @@ -29,6 +34,7 @@ type InProcessSession struct { clientCapabilities atomic.Value samplingHandler SamplingHandler elicitationHandler ElicitationHandler + rootsHandler RootsHandler mu sync.RWMutex } @@ -40,12 +46,13 @@ func NewInProcessSession(sessionID string, samplingHandler SamplingHandler) *InP } } -func NewInProcessSessionWithHandlers(sessionID string, samplingHandler SamplingHandler, elicitationHandler ElicitationHandler) *InProcessSession { +func NewInProcessSessionWithHandlers(sessionID string, samplingHandler SamplingHandler, elicitationHandler ElicitationHandler, rootsHandler RootsHandler) *InProcessSession { return &InProcessSession{ sessionID: sessionID, notifications: make(chan mcp.JSONRPCNotification, 100), samplingHandler: samplingHandler, elicitationHandler: elicitationHandler, + rootsHandler: rootsHandler, } } @@ -128,6 +135,20 @@ func (s *InProcessSession) RequestElicitation(ctx context.Context, request mcp.E return handler.Elicit(ctx, request) } +// ListRoots sends a list roots request to the client and waits for the response. +// Returns an error if no roots handler is available. +func (s *InProcessSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + s.mu.RLock() + handler := s.rootsHandler + s.mu.RUnlock() + + if handler == nil { + return nil, fmt.Errorf("no roots handler available") + } + + return handler.ListRoots(ctx, request) +} + // GenerateInProcessSessionID generates a unique session ID for inprocess clients func GenerateInProcessSessionID() string { return fmt.Sprintf("inprocess-%d", time.Now().UnixNano()) @@ -140,4 +161,5 @@ var ( _ SessionWithClientInfo = (*InProcessSession)(nil) _ SessionWithSampling = (*InProcessSession)(nil) _ SessionWithElicitation = (*InProcessSession)(nil) + _ SessionWithRoots = (*InProcessSession)(nil) ) diff --git a/server/roots.go b/server/roots.go new file mode 100644 index 000000000..29e0b94d1 --- /dev/null +++ b/server/roots.go @@ -0,0 +1,32 @@ +package server + +import ( + "context" + "errors" + + "github.com/mark3labs/mcp-go/mcp" +) + +var ( + // ErrNoClientSession is returned when there is no active client session in the context + ErrNoClientSession = errors.New("no active client session") + // ErrRootsNotSupported is returned when the session does not support roots + ErrRootsNotSupported = errors.New("session does not support roots") +) + +// RequestRoots sends an list roots request to the client. +// The client must have declared roots capability during initialization. +// The session must implement SessionWithRoots to support this operation. +func (s *MCPServer) RequestRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, ErrNoClientSession + } + + // Check if the session supports roots requests + if rootsSession, ok := session.(SessionWithRoots); ok { + return rootsSession.ListRoots(ctx, request) + } + + return nil, ErrRootsNotSupported +} diff --git a/server/roots_test.go b/server/roots_test.go new file mode 100644 index 000000000..c024cad3d --- /dev/null +++ b/server/roots_test.go @@ -0,0 +1,240 @@ +package server + +import ( + "context" + "errors" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockBasicRootsSession implements ClientSession for testing (without roots support) +type mockBasicRootsSession struct { + sessionID string +} + +func (m *mockBasicRootsSession) SessionID() string { + return m.sessionID +} + +func (m *mockBasicRootsSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return make(chan mcp.JSONRPCNotification, 1) +} + +func (m *mockBasicRootsSession) Initialize() {} + +func (m *mockBasicRootsSession) Initialized() bool { + return true +} + +// mockRootsSession implements SessionWithRoots for testing +type mockRootsSession struct { + sessionID string + result *mcp.ListRootsResult + err error +} + +func (m *mockRootsSession) SessionID() string { + return m.sessionID +} + +func (m *mockRootsSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return make(chan mcp.JSONRPCNotification, 1) +} + +func (m *mockRootsSession) Initialize() {} + +func (m *mockRootsSession) Initialized() bool { + return true +} + +func (m *mockRootsSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + if m.err != nil { + return nil, m.err + } + return m.result, nil +} + +func TestMCPServer_RequestRoots_NoSession(t *testing.T) { + server := NewMCPServer("test", "1.0.0") + server.capabilities.roots = mcp.ToBoolPtr(true) + + request := mcp.ListRootsRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodListRoots), + }, + } + + _, err := server.RequestRoots(context.Background(), request) + + if err == nil { + t.Error("expected error when no session available") + } + + if !errors.Is(err, ErrNoClientSession) { + t.Errorf("expected ErrNoClientSession, got %v", err) + } +} + +func TestMCPServer_RequestRoots_SessionDoesNotSupportRoots(t *testing.T) { + server := NewMCPServer("test", "1.0.0", WithRoots()) + + // Use a regular session that doesn't implement SessionWithRoots + mockSession := &mockBasicRootsSession{sessionID: "test-session"} + + ctx := context.Background() + ctx = server.WithContext(ctx, mockSession) + + request := mcp.ListRootsRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodListRoots), + }, + } + + _, err := server.RequestRoots(ctx, request) + + if err == nil { + t.Error("expected error when session doesn't support roots") + } + + if !errors.Is(err, ErrRootsNotSupported) { + t.Errorf("expected ErrRootsNotSupported, got %v", err) + } +} + +func TestMCPServer_RequestRoots_Success(t *testing.T) { + opts := []ServerOption{ + WithRoots(), + } + server := NewMCPServer("test", "1.0.0", opts...) + + // Create a mock roots session + mockSession := &mockRootsSession{ + sessionID: "test-session", + result: &mcp.ListRootsResult{ + Roots: []mcp.Root{ + { + Name: ".kube", + URI: "file:///User/haxxx/.kube", + }, + { + Name: "project", + URI: "file:///User/haxxx/projects/snative", + }, + }, + }, + } + + // Create context with session + ctx := context.Background() + ctx = server.WithContext(ctx, mockSession) + + request := mcp.ListRootsRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodListRoots), + }, + } + + result, err := server.RequestRoots(ctx, request) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if result == nil { + t.Error("expected result, got nil") + return + } + + if len(result.Roots) == 0 { + t.Error("roots result is empty") + return + } + + for _, value := range result.Roots { + if value.Name != "project" && value.Name != ".kube" { + t.Errorf("expected root name %q, %q, got %q", "project", ".kube", value.Name) + } + if value.URI != "file:///User/haxxx/.kube" && value.URI != "file:///User/haxxx/projects/snative" { + t.Errorf("expected root URI %q, %q, got %q", "file:///User/haxxx/.kube", "file:///User/haxxx/projects/snative", value.URI) + } + } +} + +func TestRequestRoots(t *testing.T) { + tests := []struct { + name string + session ClientSession + request mcp.ListRootsRequest + expectedError error + }{ + { + name: "successful roots with name and uri", + session: &mockRootsSession{ + sessionID: "test-1", + result: &mcp.ListRootsResult{ + Roots: []mcp.Root{ + { + Name: ".kube", + URI: "file:///User/haxxx/.kube", + }, + { + Name: "project", + URI: "file:///User/haxxx/projects/snative", + }, + }, + }, + }, + request: mcp.ListRootsRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodListRoots), + }, + }, + }, + { + name: "successful roots with empty list", + session: &mockRootsSession{ + sessionID: "test-2", + result: &mcp.ListRootsResult{ + Roots: []mcp.Root{}, + }, + }, + request: mcp.ListRootsRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodListRoots), + }, + }, + }, + { + name: "session does not support roots", + session: &fakeSession{sessionID: "test-3"}, + request: mcp.ListRootsRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodListRoots), + }, + }, + expectedError: ErrRootsNotSupported, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := NewMCPServer("test", "1.0", WithRoots()) + ctx := server.WithContext(context.Background(), tt.session) + + result, err := server.RequestRoots(ctx, tt.request) + + if tt.expectedError != nil { + require.Error(t, err) + assert.True(t, errors.Is(err, tt.expectedError), "expected %v, got %v", tt.expectedError, err) + return + } + + require.NoError(t, err) + require.NotNil(t, result) + + }) + } +} diff --git a/server/server.go b/server/server.go index f45c03536..8bb7b64ce 100644 --- a/server/server.go +++ b/server/server.go @@ -183,6 +183,7 @@ type serverCapabilities struct { logging *bool sampling *bool elicitation *bool + roots *bool } // resourceCapabilities defines the supported resource-related features @@ -326,6 +327,13 @@ func WithElicitation() ServerOption { } } +// WithRoots returns a ServerOption that enables the roots capability on the MCPServer +func WithRoots() ServerOption { + return func(s *MCPServer) { + s.capabilities.roots = mcp.ToBoolPtr(true) + } +} + // WithInstructions sets the server instructions for the client returned in the initialize response func WithInstructions(instructions string) ServerOption { return func(s *MCPServer) { @@ -696,6 +704,10 @@ func (s *MCPServer) handleInitialize( capabilities.Elicitation = &struct{}{} } + if s.capabilities.roots != nil && *s.capabilities.roots { + capabilities.Roots = &struct{}{} + } + result := mcp.InitializeResult{ ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion), ServerInfo: mcp.Implementation{ diff --git a/server/session.go b/server/session.go index 99d6db8d4..9f4471d3b 100644 --- a/server/session.go +++ b/server/session.go @@ -71,6 +71,13 @@ type SessionWithElicitation interface { RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) } +// SessionWithRoots is an extension of ClientSession that can send list roots requests +type SessionWithRoots interface { + ClientSession + // ListRoots sends an list roots request to the client and waits for response + ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) +} + // SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations type SessionWithStreamableHTTPConfig interface { ClientSession diff --git a/server/stdio.go b/server/stdio.go index 80131f06c..f5c8ddfd2 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -102,6 +102,7 @@ type stdioSession struct { mu sync.RWMutex // protects writer pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests pendingElicitations map[int64]chan *elicitationResponse // for tracking pending elicitation requests + pendingRoots map[int64]chan *rootsResponse // for tracking pending list roots requests pendingMu sync.RWMutex // protects pendingRequests and pendingElicitations } @@ -117,6 +118,12 @@ type elicitationResponse struct { err error } +// rootsResponse represents a response to an list root request +type rootsResponse struct { + result *mcp.ListRootsResult + err error +} + func (s *stdioSession) SessionID() string { return "stdio" } @@ -236,6 +243,67 @@ func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMe } } +// ListRoots sends an list roots request to the client and waits for the response. +func (s *stdioSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + s.mu.RLock() + writer := s.writer + s.mu.RUnlock() + + if writer == nil { + return nil, fmt.Errorf("no writer available for sending requests") + } + + // Generate a unique request ID + id := s.requestID.Add(1) + + // Create a response channel for this request + responseChan := make(chan *rootsResponse, 1) + s.pendingMu.Lock() + s.pendingRoots[id] = responseChan + s.pendingMu.Unlock() + + // Cleanup function to remove the pending request + cleanup := func() { + s.pendingMu.Lock() + delete(s.pendingRoots, id) + s.pendingMu.Unlock() + } + defer cleanup() + + // Create the JSON-RPC request + jsonRPCRequest := struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + }{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: string(mcp.MethodListRoots), + } + + // Marshal and send the request + requestBytes, err := json.Marshal(jsonRPCRequest) + if err != nil { + return nil, fmt.Errorf("failed to marshal list roots request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := writer.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write list roots request: %w", err) + } + + // Wait for the response or context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + } +} + // RequestElicitation sends an elicitation request to the client and waits for the response. func (s *stdioSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { s.mu.RLock() @@ -312,12 +380,14 @@ var ( _ SessionWithClientInfo = (*stdioSession)(nil) _ SessionWithSampling = (*stdioSession)(nil) _ SessionWithElicitation = (*stdioSession)(nil) + _ SessionWithRoots = (*stdioSession)(nil) ) var stdioSessionInstance = stdioSession{ notifications: make(chan mcp.JSONRPCNotification, 100), pendingRequests: make(map[int64]chan *samplingResponse), pendingElicitations: make(map[int64]chan *elicitationResponse), + pendingRoots: make(map[int64]chan *rootsResponse), } // NewStdioServer creates a new stdio server wrapper around an MCPServer. @@ -522,6 +592,11 @@ func (s *StdioServer) processMessage( return nil } + // Check if this is a response to an list roots request + if s.handleListRootsResponse(rawMessage) { + return nil + } + // Check if this is a tool call that might need sampling (and thus should be processed concurrently) var baseMessage struct { Method string `json:"method"` @@ -692,6 +767,67 @@ func (s *stdioSession) handleElicitationResponse(rawMessage json.RawMessage) boo return true } +// handleListRootsResponse checks if the message is a response to an list roots request +// and routes it to the appropriate pending request channel. +func (s *StdioServer) handleListRootsResponse(rawMessage json.RawMessage) bool { + return stdioSessionInstance.handleListRootsResponse(rawMessage) +} + +// handleListRootsResponse handles incoming list root responses for this session +func (s *stdioSession) handleListRootsResponse(rawMessage json.RawMessage) bool { + // Try to parse as a JSON-RPC response + var response struct { + JSONRPC string `json:"jsonrpc"` + ID json.Number `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal(rawMessage, &response); err != nil { + return false + } + // Parse the ID as int64 + id, err := response.ID.Int64() + if err != nil || (response.Result == nil && response.Error == nil) { + return false + } + + // Check if we have a pending list root request with this ID + s.pendingMu.RLock() + responseChan, exists := s.pendingRoots[id] + s.pendingMu.RUnlock() + + if !exists { + return false + } + + // Parse and send the response + rootsResp := &rootsResponse{} + + if response.Error != nil { + rootsResp.err = fmt.Errorf("list root request failed: %s", response.Error.Message) + } else { + var result mcp.ListRootsResult + if err := json.Unmarshal(response.Result, &result); err != nil { + rootsResp.err = fmt.Errorf("failed to unmarshal list root response: %w", err) + } else { + rootsResp.result = &result + } + } + + // Send the response (non-blocking) + select { + case responseChan <- rootsResp: + default: + // Channel is full or closed, ignore + } + + return true +} + // writeResponse marshals and writes a JSON-RPC response message followed by a newline. // Returns an error if marshaling or writing fails. func (s *StdioServer) writeResponse( diff --git a/server/streamable_http.go b/server/streamable_http.go index 8af6f1478..b5f1274d6 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -552,6 +552,20 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) case <-done: return } + case rootsReq := <-session.rootsRequestChan: + // Send list roots request to client via SSE + jsonrpcRequest := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(rootsReq.requestID), + Request: mcp.Request{ + Method: string(mcp.MethodListRoots), + }, + } + select { + case writeChan <- jsonrpcRequest: + case <-done: + return + } case <-done: return } @@ -887,6 +901,13 @@ type elicitationRequestItem struct { response chan samplingResponseItem } +// Roots support types for HTTP transport +type rootsRequestItem struct { + requestID int64 + request mcp.ListRootsRequest + response chan samplingResponseItem +} + // streamableHttpSession is a session for streamable-http transport // When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler. // When in GET handlers(listening), it's a real session, and will be registered in the MCP server. @@ -901,6 +922,7 @@ type streamableHttpSession struct { // Sampling support for bidirectional communication samplingRequestChan chan samplingRequestItem // server -> client sampling requests elicitationRequestChan chan elicitationRequestItem // server -> client elicitation requests + rootsRequestChan chan rootsRequestItem // server -> client list roots requests samplingRequests sync.Map // requestID -> pending sampling request context requestIDCounter atomic.Int64 // for generating unique request IDs @@ -915,6 +937,7 @@ func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, re logLevels: levels, samplingRequestChan: make(chan samplingRequestItem, 10), elicitationRequestChan: make(chan elicitationRequestItem, 10), + rootsRequestChan: make(chan rootsRequestItem, 10), } return s } @@ -1031,6 +1054,52 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp } } +// ListRoots implements SessionWithRoots interface for HTTP transport. +// It sends a list roots request to the client via SSE and waits for the response. +func (s *streamableHttpSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + // Generate unique request ID + requestID := s.requestIDCounter.Add(1) + + // Create response channel for this specific request + responseChan := make(chan samplingResponseItem, 1) + + // Create the roots request item + rootsRequest := rootsRequestItem{ + requestID: requestID, + request: request, + response: responseChan, + } + + // Store the pending request + s.samplingRequests.Store(requestID, responseChan) + defer s.samplingRequests.Delete(requestID) + + // Send the list roots request via the channel (non-blocking) + select { + case s.rootsRequestChan <- rootsRequest: + // Request queued successfully + case <-ctx.Done(): + return nil, ctx.Err() + default: + return nil, fmt.Errorf("list roots request queue is full - server overloaded") + } + + // Wait for response or context cancellation + select { + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + var result mcp.ListRootsResult + if err := json.Unmarshal(response.result, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal list roots response: %v", err) + } + return &result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + // RequestElicitation implements SessionWithElicitation interface for HTTP transport func (s *streamableHttpSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { // Generate unique request ID @@ -1078,6 +1147,7 @@ func (s *streamableHttpSession) RequestElicitation(ctx context.Context, request var _ SessionWithSampling = (*streamableHttpSession)(nil) var _ SessionWithElicitation = (*streamableHttpSession)(nil) +var _ SessionWithRoots = (*streamableHttpSession)(nil) // --- session id manager ---