Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 86 additions & 16 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"context"
"encoding/json"
"fmt"
"net/http"
"slices"
"sync"
"sync/atomic"
Expand All @@ -25,6 +24,7 @@
serverCapabilities mcp.ServerCapabilities
protocolVersion string
samplingHandler SamplingHandler
rootsHandler RootsHandler

Check failure on line 27 in client/client.go

View workflow job for this annotation

GitHub Actions / lint

undefined: RootsHandler (typecheck)
elicitationHandler ElicitationHandler
}

Expand All @@ -38,15 +38,26 @@
}

// WithSamplingHandler sets the sampling handler for the client.
// When set, the client will declare sampling capability during initialization.
// WithSamplingHandler sets the SamplingHandler on the client and causes the client to declare sampling
// capability during Initialize. The provided `handler` will be invoked for incoming sampling requests.
// The returned ClientOption applies this handler to a Client.
func WithSamplingHandler(handler SamplingHandler) ClientOption {
return func(c *Client) {
c.samplingHandler = handler
}
}

// WithRootsHandler sets the roots handler for the client.
// WithRootsHandler returns a ClientOption that sets the client's RootsHandler.
// When provided, the client will declare the roots capability (ListChanged) during initialization.
func WithRootsHandler(handler RootsHandler) ClientOption {

Check failure on line 53 in client/client.go

View workflow job for this annotation

GitHub Actions / lint

undefined: RootsHandler (typecheck)
return func(c *Client) {
c.rootsHandler = handler
}
}

// WithElicitationHandler sets the elicitation handler for the client.
// When set, the client will declare elicitation capability during initialization.
// to declare elicitation capability during initialization.
func WithElicitationHandler(handler ElicitationHandler) ClientOption {
return func(c *Client) {
c.elicitationHandler = handler
Expand Down Expand Up @@ -141,7 +152,6 @@
ctx context.Context,
method string,
params any,
header http.Header,
) (*json.RawMessage, error) {
if !c.initialized && method != "initialize" {
return nil, fmt.Errorf("client not initialized")
Expand All @@ -154,7 +164,6 @@
ID: mcp.NewRequestId(id),
Method: method,
Params: params,
Header: header,
}

response, err := c.transport.SendRequest(ctx, request)
Expand All @@ -180,6 +189,13 @@
if c.samplingHandler != nil {
capabilities.Sampling = &struct{}{}
}
if c.rootsHandler != nil {
capabilities.Roots = &struct {
ListChanged bool `json:"listChanged,omitempty"`
}{
ListChanged: true,
}
}
// Add elicitation capability if handler is configured
if c.elicitationHandler != nil {
capabilities.Elicitation = &struct{}{}
Expand All @@ -196,7 +212,7 @@
Capabilities: capabilities,
}

response, err := c.sendRequest(ctx, "initialize", params, request.Header)
response, err := c.sendRequest(ctx, "initialize", params)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -241,7 +257,7 @@
}

func (c *Client) Ping(ctx context.Context) error {
_, err := c.sendRequest(ctx, "ping", nil, nil)
_, err := c.sendRequest(ctx, "ping", nil)
return err
}

Expand Down Expand Up @@ -322,7 +338,7 @@
ctx context.Context,
request mcp.ReadResourceRequest,
) (*mcp.ReadResourceResult, error) {
response, err := c.sendRequest(ctx, "resources/read", request.Params, request.Header)
response, err := c.sendRequest(ctx, "resources/read", request.Params)
if err != nil {
return nil, err
}
Expand All @@ -334,15 +350,15 @@
ctx context.Context,
request mcp.SubscribeRequest,
) error {
_, err := c.sendRequest(ctx, "resources/subscribe", request.Params, request.Header)
_, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
return err
}

func (c *Client) Unsubscribe(
ctx context.Context,
request mcp.UnsubscribeRequest,
) error {
_, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params, request.Header)
_, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
return err
}

Expand Down Expand Up @@ -386,7 +402,7 @@
ctx context.Context,
request mcp.GetPromptRequest,
) (*mcp.GetPromptResult, error) {
response, err := c.sendRequest(ctx, "prompts/get", request.Params, request.Header)
response, err := c.sendRequest(ctx, "prompts/get", request.Params)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -434,7 +450,7 @@
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
response, err := c.sendRequest(ctx, "tools/call", request.Params, request.Header)
response, err := c.sendRequest(ctx, "tools/call", request.Params)
if err != nil {
return nil, err
}
Expand All @@ -446,15 +462,15 @@
ctx context.Context,
request mcp.SetLevelRequest,
) error {
_, err := c.sendRequest(ctx, "logging/setLevel", request.Params, request.Header)
_, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
return err
}

