From 81b217299ab4269bc97e4a7a599e76db5608e40e Mon Sep 17 00:00:00 2001 From: Param Parikh Date: Mon, 6 Apr 2026 15:30:59 -0700 Subject: [PATCH] Extend McpMetricsObserver with tool call completion, error, and tools/list callbacks Add three new default methods to McpMetricsObserver: - onToolCallComplete: fires after tool execution with latency, server ID, success/failure, and proxy flag - onToolCallError: fires on tool call errors with error message - onToolsList: fires on tools/list requests with tool count Instrument McpService.handleToolsCall() with System.nanoTime() latency tracking on both local and proxy code paths, and add error tracking for failed calls and unknown tools. Instrument handleToolsList() with tool count callback. Route every observer invocation, including the existing onInitialize and onToolCall callbacks, through a new safeObserve() helper that logs and swallows exceptions thrown by the observer. --- .../java/mcp/server/McpMetricsObserver.java | 57 +++- .../smithy/java/mcp/server/McpService.java | 204 +++++++++++---- .../smithy/java/mcp/server/McpServerTest.java | 247 ++++++++++++++++++ 3 files changed, 458 insertions(+), 50 deletions(-) diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpMetricsObserver.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpMetricsObserver.java index 51c941d92..e5d287a5b 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpMetricsObserver.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpMetricsObserver.java @@ -5,6 +5,7 @@ package software.amazon.smithy.java.mcp.server; +import java.time.Duration; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -27,10 +28,64 @@ void onInitialize( ); /** - * Called when tool/call request is received. + * Called when a tools/call request is received. Fires at the start of the call, + * before execution begins. Paired with {@link #onToolCallComplete} which fires after. + * + * @param method the JSON-RPC method name (e.g. 
{@code "tools/call"}) + * @param toolName the name of the tool being invoked, or {@code null} if not available */ void onToolCall( String method, String toolName ); + + /** + * Called after a tool call completes, whether it succeeded or failed. + * Every {@link #onToolCall} is guaranteed a matching {@code onToolCallComplete}. + * + * @param method the JSON-RPC method name (e.g. {@code "tools/call"}) + * @param toolName the name of the tool that was invoked, or {@code null} if not available + * @param serverId the MCP server that handled the call (e.g. {@code "aws-lambda-mcp"}), + * or {@code null} if the tool was not found + * @param latency elapsed time ({@code System.nanoTime()} based) from tool lookup to result or error + * @param success {@code true} if the tool call returned a successful result, + * {@code false} if it returned an error or threw an exception + * @param isProxy {@code true} if the call was forwarded to a proxy (StdioProxy or HttpMcpProxy), + * {@code false} if it was handled in-process by a local Smithy service + */ + default void onToolCallComplete( + String method, + String toolName, + String serverId, + Duration latency, + boolean success, + boolean isProxy + ) {} + + /** + * Called when a tool call results in an error. Always preceded by + * {@link #onToolCallComplete} with {@code success=false} for the same call. + * + * @param method the JSON-RPC method name (e.g. {@code "tools/call"}) + * @param toolName the name of the tool that failed, or {@code null} if not available + * @param serverId the MCP server that handled the call, or {@code null} if the tool was not found + * @param errorMessage a description of the error + */ + default void onToolCallError( + String method, + String toolName, + String serverId, + String errorMessage + ) {} + + /** + * Called when a {@code tools/list} request is received. + * + * @param method the JSON-RPC method name (e.g. 
{@code "tools/list"}) + * @param toolCount the number of tools returned after filtering + */ + default void onToolsList( + String method, + int toolCount + ) {} } diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java index 662020f8a..4739336d2 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java @@ -14,6 +14,7 @@ import java.io.StringWriter; import java.math.BigDecimal; import java.math.BigInteger; +import java.time.Duration; import java.util.ArrayList; import java.util.Base64; import java.util.HashMap; @@ -156,37 +157,39 @@ yield switch (method) { private JsonRpcResponse handleInitialize(JsonRpcRequest req) { if (metricsObserver != null) { - var params = req.getParams(); - var clientInfo = params.getMember("clientInfo"); - var capabilities = params.getMember("capabilities"); - - String extractedProtocolVersion = params.getMember("protocolVersion") != null - ? params.getMember("protocolVersion").asString() - : null; - - String clientName = clientInfo != null && clientInfo.getMember("name") != null - ? clientInfo.getMember("name").asString() - : null; - - String clientTitle = clientInfo != null && clientInfo.getMember("title") != null - ? 
clientInfo.getMember("title").asString() - : null; - - boolean rootsListChanged = capabilities != null - && capabilities.getMember("roots") != null - && capabilities.getMember("roots").getMember("listChanged") != null - && capabilities.getMember("roots").getMember("listChanged").asBoolean(); - - boolean sampling = capabilities != null && capabilities.getMember("sampling") != null; - boolean elicitation = capabilities != null && capabilities.getMember("elicitation") != null; - - metricsObserver.onInitialize("initialize", - extractedProtocolVersion, - rootsListChanged, - sampling, - elicitation, - clientName, - clientTitle); + safeObserve(() -> { + var params = req.getParams(); + var clientInfo = params.getMember("clientInfo"); + var capabilities = params.getMember("capabilities"); + + String extractedProtocolVersion = params.getMember("protocolVersion") != null + ? params.getMember("protocolVersion").asString() + : null; + + String clientName = clientInfo != null && clientInfo.getMember("name") != null + ? clientInfo.getMember("name").asString() + : null; + + String clientTitle = clientInfo != null && clientInfo.getMember("title") != null + ? 
clientInfo.getMember("title").asString() + : null; + + boolean rootsListChanged = capabilities != null + && capabilities.getMember("roots") != null + && capabilities.getMember("roots").getMember("listChanged") != null + && capabilities.getMember("roots").getMember("listChanged").asBoolean(); + + boolean sampling = capabilities != null && capabilities.getMember("sampling") != null; + boolean elicitation = capabilities != null && capabilities.getMember("elicitation") != null; + + metricsObserver.onInitialize("initialize", + extractedProtocolVersion, + rootsListChanged, + sampling, + elicitation, + clientName, + clientTitle); + }); } this.initializeRequest.compareAndSet(null, req); @@ -251,10 +254,17 @@ private JsonRpcResponse handlePromptsGet(JsonRpcRequest req) { } private JsonRpcResponse handleToolsList(JsonRpcRequest req, ProtocolVersion protocolVersion) { + var filteredTools = tools.values() + .stream() + .filter(t -> toolFilter.allowTool(t.serverId(), t.toolInfo().getName())) + .toList(); + + if (metricsObserver != null) { + safeObserve(() -> metricsObserver.onToolsList("tools/list", filteredTools.size())); + } + var result = ListToolsResult.builder() - .tools(tools.values() - .stream() - .filter(t -> toolFilter.allowTool(t.serverId(), t.toolInfo().getName())) + .tools(filteredTools.stream() .map(tool -> extractToolInfo(tool, protocolVersion)) .toList()) .build(); @@ -266,18 +276,36 @@ private JsonRpcResponse handleToolsCall( Consumer asyncResponseCallback, ProtocolVersion protocolVersion ) { + String toolName = req.getParams().getMember("name") != null + ? req.getParams().getMember("name").asString() + : null; + if (metricsObserver != null) { - String toolName = req.getParams().getMember("name") != null - ? 
req.getParams().getMember("name").asString() - : null; - metricsObserver.onToolCall("tools/call", toolName); + safeObserve(() -> metricsObserver.onToolCall("tools/call", toolName)); } - var operationName = req.getParams().getMember("name").asString(); - var tool = tools.get(operationName); + long startNanos = System.nanoTime(); + var tool = tools.get(toolName); if (tool == null) { - return createErrorResponse(req, "No such tool: " + operationName); + Duration latency = Duration.ofNanos(System.nanoTime() - startNanos); + if (metricsObserver != null) { + safeObserve(() -> { + metricsObserver.onToolCallComplete( + "tools/call", + toolName, + null, + latency, + false, + false); + metricsObserver.onToolCallError( + "tools/call", + toolName, + null, + "No such tool: " + toolName); + }); + } + return createErrorResponse(req, "No such tool: " + toolName); } // Check if this tool should be dispatched to a proxy @@ -291,8 +319,46 @@ private JsonRpcResponse handleToolsCall( .build(); // Get response asynchronously and invoke callback - tool.proxy().rpc(proxyRequest).thenAccept(asyncResponseCallback).exceptionally(ex -> { + tool.proxy().rpc(proxyRequest).thenAccept(response -> { + Duration latency = Duration.ofNanos(System.nanoTime() - startNanos); + boolean success = response.getError() == null; + if (metricsObserver != null) { + safeObserve(() -> { + metricsObserver.onToolCallComplete("tools/call", + toolName, + tool.serverId(), + latency, + success, + true); + if (!success) { + String errMsg = response.getError().getMessage() != null + ? 
response.getError().getMessage() + : "Unknown error"; + metricsObserver.onToolCallError("tools/call", + toolName, + tool.serverId(), + errMsg); + } + }); + } + asyncResponseCallback.accept(response); + }).exceptionally(ex -> { + Duration latency = Duration.ofNanos(System.nanoTime() - startNanos); LOG.error("Error from proxy RPC", ex); + if (metricsObserver != null) { + safeObserve(() -> { + metricsObserver.onToolCallComplete("tools/call", + toolName, + tool.serverId(), + latency, + false, + true); + metricsObserver.onToolCallError("tools/call", + toolName, + tool.serverId(), + safeErrorMessage(ex)); + }); + } asyncResponseCallback .accept(createErrorResponse(req, new RuntimeException("Proxy error: " + ex.getMessage(), ex))); return null; @@ -302,16 +368,56 @@ private JsonRpcResponse handleToolsCall( return null; } else { // Handle locally - var operation = tool.operation(); - var argumentsDoc = req.getParams().getMember("arguments"); - var adaptedDoc = adaptDocument(argumentsDoc, operation.getApiOperation().inputSchema()); - var input = adaptedDoc.asShape(operation.getApiOperation().inputBuilder()); - var output = operation.function().apply(input, null); - var result = formatStructuredContent(tool, (SerializableShape) output, protocolVersion); - return createSuccessResponse(req.getId(), result); + try { + var operation = tool.operation(); + var argumentsDoc = req.getParams().getMember("arguments"); + var adaptedDoc = adaptDocument(argumentsDoc, operation.getApiOperation().inputSchema()); + var input = adaptedDoc.asShape(operation.getApiOperation().inputBuilder()); + var output = operation.function().apply(input, null); + var result = formatStructuredContent(tool, (SerializableShape) output, protocolVersion); + Duration latency = Duration.ofNanos(System.nanoTime() - startNanos); + if (metricsObserver != null) { + safeObserve(() -> metricsObserver.onToolCallComplete("tools/call", + toolName, + tool.serverId(), + latency, + true, + false)); + } + return 
createSuccessResponse(req.getId(), result); + } catch (Exception e) { + Duration latency = Duration.ofNanos(System.nanoTime() - startNanos); + if (metricsObserver != null) { + safeObserve(() -> { + metricsObserver.onToolCallComplete("tools/call", + toolName, + tool.serverId(), + latency, + false, + false); + metricsObserver.onToolCallError("tools/call", + toolName, + tool.serverId(), + safeErrorMessage(e)); + }); + } + throw e; + } } } + private void safeObserve(Runnable observation) { + try { + observation.run(); + } catch (Exception e) { + LOG.warn("Metrics observer error", e); + } + } + + private static String safeErrorMessage(Throwable t) { + return t.getMessage() != null ? t.getMessage() : t.getClass().getName(); + } + /** * Sets the notification writer for forwarding notifications from proxies. */ diff --git a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java index aef941dec..6778ff62b 100644 --- a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java +++ b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java @@ -1670,6 +1670,253 @@ void testOtherNotificationsDoNotInvalidateCache() { assertEquals(1, callCounter.get(), "Cache should not be invalidated by other notifications"); } + // --- Metrics Observer Tests --- + + @Test + void testMetricsObserverOnInitialize() { + var observer = new TestMetricsObserver(); + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .metricsObserver(observer) + .build(); + + server.start(); + + write("initialize", + Document.of(Map.of( + "protocolVersion", + Document.of("2025-03-26"), + "clientInfo", + Document.of(Map.of( + "name", + 
Document.of("test-client"), + "title", + Document.of("Test Client"))), + "capabilities", + Document.of(Map.of( + "roots", + Document.of(Map.of("listChanged", Document.of(true))), + "sampling", + Document.of(Map.of())))))); + read(); + + assertEquals(1, observer.initializeCount); + assertEquals("2025-03-26", observer.lastProtocolVersion); + assertEquals("test-client", observer.lastClientName); + assertEquals("Test Client", observer.lastClientTitle); + assertTrue(observer.lastRootsListChanged); + assertTrue(observer.lastSampling); + } + + @Test + void testMetricsObserverOnToolsList() { + var observer = new TestMetricsObserver(); + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .metricsObserver(observer) + .build(); + + server.start(); + + initializeWithProtocolVersion(null); + write("tools/list", Document.of(Map.of())); + read(); + + assertEquals(1, observer.toolsListCount); + assertEquals(6, observer.lastToolCount); + + write("tools/list", Document.of(Map.of())); + read(); + + assertEquals(2, observer.toolsListCount); + } + + @Test + void testMetricsObserverOnToolCallForNonExistentTool() { + var observer = new TestMetricsObserver(); + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .metricsObserver(observer) + .build(); + + server.start(); + + initializeWithProtocolVersion(null); + write("tools/call", + Document.of(Map.of( + "name", + Document.of("NonExistentTool"), + "arguments", + Document.of(Map.of())))); + var response = read(); + + assertNotNull(response.getError()); + assertEquals(1, observer.toolCallCount); + 
assertEquals("NonExistentTool", observer.lastToolCallName); + // onToolCallComplete should also fire for "tool not found" with success=false + assertEquals(1, observer.toolCallCompleteCount); + assertFalse(observer.lastCompleteSuccess); + assertFalse(observer.lastCompleteIsProxy); + assertNull(observer.lastCompleteServerId); + assertNotNull(observer.lastCompleteLatency); + assertEquals(1, observer.toolCallErrorCount); + assertEquals("NonExistentTool", observer.lastErrorToolName); + assertTrue(observer.lastErrorMessage.contains("No such tool")); + } + + @Test + void testMetricsObserverOnToolCallLocal() { + var observer = new TestMetricsObserver(); + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .metricsObserver(observer) + .build(); + + server.start(); + + initializeWithProtocolVersion(null); + write("tools/call", + Document.of(Map.of( + "name", + Document.of("NoIOOperation"), + "arguments", + Document.of(Map.of())))); + read(); + + // onToolCall fires at the start of every tool call + assertEquals(1, observer.toolCallCount); + assertEquals("NoIOOperation", observer.lastToolCallName); + // onToolCallComplete fires after execution. The call is handled in-process (isProxy=false), + // but the ProxyService endpoint (localhost) is not running, so it completes with success=false + assertEquals(1, observer.toolCallCompleteCount); + assertEquals("NoIOOperation", observer.lastCompleteToolName); + assertEquals("test-mcp", observer.lastCompleteServerId); + assertFalse(observer.lastCompleteSuccess); + assertFalse(observer.lastCompleteIsProxy); + assertNotNull(observer.lastCompleteLatency); + } + + private static class TestMetricsObserver implements McpMetricsObserver { + int initializeCount; + String lastProtocolVersion; + boolean lastRootsListChanged; + boolean lastSampling; + boolean 
lastElicitation; + String lastClientName; + String lastClientTitle; + + int toolCallCount; + String lastToolCallName; + + int toolCallCompleteCount; + String lastCompleteToolName; + String lastCompleteServerId; + Duration lastCompleteLatency; + boolean lastCompleteSuccess; + boolean lastCompleteIsProxy; + + int toolCallErrorCount; + String lastErrorToolName; + String lastErrorServerId; + String lastErrorMessage; + + int toolsListCount; + int lastToolCount; + + @Override + public void onInitialize( + String method, + String extractedProtocolVersion, + boolean rootsListChanged, + boolean sampling, + boolean elicitation, + String clientName, + String clientTitle + ) { + initializeCount++; + lastProtocolVersion = extractedProtocolVersion; + lastRootsListChanged = rootsListChanged; + lastSampling = sampling; + lastElicitation = elicitation; + lastClientName = clientName; + lastClientTitle = clientTitle; + } + + @Override + public void onToolCall(String method, String toolName) { + toolCallCount++; + lastToolCallName = toolName; + } + + @Override + public void onToolCallComplete( + String method, + String toolName, + String serverId, + Duration latency, + boolean success, + boolean isProxy + ) { + toolCallCompleteCount++; + lastCompleteToolName = toolName; + lastCompleteServerId = serverId; + lastCompleteLatency = latency; + lastCompleteSuccess = success; + lastCompleteIsProxy = isProxy; + } + + @Override + public void onToolCallError( + String method, + String toolName, + String serverId, + String errorMessage + ) { + toolCallErrorCount++; + lastErrorToolName = toolName; + lastErrorServerId = serverId; + lastErrorMessage = errorMessage; + } + + @Override + public void onToolsList(String method, int toolCount) { + toolsListCount++; + lastToolCount = toolCount; + } + } + private static class CacheTestProxy extends McpServerProxy { private final AtomicInteger callCounter; private final List sentNotifications = new ArrayList<>();