diff --git a/packages/agent/src/adapters/error-classification.ts b/packages/agent/src/adapters/error-classification.ts index cf51ab762..4873724de 100644 --- a/packages/agent/src/adapters/error-classification.ts +++ b/packages/agent/src/adapters/error-classification.ts @@ -1,6 +1,7 @@ export type AgentErrorClassification = | "upstream_stream_terminated" | "upstream_connection_error" + | "upstream_timeout" | "upstream_provider_failure" | "agent_error"; @@ -23,6 +24,9 @@ export function classifyAgentError( if (/API Error:\s*Connection error\b/i.test(text)) { return "upstream_connection_error"; } + if (/API Error:.*\b(?:timed out|timeout)\b/i.test(text)) { + return "upstream_timeout"; + } if (UPSTREAM_PROVIDER_ERROR_STATUS_PATTERN.test(text)) { return "upstream_provider_failure"; } diff --git a/packages/agent/src/server/agent-server.test.ts b/packages/agent/src/server/agent-server.test.ts index 9cbcdb850..b3e5f826d 100644 --- a/packages/agent/src/server/agent-server.test.ts +++ b/packages/agent/src/server/agent-server.test.ts @@ -669,6 +669,98 @@ describe("AgentServer HTTP Mode", () => { expect(testServer.posthogAPI.updateTaskRun).not.toHaveBeenCalled(); }); + function createFailureTestServer() { + const appendRawLine = vi.fn(); + const testServer = new AgentServer({ + port, + jwtPublicKey: TEST_PUBLIC_KEY, + repositoryPath: repo.path, + apiUrl: "http://localhost:8000", + apiKey: "test-api-key", + projectId: 1, + mode: "interactive", + taskId: "test-task-id", + runId: "test-run-id", + }) as unknown as { + eventStreamSender: { + enqueue: ReturnType; + stop: ReturnType; + }; + posthogAPI: { updateTaskRun: ReturnType }; + session: unknown; + handleTurnFailure( + payload: JwtPayload, + phase: "initial" | "resume" | "followup", + error: unknown, + ): Promise; + }; + testServer.eventStreamSender = { + enqueue: vi.fn(), + stop: vi.fn(async () => {}), + }; + testServer.posthogAPI = { updateTaskRun: vi.fn(async () => ({})) }; + testServer.session = { + acpSessionId: "acp-1", + payload: { run_id: "run-1" }, + logWriter: { appendRawLine, flush: vi.fn(async () => {}) }, + }; + return testServer; + } + + const interactivePayload: JwtPayload = { + run_id: "run-1", + task_id: "task-1", + team_id: 1, + user_id: 1, + distinct_id: "distinct-id", + mode: "interactive", + }; + + it.each([ + ["genuine agent error (terminal)", "boom", "agent_error", true], + [ + "transient upstream timeout (recoverable)", + "API Error: The operation timed out.", + "upstream_timeout", + false, + ], + ] as const)( + "tags and handles a follow-up %s", + async (_name, errorMessage, expectedErrorType, expectsFailed) => { + const testServer = createFailureTestServer(); + + await testServer.handleTurnFailure( + interactivePayload, + "followup", + new Error(errorMessage), + ); + + expect(testServer.eventStreamSender.enqueue).toHaveBeenCalledWith( + expect.objectContaining({ + notification: expect.objectContaining({ + method: "session/update", + params: expect.objectContaining({ + update: expect.objectContaining({ + sessionUpdate: "error", + errorType: expectedErrorType, + }), + }), + }), + }), + ); + + if (expectsFailed) { + expect(testServer.posthogAPI.updateTaskRun).toHaveBeenCalledWith( + "task-1", + "run-1", + expect.objectContaining({ status: "failed" }), + ); + } else { + expect(testServer.posthogAPI.updateTaskRun).not.toHaveBeenCalled(); + } + }, + ); + it("persists structured turn completion notifications", () => { const appendRawLine = vi.fn(); const testServer = new AgentServer({ diff --git a/packages/agent/src/server/agent-server.ts b/packages/agent/src/server/agent-server.ts index 0c30dc5b1..309d3bda6 100644 --- a/packages/agent/src/server/agent-server.ts +++ b/packages/agent/src/server/agent-server.ts @@ -3,6 +3,7 @@ import { basename, join } from "node:path"; import { pathToFileURL } from "node:url"; import type { ContentBlock, + PromptResponse, RequestPermissionRequest, RequestPermissionResponse, } from "@agentclientprotocol/sdk"; @@ -82,6 +83,7 @@ import type { AgentServerConfig } from "./types"; const agentErrorClassificationSchema = z.enum([ "upstream_stream_terminated", "upstream_connection_error", + "upstream_timeout", "upstream_provider_failure", "agent_error", ]) satisfies z.ZodType; @@ -93,6 +95,7 @@ const upstreamProviderFailureClassifications = new Set([ "upstream_stream_terminated", "upstream_connection_error", + "upstream_timeout", "upstream_provider_failure", ]); @@ -788,17 +791,31 @@ export class AgentServer { this.session.logWriter.resetTurnMessages(this.session.payload.run_id); - const result = await this.session.clientConnection.prompt({ - sessionId: this.session.acpSessionId, - prompt, - ...(this.detectedPrUrl && { - _meta: { - // Keep the live-session PR override aligned with the startup - // prompt policy so non-Slack runs remain review-first. - prContext: this.buildDetectedPrContext(this.detectedPrUrl), - }, - }), - }); + let result: PromptResponse; + try { + result = await this.session.clientConnection.prompt({ + sessionId: this.session.acpSessionId, + prompt, + ...(this.detectedPrUrl && { + _meta: { + // Keep the live-session PR override aligned with the startup + // prompt policy so non-Slack runs remain review-first. + prContext: this.buildDetectedPrContext(this.detectedPrUrl), + }, + }), + }); + } catch (error) { + await this.session.logWriter.flushAll(); + const { recoverable } = await this.handleTurnFailure( + this.session.payload, + "followup", + error, + ); + if (!recoverable) { + throw error; + } + return { stopReason: "error_recoverable" }; + } this.logger.debug("User message completed", { stopReason: result.stopReason, @@ -1284,22 +1301,67 @@ export class AgentServer { return { classification: classifyAgentError(message), message }; } - private classifyAndSignalFailure( + private async handleTurnFailure( payload: JwtPayload, - phase: "initial" | "resume", + phase: "initial" | "resume" | "followup", error: unknown, - ): Promise { + ): Promise<{ recoverable: boolean }> { const { classification, message } = this.extractErrorClassification(error); - const errorMessage = upstreamProviderFailureClassifications.has( - classification, - ) + const isUpstreamFailure = + upstreamProviderFailureClassifications.has(classification); + const displayMessage = isUpstreamFailure ? UPSTREAM_PROVIDER_FAILURE_MESSAGE : message || "Agent error"; + const recoverable = + isUpstreamFailure && + phase === "followup" && + this.getEffectiveMode(payload) === "interactive"; + this.logger.error(`send_${phase}_task_message_failed`, { classification, message, + recoverable, }); - return this.signalTaskComplete(payload, "error", errorMessage); + + this.broadcastTurnFailure(classification, displayMessage); + + if (recoverable) { + this.broadcastTurnComplete("error_recoverable"); + return { recoverable: true }; + } + + await this.signalTaskComplete(payload, "error", displayMessage); + return { recoverable: false }; + } + + private broadcastTurnFailure( + classification: AgentErrorClassification, + message: string, + ): void { + if (!this.session) return; + const notification = { + jsonrpc: "2.0", + method: "session/update", + params: { + sessionId: this.session.acpSessionId, + update: { + sessionUpdate: "error", + errorType: classification, + message, + }, + }, + }; + + this.broadcastEvent({ + type: "notification", + timestamp: new Date().toISOString(), + notification, + }); + + this.session.logWriter.appendRawLine( + this.session.payload.run_id, + JSON.stringify(notification), + ); } private async sendInitialTaskMessage( @@ -1401,7 +1463,7 @@ export class AgentServer { if (this.session) { await this.session.logWriter.flushAll(); } - await this.classifyAndSignalFailure(payload, "initial", error); + await this.handleTurnFailure(payload, "initial", error); } } @@ -1540,7 +1602,7 @@ export class AgentServer { if (this.session) { await this.session.logWriter.flushAll(); } - await this.classifyAndSignalFailure(payload, "resume", error); + await this.handleTurnFailure(payload, "resume", error); } } diff --git a/packages/agent/src/server/question-relay.test.ts b/packages/agent/src/server/question-relay.test.ts index 073328a40..e00643d0a 100644 --- a/packages/agent/src/server/question-relay.test.ts +++ b/packages/agent/src/server/question-relay.test.ts @@ -95,6 +95,8 @@ describe("Question relay", () => { it.each([ ["API Error: terminated", "upstream_stream_terminated"], ["API Error: Connection error", "upstream_connection_error"], + ["API Error: The operation timed out.", "upstream_timeout"], + ["API Error: Request timed out.", "upstream_timeout"], ["API Error: 429 rate_limit_error", "upstream_provider_failure"], ["API Error: 529 overloaded_error", "upstream_provider_failure"], ["API Error: 503 internal_error", "upstream_provider_failure"], @@ -529,6 +531,7 @@ describe("Question relay", () => { flushAll: vi.fn().mockResolvedValue(undefined), getFullAgentResponse: vi.fn().mockReturnValue(null), resetTurnMessages: vi.fn(), + appendRawLine: vi.fn(), flush: vi.fn().mockResolvedValue(undefined), isRegistered: vi.fn().mockReturnValue(true), }, @@ -573,6 +576,7 @@ describe("Question relay", () => { flushAll: vi.fn().mockResolvedValue(undefined), getFullAgentResponse: vi.fn().mockReturnValue(null), resetTurnMessages: vi.fn(), + appendRawLine: vi.fn(), flush: vi.fn().mockResolvedValue(undefined), isRegistered: vi.fn().mockReturnValue(true), }, @@ -607,6 +611,7 @@ describe("Question relay", () => { flushAll: vi.fn().mockResolvedValue(undefined), getFullAgentResponse: vi.fn().mockReturnValue(null), resetTurnMessages: vi.fn(), + appendRawLine: vi.fn(), flush: vi.fn().mockResolvedValue(undefined), isRegistered: vi.fn().mockReturnValue(true), }, @@ -646,6 +651,7 @@ describe("Question relay", () => { flushAll: vi.fn().mockResolvedValue(undefined), getFullAgentResponse: vi.fn().mockReturnValue(null), resetTurnMessages: vi.fn(), + appendRawLine: vi.fn(), flush: vi.fn().mockResolvedValue(undefined), isRegistered: vi.fn().mockReturnValue(true), }, @@ -690,6 +696,7 @@ describe("Question relay", () => { flushAll: vi.fn().mockResolvedValue(undefined), getFullAgentResponse: vi.fn().mockReturnValue(null), resetTurnMessages: vi.fn(), + appendRawLine: vi.fn(), flush: vi.fn().mockResolvedValue(undefined), isRegistered: vi.fn().mockReturnValue(true), }, @@ -734,6 +741,7 @@ describe("Question relay", () => { flushAll: vi.fn().mockResolvedValue(undefined), getFullAgentResponse: vi.fn().mockReturnValue(null), resetTurnMessages: vi.fn(), + appendRawLine: vi.fn(), flush: vi.fn().mockResolvedValue(undefined), isRegistered: vi.fn().mockReturnValue(true), },