func (c *Client) Complete(
ctx context.Context,
request mcp.CompleteRequest,
) (*mcp.CompleteResult, error) {
response, err := c.sendRequest(ctx, "completion/complete", request.Params, request.Header)
response, err := c.sendRequest(ctx, "completion/complete", request.Params)
if err != nil {
return nil, err
}
Expand All @@ -467,6 +483,27 @@
return &result, nil
}

func (c *Client) RootListChanges(
ctx context.Context,
) error {
// Send root list changes notification
notification := mcp.JSONRPCNotification{
JSONRPC: mcp.JSONRPC_VERSION,
Notification: mcp.Notification{
Method: mcp.MethodNotificationToolsListChanged,
},
}

err := c.transport.SendNotification(ctx, notification)
if err != nil {
return fmt.Errorf(
"failed to send root list change notification: %w",
err,
)
}
return nil
}

// handleIncomingRequest processes incoming requests from the server.
// This is the main entry point for server-to-client requests like sampling and elicitation.
func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
Expand All @@ -477,6 +514,8 @@
return c.handleElicitationRequestTransport(ctx, request)
case string(mcp.MethodPing):
return c.handlePingRequestTransport(ctx, request)
case string(mcp.MethodListRoots):

Check failure on line 517 in client/client.go

View workflow job for this annotation

GitHub Actions / lint

undefined: mcp.MethodListRoots (typecheck)
return c.handleListRootsRequestTransport(ctx, request)
default:
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
}
Expand Down Expand Up @@ -539,6 +578,37 @@
return response, nil
}

// handleListRootsRequestTransport handles list roots requests at the transport level.
func (c *Client) handleListRootsRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
if c.rootsHandler == nil {
return nil, fmt.Errorf("no roots handler configured")
}

// Create the MCP request
mcpRequest := mcp.ListRootsRequest{
Request: mcp.Request{
Method: string(mcp.MethodListRoots),
},
}

// Call the list roots handler
result, err := c.rootsHandler.ListRoots(ctx, mcpRequest)
if err != nil {
return nil, err
}

// Marshal the result
resultBytes, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal result: %w", err)
}

// Create the transport response
response := transport.NewJSONRPCResultResponse(request.ID, resultBytes)

return response, nil
}

// handleElicitationRequestTransport handles elicitation requests at the transport level.
func (c *Client) handleElicitationRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
if c.elicitationHandler == nil {
Expand Down Expand Up @@ -594,7 +664,7 @@
request mcp.PaginatedRequest,
method string,
) (*T, error) {
response, err := client.sendRequest(ctx, method, request.Params, nil)
response, err := client.sendRequest(ctx, method, request.Params)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -635,4 +705,4 @@
// IsInitialized returns true if the client has been initialized.
func (c *Client) IsInitialized() bool {
return c.initialized
}
}
18 changes: 15 additions & 3 deletions client/transport/inprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type InProcessTransport struct {
server *server.MCPServer
samplingHandler server.SamplingHandler
elicitationHandler server.ElicitationHandler
rootsHandler server.RootsHandler
session *server.InProcessSession
sessionID string

Expand All @@ -31,12 +32,23 @@ func WithSamplingHandler(handler server.SamplingHandler) InProcessOption {
}
}

// WithElicitationHandler returns an InProcessOption that sets the elicitation handler on an InProcessTransport.
// The provided handler will be used to handle elicitation requests for the in-process session when the transport is started.
func WithElicitationHandler(handler server.ElicitationHandler) InProcessOption {
return func(t *InProcessTransport) {
t.elicitationHandler = handler
}
}

// WithRootsHandler returns an InProcessOption that sets the transport's roots handler.
// The provided handler is assigned to the transport's rootsHandler field when the option is applied.
func WithRootsHandler(handler server.RootsHandler) InProcessOption {
return func(t *InProcessTransport) {
t.rootsHandler = handler
}
}

// NewInProcessTransport creates an InProcessTransport that wraps the provided MCPServer with default (zero-value) configuration.
func NewInProcessTransport(server *server.MCPServer) *InProcessTransport {
return &InProcessTransport{
server: server,
Expand Down Expand Up @@ -66,8 +78,8 @@ func (c *InProcessTransport) Start(ctx context.Context) error {
c.startedMu.Unlock()

// Create and register session if we have handlers
if c.samplingHandler != nil || c.elicitationHandler != nil {
c.session = server.NewInProcessSessionWithHandlers(c.sessionID, c.samplingHandler, c.elicitationHandler)
if c.samplingHandler != nil || c.elicitationHandler != nil || c.rootsHandler != nil {
c.session = server.NewInProcessSessionWithHandlers(c.sessionID, c.samplingHandler, c.elicitationHandler, c.rootsHandler)
if err := c.server.RegisterSession(ctx, c.session); err != nil {
c.startedMu.Lock()
c.started = false
Expand Down Expand Up @@ -130,4 +142,4 @@ func (c *InProcessTransport) Close() error {

func (c *InProcessTransport) GetSessionId() string {
return ""
}
}
Loading
Loading