diff --git a/cmd/mcpcurl/main.go b/cmd/mcpcurl/main.go index f35e6926c..0dad1ea1a 100644 --- a/cmd/mcpcurl/main.go +++ b/cmd/mcpcurl/main.go @@ -1,7 +1,7 @@ package main import ( - "bytes" + "bufio" "crypto/rand" "encoding/json" "fmt" @@ -376,8 +376,8 @@ func buildJSONRPCRequest(method, toolName string, arguments map[string]any) (str return string(jsonData), nil } -// executeServerCommand runs the specified command, sends the JSON request to stdin, -// and returns the response from stdout +// executeServerCommand runs the specified command, performs the MCP initialization +// handshake, sends the JSON request to stdin, and returns the response from stdout. func executeServerCommand(cmdStr, jsonRequest string) (string, error) { // Split the command string into command and arguments cmdParts := strings.Fields(cmdStr) @@ -393,9 +393,14 @@ func executeServerCommand(cmdStr, jsonRequest string) (string, error) { return "", fmt.Errorf("failed to create stdin pipe: %w", err) } - // Setup stdout and stderr pipes - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout + // Setup stdout pipe for line-by-line reading + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return "", fmt.Errorf("failed to create stdout pipe: %w", err) + } + + // Stderr still uses a buffer + var stderr strings.Builder cmd.Stderr = &stderr // Start the command @@ -403,18 +408,112 @@ func executeServerCommand(cmdStr, jsonRequest string) (string, error) { return "", fmt.Errorf("failed to start command: %w", err) } - // Write the JSON request to stdin + // Ensure the child process is cleaned up on any error after Start() + cleanup := func() { + _ = stdin.Close() + _ = cmd.Wait() + } + + // Use a scanner with a large buffer for reading JSON-RPC responses + scanner := bufio.NewScanner(stdoutPipe) + scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024) // 1MB max line size + + // Step 1: Send MCP initialize request + initReq, err := buildInitializeRequest() + if err != nil { + cleanup() + return "", fmt.Errorf("failed to build initialize request: %w", err) + } + if _, err := io.WriteString(stdin, initReq+"\n"); err != nil { + cleanup() + return "", fmt.Errorf("failed to write initialize request: %w", err) + } + + // Step 2: Read initialize response (skip any server notifications) + if _, err := readJSONRPCResponse(scanner); err != nil { + cleanup() + return "", fmt.Errorf("failed to read initialize response: %w, stderr: %s", err, stderr.String()) + } + + // Step 3: Send initialized notification + if _, err := io.WriteString(stdin, buildInitializedNotification()+"\n"); err != nil { + cleanup() + return "", fmt.Errorf("failed to write initialized notification: %w", err) + } + + // Step 4: Send the actual request if _, err := io.WriteString(stdin, jsonRequest+"\n"); err != nil { - return "", fmt.Errorf("failed to write to stdin: %w", err) + cleanup() + return "", fmt.Errorf("failed to write request: %w", err) } - _ = stdin.Close() - // Wait for the command to complete - if err := cmd.Wait(); err != nil { - return "", fmt.Errorf("command failed: %w, stderr: %s", err, stderr.String()) + // Step 5: Read the actual response (skip any server notifications) + response, err := readJSONRPCResponse(scanner) + if err != nil { + cleanup() + return "", fmt.Errorf("failed to read response: %w, stderr: %s", err, stderr.String()) } - return stdout.String(), nil + // Close stdin and wait for process to exit. The server will see EOF and + // exit with a non-zero status, which is expected — we already have the response. + cleanup() + + return response, nil +} + +// buildInitializeRequest creates the MCP initialize handshake request. +func buildInitializeRequest() (string, error) { + id, err := rand.Int(rand.Reader, big.NewInt(10000)) + if err != nil { + return "", fmt.Errorf("failed to generate random ID: %w", err) + } + msg := map[string]any{ + "jsonrpc": "2.0", + "id": int(id.Int64()), + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]any{}, + "clientInfo": map[string]any{ + "name": "mcpcurl", + "version": "0.1.0", + }, + }, + } + data, err := json.Marshal(msg) + if err != nil { + return "", fmt.Errorf("failed to marshal initialize request: %w", err) + } + return string(data), nil +} + +// buildInitializedNotification creates the MCP initialized notification. +func buildInitializedNotification() string { + return `{"jsonrpc":"2.0","method":"notifications/initialized"}` +} + +// readJSONRPCResponse reads lines from the scanner, skipping server-initiated +// notifications (messages without an "id" field), and returns the first response. +func readJSONRPCResponse(scanner *bufio.Scanner) (string, error) { + for scanner.Scan() { + line := scanner.Text() + // JSON-RPC responses have an "id" field; notifications do not. + var msg map[string]json.RawMessage + if err := json.Unmarshal([]byte(line), &msg); err != nil { + return "", fmt.Errorf("failed to parse JSON-RPC message: %w", err) + } + if _, hasID := msg["id"]; hasID { + if errField, hasErr := msg["error"]; hasErr { + return "", fmt.Errorf("server returned error: %s", string(errField)) + } + return line, nil + } + // No "id" — this is a notification, skip it + } + if err := scanner.Err(); err != nil { + return "", err + } + return "", fmt.Errorf("unexpected end of output") } func printResponse(response string, prettyPrint bool) error { diff --git a/cmd/mcpcurl/main_test.go b/cmd/mcpcurl/main_test.go new file mode 100644 index 000000000..3d0b00d2a --- /dev/null +++ b/cmd/mcpcurl/main_test.go @@ -0,0 +1,178 @@ +package main + +import ( + "bufio" + "encoding/json" + "strings" + "testing" +) + +func TestReadJSONRPCResponse_DirectResponse(t *testing.T) { + t.Parallel() + input := `{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}` + "\n" + scanner := bufio.NewScanner(strings.NewReader(input)) + + got, err := readJSONRPCResponse(scanner) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != `{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}` { + t.Fatalf("unexpected response: %s", got) + } +} + +func TestReadJSONRPCResponse_SkipsNotifications(t *testing.T) { + t.Parallel() + input := strings.Join([]string{ + `{"jsonrpc":"2.0","method":"notifications/resources/list_changed","params":{}}`, + `{"jsonrpc":"2.0","method":"notifications/tools/list_changed"}`, + `{"jsonrpc":"2.0","id":42,"result":{"content":[{"type":"text","text":"hello"}]}}`, + }, "\n") + "\n" + scanner := bufio.NewScanner(strings.NewReader(input)) + + got, err := readJSONRPCResponse(scanner) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var msg map[string]json.RawMessage + if err := json.Unmarshal([]byte(got), &msg); err != nil { + t.Fatalf("response is not valid JSON: %v", err) + } + // Verify we got the response with id:42, not a notification + var id int + if err := json.Unmarshal(msg["id"], &id); err != nil { + t.Fatalf("failed to parse id: %v", err) + } + if id != 42 { + t.Fatalf("expected id 42, got %d", id) + } +} + +func TestReadJSONRPCResponse_NoResponse(t *testing.T) { + t.Parallel() + // Only notifications, no response + input := `{"jsonrpc":"2.0","method":"notifications/resources/list_changed","params":{}}` + "\n" + scanner := bufio.NewScanner(strings.NewReader(input)) + + _, err := readJSONRPCResponse(scanner) + if err == nil { + t.Fatal("expected error for missing response, got nil") + } + if !strings.Contains(err.Error(), "unexpected end of output") { + t.Fatalf("expected 'unexpected end of output' error, got: %v", err) + } +} + +func TestReadJSONRPCResponse_EmptyInput(t *testing.T) { + t.Parallel() + scanner := bufio.NewScanner(strings.NewReader("")) + + _, err := readJSONRPCResponse(scanner) + if err == nil { + t.Fatal("expected error for empty input, got nil") + } +} + +func TestReadJSONRPCResponse_InvalidJSON(t *testing.T) { + t.Parallel() + input := "not valid json\n" + scanner := bufio.NewScanner(strings.NewReader(input)) + + _, err := readJSONRPCResponse(scanner) + if err == nil { + t.Fatal("expected error for invalid JSON, got nil") + } + if !strings.Contains(err.Error(), "failed to parse JSON-RPC message") { + t.Fatalf("expected parse error, got: %v", err) + } +} + +func TestReadJSONRPCResponse_ServerError(t *testing.T) { + t.Parallel() + input := `{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"method not found"}}` + "\n" + scanner := bufio.NewScanner(strings.NewReader(input)) + + _, err := readJSONRPCResponse(scanner) + if err == nil { + t.Fatal("expected error for server error response, got nil") + } + if !strings.Contains(err.Error(), "server returned error") { + t.Fatalf("expected 'server returned error', got: %v", err) + } + if !strings.Contains(err.Error(), "method not found") { + t.Fatalf("expected error to contain server message, got: %v", err) + } +} + +func TestBuildInitializeRequest(t *testing.T) { + t.Parallel() + got, err := buildInitializeRequest() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var msg map[string]json.RawMessage + if err := json.Unmarshal([]byte(got), &msg); err != nil { + t.Fatalf("result is not valid JSON: %v", err) + } + + // Verify required fields + for _, field := range []string{"jsonrpc", "id", "method", "params"} { + if _, ok := msg[field]; !ok { + t.Errorf("missing required field %q", field) + } + } + + // Verify method + var method string + if err := json.Unmarshal(msg["method"], &method); err != nil { + t.Fatalf("failed to parse method: %v", err) + } + if method != "initialize" { + t.Errorf("expected method 'initialize', got %q", method) + } + + // Verify params contain protocolVersion and clientInfo + var params map[string]json.RawMessage + if err := json.Unmarshal(msg["params"], ¶ms); err != nil { + t.Fatalf("failed to parse params: %v", err) + } + for _, field := range []string{"protocolVersion", "capabilities", "clientInfo"} { + if _, ok := params[field]; !ok { + t.Errorf("missing params field %q", field) + } + } + + var version string + if err := json.Unmarshal(params["protocolVersion"], &version); err != nil { + t.Fatalf("failed to parse protocolVersion: %v", err) + } + if version != "2024-11-05" { + t.Errorf("expected protocolVersion '2024-11-05', got %q", version) + } +} + +func TestBuildInitializedNotification(t *testing.T) { + t.Parallel() + got := buildInitializedNotification() + + var msg map[string]json.RawMessage + if err := json.Unmarshal([]byte(got), &msg); err != nil { + t.Fatalf("result is not valid JSON: %v", err) + } + + // Must have jsonrpc and method + var method string + if err := json.Unmarshal(msg["method"], &method); err != nil { + t.Fatalf("failed to parse method: %v", err) + } + if method != "notifications/initialized" { + t.Errorf("expected method 'notifications/initialized', got %q", method) + } + + // Must NOT have an id (it's a notification) + if _, hasID := msg["id"]; hasID { + t.Error("notification should not have an 'id' field") + } +}