Skip to content

Commit 23cee61

Browse files
committed
feat(streamable_http): elicitation request
Author: Ghosthell
1 parent aef7c8d commit 23cee61

File tree

2 files changed

+97
-16
lines changed

2 files changed

+97
-16
lines changed

client/client.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,8 @@ func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JS
478478
return c.handleSamplingRequestTransport(ctx, request)
479479
case string(mcp.MethodElicitationCreate):
480480
return c.handleElicitationRequestTransport(ctx, request)
481+
case string(mcp.MethodPing):
482+
return c.handlePingRequestTransport(ctx, request)
481483
default:
482484
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
483485
}
@@ -579,6 +581,15 @@ func (c *Client) handleElicitationRequestTransport(ctx context.Context, request
579581
return response, nil
580582
}
581583

584+
func (c *Client) handlePingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
585+
b, _ := json.Marshal(&mcp.EmptyResult{})
586+
return &transport.JSONRPCResponse{
587+
JSONRPC: mcp.JSONRPC_VERSION,
588+
ID: request.ID,
589+
Result: b,
590+
}, nil
591+
}
592+
582593
func listByPage[T any](
583594
ctx context.Context,
584595
client *Client,

server/streamable_http.go

Lines changed: 86 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,21 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
473473
case <-done:
474474
return
475475
}
476+
case elicitationReq := <-session.elicitationRequestChan:
477+
// Send elicitation request to client via SSE
478+
jsonrpcRequest := mcp.JSONRPCRequest{
479+
JSONRPC: "2.0",
480+
ID: mcp.NewRequestId(elicitationReq.requestID),
481+
Request: mcp.Request{
482+
Method: string(mcp.MethodElicitationCreate),
483+
},
484+
Params: elicitationReq.request.Params,
485+
}
486+
select {
487+
case writeChan <- jsonrpcRequest:
488+
case <-done:
489+
return
490+
}
476491
case <-done:
477492
return
478493
}
@@ -612,12 +627,6 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *
612627
}
613628
} else if responseMessage.Result != nil {
614629
// Parse result
615-
var result mcp.CreateMessageResult
616-
if err := json.Unmarshal(responseMessage.Result, &result); err != nil {
617-
response.err = fmt.Errorf("failed to parse sampling result: %v", err)
618-
} else {
619-
response.result = &result
620-
}
621630
} else {
622631
response.err = fmt.Errorf("sampling response has neither result nor error")
623632
}
@@ -764,10 +773,17 @@ type samplingRequestItem struct {
764773

765774
type samplingResponseItem struct {
766775
requestID int64
767-
result *mcp.CreateMessageResult
776+
result json.RawMessage
768777
err error
769778
}
770779

780+
// Elicitation support types for HTTP transport
781+
type elicitationRequestItem struct {
782+
requestID int64
783+
request mcp.ElicitationRequest
784+
response chan samplingResponseItem
785+
}
786+
771787
// streamableHttpSession is a session for streamable-http transport
772788
// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler.
773789
// When in GET handlers(listening), it's a real session, and will be registered in the MCP server.
@@ -779,18 +795,21 @@ type streamableHttpSession struct {
779795
logLevels *sessionLogLevelsStore
780796

781797
// Sampling support for bidirectional communication
782-
samplingRequestChan chan samplingRequestItem // server -> client sampling requests
783-
samplingRequests sync.Map // requestID -> pending sampling request context
784-
requestIDCounter atomic.Int64 // for generating unique request IDs
798+
samplingRequestChan chan samplingRequestItem // server -> client sampling requests
799+
elicitationRequestChan chan elicitationRequestItem // server -> client elicitation requests
800+
801+
samplingRequests sync.Map // requestID -> pending sampling request context
802+
requestIDCounter atomic.Int64 // for generating unique request IDs
785803
}
786804

787805
func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession {
788806
s := &streamableHttpSession{
789-
sessionID: sessionID,
790-
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
791-
tools: toolStore,
792-
logLevels: levels,
793-
samplingRequestChan: make(chan samplingRequestItem, 10),
807+
sessionID: sessionID,
808+
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
809+
tools: toolStore,
810+
logLevels: levels,
811+
samplingRequestChan: make(chan samplingRequestItem, 10),
812+
elicitationRequestChan: make(chan elicitationRequestItem, 10),
794813
}
795814
return s
796815
}
@@ -877,13 +896,63 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp
877896
if response.err != nil {
878897
return nil, response.err
879898
}
880-
return response.result, nil
899+
var result mcp.CreateMessageResult
900+
if err := json.Unmarshal(response.result, &result); err != nil {
901+
return nil, fmt.Errorf("failed to unmarshal sampling response: %v", err)
902+
}
903+
return &result, nil
904+
case <-ctx.Done():
905+
return nil, ctx.Err()
906+
}
907+
}
908+
909+
// RequestElicitation implements SessionWithElicitation interface for HTTP transport
910+
func (s *streamableHttpSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) {
911+
// Generate unique request ID
912+
requestID := s.requestIDCounter.Add(1)
913+
914+
// Create response channel for this specific request
915+
responseChan := make(chan samplingResponseItem, 1)
916+
917+
// Create the sampling request item
918+
elicitationRequest := elicitationRequestItem{
919+
requestID: requestID,
920+
request: request,
921+
response: responseChan,
922+
}
923+
924+
// Store the pending request
925+
s.samplingRequests.Store(requestID, responseChan)
926+
defer s.samplingRequests.Delete(requestID)
927+
928+
// Send the sampling request via the channel (non-blocking)
929+
select {
930+
case s.elicitationRequestChan <- elicitationRequest:
931+
// Request queued successfully
932+
case <-ctx.Done():
933+
return nil, ctx.Err()
934+
default:
935+
return nil, fmt.Errorf("elicitation request queue is full - server overloaded")
936+
}
937+
938+
// Wait for response or context cancellation
939+
select {
940+
case response := <-responseChan:
941+
if response.err != nil {
942+
return nil, response.err
943+
}
944+
var result mcp.ElicitationResult
945+
if err := json.Unmarshal(response.result, &result); err != nil {
946+
return nil, fmt.Errorf("failed to unmarshal elicitation response: %v", err)
947+
}
948+
return &result, nil
881949
case <-ctx.Done():
882950
return nil, ctx.Err()
883951
}
884952
}
885953

886954
var _ SessionWithSampling = (*streamableHttpSession)(nil)
955+
var _ SessionWithElicitation = (*streamableHttpSession)(nil)
887956

888957
// --- session id manager ---
889958

@@ -952,6 +1021,7 @@ func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption
9521021
// - null
9531022
// - empty object: {}
9541023
// - empty array: []
1024+
//
9551025
// It also treats nil/whitespace-only input as empty.
9561026
// It does NOT treat 0, false, "" or non-empty composites as empty.
9571027
func isJSONEmpty(data json.RawMessage) bool {

0 commit comments

Comments
 (0)