Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package software.amazon.smithy.java.mcp.server;

import java.time.Duration;
import software.amazon.smithy.utils.SmithyUnstableApi;

/**
Expand All @@ -27,10 +28,64 @@ void onInitialize(
);

/**
* Called when tool/call request is received.
* Called when tool/call request is received. Fires at the start of the call,
* before execution begins. Paired with {@link #onToolCallComplete} which fires after.
*
* @param method the JSON-RPC method name (e.g. {@code "tools/call"})
* @param toolName the name of the tool being invoked, or {@code null} if not available
*/
void onToolCall(
String method,
String toolName
);

/**
* Called after a tool call completes, whether it succeeded or failed.
* Every {@link #onToolCall} is guaranteed a matching {@code onToolCallComplete}.
*
* @param method the JSON-RPC method name (e.g. {@code "tools/call"})
* @param toolName the name of the tool that was invoked, or {@code null} if not available
* @param serverId the MCP server that handled the call (e.g. {@code "aws-lambda-mcp"}),
* or {@code null} if the tool was not found
* @param latency wall-clock time from request receipt to response
* @param success {@code true} if the tool call returned a successful result,
* {@code false} if it returned an error or threw an exception
* @param isProxy {@code true} if the call was forwarded to a proxy (StdioProxy or HttpMcpProxy),
* {@code false} if it was handled in-process by a local Smithy service
*/
default void onToolCallComplete(
String method,
String toolName,
String serverId,
Duration latency,
boolean success,
boolean isProxy
Comment thread
psp65 marked this conversation as resolved.
) {}

/**
* Called when a tool call results in an error. Always preceded by
* {@link #onToolCallComplete} with {@code success=false} for the same call.
*
* @param method the JSON-RPC method name (e.g. {@code "tools/call"})
* @param toolName the name of the tool that failed, or {@code null} if not available
* @param serverId the MCP server that handled the call, or {@code null} if the tool was not found
* @param errorMessage a description of the error
*/
default void onToolCallError(
String method,
String toolName,
String serverId,
String errorMessage
) {}

