diff --git a/scripts/generate.js b/scripts/generate.js index 37cdc21..1ab5aad 100644 --- a/scripts/generate.js +++ b/scripts/generate.js @@ -75,6 +75,8 @@ async function main() { export const CLIENT_METHODS = ${JSON.stringify(metadata.clientMethods, null, 2)} as const; +export const PROTOCOL_METHODS = ${JSON.stringify(metadata.protocolMethods, null, 2)} as const; + export const PROTOCOL_VERSION = ${metadata.version}; `, { parser: "typescript" }, diff --git a/src/acp.test.ts b/src/acp.test.ts index cfeca98..1bd6a47 100644 --- a/src/acp.test.ts +++ b/src/acp.test.ts @@ -721,7 +721,8 @@ describe("Connection", () => { const appAgent = createAgent({ name: "app-agent" }) .onRequest(AGENT_METHODS.initialize, (c) => { events.push(`initialize:${c.params.protocolVersion}`); - expect(Object.keys(c).sort()).toEqual(["client", "params"]); + expect(Object.keys(c).sort()).toEqual(["client", "params", "signal"]); + expect(c.signal.aborted).toBe(false); return { protocolVersion: c.params.protocolVersion, @@ -742,7 +743,7 @@ describe("Connection", () => { }, }, (c) => { - expect(Object.keys(c).sort()).toEqual(["client", "params"]); + expect(Object.keys(c).sort()).toEqual(["client", "params", "signal"]); events.push(`agent-route:${String(c.params.message)}`); }, ) @@ -777,7 +778,8 @@ describe("Connection", () => { return { message: String(message).toUpperCase() }; }, (c) => { - expect(Object.keys(c).sort()).toEqual(["agent", "params"]); + expect(Object.keys(c).sort()).toEqual(["agent", "params", "signal"]); + expect(c.signal.aborted).toBe(false); events.push(`client-route:${String(c.params.message)}`); return { message: c.params.message }; @@ -824,6 +826,78 @@ describe("Connection", () => { ]); }); + it("aborts app request context signals for protocol cancellation", async () => { + const requestSignal = Promise.withResolvers(); + + const appAgent = createAgent({ name: "cancel-signal-agent" }).onRequest( + "vendor/slow", + (params) => params as Record, + async (c) => { + requestSignal.resolve(c.signal); + await new Promise((resolve) => { + c.signal.addEventListener("abort", () => resolve(), { once: true }); + }); + + return { cancelled: c.signal.aborted }; + }, + ); + const appClient = createClient({ name: "cancel-signal-client" }); + + const result = await appClient.connectWith(appAgent, async (agentCx) => { + const response = agentCx.request<{ cancelled: boolean }>( + "vendor/slow", + {}, + ); + const signal = await requestSignal.promise; + + expect(signal.aborted).toBe(false); + await agentCx.notify(methods.protocol.cancelRequest, { requestId: 0 }); + + return response; + }); + + expect(result).toEqual({ cancelled: true }); + }); + + it("maps app request abort errors to request cancellation", async () => { + const requestSignal = Promise.withResolvers(); + + const appAgent = createAgent({ name: "abort-error-agent" }).onRequest( + "vendor/abort", + (params) => params as Record, + async (c) => { + requestSignal.resolve(c.signal); + await new Promise((_, reject) => { + c.signal.addEventListener( + "abort", + () => { + const error = new Error("aborted"); + error.name = "AbortError"; + reject(error); + }, + { once: true }, + ); + }); + + return {}; + }, + ); + const appClient = createClient({ name: "abort-error-client" }); + + await appClient.connectWith(appAgent, async (agentCx) => { + const response = agentCx.request("vendor/abort", {}); + const signal = await requestSignal.promise; + + expect(signal.aborted).toBe(false); + await agentCx.notify(methods.protocol.cancelRequest, { requestId: 0 }); + + await expect(response).rejects.toMatchObject({ + code: -32800, + message: "Request cancelled", + }); + }); + }); + it("returns peer contexts from app connection handles", async () => { const events: string[] = []; diff --git a/src/acp.ts b/src/acp.ts index 25e188a..ee7b7c0 100644 --- a/src/acp.ts +++ b/src/acp.ts @@ -4,6 +4,7 @@ export type * from "./schema/types.gen.js"; export { AGENT_METHODS, CLIENT_METHODS, + PROTOCOL_METHODS, PROTOCOL_VERSION, } from "./schema/index.js"; export * from "./stream.js"; @@ -16,6 +17,7 @@ export type { ErrorResponse, MaybePromise, Result, + SendRequestOptions, } from "./jsonrpc.js"; import type { Stream } from "./stream.js"; @@ -28,6 +30,7 @@ import type { IncomingMessage, JsonRpcHandler, MaybePromise, + SendRequestOptions, } from "./jsonrpc.js"; function emptyObjectResponse(response: T | null | undefined | void): T { @@ -124,6 +127,9 @@ export const methods = { complete: schema.CLIENT_METHODS.elicitation_complete, }, }, + protocol: { + cancelRequest: schema.PROTOCOL_METHODS.cancel_request, + }, } as const; const startActiveSession = Symbol("startActiveSession"); @@ -192,8 +198,9 @@ class AcpContext { method: string, params?: Req, mapResponse?: (response: Resp) => Output, + options?: SendRequestOptions, ): Promise { - return this.cx.sendRequest(method, params, mapResponse); + return this.cx.sendRequest(method, params, mapResponse, options); } /** @internal */ @@ -232,16 +239,22 @@ export class AgentContext extends AcpContext { request( method: Method, params: ClientRequestParamsByMethod[Method], + options?: SendRequestOptions, ): Promise; request( method: string, params?: Params, + options?: SendRequestOptions, ): Promise; - request(method: string, params?: unknown): Promise { + request( + method: string, + params?: unknown, + options?: SendRequestOptions, + ): Promise { const spec = clientRequestSpecsByMethod[method] as | AcpRequestSpec | undefined; - return this.sendRequest(method, params, spec?.mapResponse); + return this.sendRequest(method, params, spec?.mapResponse, options); } /** @@ -279,13 +292,17 @@ export class ClientContext extends AcpContext { /** @internal */ [startActiveSession]( params: schema.NewSessionRequest, + options?: SendRequestOptions, ): Promise { return this.sendRequest< schema.NewSessionRequest, schema.NewSessionResponse, ActiveSession - >(schema.AGENT_METHODS.session_new, params, (response) => - this.attachSession(response), + >( + schema.AGENT_METHODS.session_new, + params, + (response) => this.attachSession(response), + options, ); } @@ -347,16 +364,22 @@ export class ClientContext extends AcpContext { request( method: Method, params: AgentRequestParamsByMethod[Method], + options?: SendRequestOptions, ): Promise; request( method: string, params?: Params, + options?: SendRequestOptions, ): Promise; - request(method: string, params?: unknown): Promise { + request( + method: string, + params?: unknown, + options?: SendRequestOptions, + ): Promise { const spec = agentRequestSpecsByMethod[method] as | AcpRequestSpec | undefined; - return this.sendRequest(method, params, spec?.mapResponse); + return this.sendRequest(method, params, spec?.mapResponse, options); } /** @@ -657,8 +680,8 @@ export class SessionBuilder { * Call `dispose()` on the returned session when you no longer need update * routing, or use `withSession(...)` to scope disposal automatically. */ - async start(): Promise { - return this.cx[startActiveSession](this.toRequest()); + async start(options?: SendRequestOptions): Promise { + return this.cx[startActiveSession](this.toRequest(), options); } /** @@ -752,12 +775,17 @@ export class ActiveSession { */ prompt( prompt: string | schema.ContentBlock | Array, + options?: SendRequestOptions, ): Promise { this.updates.clearErrors(); - const response = this.cx.request(schema.AGENT_METHODS.session_prompt, { - sessionId: this.sessionId, - prompt: this.promptBlocks(prompt), - }); + const response = this.cx.request( + schema.AGENT_METHODS.session_prompt, + { + sessionId: this.sessionId, + prompt: this.promptBlocks(prompt), + }, + options, + ); void response.then( (value) => { this.updates.enqueue({ @@ -874,6 +902,11 @@ export type AgentHandlerContext = { * Parsed request or notification params. */ params: Params; + /** + * AbortSignal for the current request, or the connection signal for + * notifications. + */ + signal: AbortSignal; /** * Typed client context for calling client-side ACP methods. */ @@ -888,6 +921,11 @@ export type ClientHandlerContext = { * Parsed request or notification params. */ params: Params; + /** + * AbortSignal for the current request, or the connection signal for + * notifications. + */ + signal: AbortSignal; /** * Typed agent context for calling agent-side ACP methods. */ @@ -980,14 +1018,18 @@ function notificationSpec( function registerAppRequest( builder: ConnectionBuilder, spec: AcpRequestSpec, - context: (params: Params, cx: ConnectionContext) => Context, + context: ( + params: Params, + cx: ConnectionContext, + signal: AbortSignal, + ) => Context, handler: (context: Context) => MaybePromise, ): void { builder.onReceiveRequest( spec.method, (params) => parseParams(spec.params, params), async (params, responder, cx) => { - const response = await handler(context(params, cx)); + const response = await handler(context(params, cx, responder.signal)); await responder.respond( (spec.mapResponse ? spec.mapResponse(response) @@ -1000,13 +1042,17 @@ function registerAppRequest( function registerAppNotification( builder: ConnectionBuilder, spec: AcpNotificationSpec, - context: (params: Params, cx: ConnectionContext) => Context, + context: ( + params: Params, + cx: ConnectionContext, + signal: AbortSignal, + ) => Context, handler: (context: Context) => MaybePromise, ): void { builder.onReceiveNotification( spec.method, (params) => parseParams(spec.params, params), - (params, cx) => handler(context(params, cx)), + (params, cx) => handler(context(params, cx, cx.signal)), ); } @@ -1518,9 +1564,11 @@ export type ClientNotificationParamsByMethod = { function agentHandlerContext( params: Params, client: AgentContext, + signal: AbortSignal, ): AgentHandlerContext { return { params, + signal, client, }; } @@ -1528,9 +1576,11 @@ function agentHandlerContext( function clientHandlerContext( params: Params, agent: ClientContext, + signal: AbortSignal, ): ClientHandlerContext { return { params, + signal, agent, }; } @@ -1825,7 +1875,8 @@ export class AgentApp { registerAppRequest( this.builder, spec, - (params, cx) => agentHandlerContext(params, AgentContext.create(cx)), + (params, cx, signal) => + agentHandlerContext(params, AgentContext.create(cx), signal), handler, ); return this; @@ -1838,7 +1889,8 @@ export class AgentApp { registerAppNotification( this.builder, spec, - (params, cx) => agentHandlerContext(params, AgentContext.create(cx)), + (params, cx, signal) => + agentHandlerContext(params, AgentContext.create(cx), signal), handler, ); return this; @@ -2068,7 +2120,8 @@ export class ClientApp { registerAppRequest( this.builder, spec, - (params, cx) => clientHandlerContext(params, ClientContext.create(cx)), + (params, cx, signal) => + clientHandlerContext(params, ClientContext.create(cx), signal), handler, ); return this; @@ -2081,7 +2134,8 @@ export class ClientApp { registerAppNotification( this.builder, spec, - (params, cx) => clientHandlerContext(params, ClientContext.create(cx)), + (params, cx, signal) => + clientHandlerContext(params, ClientContext.create(cx), signal), handler, ); return this; @@ -2642,16 +2696,27 @@ export class AgentSideConnection { request( method: Method, params: ClientRequestParamsByMethod[Method], + options?: SendRequestOptions, ): Promise; request( method: string, params?: Params, + options?: SendRequestOptions, ): Promise; - request(method: string, params?: unknown): Promise { + request( + method: string, + params?: unknown, + options?: SendRequestOptions, + ): Promise { const spec = clientRequestSpecsByMethod[method] as | AcpRequestSpec | undefined; - return this.connection.sendRequest(method, params, spec?.mapResponse); + return this.connection.sendRequest( + method, + params, + spec?.mapResponse, + options, + ); } /** @@ -3402,16 +3467,27 @@ export class ClientSideConnection implements Agent { request( method: Method, params: AgentRequestParamsByMethod[Method], + options?: SendRequestOptions, ): Promise; request( method: string, params?: Params, + options?: SendRequestOptions, ): Promise; - request(method: string, params?: unknown): Promise { + request( + method: string, + params?: unknown, + options?: SendRequestOptions, + ): Promise { const spec = agentRequestSpecsByMethod[method] as | AcpRequestSpec | undefined; - return this.connection.sendRequest(method, params, spec?.mapResponse); + return this.connection.sendRequest( + method, + params, + spec?.mapResponse, + options, + ); } /** diff --git a/src/jsonrpc.test.ts b/src/jsonrpc.test.ts index b6453dd..e950a53 100644 --- a/src/jsonrpc.test.ts +++ b/src/jsonrpc.test.ts @@ -1,6 +1,12 @@ -import { describe, expect, it } from "vitest"; +import { describe, expect, it, vi } from "vitest"; -import { isJsonRpcMessage } from "./jsonrpc.js"; +import { Connection, RequestError, isJsonRpcMessage } from "./jsonrpc.js"; +import type { AnyMessage, RequestResponder } from "./jsonrpc.js"; +import type { Stream } from "./stream.js"; + +type ConnectionInternals = { + pendingResponses: Map; +}; describe("JSON-RPC envelope validation", () => { it.each([ @@ -38,3 +44,386 @@ describe("JSON-RPC envelope validation", () => { expect(isJsonRpcMessage(message)).toBe(false); }); }); + +describe("JSON-RPC request cancellation", () => { + it("keeps an aborted outgoing request pending for the peer response", async () => { + const [clientStream, serverStream] = memoryStreamPair(); + const slowResponder = Promise.withResolvers(); + const cancelReceived = Promise.withResolvers<{ + requestId: string | number | null; + }>(); + const consoleError = vi + .spyOn(console, "error") + .mockImplementation(() => {}); + + const server = Connection.builder() + .onReceiveRequest( + "example/slow", + (params) => params, + (_request, responder) => { + slowResponder.resolve(responder); + return new Promise(() => {}); + }, + ) + .onReceiveRequest( + "example/barrier", + (params) => params, + (_, responder) => responder.respond({ ok: true }), + ) + .onReceiveNotification( + "$/cancel_request", + (params) => params as { requestId: string | number | null }, + (params) => { + cancelReceived.resolve(params); + }, + ) + .connect(serverStream); + const client = Connection.builder().connect(clientStream); + + try { + const abortController = new AbortController(); + const response = client.sendRequest("example/slow", {}, undefined, { + cancellationSignal: abortController.signal, + }); + let settled = false; + response.then( + () => { + settled = true; + }, + () => { + settled = true; + }, + ); + const responder = await slowResponder.promise; + const clientInternals = client as unknown as ConnectionInternals; + + abortController.abort("user cancelled"); + + await expect(cancelReceived.promise).resolves.toEqual({ + requestId: responder.id, + }); + await Promise.resolve(); + expect(settled).toBe(false); + expect(clientInternals.pendingResponses.has(responder.id)).toBe(true); + + await responder.respond({ ok: true }); + await expect(response).resolves.toEqual({ ok: true }); + expect(clientInternals.pendingResponses.has(responder.id)).toBe(false); + await expect(client.sendRequest("example/barrier", {})).resolves.toEqual({ + ok: true, + }); + expect(consoleError).not.toHaveBeenCalled(); + } finally { + consoleError.mockRestore(); + client.close(); + server.close(); + await Promise.all([client.closed, server.closed]); + } + }); + + it("sends an already-aborted cancellation signal after the request", async () => { + const [clientStream, serverStream] = memoryStreamPair(); + const slowResponder = Promise.withResolvers(); + const cancelReceived = Promise.withResolvers<{ + requestId: string | number | null; + }>(); + + const server = Connection.builder() + .onReceiveRequest( + "example/slow", + (params) => params, + (_request, responder) => { + slowResponder.resolve(responder); + return new Promise(() => {}); + }, + ) + .onReceiveNotification( + "$/cancel_request", + (params) => params as { requestId: string | number | null }, + (params) => { + cancelReceived.resolve(params); + }, + ) + .connect(serverStream); + const client = Connection.builder().connect(clientStream); + + try { + const response = client.sendRequest("example/slow", {}, undefined, { + cancellationSignal: AbortSignal.abort("already cancelled"), + }); + const responder = await slowResponder.promise; + + await expect(cancelReceived.promise).resolves.toEqual({ + requestId: responder.id, + }); + + await responder.respond({ ok: true }); + await expect(response).resolves.toEqual({ ok: true }); + } finally { + client.close(); + server.close(); + await Promise.all([client.closed, server.closed]); + } + }); + + it("queues already-aborted cancellation before later writes", async () => { + const messages: AnyMessage[] = []; + const firstWriteStarted = Promise.withResolvers(); + const unblockFirstWrite = Promise.withResolvers(); + const thirdWriteCompleted = Promise.withResolvers(); + + let writes = 0; + const client = Connection.builder().connect({ + readable: new ReadableStream(), + writable: new WritableStream({ + async write(message) { + writes += 1; + messages.push(message); + if (writes === 1) { + firstWriteStarted.resolve(); + await unblockFirstWrite.promise; + } + if (writes === 3) { + thirdWriteCompleted.resolve(); + } + }, + }), + }); + + try { + const response = client.sendRequest("example/slow", {}, undefined, { + cancellationSignal: AbortSignal.abort("already cancelled"), + }); + response.catch(() => {}); + await firstWriteStarted.promise; + + const laterNotification = client.sendNotification("example/later", {}); + unblockFirstWrite.resolve(); + await thirdWriteCompleted.promise; + await laterNotification; + + expect(messages[0]).toMatchObject({ + method: "example/slow", + }); + expect(messages[1]).toMatchObject({ + method: "$/cancel_request", + params: { requestId: "id" in messages[0] ? messages[0].id : undefined }, + }); + expect(messages[2]).toMatchObject({ + method: "example/later", + }); + } finally { + client.close(); + await client.closed; + } + }); + + it("keeps manually cancelled requests pending for the peer response", async () => { + const [clientStream, serverStream] = memoryStreamPair(); + const slowResponder = Promise.withResolvers(); + const cancelReceived = Promise.withResolvers<{ + requestId: string | number | null; + }>(); + + const server = Connection.builder() + .onReceiveRequest( + "example/slow", + (params) => params, + (_request, responder) => { + slowResponder.resolve(responder); + return new Promise(() => {}); + }, + ) + .onReceiveNotification( + "$/cancel_request", + (params) => params as { requestId: string | number | null }, + (params) => { + cancelReceived.resolve(params); + }, + ) + .connect(serverStream); + const client = Connection.builder().connect(clientStream); + + try { + const response = client.sendRequest("example/slow", {}); + const responder = await slowResponder.promise; + const clientInternals = client as unknown as ConnectionInternals; + + await client.sendCancelRequest(responder.id); + + await expect(cancelReceived.promise).resolves.toEqual({ + requestId: responder.id, + }); + expect(clientInternals.pendingResponses.has(responder.id)).toBe(true); + + await responder.respond({ ok: true }); + await expect(response).resolves.toEqual({ ok: true }); + expect(clientInternals.pendingResponses.has(responder.id)).toBe(false); + } finally { + client.close(); + server.close(); + await Promise.all([client.closed, server.closed]); + } + }); + + it("aborts the incoming request signal when $/cancel_request is received", async () => { + const [clientStream, serverStream] = memoryStreamPair(); + const requestReceived = Promise.withResolvers<{ + id: string | number | null; + signal: AbortSignal; + }>(); + + const server = Connection.builder() + .onReceiveRequest( + "example/slow", + (params) => params, + async (_request, responder) => { + requestReceived.resolve({ + id: responder.id, + signal: responder.signal, + }); + await new Promise((resolve) => { + responder.signal.addEventListener("abort", () => resolve(), { + once: true, + }); + }); + await responder.respondWithError(RequestError.requestCancelled()); + }, + ) + .connect(serverStream); + const client = Connection.builder().connect(clientStream); + + try { + const response = client.sendRequest("example/slow", {}); + const { id, signal } = await requestReceived.promise; + + expect(signal.aborted).toBe(false); + await client.sendCancelRequest(id); + + await expect(response).rejects.toMatchObject({ + code: -32800, + message: "Request cancelled", + }); + expect(signal.aborted).toBe(true); + expect(signal.reason).toBeInstanceOf(RequestError); + expect((signal.reason as RequestError).code).toBe(-32800); + } finally { + client.close(); + server.close(); + await Promise.all([client.closed, server.closed]); + } + }); + + it("maps raw request abort errors to request cancellation", async () => { + const [clientStream, serverStream] = memoryStreamPair(); + const requestReceived = Promise.withResolvers<{ + id: string | number | null; + signal: AbortSignal; + }>(); + + const server = Connection.builder() + .onReceiveRequest( + "example/slow", + (params) => params, + async (_request, responder) => { + requestReceived.resolve({ + id: responder.id, + signal: responder.signal, + }); + await new Promise((_, reject) => { + responder.signal.addEventListener( + "abort", + () => { + const error = new Error("aborted"); + error.name = "AbortError"; + reject(error); + }, + { once: true }, + ); + }); + }, + ) + .connect(serverStream); + const client = Connection.builder().connect(clientStream); + + try { + const response = client.sendRequest("example/slow", {}); + const { id, signal } = await requestReceived.promise; + + expect(signal.aborted).toBe(false); + await client.sendCancelRequest(id); + + await expect(response).rejects.toMatchObject({ + code: -32800, + message: "Request cancelled", + }); + } finally { + client.close(); + server.close(); + await Promise.all([client.closed, server.closed]); + } + }); + + it("rejects requests started from request abort listeners during close", async () => { + const [clientStream, serverStream] = memoryStreamPair(); + const requestStarted = Promise.withResolvers(); + const closeTimeRequestStarted = Promise.withResolvers(); + const closeError = new Error("closing"); + let closeTimeRequest: Promise | undefined; + + const server = Connection.builder() + .onReceiveRequest( + "example/slow", + (params) => params, + (_request, responder, cx) => { + responder.signal.addEventListener( + "abort", + () => { + closeTimeRequest = cx.sendRequest("example/after-close", {}); + closeTimeRequest.catch(() => {}); + closeTimeRequestStarted.resolve(); + }, + { once: true }, + ); + requestStarted.resolve(); + return new Promise(() => {}); + }, + ) + .connect(serverStream); + const client = Connection.builder().connect(clientStream); + + try { + const response = client.sendRequest("example/slow", {}); + response.catch(() => {}); + await requestStarted.promise; + + server.close(closeError); + await closeTimeRequestStarted.promise; + + expect( + (server as unknown as ConnectionInternals).pendingResponses.size, + ).toBe(0); + expect(closeTimeRequest).toBeDefined(); + await expect(closeTimeRequest!).rejects.toBe(closeError); + } finally { + client.close(); + server.close(); + await Promise.all([client.closed, server.closed]); + } + }); +}); + +function memoryStreamPair(): [Stream, Stream] { + const leftToRight = new TransformStream(); + const rightToLeft = new TransformStream(); + return [ + { + readable: rightToLeft.readable, + writable: leftToRight.writable, + }, + { + readable: leftToRight.readable, + writable: rightToLeft.writable, + }, + ]; +} diff --git a/src/jsonrpc.ts b/src/jsonrpc.ts index d0d62df..2d49bb9 100644 --- a/src/jsonrpc.ts +++ b/src/jsonrpc.ts @@ -59,6 +59,22 @@ export type AnyNotification = { params?: unknown; }; +const CANCEL_REQUEST_METHOD = "$/cancel_request"; +type JsonRpcId = string | number | null; + +/** + * Options for sending a JSON-RPC request. + */ +export type SendRequestOptions = { + /** + * Aborting this signal sends `$/cancel_request` for the outgoing request. + * Cancellation is cooperative: the returned promise is still settled by the + * peer's eventual response, which may be a normal result, partial result, or + * `RequestError.requestCancelled()`. + */ + cancellationSignal?: AbortSignal; +}; + /** * JSON-RPC result payload, either a successful result or an error. */ @@ -176,6 +192,14 @@ function isJsonRpcId(value: unknown): value is string | number | null { ); } +function cancelRequestId(params: unknown): JsonRpcId | undefined { + if (!isRecord(params) || !isJsonRpcId(params["requestId"])) { + return undefined; + } + + return params["requestId"]; +} + function isErrorResponse(value: unknown): value is ErrorResponse { return ( isRecord(value) && @@ -188,6 +212,8 @@ function isErrorResponse(value: unknown): value is ErrorResponse { type ConnectionPendingResponse = { resolve: (response: unknown) => void; reject: (error: unknown) => void; + cleanup?: () => void; + cancellationSent?: boolean; }; /** @@ -215,6 +241,11 @@ export type IncomingRequest = { * Original wire request. */ raw: AnyRequest; + /** + * AbortSignal that aborts when the peer sends `$/cancel_request` for this + * request or when the connection closes. + */ + signal: AbortSignal; /** * Responder used to complete the request. */ @@ -383,6 +414,45 @@ function errorToResult(error: unknown): Result { } } +function requestCancelledError(reason?: unknown): RequestError { + if (reason instanceof RequestError && reason.code === -32800) { + return reason; + } + + return RequestError.requestCancelled(reason); +} + +function errorToRequestResult( + error: unknown, + signal: AbortSignal, +): Result { + const requestCancelled = abortErrorToRequestCancelled(error, signal); + return requestCancelled ? requestCancelled.toResult() : errorToResult(error); +} + +function abortErrorToRequestCancelled( + error: unknown, + signal: AbortSignal, +): RequestError | undefined { + if (!signal.aborted || !isAbortError(error)) { + return undefined; + } + + return requestCancelledError(signal.reason); +} + +function isAbortError(error: unknown): boolean { + if (typeof error !== "object" || error === null) { + return false; + } + + const maybeAbortError = error as { code?: unknown; name?: unknown }; + return ( + maybeAbortError.name === "AbortError" || + maybeAbortError.code === "ABORT_ERR" + ); +} + /** * Responder for one incoming JSON-RPC request. * @@ -398,6 +468,11 @@ export class RequestResponder { */ public readonly id: string | number | null, private sendResult: (result: Result) => Promise, + /** + * AbortSignal for this incoming request. + */ + public readonly signal: AbortSignal = new AbortController().signal, + private finishRequest?: () => void, ) {} /** @@ -432,7 +507,9 @@ export class RequestResponder { } this.didRespond = true; - return this.sendResult(result); + return this.sendResult(result).finally(() => { + this.finishRequest?.(); + }); } } @@ -484,8 +561,9 @@ export class ConnectionContext { method: string, params?: Req, mapResponse?: (response: Resp) => Output, + options?: SendRequestOptions, ): Promise { - return this.connection.sendRequest(method, params, mapResponse); + return this.connection.sendRequest(method, params, mapResponse, options); } /** @@ -495,6 +573,13 @@ export class ConnectionContext { return this.connection.sendNotification(method, params); } + /** + * Sends a protocol-level request cancellation notification. + */ + sendCancelRequest(requestId: JsonRpcId): Promise { + return this.connection.sendCancelRequest(requestId); + } + /** * Registers a handler that can be disposed independently. */ @@ -534,10 +619,9 @@ export type ConnectionOptions = { * class when building generic JSON-RPC middleware or custom dispatch behavior. */ export class Connection { - private pendingResponses: Map< - string | number | null, - ConnectionPendingResponse - > = new Map(); + private pendingResponses: Map = + new Map(); + private incomingRequests: Map = new Map(); private nextRequestId = 0; private staticHandlers: JsonRpcHandler[] = []; private dynamicHandlers: Set = new Set(); @@ -670,14 +754,16 @@ export class Connection { method: string, params?: Req, mapResponse?: (response: Resp) => Output, + options: SendRequestOptions = {}, ): Promise { if (this.abortController.signal.aborted) { return rejectedPromise(this.closedReason()); } const id = this.nextRequestId++; + let cancel = () => {}; const responsePromise = new Promise((resolve, reject) => { - this.pendingResponses.set(id, { + const pendingResponse: ConnectionPendingResponse = { resolve: (response) => { try { const value = mapResponse @@ -689,15 +775,47 @@ export class Connection { } }, reject, + }; + + cancel = () => { + if (pendingResponse.cancellationSent) { + return; + } + + pendingResponse.cancellationSent = true; + pendingResponse.cleanup?.(); + void this.sendCancelRequest(id).catch(() => {}); + }; + + options.cancellationSignal?.addEventListener("abort", cancel, { + once: true, }); + pendingResponse.cleanup = () => { + options.cancellationSignal?.removeEventListener("abort", cancel); + }; + this.pendingResponses.set(id, pendingResponse); }); responsePromise.catch(() => {}); - void this.sendMessage({ jsonrpc: "2.0", id, method, params }).catch( - () => {}, - ); + const requestSent = this.sendMessage({ + jsonrpc: "2.0", + id, + method, + params, + }); + void requestSent.catch(() => {}); + if (options.cancellationSignal?.aborted) { + cancel(); + } return responsePromise; } + /** + * Sends a protocol-level request cancellation notification. + */ + sendCancelRequest(requestId: JsonRpcId): Promise { + return this.sendNotification(CANCEL_REQUEST_METHOD, { requestId }); + } + /** * Sends a JSON-RPC notification. */ @@ -718,11 +836,16 @@ export class Connection { } const closeError: unknown = error ?? new Error("ACP connection closed"); + this.abortController.abort(closeError); for (const pendingResponse of this.pendingResponses.values()) { + pendingResponse.cleanup?.(); pendingResponse.reject(closeError); } this.pendingResponses.clear(); - this.abortController.abort(closeError); + for (const controller of this.incomingRequests.values()) { + controller.abort(closeError); + } + this.incomingRequests.clear(); void this.receiveReader?.cancel(closeError).catch(() => {}); } @@ -797,6 +920,9 @@ export class Connection { } if ("method" in message) { + if (!("id" in message)) { + this.handleProtocolNotification(message); + } void this.processIncomingMessage(this.toIncomingMessage(message)).catch( (error) => this.close(error), ); @@ -850,7 +976,9 @@ export class Connection { } if (current.kind === "request" && !current.responder.responded) { - await current.responder.respondWithResult(errorToResult(error)); + await current.responder.respondWithResult( + errorToRequestResult(error, current.responder.signal), + ); } else { const response = errorToResult(error); if ("error" in response) { @@ -868,17 +996,30 @@ export class Connection { message: AnyRequest | AnyNotification, ): IncomingMessage { if ("id" in message) { + const abortController = new AbortController(); + this.incomingRequests.set(message.id, abortController); + const finishRequest = () => { + if (this.incomingRequests.get(message.id) === abortController) { + this.incomingRequests.delete(message.id); + } + }; + return { kind: "request", method: message.method, params: message.params, raw: message, - responder: new RequestResponder(message.id, (result) => - this.sendMessage({ - jsonrpc: "2.0", - id: message.id, - ...result, - }), + signal: abortController.signal, + responder: new RequestResponder( + message.id, + (result) => + this.sendMessage({ + jsonrpc: "2.0", + id: message.id, + ...result, + }), + abortController.signal, + finishRequest, ), }; } @@ -894,6 +1035,9 @@ export class Connection { private handleResponse(response: AnyResponse): void { const pendingResponse = this.pendingResponses.get(response.id); if (pendingResponse) { + this.pendingResponses.delete(response.id); + pendingResponse.cleanup?.(); + if ("result" in response) { pendingResponse.resolve(response.result); } else if ("error" in response) { @@ -902,12 +1046,29 @@ export class Connection { } else { pendingResponse.reject(RequestError.invalidRequest(response)); } - this.pendingResponses.delete(response.id); } else { console.error("Got response to unknown request", response.id); } } + private handleProtocolNotification(message: AnyNotification): void { + if (message.method !== CANCEL_REQUEST_METHOD) { + return; + } + + const requestId = cancelRequestId(message.params); + if (requestId === undefined) { + return; + } + + const controller = this.incomingRequests.get(requestId); + if (!controller || controller.signal.aborted) { + return; + } + + controller.abort(RequestError.requestCancelled({ requestId })); + } + private closedReason(): unknown { return ( this.abortController.signal.reason ?? new Error("ACP connection closed") @@ -1138,6 +1299,20 @@ export class RequestError extends Error { ); } + /** + * Execution of the request was aborted. + */ + static requestCancelled( + data?: unknown, + additionalMessage?: string, + ): RequestError { + return new RequestError( + -32800, + `Request cancelled${additionalMessage ? `: ${additionalMessage}` : ""}`, + data, + ); + } + /** * Authentication required. */ diff --git a/src/schema/index.ts b/src/schema/index.ts index 8831367..7c3d4fb 100644 --- a/src/schema/index.ts +++ b/src/schema/index.ts @@ -309,4 +309,8 @@ export const CLIENT_METHODS = { elicitation_complete: "elicitation/complete", } as const; +export const PROTOCOL_METHODS = { + cancel_request: "$/cancel_request", +} as const; + export const PROTOCOL_VERSION = 1;