From 7a10a7e6c99bc66db7dac0f3c8eb0afe46c6e197 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 23 Jun 2026 16:02:12 +1000 Subject: [PATCH 1/2] feat: add experimental HTTP backend SPI --- package.json | 5 + src/acp.ts | 41 +- src/connection.ts | 213 ++++++++- src/http-backend.conformance.test.ts | 663 +++++++++++++++++++++++++++ src/http-backend.ts | 149 ++++++ src/jsonrpc.ts | 49 +- src/server-elicitation.test.ts | 302 ++++++++++++ src/server-permission.test.ts | 69 +++ src/server-session-sse.test.ts | 41 ++ src/server-websocket-upgrade.test.ts | 57 +++ src/server.ts | 198 ++++---- src/test-support/test-http-server.ts | 8 +- 12 files changed, 1661 insertions(+), 134 deletions(-) create mode 100644 src/http-backend.conformance.test.ts create mode 100644 src/http-backend.ts create mode 100644 src/server-elicitation.test.ts diff --git a/package.json b/package.json index 023a702..290b237 100644 --- a/package.json +++ b/package.json @@ -45,6 +45,11 @@ "import": "./dist/server.js", "default": "./dist/server.js" }, + "./experimental/http-backend": { + "types": "./dist/http-backend.d.ts", + "import": "./dist/http-backend.js", + "default": "./dist/http-backend.js" + }, "./experimental/node": { "types": "./dist/node-adapter.d.ts", "import": "./dist/node-adapter.js", diff --git a/src/acp.ts b/src/acp.ts index ee7b7c0..3f6d8b9 100644 --- a/src/acp.ts +++ b/src/acp.ts @@ -15,6 +15,7 @@ export type { AnyRequest, AnyResponse, ErrorResponse, + JsonRpcRequestIdGenerator, MaybePromise, Result, SendRequestOptions, @@ -29,6 +30,7 @@ import type { HandleResult, IncomingMessage, JsonRpcHandler, + JsonRpcRequestIdGenerator, MaybePromise, SendRequestOptions, } from "./jsonrpc.js"; @@ -184,6 +186,16 @@ export interface ClientConnection extends AcpConnection { readonly agent: ClientContext; } +export interface AcpConnectionOptions { + /** + * Allocates IDs for JSON-RPC requests sent by this app-side connection. + * + * Most users should not need this. HTTP server integrations use it to keep + * server-originated request IDs unique across distributed server instances. + */ + readonly requestIdGenerator?: JsonRpcRequestIdGenerator; +} + class AcpContext { /** @internal */ constructor(private readonly cx: ConnectionContext) {} @@ -1675,7 +1687,7 @@ const appBuilder = Symbol("appBuilder"); const runAgentConnectHandlers = Symbol("runAgentConnectHandlers"); const runClientConnectHandlers = Symbol("runClientConnectHandlers"); -type AppConnectOptions = { +type AppConnectOptions = AcpConnectionOptions & { readonly deferConnectHandlers?: boolean; }; @@ -1746,7 +1758,6 @@ export class AgentApp { ): AgentConnection { return this.connectConnection(target, options).connection; } - /** * Connects this agent app to a transport stream for the lifetime of `op`. * @@ -1756,6 +1767,7 @@ export class AgentApp { connectWith( stream: Stream, op: (context: AgentContext) => MaybePromise, + options?: AppConnectOptions, ): Promise; /** * Connects this agent app directly to a client app for the lifetime of `op`. @@ -1767,8 +1779,12 @@ export class AgentApp { connectWith( target: Stream | ClientApp, op: (context: AgentContext) => MaybePromise, + options: AppConnectOptions = {}, ): Promise { - const { rawConnection, connection } = this.connectConnection(target); + const { rawConnection, connection } = this.connectConnection( + target, + options, + ); return rawConnection.runUntil(() => op(connection.client)); } @@ -1901,7 +1917,7 @@ export class AgentApp { options: AppConnectOptions = {}, ): AgentConnectionState { if (isStream(target)) { - const state = this.openStreamConnection(target); + const state = this.openStreamConnection(target, options); if (!options.deferConnectHandlers) { this[runAgentConnectHandlers](state.connection); } @@ -1925,8 +1941,13 @@ export class AgentApp { return state; } - private openStreamConnection(stream: Stream): AgentConnectionState { - const rawConnection = this.builder.connect(stream); + private openStreamConnection( + stream: Stream, + options: AcpConnectionOptions = {}, + ): AgentConnectionState { + const rawConnection = this.builder.connect(stream, { + requestIdGenerator: options.requestIdGenerator, + }); return { rawConnection, connection: agentConnection(rawConnection, this.connectHandlers), @@ -2542,10 +2563,14 @@ export class AgentSideConnection { * * @deprecated Prefer `agent({ name }).connect(stream)`. */ - constructor(toAgent: (conn: AgentSideConnection) => Agent, stream: Stream) { + constructor( + toAgent: (conn: AgentSideConnection) => Agent, + stream: Stream, + options?: AcpConnectionOptions, + ) { this.connection = legacyAgentApp(toAgent(this)) [appBuilder]() - .connect(stream); + .connect(stream, options); } /** diff --git a/src/connection.ts b/src/connection.ts index 34e8cb3..4210333 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -5,11 +5,30 @@ import { sessionIdFromResponseResult, } from "./protocol.js"; -import type { AnyMessage, AnyResponse } from "./jsonrpc.js"; +import type { + AnyMessage, + AnyResponse, + JsonRpcRequestIdGenerator, +} from "./jsonrpc.js"; import type { Stream } from "./stream.js"; +import type { + AcpHttpBackend, + HttpBackendAcceptClientMethodMessageInput, + HttpBackendAcceptClientResponseInput, + HttpBackendAcceptResult, + HttpBackendCloseConnectionInput, + HttpBackendInitializeInput, + HttpBackendInitializeResult, + HttpBackendLoadConnectionInput, + HttpBackendLoadedConnection, + HttpBackendOpenConnectionStreamInput, + HttpBackendOpenSessionStreamInput, + HttpBackendTouchConnectionInput, +} from "./http-backend.js"; export interface AgentConnectOptions { readonly deferConnectHandlers?: boolean; + readonly requestIdGenerator?: JsonRpcRequestIdGenerator; } export interface AgentConnectionLifecycle { @@ -115,7 +134,7 @@ export class ConnectionState { private shutdownPromise: Promise | undefined; private resolveClosed: () => void = () => {}; - constructor(agent: AgentConnector) { + constructor(agent: AgentConnector, options: AgentConnectOptions = {}) { this.connectionId = globalThis.crypto.randomUUID(); this.closed = new Promise((resolve) => { this.resolveClosed = resolve; @@ -133,6 +152,7 @@ export class ConnectionState { this.agentConnection = agent.connect(stream, { deferConnectHandlers: true, + requestIdGenerator: options.requestIdGenerator, }); this.observeAgentConnection(); } @@ -206,6 +226,18 @@ export class ConnectionState { return stream; } + trackPendingResponseRoute(key: string, route: ResponseRoute): void { + this.pendingRoutes.set(key, route); + } + + clientResponseRoute(key: string): ResponseRoute | undefined { + return this.clientResponseRoutes.get(key); + } + + clearClientResponseRoute(key: string): void { + this.clientResponseRoutes.delete(key); + } + async shutdown(): Promise { if (!this.shutdownPromise) { this.shutdownPromise = this.runShutdown(); @@ -368,15 +400,21 @@ export class ConnectionRegistry { private readonly connections = new Map(); private readonly pendingConnections = new Map(); - createConnection(agent: AgentConnector): ConnectionState { - const connection = new ConnectionState(agent); + createConnection( + agent: AgentConnector, + options: AgentConnectOptions = {}, + ): ConnectionState { + const connection = new ConnectionState(agent, options); this.connections.set(connection.connectionId, connection); this.trackConnectionClose(connection); return connection; } - createPendingConnection(agent: AgentConnector): ConnectionState { - const connection = new ConnectionState(agent); + createPendingConnection( + agent: AgentConnector, + options: AgentConnectOptions = {}, + ): ConnectionState { + const connection = new ConnectionState(agent, options); this.pendingConnections.set(connection.connectionId, connection); this.trackConnectionClose(connection); return connection; @@ -443,6 +481,169 @@ export class ConnectionRegistry { } } +export class InMemoryAcpHttpBackend implements AcpHttpBackend { + constructor( + private readonly registry = new ConnectionRegistry(), + readonly generateServerRequestId?: JsonRpcRequestIdGenerator, + ) {} + + async initialize({ + agent, + message, + signal, + }: HttpBackendInitializeInput): Promise { + if (!("id" in message) || message.id === null) { + throw new Error("Initialize request must include an ID"); + } + + const connection = this.registry.createPendingConnection(agent, { + requestIdGenerator: this.generateServerRequestId, + }); + + try { + await connection.writeInbound(message); + const response = await connection.recvInitial(message.id); + + if (signal.aborted) { + throw new Error("Request aborted"); + } + + connection.startRouter(); + this.registry.register(connection); + connection.startConnectHandlers(); + + return { + connectionId: connection.connectionId, + response, + }; + } catch (error) { + this.registry.discard(connection.connectionId); + throw error; + } + } + + async loadConnection({ + connectionId, + }: HttpBackendLoadConnectionInput): Promise< + HttpBackendLoadedConnection | undefined + > { + const connection = this.registry.get(connectionId); + + if (!connection) { + return undefined; + } + + return { connectionId }; + } + + async touchConnection( + _input: HttpBackendTouchConnectionInput, + ): Promise { + // In-memory connections do not need TTL refresh. + } + + async acceptClientMethodMessage({ + connectionId, + message, + route, + responseRoute, + }: HttpBackendAcceptClientMethodMessageInput): Promise { + const connection = this.registry.get(connectionId); + + if (!connection) { + return { + ok: false, + status: 404, + message: "Unknown Acp-Connection-Id", + }; + } + + if (route !== "connection") { + connection.ensureSession(route.session); + } + + const key = "id" in message ? messageIdKey(message.id) : undefined; + if (key) { + connection.trackPendingResponseRoute(key, responseRoute); + } + + await connection.writeInbound(message); + return { ok: true }; + } + + async acceptClientResponse({ + connectionId, + message, + headerSessionId, + }: HttpBackendAcceptClientResponseInput): Promise { + const connection = this.registry.get(connectionId); + + if (!connection) { + return { + ok: false, + status: 404, + message: "Unknown Acp-Connection-Id", + }; + } + + const key = messageIdKey(message.id); + const route = key ? connection.clientResponseRoute(key) : undefined; + + if (route && route !== "connection" && !headerSessionId) { + return { + ok: false, + status: 400, + message: "Missing Acp-Session-Id", + }; + } + + if (route && route !== "connection" && headerSessionId !== route.session) { + return { + ok: false, + status: 400, + message: "Mismatched Acp-Session-Id", + }; + } + + if (key) { + connection.clearClientResponseRoute(key); + } + + await connection.writeInbound(message); + return { ok: true }; + } + + async openConnectionStream({ + connectionId, + }: HttpBackendOpenConnectionStreamInput): Promise< + OutboundSubscription | undefined + > { + return this.registry.get(connectionId)?.connectionStream.subscribe(); + } + + async openSessionStream({ + connectionId, + sessionId, + }: HttpBackendOpenSessionStreamInput): Promise< + OutboundSubscription | undefined + > { + return this.registry + .get(connectionId) + ?.ensureSession(sessionId) + .subscribe(); + } + + async closeConnection({ + connectionId, + }: HttpBackendCloseConnectionInput): Promise { + return Boolean(this.registry.remove(connectionId)); + } + + async close(): Promise { + await this.registry.closeAll(); + } +} + class OutboundSubscriber { readonly stream: ReadableStream; diff --git a/src/http-backend.conformance.test.ts b/src/http-backend.conformance.test.ts new file mode 100644 index 0000000..65dbaff --- /dev/null +++ b/src/http-backend.conformance.test.ts @@ -0,0 +1,663 @@ +import { describe, expect, it } from "vitest"; + +import { ConnectionRegistry, InMemoryAcpHttpBackend } from "./connection.js"; +import { + EVENT_STREAM_MIME_TYPE, + HEADER_CONNECTION_ID, + HEADER_SESSION_ID, + JSON_MIME_TYPE, +} from "./protocol.js"; +import { AcpServer } from "./server.js"; +import { parseSseStream } from "./sse.js"; +import { PROTOCOL_VERSION, agent as createAgentApp, methods } from "./acp.js"; +import { createTestAgentApp } from "./test-support/test-agent.js"; + +import type { AgentApp } from "./acp.js"; +import type { AnyMessage } from "./jsonrpc.js"; + +const initializeRequest = { + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: PROTOCOL_VERSION, + clientCapabilities: {}, + }, +}; + +const sessionNewRequest = { + jsonrpc: "2.0", + id: 2, + method: "session/new", + params: { + cwd: "/tmp", + mcpServers: [], + }, +}; + +const promptRequest = { + jsonrpc: "2.0", + id: 3, + method: "session/prompt", + params: { + sessionId: "session-1", + prompt: [{ type: "text", text: "Hello" }], + }, +}; + +type HarnessRole = + | "initialize" + | "post" + | "connectionSse" + | "sessionSse" + | "delete"; + +interface HttpBackendHarness { + readonly name: string; + readonly serverRequestIdPrefix: string; + handle(role: HarnessRole, request: Request): Promise; + close(): Promise; +} + +type HarnessFactory = (createAgent: () => AgentApp) => HttpBackendHarness; + +const harnesses: Array<{ + readonly name: string; + readonly createHarness: HarnessFactory; +}> = [ + { + name: "in-memory backend", + createHarness: (createAgent) => { + let nextRequestId = 0; + const server = new AcpServer({ + createAgent, + httpBackend: new InMemoryAcpHttpBackend( + new ConnectionRegistry(), + () => `memory-${nextRequestId++}`, + ), + }); + + return { + name: "in-memory backend", + serverRequestIdPrefix: "memory-", + handle: (_role, request) => server.handleRequest(request), + close: () => server.close(), + }; + }, + }, + { + name: "fake distributed backend", + createHarness: (createAgent) => { + const registry = new ConnectionRegistry(); + const counters = new Map(); + const createServer = (nodeId: string): AcpServer => + new AcpServer({ + createAgent, + httpBackend: new InMemoryAcpHttpBackend(registry, () => { + const next = counters.get(nodeId) ?? 0; + counters.set(nodeId, next + 1); + return `${nodeId}-${next}`; + }), + }); + const servers = { + initialize: createServer("node-a"), + post: createServer("node-b"), + connectionSse: createServer("node-c"), + sessionSse: createServer("node-d"), + delete: createServer("node-e"), + } satisfies Record; + + return { + name: "fake distributed backend", + serverRequestIdPrefix: "node-a-", + handle: (role, request) => servers[role].handleRequest(request), + close: async () => { + await Promise.all( + Array.from(new Set(Object.values(servers)), (server) => + server.close(), + ), + ); + }, + }; + }, + }, +]; + +describe.each(harnesses)( + "AcpServer HTTP backend conformance: $name", + ({ createHarness }) => { + it("preserves initialize, connected POST, connection-stream replay, and session-stream late attach", async () => { + const harness = createHarness(() => + createTestAgentApp({ + newSession: () => ({ sessionId: "session-1" }), + }), + ); + + try { + const connectionId = await initialize(harness); + + expect( + await postJson(harness, "post", sessionNewRequest, { + [HEADER_CONNECTION_ID]: connectionId, + }), + ).toMatchObject({ status: 202 }); + + const connectionSse = await openConnectionSse(harness, connectionId); + expect(connectionSse.status).toBe(200); + expect(await readSseMessages(connectionSse, 1)).toMatchObject([ + { + jsonrpc: "2.0", + id: 2, + result: { sessionId: "session-1" }, + }, + ]); + + expect( + await postJson(harness, "post", promptRequest, { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: "session-1", + }), + ).toMatchObject({ status: 202 }); + + const sessionSse = await openSessionSse( + harness, + connectionId, + "session-1", + ); + expect(sessionSse.status).toBe(200); + expect(await readSseMessages(sessionSse, 2)).toMatchObject([ + { + jsonrpc: "2.0", + method: "session/update", + params: { + sessionId: "session-1", + update: { + sessionUpdate: "agent_message_chunk", + content: { text: "chunk-1" }, + }, + }, + }, + { + jsonrpc: "2.0", + id: 3, + result: { stopReason: "end_turn" }, + }, + ]); + } finally { + await harness.close(); + } + }); + + it("closes HTTP backend connections through DELETE", async () => { + const harness = createHarness(() => createTestAgentApp()); + + try { + const connectionId = await initialize(harness); + const deleted = await harness.handle( + "delete", + new Request("http://example.test/acp", { + method: "DELETE", + headers: { + [HEADER_CONNECTION_ID]: connectionId, + }, + }), + ); + + expect(deleted.status).toBe(202); + expect( + await postJson(harness, "post", sessionNewRequest, { + [HEADER_CONNECTION_ID]: connectionId, + }), + ).toMatchObject({ status: 404 }); + expect(await openConnectionSse(harness, connectionId)).toMatchObject({ + status: 404, + }); + } finally { + await harness.close(); + } + }); + + it("routes permission and elicitation request/response flows without method allowlists", async () => { + const harness = createHarness(() => createInteractiveAgent()); + + try { + const connectionId = await initialize(harness); + const sessionId = await createSession(harness, connectionId); + const connectionSse = await openConnectionSse(harness, connectionId); + const connectionEvents = createSseMessageIterator(connectionSse); + const sessionSse = await openSessionSse( + harness, + connectionId, + sessionId, + ); + const sessionEvents = createSseMessageIterator(sessionSse); + + expect( + await postJson(harness, "post", promptRequest, { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }), + ).toMatchObject({ status: 202 }); + + expect(await readNextSseMessage(sessionEvents)).toMatchObject({ + jsonrpc: "2.0", + method: "session/update", + params: { + sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { text: "before-permission" }, + }, + }, + }); + + const permissionRequest = await readNextSseMessage(sessionEvents); + expect(permissionRequest).toMatchObject({ + jsonrpc: "2.0", + id: expect.stringMatching( + new RegExp(`^${escapeRegExp(harness.serverRequestIdPrefix)}`), + ), + method: "session/request_permission", + params: { + sessionId, + toolCall: { + toolCallId: "permission-tool", + title: "Permission tool", + }, + }, + }); + + expect( + await postJson( + harness, + "post", + { + jsonrpc: "2.0", + id: readMessageId(permissionRequest), + result: { + outcome: { + outcome: "selected", + optionId: "allow", + }, + }, + }, + { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + ), + ).toMatchObject({ status: 202 }); + + expect(await readNextSseMessage(sessionEvents)).toMatchObject({ + jsonrpc: "2.0", + method: "session/update", + params: { + sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { text: "permission-selected-allow" }, + }, + }, + }); + + const elicitationRequest = await readNextSseMessage(sessionEvents); + expect(elicitationRequest).toMatchObject({ + jsonrpc: "2.0", + id: expect.stringMatching( + new RegExp(`^${escapeRegExp(harness.serverRequestIdPrefix)}`), + ), + method: "elicitation/create", + params: { + sessionId, + mode: "form", + message: "Name", + }, + }); + expect(readMessageId(elicitationRequest)).not.toBe( + readMessageId(permissionRequest), + ); + + expect( + await postJson( + harness, + "post", + { + jsonrpc: "2.0", + id: readMessageId(elicitationRequest), + result: { + action: "accept", + content: { name: "Alice" }, + }, + }, + { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + ), + ).toMatchObject({ status: 202 }); + + expect(await readNextSseMessage(connectionEvents)).toMatchObject({ + jsonrpc: "2.0", + method: "elicitation/complete", + params: { elicitationId: "elicitation-1" }, + }); + expect(await readSseIteratorMessages(sessionEvents, 2)).toMatchObject([ + { + jsonrpc: "2.0", + method: "session/update", + params: { + sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { text: "elicitation-accept-Alice" }, + }, + }, + }, + { + jsonrpc: "2.0", + id: 3, + result: { stopReason: "end_turn" }, + }, + ]); + + await sessionEvents.return?.(); + await connectionEvents.return?.(); + await sessionSse.body?.cancel(); + await connectionSse.body?.cancel(); + } finally { + await harness.close(); + } + }, 10_000); + }, +); + +function createInteractiveAgent(): AgentApp { + return createAgentApp({ name: "http-backend-conformance-agent" }) + .onRequest(methods.agent.initialize, () => ({ + protocolVersion: PROTOCOL_VERSION, + agentCapabilities: { + loadSession: false, + }, + })) + .onRequest(methods.agent.session.new, () => ({ sessionId: "session-1" })) + .onRequest(methods.agent.authenticate, () => ({})) + .onRequest(methods.agent.session.prompt, async (c) => { + await c.client.notify(methods.client.session.update, { + sessionId: c.params.sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { + type: "text", + text: "before-permission", + }, + }, + }); + + const permission = await c.client.request( + methods.client.session.requestPermission, + { + sessionId: c.params.sessionId, + toolCall: { + toolCallId: "permission-tool", + title: "Permission tool", + }, + options: [ + { + kind: "allow_once", + name: "Allow once", + optionId: "allow", + }, + { + kind: "reject_once", + name: "Reject once", + optionId: "reject", + }, + ], + }, + ); + + if (!isPermissionResponse(permission)) { + throw new Error("Expected permission response"); + } + + await c.client.notify(methods.client.session.update, { + sessionId: c.params.sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { + type: "text", + text: + permission.outcome.outcome === "selected" + ? `permission-selected-${permission.outcome.optionId}` + : "permission-cancelled", + }, + }, + }); + + const elicitation = await c.client.request( + methods.client.elicitation.create, + { + sessionId: c.params.sessionId, + mode: "form", + message: "Name", + requestedSchema: { + type: "object", + properties: { + name: { type: "string" }, + }, + }, + }, + ); + + if (!isAcceptedElicitation(elicitation)) { + throw new Error("Expected accepted elicitation response"); + } + + await c.client.notify(methods.client.elicitation.complete, { + elicitationId: "elicitation-1", + }); + await c.client.notify(methods.client.session.update, { + sessionId: c.params.sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { + type: "text", + text: `elicitation-accept-${String(elicitation.content["name"])}`, + }, + }, + }); + + return { stopReason: "end_turn" }; + }) + .onNotification(methods.agent.session.cancel, () => {}); +} + +async function initialize(harness: HttpBackendHarness): Promise { + const response = await postJson(harness, "initialize", initializeRequest); + const connectionId = response.headers.get(HEADER_CONNECTION_ID); + + expect(response.status).toBe(200); + expect(connectionId).toMatch(/^[0-9a-f-]{36}$/); + + return connectionId ?? ""; +} + +async function createSession( + harness: HttpBackendHarness, + connectionId: string, +): Promise { + const response = await openConnectionSse(harness, connectionId); + const events = createSseMessageIterator(response); + + try { + expect( + await postJson(harness, "post", sessionNewRequest, { + [HEADER_CONNECTION_ID]: connectionId, + }), + ).toMatchObject({ status: 202 }); + + return readSessionId(await readNextSseMessage(events)); + } finally { + await events.return?.(); + await response.body?.cancel(); + } +} + +function openConnectionSse( + harness: HttpBackendHarness, + connectionId: string, +): Promise { + return harness.handle( + "connectionSse", + new Request("http://example.test/acp", { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + }, + }), + ); +} + +function openSessionSse( + harness: HttpBackendHarness, + connectionId: string, + sessionId: string, +): Promise { + return harness.handle( + "sessionSse", + new Request("http://example.test/acp", { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + }), + ); +} + +function postJson( + harness: HttpBackendHarness, + role: HarnessRole, + body: unknown, + headers: Record = {}, +): Promise { + return harness.handle( + role, + new Request("http://example.test/acp", { + method: "POST", + headers: { + "Content-Type": JSON_MIME_TYPE, + ...headers, + }, + body: JSON.stringify(body), + }), + ); +} + +function createSseMessageIterator( + response: Response, +): AsyncIterator { + if (!response.body) { + throw new Error("Expected SSE response body"); + } + + return parseSseStream(response.body)[Symbol.asyncIterator](); +} + +async function readNextSseMessage( + iterator: AsyncIterator, +): Promise { + const result = await iterator.next(); + + if (result.done) { + throw new Error("Expected SSE message"); + } + + return result.value; +} + +async function readSseIteratorMessages( + iterator: AsyncIterator, + count: number, +): Promise { + const messages: AnyMessage[] = []; + + for (const __unused of Array.from({ length: count })) { + void __unused; + messages.push(await readNextSseMessage(iterator)); + } + + return messages; +} + +async function readSseMessages( + response: Response, + count: number, +): Promise { + const iterator = createSseMessageIterator(response); + + try { + return await readSseIteratorMessages(iterator, count); + } finally { + await iterator.return?.(); + await response.body?.cancel(); + } +} + +function readSessionId(message: AnyMessage): string { + if (!("result" in message) || !isRecord(message.result)) { + throw new Error("Expected session/new response result"); + } + + const sessionId = message.result["sessionId"]; + + if (typeof sessionId !== "string") { + throw new Error("Expected session ID"); + } + + return sessionId; +} + +function readMessageId(message: AnyMessage): string | number | null { + if (!("id" in message)) { + throw new Error("Expected message ID"); + } + + return message.id; +} + +function isPermissionResponse(value: unknown): value is { + readonly outcome: + | { readonly outcome: "cancelled" } + | { readonly outcome: "selected"; readonly optionId: string }; +} { + if (!isRecord(value) || !isRecord(value["outcome"])) { + return false; + } + + const outcome = value["outcome"]; + return ( + outcome["outcome"] === "cancelled" || outcome["outcome"] === "selected" + ); +} + +function isAcceptedElicitation(value: unknown): value is { + readonly action: "accept"; + readonly content: Record; +} { + return ( + isRecord(value) && + value["action"] === "accept" && + isRecord(value["content"]) + ); +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} + +function escapeRegExp(value: string): string { + return value.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); +} diff --git a/src/http-backend.ts b/src/http-backend.ts new file mode 100644 index 0000000..eb26bc5 --- /dev/null +++ b/src/http-backend.ts @@ -0,0 +1,149 @@ +import type { + AgentConnector, + OutboundSubscription, + ResponseRoute, +} from "./connection.js"; +import type { + AnyMessage, + AnyNotification, + AnyRequest, + AnyResponse, +} from "./jsonrpc.js"; + +export type HttpBackendServerRequestIdGenerator = () => string | number; + +export interface HttpBackendInitializeInput { + readonly agent: AgentConnector; + readonly message: AnyMessage; + readonly signal: AbortSignal; +} + +export interface HttpBackendInitializeResult { + readonly connectionId: string; + readonly response: AnyResponse; +} + +export interface HttpBackendLoadConnectionInput { + readonly connectionId: string; +} + +export interface HttpBackendTouchConnectionInput { + readonly connectionId: string; +} + +export interface HttpBackendCloseConnectionInput { + readonly connectionId: string; +} + +export interface HttpBackendAcceptClientMethodMessageInput { + readonly connectionId: string; + readonly message: AnyRequest | AnyNotification; + readonly route: ResponseRoute; + readonly responseRoute: ResponseRoute; +} + +export interface HttpBackendAcceptClientResponseInput { + readonly connectionId: string; + readonly message: AnyResponse; + readonly headerSessionId: string | null; +} + +export interface HttpBackendOpenConnectionStreamInput { + readonly connectionId: string; + readonly cursor?: string; +} + +export interface HttpBackendOpenSessionStreamInput { + readonly connectionId: string; + readonly sessionId: string; + readonly cursor?: string; +} + +export type HttpBackendAcceptResult = + | { + readonly ok: true; + } + | { + readonly ok: false; + readonly status: number; + readonly message: string; + }; + +export interface HttpBackendLoadedConnection { + readonly connectionId: string; +} + +export interface AcpHttpBackend { + /** + * Allocates JSON-RPC request IDs for server-originated client requests. + * + * Distributed HTTP backends should provide IDs that do not collide across + * server instances. The in-memory backend intentionally keeps the existing + * monotonically increasing numeric behavior. + */ + readonly generateServerRequestId?: HttpBackendServerRequestIdGenerator; + + /** + * Creates a new HTTP connection, forwards initialize to the agent, and only + * exposes the connection for subsequent connected HTTP requests once + * initialize succeeds. + */ + initialize( + input: HttpBackendInitializeInput, + ): Promise; + + /** + * Loads connection metadata/state. Backends may also use this as the activity + * touch point for connected HTTP requests. + */ + loadConnection( + input: HttpBackendLoadConnectionInput, + ): Promise; + + /** + * Refreshes connection metadata after accepted activity. + */ + touchConnection(input: HttpBackendTouchConnectionInput): Promise; + + /** + * Accepts any client-originated ACP request/notification after the HTTP + * server has performed protocol-generic route determination. + */ + acceptClientMethodMessage( + input: HttpBackendAcceptClientMethodMessageInput, + ): Promise; + + /** + * Accepts any client response to a server-originated request and validates + * response routing against the backend-owned server request route map. + */ + acceptClientResponse( + input: HttpBackendAcceptClientResponseInput, + ): Promise; + + /** + * Opens the connection-level hot stream. + */ + openConnectionStream( + input: HttpBackendOpenConnectionStreamInput, + ): Promise; + + /** + * Opens a session-level hot stream. + */ + openSessionStream( + input: HttpBackendOpenSessionStreamInput, + ): Promise; + + /** + * Closes a connection and releases transport state. + * + * Returns false when the connection is unknown. + */ + closeConnection(input: HttpBackendCloseConnectionInput): Promise; + + /** + * Closes backend-owned resources. + */ + close(): Promise; +} diff --git a/src/jsonrpc.ts b/src/jsonrpc.ts index 2d49bb9..6c0c77b 100644 --- a/src/jsonrpc.ts +++ b/src/jsonrpc.ts @@ -221,6 +221,15 @@ type ConnectionPendingResponse = { */ export type MaybePromise = T | Promise; +/** + * Allocates IDs for JSON-RPC requests sent by this connection. + * + * The default generator is a per-connection numeric sequence starting at `0`. + * HTTP server integrations can inject a custom generator so server-originated + * request IDs remain unique across distributed server instances. + */ +export type JsonRpcRequestIdGenerator = () => string | number; + /** * Incoming request passed to JSON-RPC handlers. */ @@ -610,6 +619,13 @@ export type ConnectionOptions = { * Extra handlers to prepend to the connection's handler chain. */ handlers?: JsonRpcHandler[]; + + /** + * Allocates IDs for outbound JSON-RPC requests. + * + * Defaults to the existing per-connection numeric sequence: 0, 1, 2, ... + */ + requestIdGenerator?: JsonRpcRequestIdGenerator; }; /** @@ -623,6 +639,8 @@ export class Connection { new Map(); private incomingRequests: Map = new Map(); private nextRequestId = 0; + private requestIdGenerator: JsonRpcRequestIdGenerator = () => + this.nextRequestId++; private staticHandlers: JsonRpcHandler[] = []; private dynamicHandlers: Set = new Set(); private stream!: Stream; @@ -655,20 +673,25 @@ export class Connection { const notificationHandler = notificationHandlerOrHandlers as NotificationHandler; const stream = streamOrOptions as Stream; - this.initialize(stream, [ - ...(options?.handlers ?? []), - this.legacyHandler(requestHandler, notificationHandler), - ]); + this.initialize( + stream, + [ + ...(options?.handlers ?? []), + this.legacyHandler(requestHandler, notificationHandler), + ], + options, + ); return; } const stream = requestHandlerOrStream; const handlers = notificationHandlerOrHandlers as JsonRpcHandler[]; const connectionOptions = streamOrOptions as ConnectionOptions | undefined; - this.initialize(stream, [ - ...(connectionOptions?.handlers ?? []), - ...handlers, - ]); + this.initialize( + stream, + [...(connectionOptions?.handlers ?? []), ...handlers], + connectionOptions, + ); } /** @@ -760,7 +783,7 @@ export class Connection { return rejectedPromise(this.closedReason()); } - const id = this.nextRequestId++; + const id = this.requestIdGenerator(); let cancel = () => {}; const responsePromise = new Promise((resolve, reject) => { const pendingResponse: ConnectionPendingResponse = { @@ -849,9 +872,15 @@ export class Connection { void this.receiveReader?.cancel(closeError).catch(() => {}); } - private initialize(stream: Stream, handlers: JsonRpcHandler[]): void { + private initialize( + stream: Stream, + handlers: JsonRpcHandler[], + options?: ConnectionOptions, + ): void { this.stream = stream; this.staticHandlers = handlers; + this.requestIdGenerator = + options?.requestIdGenerator ?? (() => this.nextRequestId++); this.closedPromise = new Promise((resolve) => { this.abortController.signal.addEventListener("abort", () => resolve()); }); diff --git a/src/server-elicitation.test.ts b/src/server-elicitation.test.ts new file mode 100644 index 0000000..5cc08ca --- /dev/null +++ b/src/server-elicitation.test.ts @@ -0,0 +1,302 @@ +import { describe, expect, it } from "vitest"; + +import { ConnectionRegistry, InMemoryAcpHttpBackend } from "./connection.js"; +import { + EVENT_STREAM_MIME_TYPE, + HEADER_CONNECTION_ID, + HEADER_SESSION_ID, + JSON_MIME_TYPE, +} from "./protocol.js"; +import { AcpServer } from "./server.js"; +import { parseSseStream } from "./sse.js"; +import { PROTOCOL_VERSION, agent as createAgentApp, methods } from "./acp.js"; + +import type { AgentApp } from "./acp.js"; +import type { AnyMessage } from "./jsonrpc.js"; + +const initializeRequest = { + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: PROTOCOL_VERSION, + clientCapabilities: {}, + }, +}; + +const sessionNewRequest = { + jsonrpc: "2.0", + id: 2, + method: "session/new", + params: { + cwd: "/tmp", + mcpServers: [], + }, +}; + +const promptRequest = { + jsonrpc: "2.0", + id: 3, + method: "session/prompt", + params: { + sessionId: "elicitation-session", + prompt: [{ type: "text", text: "Start" }], + }, +}; + +describe("AcpServer elicitation requests over HTTP", () => { + it("routes elicitation requests through the HTTP backend with injected request IDs", async () => { + let nextRequestId = 0; + const server = new AcpServer({ + createAgent: () => createElicitationAgent(), + httpBackend: new InMemoryAcpHttpBackend( + new ConnectionRegistry(), + () => `elicitation-${nextRequestId++}`, + ), + }); + + try { + const connectionId = await initialize(server); + const sessionId = await createSession(server, connectionId); + const connectionSse = await openConnectionSse(server, connectionId); + const connectionEvents = createSseMessageIterator(connectionSse); + const sessionSse = await openSessionSse(server, connectionId, sessionId); + const sessionEvents = createSseMessageIterator(sessionSse); + + expect( + await postJson(server, promptRequest, { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }), + ).toMatchObject({ status: 202 }); + + const elicitationRequest = await readNextSseMessage(sessionEvents); + expect(elicitationRequest).toMatchObject({ + jsonrpc: "2.0", + id: "elicitation-0", + method: "elicitation/create", + params: { + sessionId, + mode: "form", + message: "Please enter your name", + }, + }); + + expect( + await postJson( + server, + { + jsonrpc: "2.0", + id: readMessageId(elicitationRequest), + result: { + action: "accept", + content: { name: "Alice" }, + }, + }, + { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + ), + ).toMatchObject({ status: 202 }); + + expect(await readNextSseMessage(connectionEvents)).toMatchObject({ + jsonrpc: "2.0", + method: "elicitation/complete", + params: { elicitationId: "elicitation-1" }, + }); + expect(await readNextSseMessage(sessionEvents)).toMatchObject({ + jsonrpc: "2.0", + id: 3, + result: { stopReason: "end_turn" }, + }); + + await connectionEvents.return?.(); + await sessionEvents.return?.(); + await connectionSse.body?.cancel(); + await sessionSse.body?.cancel(); + } finally { + await server.close(); + } + }); +}); + +function createElicitationAgent(): AgentApp { + return createAgentApp({ name: "elicitation-http-agent" }) + .onRequest(methods.agent.initialize, () => ({ + protocolVersion: PROTOCOL_VERSION, + agentCapabilities: { + loadSession: false, + }, + })) + .onRequest(methods.agent.session.new, () => ({ + sessionId: "elicitation-session", + })) + .onRequest(methods.agent.authenticate, () => ({})) + .onRequest(methods.agent.session.prompt, async (c) => { + const elicitation = await c.client.request( + methods.client.elicitation.create, + { + sessionId: c.params.sessionId, + mode: "form", + message: "Please enter your name", + requestedSchema: { + type: "object", + properties: { + name: { type: "string" }, + }, + }, + }, + ); + + if (!isAcceptedElicitation(elicitation)) { + throw new Error("Expected accepted elicitation response"); + } + + await c.client.notify(methods.client.elicitation.complete, { + elicitationId: "elicitation-1", + }); + + return { stopReason: "end_turn" }; + }) + .onNotification(methods.agent.session.cancel, () => {}); +} + +async function initialize(server: AcpServer): Promise { + const response = await postJson(server, initializeRequest); + const connectionId = response.headers.get(HEADER_CONNECTION_ID); + + expect(response.status).toBe(200); + expect(connectionId).toMatch(/^[0-9a-f-]{36}$/); + + return connectionId ?? ""; +} + +async function createSession( + server: AcpServer, + connectionId: string, +): Promise { + const response = await openConnectionSse(server, connectionId); + const events = createSseMessageIterator(response); + + try { + expect( + await postJson(server, sessionNewRequest, { + [HEADER_CONNECTION_ID]: connectionId, + }), + ).toMatchObject({ status: 202 }); + + return readSessionId(await readNextSseMessage(events)); + } finally { + await events.return?.(); + await response.body?.cancel(); + } +} + +function openConnectionSse( + server: AcpServer, + connectionId: string, +): Promise { + return server.handleRequest( + new Request("http://example.test/acp", { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + }, + }), + ); +} + +function openSessionSse( + server: AcpServer, + connectionId: string, + sessionId: string, +): Promise { + return server.handleRequest( + new Request("http://example.test/acp", { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + }), + ); +} + +function postJson( + server: AcpServer, + body: unknown, + headers: Record = {}, +): Promise { + return server.handleRequest( + new Request("http://example.test/acp", { + method: "POST", + headers: { + "Content-Type": JSON_MIME_TYPE, + ...headers, + }, + body: JSON.stringify(body), + }), + ); +} + +function createSseMessageIterator( + response: Response, +): AsyncIterator { + if (!response.body) { + throw new Error("Expected SSE response body"); + } + + return parseSseStream(response.body)[Symbol.asyncIterator](); +} + +async function readNextSseMessage( + iterator: AsyncIterator, +): Promise { + const result = await iterator.next(); + + if (result.done) { + throw new Error("Expected SSE message"); + } + + return result.value; +} + +function readSessionId(message: AnyMessage): string { + if (!("result" in message) || !isRecord(message.result)) { + throw new Error("Expected session/new response result"); + } + + const sessionId = message.result["sessionId"]; + + if (typeof sessionId !== "string") { + throw new Error("Expected session ID"); + } + + return sessionId; +} + +function readMessageId(message: AnyMessage): string | number | null { + if (!("id" in message)) { + throw new Error("Expected message ID"); + } + + return message.id; +} + +function isAcceptedElicitation(value: unknown): value is { + readonly action: "accept"; + readonly content: Record; +} { + return ( + isRecord(value) && + value["action"] === "accept" && + isRecord(value["content"]) + ); +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} diff --git a/src/server-permission.test.ts b/src/server-permission.test.ts index b9c16fc..f5d12e0 100644 --- a/src/server-permission.test.ts +++ b/src/server-permission.test.ts @@ -1,4 +1,5 @@ import { describe, expect, it } from "vitest"; +import { ConnectionRegistry, InMemoryAcpHttpBackend } from "./connection.js"; import { EVENT_STREAM_MIME_TYPE, HEADER_CONNECTION_ID, @@ -101,6 +102,74 @@ describe("AcpServer permission requests over HTTP", () => { } }, 10_000); + it("uses injected backend request IDs for permission requests", async () => { + let nextRequestId = 0; + const server = await startTestServer( + () => createTestAgentApp({ enablePermission: true }), + { + httpBackend: new InMemoryAcpHttpBackend( + new ConnectionRegistry(), + () => `permission-${nextRequestId++}`, + ), + }, + ); + + try { + const connectionId = await initialize(server.url); + const sessionId = await createSession(server.url, connectionId); + const sessionSse = await openSessionSse( + server.url, + connectionId, + sessionId, + ); + const sessionEvents = createSseMessageIterator(sessionSse); + + expect( + await postJson(server.url, createPromptRequest(3, sessionId), { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }), + ).toMatchObject({ status: 202 }); + + await readNextSseMessage(sessionEvents); + const permissionRequest = await readNextSseMessage(sessionEvents); + + expect(permissionRequest).toMatchObject({ + jsonrpc: "2.0", + id: "permission-0", + method: "session/request_permission", + params: { sessionId }, + }); + + expect( + await postJson( + server.url, + { + jsonrpc: "2.0", + id: readMessageId(permissionRequest), + result: { + outcome: { + outcome: "selected", + optionId: "allow", + }, + }, + }, + { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + ), + ).toMatchObject({ status: 202 }); + + await readNextSseMessage(sessionEvents); + await readNextSseMessage(sessionEvents); + await sessionEvents.return?.(); + await sessionSse.body?.cancel(); + } finally { + await server.close(); + } + }, 10_000); + it("routes permission requests over session SSE and accepts client responses", async () => { const server = await startTestServer(() => createTestAgentApp({ enablePermission: true }), diff --git a/src/server-session-sse.test.ts b/src/server-session-sse.test.ts index e1af341..ff0f5c4 100644 --- a/src/server-session-sse.test.ts +++ b/src/server-session-sse.test.ts @@ -1,4 +1,5 @@ import { describe, expect, it } from "vitest"; +import { ConnectionRegistry, InMemoryAcpHttpBackend } from "./connection.js"; import { EVENT_STREAM_MIME_TYPE, HEADER_CONNECTION_ID, @@ -392,6 +393,46 @@ describe("AcpServer session SSE", () => { } }); + it("replays buffered session messages through a configured HTTP backend", async () => { + const server = await startTestServer(() => createTestAgentApp(), { + httpBackend: new InMemoryAcpHttpBackend(new ConnectionRegistry()), + }); + + try { + const connectionId = await initialize(server.url); + const sessionId = await createSession(server.url, connectionId); + const accepted = await postJson( + server.url, + createPromptRequest(3, sessionId), + { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + ); + const sessionSse = await openSessionSse( + server.url, + connectionId, + sessionId, + ); + + expect(accepted.status).toBe(202); + expect(await readSseMessages(sessionSse, 2)).toMatchObject([ + { + jsonrpc: "2.0", + method: "session/update", + params: { sessionId }, + }, + { + jsonrpc: "2.0", + id: 3, + result: { stopReason: "end_turn" }, + }, + ]); + } finally { + await server.close(); + } + }); + it("replays buffered session messages when session SSE attaches after prompt", async () => { const server = await startTestServer(); diff --git a/src/server-websocket-upgrade.test.ts b/src/server-websocket-upgrade.test.ts index 9e43b28..5331b81 100644 --- a/src/server-websocket-upgrade.test.ts +++ b/src/server-websocket-upgrade.test.ts @@ -13,6 +13,7 @@ import { createTestAgentApp, TestAgent } from "./test-support/test-agent.js"; import { handleWebSocketConnection } from "./ws-server.js"; import type { InitializeResponse } from "./acp.js"; +import type { AcpHttpBackend } from "./http-backend.js"; import type { AnyMessage } from "./jsonrpc.js"; import type { WebSocketServerSocket } from "./ws-server.js"; @@ -521,6 +522,43 @@ describe("AcpServer prepared WebSocket upgrades", () => { } }); + it("keeps prepared WebSocket upgrades on the in-memory path when an HTTP backend is configured", async () => { + const server = new AcpServer({ + createAgent: () => + createTestAgentApp({ + newSession: () => ({ sessionId: "ws-session" }), + }), + httpBackend: createThrowingHttpBackend(), + }); + const socket = new FakeServerSocket(); + + try { + server.prepareWebSocketUpgrade().accept(socket); + socket.receive(JSON.stringify(initializeRequest)); + + await expect(readSentMessage(socket)).resolves.toMatchObject({ + jsonrpc: "2.0", + id: initializeRequest.id, + result: { + protocolVersion: PROTOCOL_VERSION, + }, + }); + + socket.receive(JSON.stringify(sessionNewRequest)); + + await expect(readSentMessage(socket)).resolves.toMatchObject({ + jsonrpc: "2.0", + id: sessionNewRequest.id, + result: { + sessionId: "ws-session", + }, + }); + } finally { + socket.close(); + await server.close(); + } + }); + it("keeps existing double-settle behavior for prepared WebSocket upgrades", async () => { const server = new AcpServer({ createAgent: () => createTestAgentApp(), @@ -547,6 +585,25 @@ describe("AcpServer prepared WebSocket upgrades", () => { }); }); +function createThrowingHttpBackend(): AcpHttpBackend { + const error = () => new Error("HTTP backend must not be used by WebSocket"); + + return { + generateServerRequestId: () => { + throw error(); + }, + initialize: () => Promise.reject(error()), + loadConnection: () => Promise.reject(error()), + touchConnection: () => Promise.reject(error()), + acceptClientMethodMessage: () => Promise.reject(error()), + acceptClientResponse: () => Promise.reject(error()), + openConnectionStream: () => Promise.reject(error()), + openSessionStream: () => Promise.reject(error()), + closeConnection: () => Promise.reject(error()), + close: () => Promise.resolve(), + }; +} + function recordingFactory( createdBy: string[], label: string, diff --git a/src/server.ts b/src/server.ts index ff11fee..b02215a 100644 --- a/src/server.ts +++ b/src/server.ts @@ -1,11 +1,10 @@ -import { ConnectionRegistry } from "./connection.js"; +import { ConnectionRegistry, InMemoryAcpHttpBackend } from "./connection.js"; import { EVENT_STREAM_MIME_TYPE, HEADER_CONNECTION_ID, HEADER_SESSION_ID, JSON_MIME_TYPE, isInitializeRequest, - messageIdKey, methodRequiresSessionHeader, sessionIdFromParams, } from "./protocol.js"; @@ -21,17 +20,12 @@ import type { import type { AgentConnector, - ConnectionState, OutboundSubscription, ResponseRoute, } from "./connection.js"; -import type { - AnyMessage, - AnyNotification, - AnyRequest, - AnyResponse, -} from "./jsonrpc.js"; +import type { AnyMessage, AnyNotification, AnyRequest } from "./jsonrpc.js"; import type { Agent, AgentApp } from "./acp.js"; +import type { AcpHttpBackend } from "./http-backend.js"; export type AgentFactory = () => AgentApp; /** @deprecated Prefer {@link AgentFactory}. */ @@ -86,7 +80,15 @@ type OptionalAgentOption = }; /** Options for creating an ACP server transport. */ -export type AcpServerOptions = AgentOption; +export type AcpServerOptions = AgentOption & { + /** + * Experimental backend for Streamable HTTP transport state. + * + * WebSocket upgrades always use the server's in-memory registry and are not + * affected by this backend. + */ + readonly httpBackend?: AcpHttpBackend; +}; export type HandleRequestOptions = OptionalAgentOption; @@ -108,10 +110,13 @@ export interface PreparedWebSocketUpgrade { export class AcpServer { private readonly agent: AgentConnector; private readonly registry = new ConnectionRegistry(); + private readonly httpBackend: AcpHttpBackend; private readonly webSocketSessions = new Set(); constructor(options: AcpServerOptions) { this.agent = resolveAgent(options); + this.httpBackend = + options.httpBackend ?? new InMemoryAcpHttpBackend(this.registry); } /** Handles one Streamable HTTP ACP request. */ @@ -124,11 +129,11 @@ export class AcpServer { } if (req.method === "GET") { - return this.handleGet(req); + return await this.handleGet(req); } if (req.method === "DELETE") { - return this.handleDelete(req); + return await this.handleDelete(req); } return textResponse("Method Not Allowed", 405); @@ -174,11 +179,12 @@ export class AcpServer { /** Closes all active ACP connections owned by this server. */ async close(): Promise { const closeConnections = this.registry.closeAll(); + const closeHttpBackend = this.httpBackend.close(); const closeWebSockets = Promise.all( Array.from(this.webSocketSessions, (session) => session.close()), ); - await Promise.all([closeConnections, closeWebSockets]); + await Promise.all([closeConnections, closeHttpBackend, closeWebSockets]); } private async handlePost( @@ -219,25 +225,28 @@ export class AcpServer { return textResponse("Missing Acp-Connection-Id", 400); } - const connection = this.registry.get(connectionId); + const connection = await this.httpBackend.loadConnection({ connectionId }); if (!connection) { return textResponse("Unknown Acp-Connection-Id", 404); } const forwarded = await this.forwardConnectedMessage( - connection, + connectionId, body.value, req.headers, ); + if (!forwarded.ok) { return textResponse(forwarded.message, forwarded.status); } + await this.httpBackend.touchConnection({ connectionId }); + return emptyResponse(202); } - private handleGet(req: Request): Response { + private async handleGet(req: Request): Promise { if (req.headers.get("Upgrade")?.toLowerCase() === "websocket") { return textResponse("WebSocket upgrade is not implemented", 426); } @@ -254,28 +263,34 @@ export class AcpServer { return textResponse("Missing Acp-Connection-Id", 400); } - const connection = this.registry.get(connectionId); + const connection = await this.httpBackend.loadConnection({ connectionId }); if (!connection) { return textResponse("Unknown Acp-Connection-Id", 404); } const sessionId = req.headers.get(HEADER_SESSION_ID); - if (sessionId) { - return sseResponse(connection.ensureSession(sessionId).subscribe()); + const subscription = sessionId + ? await this.httpBackend.openSessionStream({ connectionId, sessionId }) + : await this.httpBackend.openConnectionStream({ connectionId }); + + if (!subscription) { + return textResponse("Unknown Acp-Connection-Id", 404); } - return sseResponse(connection.connectionStream.subscribe()); + await this.httpBackend.touchConnection({ connectionId }); + + return sseResponse(subscription); } - private handleDelete(req: Request): Response { + private async handleDelete(req: Request): Promise { const connectionId = req.headers.get(HEADER_CONNECTION_ID); if (!connectionId) { return textResponse("Missing Acp-Connection-Id", 400); } - if (!this.registry.remove(connectionId)) { + if (!(await this.httpBackend.closeConnection({ connectionId }))) { return textResponse("Unknown Acp-Connection-Id", 404); } @@ -295,38 +310,28 @@ export class AcpServer { return textResponse("Request aborted", 499); } - let connection: - | ReturnType - | undefined; - try { - connection = this.registry.createConnection( - agentOverride(options, this.agent), - ); - const initialResponsePromise = writeAndReceiveInitial( - connection, + const initializePromise = this.httpBackend.initialize({ + agent: agentOverride(options, this.agent), message, - ); - initialResponsePromise.catch(() => undefined); + signal, + }); + initializePromise.catch(() => undefined); - const initialResponse = await raceAbort(initialResponsePromise, signal); + const { connectionId, response } = await raceAbort( + initializePromise, + signal, + ); if (signal.aborted) { throw new RequestAbortedError(); } - connection.startRouter(); - connection.startConnectHandlers(); - - return jsonResponse(initialResponse, 200, { - [HEADER_CONNECTION_ID]: connection.connectionId, + return jsonResponse(response, 200, { + [HEADER_CONNECTION_ID]: connectionId, }); } catch (error) { - if (connection) { - this.registry.remove(connection.connectionId); - } - - if (error instanceof RequestAbortedError) { + if (error instanceof RequestAbortedError || signal.aborted) { return textResponse("Request aborted", 499); } @@ -346,15 +351,25 @@ export class AcpServer { } private async forwardConnectedMessage( - connection: ConnectionState, + connectionId: string, message: AnyMessage, headers: Headers, ): Promise { if (isResponseMessage(message)) { - return await forwardClientResponse(connection, message, headers); + return await forwardClientResponse( + this.httpBackend, + connectionId, + message, + headers, + ); } - return await forwardClientMethodMessage(connection, message, headers); + return await forwardClientMethodMessage( + this.httpBackend, + connectionId, + message, + headers, + ); } } @@ -394,7 +409,10 @@ function resolveAgent(options: AgentOptions): AgentConnector { } if (options.agent) { - return options.agent; + return { + connect: (stream, connectionOptions) => + options.agent!.connect(stream, connectionOptions ?? {}), + }; } if (options.createAgent) { @@ -405,8 +423,12 @@ function resolveAgent(options: AgentOptions): AgentConnector { } return { - connect: (stream) => { - new AgentSideConnection(options.createLegacyAgent!, stream); + connect: (stream, connectionOptions) => { + new AgentSideConnection( + options.createLegacyAgent!, + stream, + connectionOptions, + ); }, }; } @@ -463,26 +485,6 @@ async function readJson(req: Request): Promise { } } -async function writeInbound( - connection: ConnectionState, - message: AnyMessage, -): Promise { - await connection.writeInbound(message); -} - -async function writeAndReceiveInitial( - connection: ConnectionState, - message: AnyMessage, -): Promise { - await writeInbound(connection, message); - - if (!("id" in message) || message.id === null) { - throw new Error("Initialize request must include an ID"); - } - - return await connection.recvInitial(message.id); -} - async function raceAbort( promise: Promise, signal: AbortSignal, @@ -511,7 +513,8 @@ async function raceAbort( } async function forwardClientMethodMessage( - connection: ConnectionState, + httpBackend: AcpHttpBackend, + connectionId: string, message: ClientMethodMessage, headers: Headers, ): Promise { @@ -521,54 +524,33 @@ async function forwardClientMethodMessage( return route; } - if (route.value !== "connection") { - connection.ensureSession(route.value.session); - } - - const key = "id" in message ? messageIdKey(message.id) : undefined; - - if (key) { - connection.pendingRoutes.set( - key, - pendingResponseRoute(message, route.value), - ); - } - - await writeInbound(connection, message); - return { ok: true }; + return await httpBackend.acceptClientMethodMessage({ + connectionId, + message, + route: route.value, + responseRoute: pendingResponseRoute(message, route.value), + }); } async function forwardClientResponse( - connection: ConnectionState, - message: AnyResponse, + httpBackend: AcpHttpBackend, + connectionId: string, + message: AnyMessage, headers: Headers, ): Promise { - const key = messageIdKey(message.id); - const route = key ? connection.clientResponseRoutes.get(key) : undefined; - const headerSessionId = headers.get(HEADER_SESSION_ID); - - if (route && route !== "connection" && !headerSessionId) { - return { - ok: false, - status: 400, - message: "Missing Acp-Session-Id", - }; - } - - if (route && route !== "connection" && headerSessionId !== route.session) { + if (!isResponseMessage(message)) { return { ok: false, status: 400, - message: "Mismatched Acp-Session-Id", + message: "Invalid JSON-RPC response", }; } - if (key) { - connection.clientResponseRoutes.delete(key); - } - - await writeInbound(connection, message); - return { ok: true }; + return await httpBackend.acceptClientResponse({ + connectionId, + message, + headerSessionId: headers.get(HEADER_SESSION_ID), + }); } function pendingResponseRoute( diff --git a/src/test-support/test-http-server.ts b/src/test-support/test-http-server.ts index a292b57..778678e 100644 --- a/src/test-support/test-http-server.ts +++ b/src/test-support/test-http-server.ts @@ -10,6 +10,7 @@ import { createTestAgentApp } from "./test-agent.js"; import type { AddressInfo } from "node:net"; import type { AgentApp } from "../acp.js"; +import type { AcpHttpBackend } from "../http-backend.js"; export interface TestHttpServer { readonly url: string; @@ -19,9 +20,12 @@ export interface TestHttpServer { export async function startTestServer( createAgent: () => AgentApp = () => createTestAgentApp(), - options: { port?: number } = {}, + options: { port?: number; httpBackend?: AcpHttpBackend } = {}, ): Promise { - const acpServer = new AcpServer({ createAgent }); + const acpServer = new AcpServer({ + createAgent, + httpBackend: options.httpBackend, + }); const httpServer = http.createServer(createNodeHttpHandler(acpServer)); const webSocketServer = new WebSocketServer({ noServer: true }); From 2bfed85d37783b8409ef013abd98b0290e58556c Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 23 Jun 2026 16:20:40 +1000 Subject: [PATCH 2/2] fix: harden HTTP backend error handling --- src/acp.ts | 2 +- src/connection.ts | 44 ++- src/http-backend.conformance.test.ts | 470 ++++++++++++++++++++++++++- src/http-backend.ts | 17 + src/server.test.ts | 193 +++++++++++ src/server.ts | 45 ++- 6 files changed, 754 insertions(+), 17 deletions(-) diff --git a/src/acp.ts b/src/acp.ts index 3f6d8b9..650f544 100644 --- a/src/acp.ts +++ b/src/acp.ts @@ -1767,7 +1767,7 @@ export class AgentApp { connectWith( stream: Stream, op: (context: AgentContext) => MaybePromise, - options?: AppConnectOptions, + options?: AcpConnectionOptions, ): Promise; /** * Connects this agent app directly to a client app for the lifetime of `op`. diff --git a/src/connection.ts b/src/connection.ts index 4210333..dcf9a85 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -501,8 +501,16 @@ export class InMemoryAcpHttpBackend implements AcpHttpBackend { }); try { - await connection.writeInbound(message); - const response = await connection.recvInitial(message.id); + const discard = (): void => { + this.registry.discard(connection.connectionId); + }; + + await raceAbort(connection.writeInbound(message), signal, discard); + const response = await raceAbort( + connection.recvInitial(message.id), + signal, + discard, + ); if (signal.aborted) { throw new Error("Request aborted"); @@ -736,3 +744,35 @@ function isMatchingResponse( ): msg is AnyResponse { return "id" in msg && !("method" in msg) && msg.id === id; } + +async function raceAbort( + promise: Promise, + signal: AbortSignal, + onAbort: () => void, +): Promise { + promise.catch(() => undefined); + + if (signal.aborted) { + onAbort(); + throw new Error("Request aborted"); + } + + let removeAbortListener: () => void = () => {}; + const abortPromise = new Promise((_resolve, reject) => { + const abort = (): void => { + onAbort(); + reject(new Error("Request aborted")); + }; + + signal.addEventListener("abort", abort, { once: true }); + removeAbortListener = () => { + signal.removeEventListener("abort", abort); + }; + }); + + try { + return await Promise.race([promise, abortPromise]); + } finally { + removeAbortListener(); + } +} diff --git a/src/http-backend.conformance.test.ts b/src/http-backend.conformance.test.ts index 65dbaff..3d3b143 100644 --- a/src/http-backend.conformance.test.ts +++ b/src/http-backend.conformance.test.ts @@ -1,11 +1,18 @@ import { describe, expect, it } from "vitest"; -import { ConnectionRegistry, InMemoryAcpHttpBackend } from "./connection.js"; +import { + ConnectionRegistry, + InMemoryAcpHttpBackend, + OutboundStream, +} from "./connection.js"; import { EVENT_STREAM_MIME_TYPE, HEADER_CONNECTION_ID, HEADER_SESSION_ID, JSON_MIME_TYPE, + messageIdKey, + sessionIdFromMessageParams, + sessionIdFromResponseResult, } from "./protocol.js"; import { AcpServer } from "./server.js"; import { parseSseStream } from "./sse.js"; @@ -13,7 +20,15 @@ import { PROTOCOL_VERSION, agent as createAgentApp, methods } from "./acp.js"; import { createTestAgentApp } from "./test-support/test-agent.js"; import type { AgentApp } from "./acp.js"; -import type { AnyMessage } from "./jsonrpc.js"; +import type { AgentConnector, ResponseRoute } from "./connection.js"; +import type { AcpHttpBackend } from "./http-backend.js"; +import type { + AnyMessage, + AnyResponse, + JsonRpcRequestIdGenerator, +} from "./jsonrpc.js"; +import type { Stream } from "./stream.js"; +import { isResponseMessage } from "./jsonrpc.js"; const initializeRequest = { jsonrpc: "2.0", @@ -88,12 +103,12 @@ const harnesses: Array<{ { name: "fake distributed backend", createHarness: (createAgent) => { - const registry = new ConnectionRegistry(); + const store = new FakeDistributedTransportStore(); const counters = new Map(); const createServer = (nodeId: string): AcpServer => new AcpServer({ createAgent, - httpBackend: new InMemoryAcpHttpBackend(registry, () => { + httpBackend: new FakeDistributedAcpHttpBackend(store, () => { const next = counters.get(nodeId) ?? 0; counters.set(nodeId, next + 1); return `${nodeId}-${next}`; @@ -371,6 +386,453 @@ describe.each(harnesses)( }, ); +class FakeDistributedAcpHttpBackend implements AcpHttpBackend { + constructor( + private readonly store: FakeDistributedTransportStore, + readonly generateServerRequestId?: JsonRpcRequestIdGenerator, + ) {} + + initialize: AcpHttpBackend["initialize"] = (input) => + this.store.initialize(input, this.generateServerRequestId); + + loadConnection: AcpHttpBackend["loadConnection"] = (input) => + this.store.loadConnection(input); + + touchConnection: AcpHttpBackend["touchConnection"] = (input) => + this.store.touchConnection(input); + + acceptClientMethodMessage: AcpHttpBackend["acceptClientMethodMessage"] = ( + input, + ) => this.store.acceptClientMethodMessage(input); + + acceptClientResponse: AcpHttpBackend["acceptClientResponse"] = (input) => + this.store.acceptClientResponse(input); + + openConnectionStream: AcpHttpBackend["openConnectionStream"] = (input) => + this.store.openConnectionStream(input); + + openSessionStream: AcpHttpBackend["openSessionStream"] = (input) => + this.store.openSessionStream(input); + + closeConnection: AcpHttpBackend["closeConnection"] = (input) => + this.store.closeConnection(input); + + close: AcpHttpBackend["close"] = () => this.store.close(); +} + +class FakeDistributedTransportStore { + private readonly connections = new Map(); + + async initialize( + { agent, message, signal }: Parameters[0], + requestIdGenerator?: JsonRpcRequestIdGenerator, + ): ReturnType { + if (!("id" in message) || message.id === null) { + throw new Error("Initialize request must include an ID"); + } + + const connection = new FakeDistributedConnection(agent, requestIdGenerator); + + try { + await connection.writeInbound(message); + const response = await connection.recvInitial(message.id); + + if (signal.aborted) { + throw new Error("Request aborted"); + } + + connection.startRouter(); + this.connections.set(connection.connectionId, connection); + + return { + connectionId: connection.connectionId, + response, + }; + } catch (error) { + this.connections.delete(connection.connectionId); + void connection.shutdown(); + throw error; + } + } + + async loadConnection({ + connectionId, + }: Parameters[0]): ReturnType< + AcpHttpBackend["loadConnection"] + > { + if (!this.connections.has(connectionId)) { + return undefined; + } + + return { connectionId }; + } + + async touchConnection( + _input: Parameters[0], + ): ReturnType {} + + async acceptClientMethodMessage({ + connectionId, + message, + route, + responseRoute, + }: Parameters[0]): ReturnType< + AcpHttpBackend["acceptClientMethodMessage"] + > { + const connection = this.connections.get(connectionId); + + if (!connection) { + return unknownConnectionResult(); + } + + if (route !== "connection") { + connection.ensureSession(route.session); + } + + const key = "id" in message ? messageIdKey(message.id) : undefined; + if (key) { + connection.trackPendingResponseRoute(key, responseRoute); + } + + await connection.writeInbound(message); + return { ok: true }; + } + + async acceptClientResponse({ + connectionId, + message, + headerSessionId, + }: Parameters[0]): ReturnType< + AcpHttpBackend["acceptClientResponse"] + > { + const connection = this.connections.get(connectionId); + + if (!connection) { + return unknownConnectionResult(); + } + + const key = messageIdKey(message.id); + const route = key ? connection.clientResponseRoute(key) : undefined; + + if (route && route !== "connection" && !headerSessionId) { + return { + ok: false, + status: 400, + message: "Missing Acp-Session-Id", + }; + } + + if (route && route !== "connection" && headerSessionId !== route.session) { + return { + ok: false, + status: 400, + message: "Mismatched Acp-Session-Id", + }; + } + + if (key) { + connection.clearClientResponseRoute(key); + } + + await connection.writeInbound(message); + return { ok: true }; + } + + async openConnectionStream({ + connectionId, + }: Parameters[0]): ReturnType< + AcpHttpBackend["openConnectionStream"] + > { + return this.connections.get(connectionId)?.connectionStream.subscribe(); + } + + async openSessionStream({ + connectionId, + sessionId, + }: Parameters[0]): ReturnType< + AcpHttpBackend["openSessionStream"] + > { + return this.connections + .get(connectionId) + ?.ensureSession(sessionId) + .subscribe(); + } + + async closeConnection({ + connectionId, + }: Parameters[0]): ReturnType< + AcpHttpBackend["closeConnection"] + > { + const connection = this.connections.get(connectionId); + + if (!connection) { + return false; + } + + this.connections.delete(connectionId); + void connection.shutdown(); + return true; + } + + async close(): Promise { + const connections = Array.from(this.connections.values()); + this.connections.clear(); + await Promise.all(connections.map((connection) => connection.shutdown())); + } +} + +class FakeDistributedConnection { + readonly connectionId = globalThis.crypto.randomUUID(); + readonly connectionStream = new OutboundStream(); + + private readonly inboundTx: WritableStream; + private readonly outboundRx: ReadableStream; + private readonly sessionStreams = new Map(); + private readonly pendingRoutes = new Map(); + private readonly clientResponseRoutes = new Map(); + private inboundWriteChain: Promise = Promise.resolve(); + private initialReader: ReadableStreamDefaultReader | undefined; + private outboundReader: ReadableStreamDefaultReader | undefined; + private hasStartedRouter = false; + private shutdownPromise: Promise | undefined; + + constructor( + agent: AgentConnector, + requestIdGenerator?: JsonRpcRequestIdGenerator, + ) { + const inbound = new TransformStream(); + const outbound = new TransformStream(); + this.inboundTx = inbound.writable; + this.outboundRx = outbound.readable; + + const stream: Stream = { + readable: inbound.readable, + writable: outbound.writable, + }; + + agent.connect(stream, { requestIdGenerator }); + } + + async recvInitial(initializeId: string | number): Promise { + const reader = this.outboundRx.getReader(); + this.initialReader = reader; + + try { + const result = await reader.read(); + + if ( + result.done || + !result.value || + !isMatchingResponse(result.value, initializeId) + ) { + if (!this.shutdownPromise) { + await this.shutdown(); + } + + throw new Error("Expected initialize response from agent"); + } + + return result.value; + } finally { + if (this.initialReader === reader) { + this.initialReader = undefined; + } + + reader.releaseLock(); + } + } + + async writeInbound(message: AnyMessage): Promise { + const write = this.inboundWriteChain.then(() => + this.writeInboundMessage(message), + ); + this.inboundWriteChain = write.catch(() => undefined); + await write; + } + + startRouter(): void { + if (this.hasStartedRouter) { + return; + } + + this.hasStartedRouter = true; + void this.runRouter(); + } + + ensureSession(sessionId: string): OutboundStream { + const existing = this.sessionStreams.get(sessionId); + if (existing) { + return existing; + } + + const stream = new OutboundStream(); + this.sessionStreams.set(sessionId, stream); + + return stream; + } + + trackPendingResponseRoute(key: string, route: ResponseRoute): void { + this.pendingRoutes.set(key, route); + } + + clientResponseRoute(key: string): ResponseRoute | undefined { + return this.clientResponseRoutes.get(key); + } + + clearClientResponseRoute(key: string): void { + this.clientResponseRoutes.delete(key); + } + + async shutdown(): Promise { + if (!this.shutdownPromise) { + this.shutdownPromise = this.runShutdown(); + } + + return this.shutdownPromise; + } + + private async runShutdown(): Promise { + this.connectionStream.close(); + + for (const stream of this.sessionStreams.values()) { + stream.close(); + } + + this.sessionStreams.clear(); + this.pendingRoutes.clear(); + this.clientResponseRoutes.clear(); + + await Promise.allSettled([ + this.inboundTx.close(), + this.cancelOutboundReader(), + ]); + } + + private cancelOutboundReader(): Promise { + const reader = this.initialReader ?? this.outboundReader; + if (reader) { + return reader.cancel(); + } + + return this.outboundRx.cancel(); + } + + private async writeInboundMessage(message: AnyMessage): Promise { + const writer = this.inboundTx.getWriter(); + + try { + await writer.write(message); + } finally { + writer.releaseLock(); + } + } + + private async runRouter(): Promise { + const reader = this.outboundRx.getReader(); + this.outboundReader = reader; + + try { + while (true) { + const result = await reader.read(); + + if (result.done) { + return; + } + + this.routeOutbound(result.value); + } + } catch (error) { + console.error("Fake distributed ACP router stopped unexpectedly:", error); + } finally { + if (this.outboundReader === reader) { + this.outboundReader = undefined; + } + + reader.releaseLock(); + this.connectionStream.close(); + + for (const stream of this.sessionStreams.values()) { + stream.close(); + } + } + } + + private routeOutbound(message: AnyMessage): void { + if (isResponseMessage(message)) { + this.routeOutboundResponse(message); + return; + } + + this.routeOutboundRequestOrNotification(message); + } + + private routeOutboundResponse(message: AnyResponse): void { + const key = messageIdKey(message.id); + const route = key ? this.pendingRoutes.get(key) : undefined; + const sessionId = sessionIdFromResponseResult(message); + + if (sessionId) { + this.ensureSession(sessionId); + } + + if (key) { + this.pendingRoutes.delete(key); + } + + this.pushToRoute(route ?? "connection", message); + } + + private routeOutboundRequestOrNotification(message: AnyMessage): void { + const sessionId = sessionIdFromMessageParams(message); + if (sessionId) { + this.trackClientResponseRoute(message, { session: sessionId }); + this.ensureSession(sessionId).push(message); + return; + } + + this.trackClientResponseRoute(message, "connection"); + this.connectionStream.push(message); + } + + private trackClientResponseRoute( + message: AnyMessage, + route: ResponseRoute, + ): void { + if (!("id" in message) || !("method" in message)) { + return; + } + + const key = messageIdKey(message.id); + if (key) { + this.clientResponseRoutes.set(key, route); + } + } + + private pushToRoute(route: ResponseRoute, message: AnyMessage): void { + if (route === "connection") { + this.connectionStream.push(message); + return; + } + + this.ensureSession(route.session).push(message); + } +} + +function unknownConnectionResult() { + return { + ok: false as const, + status: 404, + message: "Unknown Acp-Connection-Id", + }; +} + +function isMatchingResponse( + msg: AnyMessage, + id: string | number, +): msg is AnyResponse { + return "id" in msg && !("method" in msg) && msg.id === id; +} + function createInteractiveAgent(): AgentApp { return createAgentApp({ name: "http-backend-conformance-agent" }) .onRequest(methods.agent.initialize, () => ({ diff --git a/src/http-backend.ts b/src/http-backend.ts index eb26bc5..793c216 100644 --- a/src/http-backend.ts +++ b/src/http-backend.ts @@ -12,6 +12,23 @@ import type { export type HttpBackendServerRequestIdGenerator = () => string | number; +export class AcpHttpBackendError extends Error { + constructor( + readonly status: number, + message: string, + options?: ErrorOptions, + ) { + super(message, options); + this.name = "AcpHttpBackendError"; + } +} + +export function isAcpHttpBackendError( + error: unknown, +): error is AcpHttpBackendError { + return error instanceof AcpHttpBackendError; +} + export interface HttpBackendInitializeInput { readonly agent: AgentConnector; readonly message: AnyMessage; diff --git a/src/server.test.ts b/src/server.test.ts index cb41ec8..8d1f1c7 100644 --- a/src/server.test.ts +++ b/src/server.test.ts @@ -5,6 +5,11 @@ import { agent as createAgentApp, methods, } from "./acp.js"; +import { + ConnectionRegistry, + InMemoryAcpHttpBackend, + type OutboundSubscription, +} from "./connection.js"; import { EVENT_STREAM_MIME_TYPE, HEADER_CONNECTION_ID, @@ -15,7 +20,9 @@ import { AcpServer } from "./server.js"; import { parseSseStream } from "./sse.js"; import { createTestAgentApp, TestAgent } from "./test-support/test-agent.js"; import { startTestServer } from "./test-support/test-http-server.js"; +import { AcpHttpBackendError } from "./http-backend.js"; +import type { AcpHttpBackend } from "./http-backend.js"; import type { AnyMessage } from "./jsonrpc.js"; const initializeRequest = { @@ -515,6 +522,106 @@ describe("AcpServer", () => { } }); + it("removes pending HTTP initialize connections as soon as the request aborts", async () => { + const registry = new ConnectionRegistry(); + const agentCreated = createDeferred(); + const initializeStarted = createDeferred(); + const initializeResponse = createDeferred<{ + protocolVersion: 1; + agentCapabilities: { loadSession: false }; + }>(); + const abortController = new AbortController(); + const server = new AcpServer({ + createAgent: () => { + agentCreated.resolve(); + return createTestAgentApp({ + initialize: () => { + initializeStarted.resolve(); + return initializeResponse.promise; + }, + }); + }, + httpBackend: new InMemoryAcpHttpBackend(registry), + }); + + try { + const responsePromise = server.handleRequest( + jsonRequest(initializeRequest, {}, abortController.signal), + ); + + await agentCreated.promise; + await withTimeout(initializeStarted.promise); + abortController.abort(); + + const response = await withTimeout(responsePromise); + + expect(response.status).toBe(499); + expect(pendingConnectionCount(registry)).toBe(0); + + initializeResponse.resolve({ + protocolVersion: 1, + agentCapabilities: { loadSession: false }, + }); + await flushMicrotasks(); + } finally { + await server.close(); + } + }); + + it.each([ + { + name: "initialize", + method: "initialize" as const, + request: () => jsonRequest(initializeRequest), + expectedStatus: 503, + }, + { + name: "connected POST load", + method: "loadConnection" as const, + request: () => + jsonRequest(sessionNewRequest, { + [HEADER_CONNECTION_ID]: "connection-1", + }), + expectedStatus: 503, + }, + { + name: "session SSE open", + method: "openSessionStream" as const, + request: () => + sseRequest({ + [HEADER_CONNECTION_ID]: "connection-1", + [HEADER_SESSION_ID]: "session-1", + }), + expectedStatus: 403, + }, + { + name: "connection close", + method: "closeConnection" as const, + request: () => + deleteRequest({ + [HEADER_CONNECTION_ID]: "connection-1", + }), + expectedStatus: 503, + }, + ])( + "maps status-bearing HTTP backend errors from $name to HTTP responses", + async ({ method, request, expectedStatus }) => { + const server = new AcpServer({ + createAgent: () => createTestAgentApp(), + httpBackend: createStatusThrowingBackend(method, expectedStatus), + }); + + try { + const response = await server.handleRequest(request()); + + expect(response.status).toBe(expectedStatus); + expect(await response.text()).toBe("Backend unavailable"); + } finally { + await server.close(); + } + }, + ); + it("ignores HTTP factory overrides for existing-connection POST requests", async () => { const createdBy: string[] = []; const server = new AcpServer({ @@ -1125,6 +1232,23 @@ function jsonRequest( }); } +function sseRequest(headers: Record = {}): Request { + return new Request("http://127.0.0.1/acp", { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + ...headers, + }, + }); +} + +function deleteRequest(headers: Record = {}): Request { + return new Request("http://127.0.0.1/acp", { + method: "DELETE", + headers, + }); +} + async function initialize(url: string): Promise { const response = await postJson(url, initializeRequest); const connectionId = response.headers.get(HEADER_CONNECTION_ID); @@ -1326,6 +1450,75 @@ function createDeferred(): { return { promise, resolve, reject }; } +type StatusThrowingBackendMethod = + | "initialize" + | "loadConnection" + | "openSessionStream" + | "closeConnection"; + +function createStatusThrowingBackend( + failingMethod: StatusThrowingBackendMethod, + status: number, +): AcpHttpBackend { + const throwIfFailing = (method: StatusThrowingBackendMethod): void => { + if (method === failingMethod) { + throw new AcpHttpBackendError(status, "Backend unavailable"); + } + }; + + return { + async initialize({ message }) { + throwIfFailing("initialize"); + return { + connectionId: "connection-1", + response: { + jsonrpc: "2.0", + id: "id" in message ? message.id : null, + result: {}, + }, + }; + }, + async loadConnection({ connectionId }) { + throwIfFailing("loadConnection"); + return { connectionId }; + }, + async touchConnection() {}, + async acceptClientMethodMessage() { + return { ok: true }; + }, + async acceptClientResponse() { + return { ok: true }; + }, + async openConnectionStream() { + return emptyOutboundSubscription(); + }, + async openSessionStream() { + throwIfFailing("openSessionStream"); + return emptyOutboundSubscription(); + }, + async closeConnection() { + throwIfFailing("closeConnection"); + return true; + }, + async close() {}, + }; +} + +function emptyOutboundSubscription(): OutboundSubscription { + return { + replay: [], + stream: new ReadableStream(), + }; +} + +function pendingConnectionCount(registry: ConnectionRegistry): number { + return ( + registry as unknown as { + readonly pendingConnections: Map; + } + ).pendingConnections.size; +} + async function withTimeout(promise: Promise): Promise { let timeout: ReturnType | undefined; diff --git a/src/server.ts b/src/server.ts index b02215a..a6720fe 100644 --- a/src/server.ts +++ b/src/server.ts @@ -13,6 +13,7 @@ import { AGENT_METHODS } from "./schema/index.js"; import { serializeSseEvent, serializeSseKeepAlive } from "./sse.js"; import { handleWebSocketConnection } from "./ws-server.js"; import { AgentSideConnection } from "./acp.js"; +import { isAcpHttpBackendError } from "./http-backend.js"; import type { WebSocketServerSessionHandle, WebSocketServerSocket, @@ -124,19 +125,29 @@ export class AcpServer { req: Request, options: HandleRequestOptions = {}, ): Promise { - if (req.method === "POST") { - return await this.handlePost(req, options); - } + try { + if (req.method === "POST") { + return await this.handlePost(req, options); + } - if (req.method === "GET") { - return await this.handleGet(req); - } + if (req.method === "GET") { + return await this.handleGet(req); + } - if (req.method === "DELETE") { - return await this.handleDelete(req); - } + if (req.method === "DELETE") { + return await this.handleDelete(req); + } - return textResponse("Method Not Allowed", 405); + return textResponse("Method Not Allowed", 405); + } catch (error) { + const backendResponse = backendErrorResponse(error); + + if (backendResponse) { + return backendResponse; + } + + throw error; + } } /** Creates a WebSocket connection before accepting the HTTP upgrade. */ @@ -335,6 +346,12 @@ export class AcpServer { return textResponse("Request aborted", 499); } + const backendResponse = backendErrorResponse(error); + + if (backendResponse) { + return backendResponse; + } + return jsonResponse( { jsonrpc: "2.0", @@ -785,3 +802,11 @@ function textResponse(body: string, status: number): Response { function emptyResponse(status: number): Response { return new Response(null, { status }); } + +function backendErrorResponse(error: unknown): Response | undefined { + if (!isAcpHttpBackendError(error)) { + return undefined; + } + + return textResponse(error.message, error.status); +}