/**
* Called when a {@code tools/list} request is received.
*
* @param method the JSON-RPC method name (e.g. {@code "tools/list"})
* @param toolCount the number of tools returned after filtering
*/
default void onToolsList(
String method,
int toolCount
) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.io.StringWriter;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
Expand Down Expand Up @@ -156,37 +157,39 @@ yield switch (method) {

private JsonRpcResponse handleInitialize(JsonRpcRequest req) {
if (metricsObserver != null) {
var params = req.getParams();
var clientInfo = params.getMember("clientInfo");
var capabilities = params.getMember("capabilities");

String extractedProtocolVersion = params.getMember("protocolVersion") != null
? params.getMember("protocolVersion").asString()
: null;

String clientName = clientInfo != null && clientInfo.getMember("name") != null
? clientInfo.getMember("name").asString()
: null;

String clientTitle = clientInfo != null && clientInfo.getMember("title") != null
? clientInfo.getMember("title").asString()
: null;

boolean rootsListChanged = capabilities != null
&& capabilities.getMember("roots") != null
&& capabilities.getMember("roots").getMember("listChanged") != null
&& capabilities.getMember("roots").getMember("listChanged").asBoolean();

boolean sampling = capabilities != null && capabilities.getMember("sampling") != null;
boolean elicitation = capabilities != null && capabilities.getMember("elicitation") != null;

metricsObserver.onInitialize("initialize",
extractedProtocolVersion,
rootsListChanged,
sampling,
elicitation,
clientName,
clientTitle);
safeObserve(() -> {
var params = req.getParams();
var clientInfo = params.getMember("clientInfo");
var capabilities = params.getMember("capabilities");

String extractedProtocolVersion = params.getMember("protocolVersion") != null
? params.getMember("protocolVersion").asString()
: null;

String clientName = clientInfo != null && clientInfo.getMember("name") != null
? clientInfo.getMember("name").asString()
: null;

String clientTitle = clientInfo != null && clientInfo.getMember("title") != null
? clientInfo.getMember("title").asString()
: null;

boolean rootsListChanged = capabilities != null
&& capabilities.getMember("roots") != null
&& capabilities.getMember("roots").getMember("listChanged") != null
&& capabilities.getMember("roots").getMember("listChanged").asBoolean();

boolean sampling = capabilities != null && capabilities.getMember("sampling") != null;
boolean elicitation = capabilities != null && capabilities.getMember("elicitation") != null;

metricsObserver.onInitialize("initialize",
extractedProtocolVersion,
rootsListChanged,
sampling,
elicitation,
clientName,
clientTitle);
});
}

this.initializeRequest.compareAndSet(null, req);
Expand Down Expand Up @@ -251,10 +254,17 @@ private JsonRpcResponse handlePromptsGet(JsonRpcRequest req) {
}

private JsonRpcResponse handleToolsList(JsonRpcRequest req, ProtocolVersion protocolVersion) {
var filteredTools = tools.values()
.stream()
.filter(t -> toolFilter.allowTool(t.serverId(), t.toolInfo().getName()))
.toList();

if (metricsObserver != null) {
safeObserve(() -> metricsObserver.onToolsList("tools/list", filteredTools.size()));
}

var result = ListToolsResult.builder()
.tools(tools.values()
.stream()
.filter(t -> toolFilter.allowTool(t.serverId(), t.toolInfo().getName()))
.tools(filteredTools.stream()
.map(tool -> extractToolInfo(tool, protocolVersion))
.toList())
.build();
Expand All @@ -266,18 +276,36 @@ private JsonRpcResponse handleToolsCall(
Consumer<JsonRpcResponse> asyncResponseCallback,
ProtocolVersion protocolVersion
) {
String toolName = req.getParams().getMember("name") != null
? req.getParams().getMember("name").asString()
: null;

if (metricsObserver != null) {
String toolName = req.getParams().getMember("name") != null
? req.getParams().getMember("name").asString()
: null;
metricsObserver.onToolCall("tools/call", toolName);
safeObserve(() -> metricsObserver.onToolCall("tools/call", toolName));
}

var operationName = req.getParams().getMember("name").asString();
var tool = tools.get(operationName);
long startNanos = System.nanoTime();
var tool = tools.get(toolName);

if (tool == null) {
return createErrorResponse(req, "No such tool: " + operationName);
Duration latency = Duration.ofNanos(System.nanoTime() - startNanos);
if (metricsObserver != null) {
safeObserve(() -> {
metricsObserver.onToolCallComplete(
"tools/call",
toolName,
null,
latency,
false,
false);
metricsObserver.onToolCallError(
"tools/call",
toolName,
null,
"No such tool: " + toolName);
});
}
return createErrorResponse(req, "No such tool: " + toolName);
}

// Check if this tool should be dispatched to a proxy
Expand All @@ -291,8 +319,46 @@ private JsonRpcResponse handleToolsCall(
.build();

// Get response asynchronously and invoke callback
tool.proxy().rpc(proxyRequest).thenAccept(asyncResponseCallback).exceptionally(ex -> {
tool.proxy().rpc(proxyRequest).thenAccept(response -> {
Duration latency = Duration.ofNanos(System.nanoTime() - startNanos);
boolean success = response.getError() == null;
if (metricsObserver != null) {
safeObserve(() -> {
metricsObserver.onToolCallComplete("tools/call",
toolName,
tool.serverId(),
latency,
success,
true);
if (!success) {
String errMsg = response.getError().getMessage() != null
? response.getError().getMessage()
: "Unknown error";
metricsObserver.onToolCallError("tools/call",
toolName,
tool.serverId(),
errMsg);
}
});
}
asyncResponseCallback.accept(response);
}).exceptionally(ex -> {
Duration latency = Duration.ofNanos(System.nanoTime() - startNanos);
LOG.error("Error from proxy RPC", ex);
if (metricsObserver != null) {
safeObserve(() -> {
metricsObserver.onToolCallComplete("tools/call",
toolName,
tool.serverId(),
latency,
false,
true);
metricsObserver.onToolCallError("tools/call",
toolName,
tool.serverId(),
safeErrorMessage(ex));
});
}
asyncResponseCallback
.accept(createErrorResponse(req, new RuntimeException("Proxy error: " + ex.getMessage(), ex)));
return null;
Expand All @@ -302,16 +368,56 @@ private JsonRpcResponse handleToolsCall(
return null;
} else {
// Handle locally
var operation = tool.operation();
var argumentsDoc = req.getParams().getMember("arguments");
var adaptedDoc = adaptDocument(argumentsDoc, operation.getApiOperation().inputSchema());
var input = adaptedDoc.asShape(operation.getApiOperation().inputBuilder());
var output = operation.function().apply(input, null);
var result = formatStructuredContent(tool, (SerializableShape) output, protocolVersion);
return createSuccessResponse(req.getId(), result);
try {
var operation = tool.operation();
var argumentsDoc = req.getParams().getMember("arguments");
var adaptedDoc = adaptDocument(argumentsDoc, operation.getApiOperation().inputSchema());
var input = adaptedDoc.asShape(operation.getApiOperation().inputBuilder());
var output = operation.function().apply(input, null);
var result = formatStructuredContent(tool, (SerializableShape) output, protocolVersion);
Duration latency = Duration.ofNanos(System.nanoTime() - startNanos);
if (metricsObserver != null) {
safeObserve(() -> metricsObserver.onToolCallComplete("tools/call",
toolName,
tool.serverId(),
latency,
true,
false));
}
return createSuccessResponse(req.getId(), result);
} catch (Exception e) {
Duration latency = Duration.ofNanos(System.nanoTime() - startNanos);
if (metricsObserver != null) {
safeObserve(() -> {
metricsObserver.onToolCallComplete("tools/call",
toolName,
tool.serverId(),
latency,
false,
false);
metricsObserver.onToolCallError("tools/call",
toolName,
tool.serverId(),
safeErrorMessage(e));
});
}
throw e;
}
}
}

private void safeObserve(Runnable observation) {
try {
observation.run();
} catch (Exception e) {
LOG.warn("Metrics observer error", e);
}
}

private static String safeErrorMessage(Throwable t) {
return t.getMessage() != null ? t.getMessage() : t.getClass().getName();
}

/**
* Sets the notification writer for forwarding notifications from proxies.
*/
Expand Down
Loading