-
Notifications
You must be signed in to change notification settings - Fork 3.6k
fix: add MCP initialize handshake to mcpcurl #2009
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| package main | ||
|
|
||
| import ( | ||
| "bytes" | ||
| "bufio" | ||
| "crypto/rand" | ||
| "encoding/json" | ||
| "fmt" | ||
|
|
@@ -376,8 +376,8 @@ func buildJSONRPCRequest(method, toolName string, arguments map[string]any) (str | |
| return string(jsonData), nil | ||
| } | ||
|
|
||
| // executeServerCommand runs the specified command, sends the JSON request to stdin, | ||
| // and returns the response from stdout | ||
| // executeServerCommand runs the specified command, performs the MCP initialization | ||
| // handshake, sends the JSON request to stdin, and returns the response from stdout. | ||
| func executeServerCommand(cmdStr, jsonRequest string) (string, error) { | ||
| // Split the command string into command and arguments | ||
| cmdParts := strings.Fields(cmdStr) | ||
|
|
@@ -393,28 +393,127 @@ func executeServerCommand(cmdStr, jsonRequest string) (string, error) { | |
| return "", fmt.Errorf("failed to create stdin pipe: %w", err) | ||
| } | ||
|
|
||
| // Setup stdout and stderr pipes | ||
| var stdout, stderr bytes.Buffer | ||
| cmd.Stdout = &stdout | ||
| // Setup stdout pipe for line-by-line reading | ||
| stdoutPipe, err := cmd.StdoutPipe() | ||
| if err != nil { | ||
| return "", fmt.Errorf("failed to create stdout pipe: %w", err) | ||
| } | ||
|
|
||
| // Stderr still uses a buffer | ||
| var stderr strings.Builder | ||
| cmd.Stderr = &stderr | ||
|
|
||
| // Start the command | ||
| if err := cmd.Start(); err != nil { | ||
| return "", fmt.Errorf("failed to start command: %w", err) | ||
| } | ||
|
Comment on lines
406
to
409
|
||
|
|
||
| // Write the JSON request to stdin | ||
| // Ensure the child process is cleaned up on any error after Start() | ||
| cleanup := func() { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi, can't this just be deferred using defer keyword? Normally must execute functions can be cleanly deferred |
||
| _ = stdin.Close() | ||
| _ = cmd.Wait() | ||
| } | ||
|
|
||
| // Use a scanner with a large buffer for reading JSON-RPC responses | ||
| scanner := bufio.NewScanner(stdoutPipe) | ||
| scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024) // 1MB max line size | ||
|
|
||
| // Step 1: Send MCP initialize request | ||
| initReq, err := buildInitializeRequest() | ||
| if err != nil { | ||
| cleanup() | ||
| return "", fmt.Errorf("failed to build initialize request: %w", err) | ||
| } | ||
| if _, err := io.WriteString(stdin, initReq+"\n"); err != nil { | ||
| cleanup() | ||
| return "", fmt.Errorf("failed to write initialize request: %w", err) | ||
| } | ||
|
|
||
| // Step 2: Read initialize response (skip any server notifications) | ||
| if _, err := readJSONRPCResponse(scanner); err != nil { | ||
| cleanup() | ||
| return "", fmt.Errorf("failed to read initialize response: %w, stderr: %s", err, stderr.String()) | ||
| } | ||
|
|
||
| // Step 3: Send initialized notification | ||
| if _, err := io.WriteString(stdin, buildInitializedNotification()+"\n"); err != nil { | ||
| cleanup() | ||
| return "", fmt.Errorf("failed to write initialized notification: %w", err) | ||
| } | ||
|
|
||
| // Step 4: Send the actual request | ||
| if _, err := io.WriteString(stdin, jsonRequest+"\n"); err != nil { | ||
| return "", fmt.Errorf("failed to write to stdin: %w", err) | ||
| cleanup() | ||
| return "", fmt.Errorf("failed to write request: %w", err) | ||
| } | ||
| _ = stdin.Close() | ||
|
|
||
| // Wait for the command to complete | ||
| if err := cmd.Wait(); err != nil { | ||
| return "", fmt.Errorf("command failed: %w, stderr: %s", err, stderr.String()) | ||
| // Step 5: Read the actual response (skip any server notifications) | ||
| response, err := readJSONRPCResponse(scanner) | ||
| if err != nil { | ||
| cleanup() | ||
| return "", fmt.Errorf("failed to read response: %w, stderr: %s", err, stderr.String()) | ||
| } | ||
|
|
||
| return stdout.String(), nil | ||
| // Close stdin and wait for process to exit. The server will see EOF and | ||
| // exit with a non-zero status, which is expected — we already have the response. | ||
| cleanup() | ||
|
|
||
| return response, nil | ||
| } | ||
|
|
||
| // buildInitializeRequest creates the MCP initialize handshake request. | ||
| func buildInitializeRequest() (string, error) { | ||
| id, err := rand.Int(rand.Reader, big.NewInt(10000)) | ||
| if err != nil { | ||
| return "", fmt.Errorf("failed to generate random ID: %w", err) | ||
| } | ||
| msg := map[string]any{ | ||
| "jsonrpc": "2.0", | ||
| "id": int(id.Int64()), | ||
| "method": "initialize", | ||
| "params": map[string]any{ | ||
| "protocolVersion": "2024-11-05", | ||
| "capabilities": map[string]any{}, | ||
| "clientInfo": map[string]any{ | ||
| "name": "mcpcurl", | ||
| "version": "0.1.0", | ||
| }, | ||
| }, | ||
| } | ||
| data, err := json.Marshal(msg) | ||
| if err != nil { | ||
| return "", fmt.Errorf("failed to marshal initialize request: %w", err) | ||
| } | ||
| return string(data), nil | ||
| } | ||
|
|
||
| // buildInitializedNotification creates the MCP initialized notification. | ||
| func buildInitializedNotification() string { | ||
| return `{"jsonrpc":"2.0","method":"notifications/initialized"}` | ||
| } | ||
|
|
||
| // readJSONRPCResponse reads lines from the scanner, skipping server-initiated | ||
| // notifications (messages without an "id" field), and returns the first response. | ||
| func readJSONRPCResponse(scanner *bufio.Scanner) (string, error) { | ||
| for scanner.Scan() { | ||
| line := scanner.Text() | ||
| // JSON-RPC responses have an "id" field; notifications do not. | ||
| var msg map[string]json.RawMessage | ||
| if err := json.Unmarshal([]byte(line), &msg); err != nil { | ||
| return "", fmt.Errorf("failed to parse JSON-RPC message: %w", err) | ||
| } | ||
| if _, hasID := msg["id"]; hasID { | ||
| if errField, hasErr := msg["error"]; hasErr { | ||
| return "", fmt.Errorf("server returned error: %s", string(errField)) | ||
| } | ||
| return line, nil | ||
| } | ||
| // No "id" — this is a notification, skip it | ||
| } | ||
| if err := scanner.Err(); err != nil { | ||
| return "", err | ||
| } | ||
| return "", fmt.Errorf("unexpected end of output") | ||
| } | ||
|
Comment on lines
497
to
517
|
||
|
|
||
| func printResponse(response string, prettyPrint bool) error { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| package main | ||
|
|
||
| import ( | ||
| "bufio" | ||
| "encoding/json" | ||
| "strings" | ||
| "testing" | ||
| ) | ||
|
|
||
| func TestReadJSONRPCResponse_DirectResponse(t *testing.T) { | ||
| t.Parallel() | ||
| input := `{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}` + "\n" | ||
| scanner := bufio.NewScanner(strings.NewReader(input)) | ||
|
|
||
| got, err := readJSONRPCResponse(scanner) | ||
| if err != nil { | ||
| t.Fatalf("unexpected error: %v", err) | ||
| } | ||
| if got != `{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}` { | ||
| t.Fatalf("unexpected response: %s", got) | ||
| } | ||
| } | ||
|
|
||
| func TestReadJSONRPCResponse_SkipsNotifications(t *testing.T) { | ||
| t.Parallel() | ||
| input := strings.Join([]string{ | ||
| `{"jsonrpc":"2.0","method":"notifications/resources/list_changed","params":{}}`, | ||
| `{"jsonrpc":"2.0","method":"notifications/tools/list_changed"}`, | ||
| `{"jsonrpc":"2.0","id":42,"result":{"content":[{"type":"text","text":"hello"}]}}`, | ||
| }, "\n") + "\n" | ||
| scanner := bufio.NewScanner(strings.NewReader(input)) | ||
|
|
||
| got, err := readJSONRPCResponse(scanner) | ||
| if err != nil { | ||
| t.Fatalf("unexpected error: %v", err) | ||
| } | ||
|
|
||
| var msg map[string]json.RawMessage | ||
| if err := json.Unmarshal([]byte(got), &msg); err != nil { | ||
| t.Fatalf("response is not valid JSON: %v", err) | ||
| } | ||
| // Verify we got the response with id:42, not a notification | ||
| var id int | ||
| if err := json.Unmarshal(msg["id"], &id); err != nil { | ||
| t.Fatalf("failed to parse id: %v", err) | ||
| } | ||
| if id != 42 { | ||
| t.Fatalf("expected id 42, got %d", id) | ||
| } | ||
| } | ||
|
|
||
| func TestReadJSONRPCResponse_NoResponse(t *testing.T) { | ||
| t.Parallel() | ||
| // Only notifications, no response | ||
| input := `{"jsonrpc":"2.0","method":"notifications/resources/list_changed","params":{}}` + "\n" | ||
| scanner := bufio.NewScanner(strings.NewReader(input)) | ||
|
|
||
| _, err := readJSONRPCResponse(scanner) | ||
| if err == nil { | ||
| t.Fatal("expected error for missing response, got nil") | ||
| } | ||
| if !strings.Contains(err.Error(), "unexpected end of output") { | ||
| t.Fatalf("expected 'unexpected end of output' error, got: %v", err) | ||
| } | ||
| } | ||
|
|
||
| func TestReadJSONRPCResponse_EmptyInput(t *testing.T) { | ||
| t.Parallel() | ||
| scanner := bufio.NewScanner(strings.NewReader("")) | ||
|
|
||
| _, err := readJSONRPCResponse(scanner) | ||
| if err == nil { | ||
| t.Fatal("expected error for empty input, got nil") | ||
| } | ||
| } | ||
|
|
||
| func TestReadJSONRPCResponse_InvalidJSON(t *testing.T) { | ||
| t.Parallel() | ||
| input := "not valid json\n" | ||
| scanner := bufio.NewScanner(strings.NewReader(input)) | ||
|
|
||
| _, err := readJSONRPCResponse(scanner) | ||
| if err == nil { | ||
| t.Fatal("expected error for invalid JSON, got nil") | ||
| } | ||
| if !strings.Contains(err.Error(), "failed to parse JSON-RPC message") { | ||
| t.Fatalf("expected parse error, got: %v", err) | ||
| } | ||
| } | ||
|
|
||
| func TestReadJSONRPCResponse_ServerError(t *testing.T) { | ||
| t.Parallel() | ||
| input := `{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"method not found"}}` + "\n" | ||
| scanner := bufio.NewScanner(strings.NewReader(input)) | ||
|
|
||
| _, err := readJSONRPCResponse(scanner) | ||
| if err == nil { | ||
| t.Fatal("expected error for server error response, got nil") | ||
| } | ||
| if !strings.Contains(err.Error(), "server returned error") { | ||
| t.Fatalf("expected 'server returned error', got: %v", err) | ||
| } | ||
| if !strings.Contains(err.Error(), "method not found") { | ||
| t.Fatalf("expected error to contain server message, got: %v", err) | ||
| } | ||
| } | ||
|
|
||
| func TestBuildInitializeRequest(t *testing.T) { | ||
| t.Parallel() | ||
| got, err := buildInitializeRequest() | ||
| if err != nil { | ||
| t.Fatalf("unexpected error: %v", err) | ||
| } | ||
|
|
||
| var msg map[string]json.RawMessage | ||
| if err := json.Unmarshal([]byte(got), &msg); err != nil { | ||
| t.Fatalf("result is not valid JSON: %v", err) | ||
| } | ||
|
|
||
| // Verify required fields | ||
| for _, field := range []string{"jsonrpc", "id", "method", "params"} { | ||
| if _, ok := msg[field]; !ok { | ||
| t.Errorf("missing required field %q", field) | ||
| } | ||
| } | ||
|
|
||
| // Verify method | ||
| var method string | ||
| if err := json.Unmarshal(msg["method"], &method); err != nil { | ||
| t.Fatalf("failed to parse method: %v", err) | ||
| } | ||
| if method != "initialize" { | ||
| t.Errorf("expected method 'initialize', got %q", method) | ||
| } | ||
|
|
||
| // Verify params contain protocolVersion and clientInfo | ||
| var params map[string]json.RawMessage | ||
| if err := json.Unmarshal(msg["params"], ¶ms); err != nil { | ||
| t.Fatalf("failed to parse params: %v", err) | ||
| } | ||
| for _, field := range []string{"protocolVersion", "capabilities", "clientInfo"} { | ||
| if _, ok := params[field]; !ok { | ||
| t.Errorf("missing params field %q", field) | ||
| } | ||
| } | ||
|
|
||
| var version string | ||
| if err := json.Unmarshal(params["protocolVersion"], &version); err != nil { | ||
| t.Fatalf("failed to parse protocolVersion: %v", err) | ||
| } | ||
| if version != "2024-11-05" { | ||
| t.Errorf("expected protocolVersion '2024-11-05', got %q", version) | ||
| } | ||
| } | ||
|
|
||
| func TestBuildInitializedNotification(t *testing.T) { | ||
| t.Parallel() | ||
| got := buildInitializedNotification() | ||
|
|
||
| var msg map[string]json.RawMessage | ||
| if err := json.Unmarshal([]byte(got), &msg); err != nil { | ||
| t.Fatalf("result is not valid JSON: %v", err) | ||
| } | ||
|
|
||
| // Must have jsonrpc and method | ||
| var method string | ||
| if err := json.Unmarshal(msg["method"], &method); err != nil { | ||
| t.Fatalf("failed to parse method: %v", err) | ||
| } | ||
| if method != "notifications/initialized" { | ||
| t.Errorf("expected method 'notifications/initialized', got %q", method) | ||
| } | ||
|
|
||
| // Must NOT have an id (it's a notification) | ||
| if _, hasID := msg["id"]; hasID { | ||
| t.Error("notification should not have an 'id' field") | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR description says it adds 7 unit tests in
cmd/mcpcurl/main_test.go, but there is no such file in this PR/repo state. If tests are intended, they need to be included (and ideally cover the initialize handshake + notification interleaving behavior); otherwise the PR description should be updated to match what’s actually changed.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test file was missing from the initial push. It's included in the force push — 7 unit tests covering
readJSONRPCResponse,buildInitializeRequest, andbuildInitializedNotification.