diff --git a/apps/sim/app/api/a2a/agents/[agentId]/route.ts b/apps/sim/app/api/a2a/agents/[agentId]/route.ts index bf05379f861..cc399e9e0f6 100644 --- a/apps/sim/app/api/a2a/agents/[agentId]/route.ts +++ b/apps/sim/app/api/a2a/agents/[agentId]/route.ts @@ -3,8 +3,8 @@ import { a2aAgent, workflow } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { and, eq, isNull } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' -import { generateAgentCard, generateSkillsFromWorkflow } from '@/lib/a2a/agent-card' -import type { AgentCapabilities, AgentSkill } from '@/lib/a2a/types' +import { buildAgentCard, generateSkillsFromWorkflow } from '@/lib/a2a/agent-card' +import type { AgentAuthentication, AgentCapabilities, AgentSkill } from '@/lib/a2a/types' import { a2aAgentParamsSchema, publishA2AAgentContract, @@ -13,10 +13,12 @@ import { import { parseRequest } from '@/lib/api/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' import { getRedisClient } from '@/lib/core/config/redis' +import { getBaseUrl } from '@/lib/core/utils/urls' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { captureServerEvent } from '@/lib/posthog/server' import { loadWorkflowFromNormalizedTables } from '@/lib/workflows/persistence/utils' import { checkWorkspaceAccess } from '@/lib/workspaces/permissions/utils' +import { getBrandConfig } from '@/ee/whitelabeling' const logger = createLogger('A2AAgentCardAPI') @@ -60,21 +62,23 @@ export const GET = withRouteHandler( } } - const agentCard = generateAgentCard( - { + const agentCard = buildAgentCard({ + agent: { id: agent.agent.id, name: agent.agent.name, description: agent.agent.description, version: agent.agent.version, capabilities: agent.agent.capabilities as AgentCapabilities, skills: agent.agent.skills as AgentSkill[], + authentication: agent.agent.authentication as AgentAuthentication, }, - { - id: agent.workflow.id, + baseUrl: getBaseUrl(), + providerOrganization: getBrandConfig().name, + workflow: { name: agent.workflow.name, description: agent.workflow.description, - } - ) + }, + }) return NextResponse.json(agentCard, { headers: { diff --git a/apps/sim/app/api/a2a/serve/[agentId]/.well-known/agent-card.json/route.ts b/apps/sim/app/api/a2a/serve/[agentId]/.well-known/agent-card.json/route.ts new file mode 100644 index 00000000000..b69f1165c96 --- /dev/null +++ b/apps/sim/app/api/a2a/serve/[agentId]/.well-known/agent-card.json/route.ts @@ -0,0 +1,39 @@ +import type { NextRequest } from 'next/server' +import { NextResponse } from 'next/server' +import { a2aServeAgentParamsSchema } from '@/lib/api/contracts/a2a-agents' +import { withRouteHandler } from '@/lib/core/utils/with-route-handler' +import { getServedAgentCard } from '@/app/api/a2a/serve/[agentId]/utils' + +export const dynamic = 'force-dynamic' +export const runtime = 'nodejs' + +interface RouteParams { + agentId: string +} + +/** + * GET - A2A v0.3 well-known discovery endpoint. + * + * Serves the Agent Card at the RFC 8615 context path + * (`/api/a2a/serve/{agentId}/.well-known/agent-card.json`) so standard A2A + * clients that append the well-known path to an agent's base URL can discover + * Sim-hosted agents. + */ +export const GET = withRouteHandler( + async (_request: NextRequest, { params }: { params: Promise }) => { + const { agentId } = a2aServeAgentParamsSchema.parse(await params) + + const result = await getServedAgentCard(agentId) + if (!result.ok) { + return NextResponse.json({ error: result.error }, { status: result.status }) + } + + return NextResponse.json(result.card, { + headers: { + 'Content-Type': 'application/json', + 'Cache-Control': 'public, max-age=60', + 'X-Cache': result.cacheHit ? 'HIT' : 'MISS', + }, + }) + } +) diff --git a/apps/sim/app/api/a2a/serve/[agentId]/request-handler.test.ts b/apps/sim/app/api/a2a/serve/[agentId]/request-handler.test.ts new file mode 100644 index 00000000000..07e0da42929 --- /dev/null +++ b/apps/sim/app/api/a2a/serve/[agentId]/request-handler.test.ts @@ -0,0 +1,123 @@ +/** + * @vitest-environment node + */ +import { A2AError } from '@a2a-js/sdk/server' +import { dbChainMock, dbChainMockFns, resetDbChainMock } from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@sim/db', () => dbChainMock) +vi.mock('@/lib/a2a/push-notifications', () => ({ + notifyTaskStateChange: vi.fn().mockResolvedValue(undefined), +})) +vi.mock('@/lib/execution/cancellation', () => ({ + markExecutionCancelled: vi.fn().mockResolvedValue(undefined), +})) +vi.mock('@/lib/core/config/redis', () => ({ + acquireLock: vi.fn().mockResolvedValue(true), + releaseLock: vi.fn().mockResolvedValue(undefined), + getRedisClient: vi.fn(() => null), +})) +vi.mock('@/lib/core/security/input-validation.server', () => ({ + validateUrlWithDNS: vi.fn().mockResolvedValue({ isValid: true, resolvedIP: '1.2.3.4' }), + secureFetchWithPinnedIP: vi.fn(), +})) +vi.mock('@/lib/auth/internal', () => ({ + generateInternalToken: vi.fn().mockResolvedValue('internal-token'), +})) +vi.mock('@/ee/whitelabeling', () => ({ + getBrandConfig: () => ({ name: 'Sim' }), +})) + +import { buildAgentCard } from '@/lib/a2a/agent-card' +import { SimA2ARequestHandler } from '@/app/api/a2a/serve/[agentId]/request-handler' + +const agent = { id: 'agent-1', name: 'A', workflowId: 'wf-1', workspaceId: 'ws-1' } +const agentCard = buildAgentCard({ + agent: { id: agent.id, name: agent.name, version: '1.0.0' }, + baseUrl: 'https://example.com', + providerOrganization: 'Sim', +}) + +function makeHandler(callerFingerprint = 'user:u1') { + return new SimA2ARequestHandler({ agent, agentCard, callerFingerprint }) +} + +function taskRow(overrides: Record = {}) { + return { + id: 't1', + agentId: 'agent-1', + sessionId: 'ctx-1', + status: 'completed', + messages: [], + artifacts: [], + executionId: null, + metadata: { callerFingerprint: 'user:u1' }, + ...overrides, + } +} + +describe('SimA2ARequestHandler', () => { + beforeEach(() => { + vi.clearAllMocks() + resetDbChainMock() + }) + + it('getAgentCard returns the configured card', async () => { + await expect(makeHandler().getAgentCard()).resolves.toBe(agentCard) + }) + + it('getAuthenticatedExtendedAgentCard rejects when not configured', async () => { + await expect(makeHandler().getAuthenticatedExtendedAgentCard()).rejects.toMatchObject({ + code: -32007, + }) + }) + + it('getTask returns an SDK Task for an owned task', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([taskRow({ status: 'completed' })]) + + const task = await makeHandler('user:u1').getTask({ id: 't1' }) + + expect(task.kind).toBe('task') + expect(task.id).toBe('t1') + expect(task.contextId).toBe('ctx-1') + expect(task.status.state).toBe('completed') + }) + + it('hides a task owned by a different caller (taskNotFound)', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([ + taskRow({ metadata: { callerFingerprint: 'user:someone-else' } }), + ]) + + await expect(makeHandler('user:u1').getTask({ id: 't1' })).rejects.toMatchObject({ + code: -32001, + }) + }) + + it('cancelTask rejects a task already in a terminal state (taskNotCancelable)', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([taskRow({ status: 'completed' })]) + + await expect(makeHandler('user:u1').cancelTask({ id: 't1' })).rejects.toMatchObject({ + code: -32002, + }) + }) + + it('cancelTask cancels a running task and returns canceled state', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([ + taskRow({ status: 'working', executionId: 'exec-1' }), + ]) + + const task = await makeHandler('user:u1').cancelTask({ id: 't1' }) + + expect(task.status.state).toBe('canceled') + expect(task.id).toBe('t1') + }) + + it('all thrown errors are SDK A2AError instances', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([]) + const error = await makeHandler('user:u1') + .getTask({ id: 'missing' }) + .catch((e) => e) + expect(error).toBeInstanceOf(A2AError) + expect(error.code).toBe(-32001) + }) +}) diff --git a/apps/sim/app/api/a2a/serve/[agentId]/request-handler.ts b/apps/sim/app/api/a2a/serve/[agentId]/request-handler.ts new file mode 100644 index 00000000000..a31573e3f4c --- /dev/null +++ b/apps/sim/app/api/a2a/serve/[agentId]/request-handler.ts @@ -0,0 +1,708 @@ +import type { + AgentCard, + Artifact, + DeleteTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigParams, + ListTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskState, + TaskStatusUpdateEvent, +} from '@a2a-js/sdk' +import { A2AError, type A2ARequestHandler } from '@a2a-js/sdk/server' +import { db } from '@sim/db' +import { a2aPushNotificationConfig, a2aTask } from '@sim/db/schema' +import { createLogger } from '@sim/logger' +import { getErrorMessage } from '@sim/utils/errors' +import { sleep } from '@sim/utils/helpers' +import { generateId } from '@sim/utils/id' +import { eq } from 'drizzle-orm' +import { A2A_DEFAULT_TIMEOUT, A2A_MAX_HISTORY_LENGTH } from '@/lib/a2a/constants' +import { notifyTaskStateChange } from '@/lib/a2a/push-notifications' +import { + createAgentMessage, + extractWorkflowInput, + isTerminalState, + parseWorkflowSSEChunk, +} from '@/lib/a2a/utils' +import { acquireLock, releaseLock } from '@/lib/core/config/redis' +import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' +import { markExecutionCancelled } from '@/lib/execution/cancellation' +import { + buildExecuteRequest, + buildStatusUpdate, + buildTaskResponse, + extractAgentContent, +} from '@/app/api/a2a/serve/[agentId]/utils' + +const logger = createLogger('A2ARequestHandler') + +const RESUBSCRIBE_POLL_INTERVAL_MS = 3000 +const RESUBSCRIBE_MAX_POLLS = 100 + +interface HandlerAgent { + id: string + name: string + workflowId: string + workspaceId: string +} + +export interface SimA2AHandlerConfig { + agent: HandlerAgent + agentCard: AgentCard + apiKey?: string | null + executionUserId?: string + callerFingerprint: string + requestSignal?: AbortSignal +} + +/** + * Sim implementation of the A2A {@link A2ARequestHandler} interface. + * + * The SDK's {@link import('@a2a-js/sdk/server').JsonRpcTransportHandler} drives this + * handler — it owns JSON-RPC parsing, method routing, capability checks, and + * error enveloping. This class supplies the Sim-specific behavior that the SDK + * cannot know about: executing the backing workflow, task persistence, caller + * isolation, distributed locking, and push-notification delivery. + * + * Auth, workspace access, and billing entitlement are enforced by the route + * before the request reaches this handler. + */ +export class SimA2ARequestHandler implements A2ARequestHandler { + constructor(private readonly config: SimA2AHandlerConfig) {} + + getAgentCard(): Promise { + return Promise.resolve(this.config.agentCard) + } + + async getAuthenticatedExtendedAgentCard(): Promise { + throw A2AError.authenticatedExtendedCardNotConfigured() + } + + async sendMessage(params: MessageSendParams): Promise { + const { agent } = this.config + const message = params.message + const taskId = message.taskId || generateId() + const contextId = message.contextId || generateId() + + const lockKey = `a2a:task:${taskId}:lock` + const lockValue = generateId() + const acquired = await acquireLock(lockKey, lockValue, 60) + if (!acquired) { + throw A2AError.internalError('Task is currently being processed') + } + + let movedToWorking = false + try { + const existingTask = await this.loadExistingTaskForSend(message.taskId) + const history: Message[] = existingTask ? (existingTask.messages as Message[]) : [] + history.push(message) + this.truncateHistory(history) + + if (existingTask) { + await db + .update(a2aTask) + .set({ status: 'working', messages: history, updatedAt: new Date() }) + .where(eq(a2aTask.id, taskId)) + } else { + await db.insert(a2aTask).values({ + id: taskId, + agentId: agent.id, + sessionId: contextId, + status: 'working', + messages: history, + metadata: { callerFingerprint: this.config.callerFingerprint }, + createdAt: new Date(), + updatedAt: new Date(), + }) + } + movedToWorking = true + + const workflowInput = extractWorkflowInput(message) + if (!workflowInput) { + throw A2AError.invalidParams('Message must contain at least one part with content') + } + + const { url, headers, useInternalAuth } = await buildExecuteRequest({ + workflowId: agent.workflowId, + apiKey: this.config.apiKey, + userId: this.config.executionUserId, + }) + + logger.info(`Executing workflow ${agent.workflowId} for A2A task ${taskId}`) + + const response = await fetch(url, { + method: 'POST', + headers, + body: JSON.stringify({ + ...workflowInput, + triggerType: 'a2a', + ...(useInternalAuth && { workflowId: agent.workflowId }), + }), + signal: AbortSignal.timeout(A2A_DEFAULT_TIMEOUT), + }) + + const executeResult = await response.json() + const executionId = executeResult.executionId || executeResult.metadata?.executionId + const executionSucceeded = response.ok && executeResult.success !== false + const finalState: TaskState = executionSucceeded ? 'completed' : 'failed' + + const agentMessage = createAgentMessage(extractAgentContent(executeResult)) + agentMessage.taskId = taskId + agentMessage.contextId = contextId + history.push(agentMessage) + + const artifacts: Artifact[] = executeResult.output?.artifacts || [] + + await db + .update(a2aTask) + .set({ + status: finalState, + messages: history, + artifacts, + executionId, + completedAt: new Date(), + updatedAt: new Date(), + }) + .where(eq(a2aTask.id, taskId)) + + this.notifyIfTerminal(taskId, finalState) + + return buildTaskResponse({ taskId, contextId, state: finalState, history, artifacts }) + } catch (error) { + await this.failTask(taskId, movedToWorking) + throw this.toA2AError(error) + } finally { + await releaseLock(lockKey, lockValue) + } + } + + async *sendMessageStream( + params: MessageSendParams + ): AsyncGenerator< + Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent, + void, + undefined + > { + const { agent } = this.config + const message = params.message + const contextId = message.contextId || generateId() + const taskId = message.taskId || generateId() + + const lockKey = `a2a:task:${taskId}:lock` + const lockValue = generateId() + const acquired = await acquireLock(lockKey, lockValue, 300) + if (!acquired) { + throw A2AError.internalError('Task is currently being processed') + } + + let movedToWorking = false + try { + const existingTask = await this.loadExistingTaskForSend(message.taskId) + const history: Message[] = existingTask ? (existingTask.messages as Message[]) : [] + history.push(message) + this.truncateHistory(history) + + if (existingTask) { + await db + .update(a2aTask) + .set({ status: 'working', messages: history, updatedAt: new Date() }) + .where(eq(a2aTask.id, taskId)) + } else { + await db.insert(a2aTask).values({ + id: taskId, + agentId: agent.id, + sessionId: contextId, + status: 'working', + messages: history, + metadata: { callerFingerprint: this.config.callerFingerprint }, + createdAt: new Date(), + updatedAt: new Date(), + }) + } + movedToWorking = true + + yield buildStatusUpdate({ taskId, contextId, state: 'working', final: false }) + + const workflowInput = extractWorkflowInput(message) + if (!workflowInput) { + throw A2AError.invalidParams('Message must contain at least one part with content') + } + + const { url, headers, useInternalAuth } = await buildExecuteRequest({ + workflowId: agent.workflowId, + apiKey: this.config.apiKey, + userId: this.config.executionUserId, + stream: true, + }) + + const response = await fetch(url, { + method: 'POST', + headers, + body: JSON.stringify({ + ...workflowInput, + triggerType: 'a2a', + stream: true, + ...(useInternalAuth && { workflowId: agent.workflowId }), + }), + signal: AbortSignal.timeout(A2A_DEFAULT_TIMEOUT), + }) + + if (!response.ok) { + let errorMessage = 'Workflow execution failed' + try { + const errorResult = await response.json() + errorMessage = errorResult.error || errorMessage + } catch { + // Response may not be JSON + } + throw new Error(errorMessage) + } + + const contentType = response.headers.get('content-type') || '' + const streamingExecutionId = response.headers.get('X-Execution-Id') || undefined + const isStreamingResponse = + contentType.includes('text/event-stream') || contentType.includes('text/plain') + + if (response.body && isStreamingResponse) { + const reader = response.body.getReader() + const decoder = new TextDecoder() + const contentChunks: string[] = [] + let finalContent: string | undefined + let finalArtifacts: Artifact[] = [] + let sseBuffer = '' + + while (true) { + if (this.config.requestSignal?.aborted) { + await reader.cancel().catch(() => {}) + return + } + + const { done, value } = await reader.read() + if (done) break + + sseBuffer += decoder.decode(value, { stream: true }) + const frames = sseBuffer.split('\n\n') + sseBuffer = frames.pop() ?? '' + + for (const frame of frames) { + const parsed = parseWorkflowSSEChunk(frame) + + if (parsed.content) { + contentChunks.push(parsed.content) + yield this.streamMessage(taskId, contextId, parsed.content) + } + + if (parsed.finalContent) finalContent = parsed.finalContent + if (parsed.finalArtifacts) finalArtifacts = parsed.finalArtifacts + + if (parsed.terminalState === 'canceled') { + const agentMessage = createAgentMessage(finalContent || 'Task canceled') + agentMessage.taskId = taskId + agentMessage.contextId = contextId + history.push(agentMessage) + + await db + .update(a2aTask) + .set({ + status: 'canceled', + messages: history, + executionId: streamingExecutionId, + artifacts: finalArtifacts, + completedAt: new Date(), + updatedAt: new Date(), + }) + .where(eq(a2aTask.id, taskId)) + + this.notifyIfTerminal(taskId, 'canceled') + + yield buildTaskResponse({ + taskId, + contextId, + state: 'canceled', + history, + artifacts: finalArtifacts, + }) + return + } + + if (parsed.finalSuccess === false) { + throw new Error('Workflow execution failed') + } + } + } + + if (sseBuffer.trim().length > 0) { + const parsed = parseWorkflowSSEChunk(sseBuffer) + if (parsed.content) { + contentChunks.push(parsed.content) + yield this.streamMessage(taskId, contextId, parsed.content) + } + if (parsed.finalContent) finalContent = parsed.finalContent + if (parsed.finalArtifacts) finalArtifacts = parsed.finalArtifacts + if (parsed.finalSuccess === false) { + throw new Error('Workflow execution failed') + } + } + + const accumulatedContent = contentChunks.join('') + const messageContent = + (finalContent !== undefined && finalContent.length > 0 + ? finalContent + : accumulatedContent) || 'Task completed' + const agentMessage = createAgentMessage(messageContent) + agentMessage.taskId = taskId + agentMessage.contextId = contextId + history.push(agentMessage) + + await db + .update(a2aTask) + .set({ + status: 'completed', + messages: history, + executionId: streamingExecutionId, + artifacts: finalArtifacts, + completedAt: new Date(), + updatedAt: new Date(), + }) + .where(eq(a2aTask.id, taskId)) + + this.notifyIfTerminal(taskId, 'completed') + + yield buildTaskResponse({ + taskId, + contextId, + state: 'completed', + history, + artifacts: finalArtifacts, + }) + return + } + + const result = await response.json() + const executionSucceeded = result.success !== false + const content = extractAgentContent(result) + const finalState: TaskState = executionSucceeded ? 'completed' : 'failed' + + yield this.streamMessage(taskId, contextId, content) + + const agentMessage = createAgentMessage(content) + agentMessage.taskId = taskId + agentMessage.contextId = contextId + history.push(agentMessage) + + const artifacts: Artifact[] = (result.output?.artifacts as Artifact[]) || [] + + await db + .update(a2aTask) + .set({ + status: finalState, + messages: history, + artifacts, + executionId: result.executionId || result.metadata?.executionId, + completedAt: new Date(), + updatedAt: new Date(), + }) + .where(eq(a2aTask.id, taskId)) + + this.notifyIfTerminal(taskId, finalState) + + yield buildTaskResponse({ taskId, contextId, state: finalState, history, artifacts }) + } catch (error) { + await this.failTask(taskId, movedToWorking) + throw this.toA2AError(error) + } finally { + await releaseLock(lockKey, lockValue) + } + } + + async getTask(params: TaskQueryParams): Promise { + const task = await this.getTaskForCaller(params.id) + const historyLength = + params.historyLength !== undefined && params.historyLength >= 0 + ? params.historyLength + : undefined + + const history = task.messages as Message[] + return buildTaskResponse({ + taskId: task.id, + contextId: task.sessionId || task.id, + state: task.status as TaskState, + history: historyLength !== undefined ? history.slice(-historyLength) : history, + artifacts: (task.artifacts as Artifact[]) || [], + }) + } + + async cancelTask(params: TaskIdParams): Promise { + const task = await this.getTaskForCaller(params.id) + + if (isTerminalState(task.status as TaskState)) { + throw A2AError.taskNotCancelable(params.id) + } + + if (task.executionId) { + try { + await markExecutionCancelled(task.executionId) + logger.info('Cancelled workflow execution', { + taskId: task.id, + executionId: task.executionId, + }) + } catch (error) { + logger.warn('Failed to cancel workflow execution', { + taskId: task.id, + executionId: task.executionId, + error, + }) + } + } + + await db + .update(a2aTask) + .set({ status: 'canceled', completedAt: new Date(), updatedAt: new Date() }) + .where(eq(a2aTask.id, params.id)) + + this.notifyIfTerminal(params.id, 'canceled') + + return buildTaskResponse({ + taskId: task.id, + contextId: task.sessionId || task.id, + state: 'canceled', + history: task.messages as Message[], + artifacts: (task.artifacts as Artifact[]) || [], + }) + } + + async setTaskPushNotificationConfig( + params: TaskPushNotificationConfig + ): Promise { + const config = params.pushNotificationConfig + const urlValidation = await validateUrlWithDNS(config.url, 'Push notification URL') + if (!urlValidation.isValid) { + throw A2AError.invalidParams(urlValidation.error || 'Invalid push notification URL') + } + + await this.getTaskForCaller(params.taskId) + + const [existingConfig] = await db + .select() + .from(a2aPushNotificationConfig) + .where(eq(a2aPushNotificationConfig.taskId, params.taskId)) + .limit(1) + + if (existingConfig) { + await db + .update(a2aPushNotificationConfig) + .set({ + url: config.url, + token: config.token || null, + isActive: true, + updatedAt: new Date(), + }) + .where(eq(a2aPushNotificationConfig.id, existingConfig.id)) + } else { + await db.insert(a2aPushNotificationConfig).values({ + id: generateId(), + taskId: params.taskId, + url: config.url, + token: config.token || null, + isActive: true, + createdAt: new Date(), + updatedAt: new Date(), + }) + } + + return { + taskId: params.taskId, + pushNotificationConfig: { url: config.url, token: config.token }, + } + } + + async getTaskPushNotificationConfig( + params: TaskIdParams | GetTaskPushNotificationConfigParams + ): Promise { + await this.getTaskForCaller(params.id) + + const [config] = await db + .select() + .from(a2aPushNotificationConfig) + .where(eq(a2aPushNotificationConfig.taskId, params.id)) + .limit(1) + + if (!config) { + throw A2AError.invalidParams('No push notification configuration found for task') + } + + return { + taskId: params.id, + pushNotificationConfig: { url: config.url, token: config.token || undefined }, + } + } + + async listTaskPushNotificationConfigs( + params: ListTaskPushNotificationConfigParams + ): Promise { + await this.getTaskForCaller(params.id) + + const configs = await db + .select() + .from(a2aPushNotificationConfig) + .where(eq(a2aPushNotificationConfig.taskId, params.id)) + + return configs.map((config) => ({ + taskId: params.id, + pushNotificationConfig: { url: config.url, token: config.token || undefined }, + })) + } + + async deleteTaskPushNotificationConfig( + params: DeleteTaskPushNotificationConfigParams + ): Promise { + await this.getTaskForCaller(params.id) + + await db + .delete(a2aPushNotificationConfig) + .where(eq(a2aPushNotificationConfig.taskId, params.id)) + } + + async *resubscribe( + params: TaskIdParams + ): AsyncGenerator { + const task = await this.getTaskForCaller(params.id) + const contextId = task.sessionId || task.id + + if (isTerminalState(task.status as TaskState)) { + yield buildTaskResponse({ + taskId: task.id, + contextId, + state: task.status as TaskState, + history: task.messages as Message[], + artifacts: (task.artifacts as Artifact[]) || [], + }) + return + } + + yield buildStatusUpdate({ + taskId: task.id, + contextId, + state: task.status as TaskState, + final: false, + }) + + let lastStatus = task.status + + for (let poll = 0; poll < RESUBSCRIBE_MAX_POLLS; poll++) { + if (this.config.requestSignal?.aborted) return + + await sleep(RESUBSCRIBE_POLL_INTERVAL_MS) + if (this.config.requestSignal?.aborted) return + + const [updated] = await db.select().from(a2aTask).where(eq(a2aTask.id, params.id)).limit(1) + if (!updated) { + throw A2AError.taskNotFound(params.id) + } + + const terminal = isTerminalState(updated.status as TaskState) + + if (updated.status !== lastStatus) { + lastStatus = updated.status + yield buildStatusUpdate({ + taskId: updated.id, + contextId: updated.sessionId || updated.id, + state: updated.status as TaskState, + final: terminal, + }) + } + + if (terminal) { + yield buildTaskResponse({ + taskId: updated.id, + contextId: updated.sessionId || updated.id, + state: updated.status as TaskState, + history: updated.messages as Message[], + artifacts: (updated.artifacts as Artifact[]) || [], + }) + return + } + } + } + + private async loadExistingTaskForSend( + taskId: string | undefined + ): Promise { + if (!taskId) return null + + const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, taskId)).limit(1) + if (!task || task.agentId !== this.config.agent.id) { + throw A2AError.taskNotFound(taskId) + } + if (!this.hasCallerAccess(task)) { + throw A2AError.taskNotFound(taskId) + } + if (isTerminalState(task.status as TaskState)) { + throw A2AError.invalidRequest(`Task ${taskId} is already in a terminal state`) + } + return task + } + + private async getTaskForCaller(taskId: string): Promise { + const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, taskId)).limit(1) + if (!task || task.agentId !== this.config.agent.id || !this.hasCallerAccess(task)) { + throw A2AError.taskNotFound(taskId) + } + return task + } + + private hasCallerAccess(task: typeof a2aTask.$inferSelect): boolean { + const metadata = (task.metadata as Record | null) ?? {} + const stored = + typeof metadata.callerFingerprint === 'string' ? metadata.callerFingerprint : null + return !stored || stored === this.config.callerFingerprint + } + + private streamMessage(taskId: string, contextId: string, text: string): Message { + const message = createAgentMessage(text) + message.taskId = taskId + message.contextId = contextId + return message + } + + private truncateHistory(history: Message[]): void { + if (history.length > A2A_MAX_HISTORY_LENGTH) { + history.splice(0, history.length - A2A_MAX_HISTORY_LENGTH) + } + } + + private notifyIfTerminal(taskId: string, state: TaskState): void { + if (!isTerminalState(state)) return + notifyTaskStateChange(taskId, state).catch((err) => { + logger.error('Failed to trigger push notification', { taskId, state, error: err }) + }) + } + + private async failTask(taskId: string, movedToWorking: boolean): Promise { + if (!movedToWorking) return + try { + await db + .update(a2aTask) + .set({ status: 'failed', completedAt: new Date(), updatedAt: new Date() }) + .where(eq(a2aTask.id, taskId)) + this.notifyIfTerminal(taskId, 'failed') + } catch (error) { + logger.error('Failed to mark A2A task as failed', { taskId, error }) + } + } + + private toA2AError(error: unknown): A2AError { + if (error instanceof A2AError) return error + const isTimeout = error instanceof Error && error.name === 'TimeoutError' + const message = isTimeout + ? `Workflow execution timed out after ${A2A_DEFAULT_TIMEOUT}ms` + : getErrorMessage(error, 'Workflow execution failed') + return A2AError.internalError(message) + } +} diff --git a/apps/sim/app/api/a2a/serve/[agentId]/route.ts b/apps/sim/app/api/a2a/serve/[agentId]/route.ts index 62c41b0d48a..05bc28b4cd8 100644 --- a/apps/sim/app/api/a2a/serve/[agentId]/route.ts +++ b/apps/sim/app/api/a2a/serve/[agentId]/route.ts @@ -1,62 +1,37 @@ -import type { Artifact, Message, PushNotificationConfig, TaskState } from '@a2a-js/sdk' +import { A2AError, JsonRpcTransportHandler } from '@a2a-js/sdk/server' import { db } from '@sim/db' -import { a2aAgent, a2aPushNotificationConfig, a2aTask, workflow } from '@sim/db/schema' +import { a2aAgent, workflow } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { getErrorMessage } from '@sim/utils/errors' -import { generateId } from '@sim/utils/id' import { and, eq, isNull } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' -import { A2A_DEFAULT_TIMEOUT, A2A_MAX_HISTORY_LENGTH } from '@/lib/a2a/constants' -import { notifyTaskStateChange } from '@/lib/a2a/push-notifications' -import { - createAgentMessage, - extractWorkflowInput, - isTerminalState, - parseWorkflowSSEChunk, -} from '@/lib/a2a/utils' -import { - type A2AJsonRpcId, - type A2AMessageSendParams, - type A2APushNotificationSetParams, - type A2ATaskIdParams, - a2aJsonRpcRequestSchema, - a2aMessageSendParamsSchema, - a2aPushNotificationSetParamsSchema, - a2aServeAgentParamsSchema, - a2aTaskIdParamsSchema, -} from '@/lib/api/contracts/a2a-agents' +import { a2aServeAgentParamsSchema } from '@/lib/api/contracts/a2a-agents' import { type AuthResult, AuthType, checkHybridAuth } from '@/lib/auth/hybrid' import { API_EXECUTION_REQUIRES_PAID_PLAN_MESSAGE, isApiExecutionEntitled, } from '@/lib/billing/core/api-access' -import { acquireLock, getRedisClient, releaseLock } from '@/lib/core/config/redis' -import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' import { getClientIp } from '@/lib/core/utils/request' import { SSE_HEADERS } from '@/lib/core/utils/sse' -import { getBaseUrl } from '@/lib/core/utils/urls' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' -import { markExecutionCancelled } from '@/lib/execution/cancellation' import { checkWorkspaceAccess } from '@/lib/workspaces/permissions/utils' import { getWorkspaceBilledAccountUserId } from '@/lib/workspaces/utils' -import { - A2A_ERROR_CODES, - A2A_METHODS, - buildExecuteRequest, - buildTaskResponse, - createError, - createResponse, - extractAgentContent, - formatTaskResponse, - generateTaskId, -} from '@/app/api/a2a/serve/[agentId]/utils' -import { getBrandConfig } from '@/ee/whitelabeling' +import { SimA2ARequestHandler } from '@/app/api/a2a/serve/[agentId]/request-handler' +import { getServedAgentCard } from '@/app/api/a2a/serve/[agentId]/utils' const logger = createLogger('A2AServeAPI') export const dynamic = 'force-dynamic' export const runtime = 'nodejs' +/** + * JSON-RPC server-error code (-32000) for Sim-specific conditions the A2A spec + * does not define (agent unavailable, unauthorized, billing). The real signal + * is the HTTP status; this avoids colliding with the SDK's reserved A2A codes + * (-32001..-32007). + */ +const A2A_SERVER_ERROR_CODE = -32000 + interface RouteParams { agentId: string } @@ -65,1512 +40,244 @@ function getCallerFingerprint(request: NextRequest, userId?: string | null): str if (userId) { return `user:${userId}` } - const clientIp = getClientIp(request) const userAgent = request.headers.get('user-agent')?.trim() || 'unknown' return `public:${clientIp}:${userAgent}` } -function hasCallerAccessToTask( - task: typeof a2aTask.$inferSelect, - callerFingerprint: string -): boolean { - const metadata = (task.metadata as Record | null) ?? {} - const storedFingerprint = - typeof metadata.callerFingerprint === 'string' ? metadata.callerFingerprint : null - return !storedFingerprint || storedFingerprint === callerFingerprint +function extractJsonRpcId(body: unknown): string | number | null { + if (body && typeof body === 'object' && 'id' in body) { + const id = (body as { id: unknown }).id + if (typeof id === 'string' || typeof id === 'number') return id + } + return null } -/** - * GET - Returns the Agent Card (discovery document) - */ -export const GET = withRouteHandler( - async (_request: NextRequest, { params }: { params: Promise }) => { - const { agentId } = a2aServeAgentParamsSchema.parse(await params) +function jsonRpcErrorResponse( + id: string | number | null, + error: A2AError, + status: number +): NextResponse { + return NextResponse.json({ jsonrpc: '2.0', id, error: error.toJSONRPCError() }, { status }) +} - const redis = getRedisClient() - const cacheKey = `a2a:agent:${agentId}:card` +function isAsyncIterable(value: unknown): value is AsyncIterable { + return ( + typeof value === 'object' && + value !== null && + typeof (value as AsyncIterable)[Symbol.asyncIterator] === 'function' + ) +} - if (redis) { +function streamJsonRpc( + stream: AsyncIterable, + requestId: string | number | null +): NextResponse { + const encoder = new TextEncoder() + const readable = new ReadableStream({ + async start(controller) { try { - const cached = await redis.get(cacheKey) - if (cached) { - return NextResponse.json(JSON.parse(cached), { - headers: { - 'Content-Type': 'application/json', - 'Cache-Control': 'private, max-age=60', - 'X-Cache': 'HIT', - }, - }) + for await (const event of stream) { + controller.enqueue(encoder.encode(`data: ${JSON.stringify(event)}\n\n`)) } - } catch (err) { - logger.warn('Redis cache read failed', { agentId, error: err }) - } - } - - try { - const [agent] = await db - .select({ - id: a2aAgent.id, - name: a2aAgent.name, - description: a2aAgent.description, - version: a2aAgent.version, - capabilities: a2aAgent.capabilities, - skills: a2aAgent.skills, - authentication: a2aAgent.authentication, - isPublished: a2aAgent.isPublished, - }) - .from(a2aAgent) - .where(and(eq(a2aAgent.id, agentId), isNull(a2aAgent.archivedAt))) - .limit(1) - - if (!agent) { - return NextResponse.json({ error: 'Agent not found' }, { status: 404 }) - } - - if (!agent.isPublished) { - return NextResponse.json({ error: 'Agent not published' }, { status: 404 }) + } catch (error) { + const a2aError = + error instanceof A2AError + ? error + : A2AError.internalError(getErrorMessage(error, 'Streaming error')) + const errorEnvelope = { + jsonrpc: '2.0' as const, + id: requestId, + error: a2aError.toJSONRPCError(), + } + controller.enqueue( + encoder.encode(`event: error\ndata: ${JSON.stringify(errorEnvelope)}\n\n`) + ) + } finally { + controller.close() } + }, + }) - const baseUrl = getBaseUrl() - const brandConfig = getBrandConfig() - - const authConfig = agent.authentication as { schemes?: string[] } | undefined - const schemes = authConfig?.schemes || [] - const isPublic = schemes.includes('none') - - const agentCard = { - protocolVersion: '0.3.0', - name: agent.name, - description: agent.description || '', - url: `${baseUrl}/api/a2a/serve/${agent.id}`, - version: agent.version, - preferredTransport: 'JSONRPC', - documentationUrl: `${baseUrl}/docs/a2a`, - provider: { - organization: brandConfig.name, - url: baseUrl, - }, - capabilities: agent.capabilities, - skills: agent.skills || [], - ...(isPublic - ? {} - : { - securitySchemes: { - apiKey: { - type: 'apiKey' as const, - name: 'X-API-Key', - in: 'header' as const, - description: 'API key authentication', - }, - }, - security: [{ apiKey: [] }], - }), - defaultInputModes: ['text/plain', 'application/json'], - defaultOutputModes: ['text/plain', 'application/json'], - } + return new NextResponse(readable, { headers: SSE_HEADERS }) +} - if (redis) { - try { - await redis.set(cacheKey, JSON.stringify(agentCard), 'EX', 60) - } catch (err) { - logger.warn('Redis cache write failed', { agentId, error: err }) - } - } +/** + * GET - Returns the Agent Card (discovery document) + */ +export const GET = withRouteHandler( + async (_request: NextRequest, { params }: { params: Promise }) => { + const { agentId } = a2aServeAgentParamsSchema.parse(await params) - return NextResponse.json(agentCard, { - headers: { - 'Content-Type': 'application/json', - 'Cache-Control': 'private, max-age=60', - 'X-Cache': 'MISS', - }, - }) - } catch (error) { - logger.error('Error getting Agent Card:', error) - return NextResponse.json({ error: 'Internal server error' }, { status: 500 }) + const result = await getServedAgentCard(agentId) + if (!result.ok) { + return NextResponse.json({ error: result.error }, { status: result.status }) } + + return NextResponse.json(result.card, { + headers: { + 'Content-Type': 'application/json', + 'Cache-Control': 'private, max-age=60', + 'X-Cache': result.cacheHit ? 'HIT' : 'MISS', + }, + }) } ) /** - * POST - Handle JSON-RPC requests + * POST - Handle JSON-RPC requests via the A2A SDK transport handler */ export const POST = withRouteHandler( async (request: NextRequest, { params }: { params: Promise }) => { const { agentId } = a2aServeAgentParamsSchema.parse(await params) + let body: unknown try { - const [agent] = await db - .select({ - id: a2aAgent.id, - name: a2aAgent.name, - workflowId: a2aAgent.workflowId, - workspaceId: a2aAgent.workspaceId, - isPublished: a2aAgent.isPublished, - capabilities: a2aAgent.capabilities, - authentication: a2aAgent.authentication, - }) - .from(a2aAgent) - .where(and(eq(a2aAgent.id, agentId), isNull(a2aAgent.archivedAt))) - .limit(1) - - if (!agent) { - return NextResponse.json( - createError(null, A2A_ERROR_CODES.AGENT_UNAVAILABLE, 'Agent not found'), - { status: 404 } - ) - } - - if (!agent.isPublished) { - return NextResponse.json( - createError(null, A2A_ERROR_CODES.AGENT_UNAVAILABLE, 'Agent not published'), - { status: 404 } - ) - } - - const authSchemes = (agent.authentication as { schemes?: string[] })?.schemes || [] - const requiresAuth = !authSchemes.includes('none') - let authenticatedUserId: string | null = null - let authenticatedAuthType: AuthResult['authType'] - let authenticatedApiKeyType: AuthResult['apiKeyType'] - - if (requiresAuth) { - const auth = await checkHybridAuth(request, { requireWorkflowId: false }) - if (!auth.success || !auth.userId) { - return NextResponse.json( - createError(null, A2A_ERROR_CODES.AUTHENTICATION_REQUIRED, 'Unauthorized'), - { status: 401 } - ) - } - authenticatedUserId = auth.userId - authenticatedAuthType = auth.authType - authenticatedApiKeyType = auth.apiKeyType - - if (auth.apiKeyType === 'workspace' && auth.workspaceId !== agent.workspaceId) { - return NextResponse.json( - createError(null, A2A_ERROR_CODES.AUTHENTICATION_REQUIRED, 'Access denied'), - { status: 403 } - ) - } - - const workspaceAccess = await checkWorkspaceAccess(agent.workspaceId, authenticatedUserId) - if (!workspaceAccess.exists || !workspaceAccess.hasAccess) { - return NextResponse.json( - createError(null, A2A_ERROR_CODES.AUTHENTICATION_REQUIRED, 'Access denied'), - { status: 403 } - ) - } - } - - const [wf] = await db - .select({ isDeployed: workflow.isDeployed }) - .from(workflow) - .where(and(eq(workflow.id, agent.workflowId), isNull(workflow.archivedAt))) - .limit(1) - - if (!wf?.isDeployed) { - return NextResponse.json( - createError(null, A2A_ERROR_CODES.AGENT_UNAVAILABLE, 'Workflow is not deployed'), - { status: 400 } - ) - } - - let rawBody: unknown - try { - rawBody = await request.json() - } catch { - return NextResponse.json( - createError(null, A2A_ERROR_CODES.PARSE_ERROR, 'Invalid JSON body'), - { status: 400 } - ) - } - - const bodyResult = a2aJsonRpcRequestSchema.safeParse(rawBody) - - if (!bodyResult.success) { - return NextResponse.json( - createError(null, A2A_ERROR_CODES.INVALID_REQUEST, 'Invalid JSON-RPC request'), - { status: 400 } - ) - } - - const body = bodyResult.data - const { id, method, params: rpcParams } = body - const requestApiKey = request.headers.get('X-API-Key') - const apiKey = authenticatedAuthType === AuthType.API_KEY ? requestApiKey : null - const isPersonalApiKeyCaller = - authenticatedAuthType === AuthType.API_KEY && authenticatedApiKeyType === 'personal' - const callerFingerprint = getCallerFingerprint(request, authenticatedUserId) - const billedUserId = await getWorkspaceBilledAccountUserId(agent.workspaceId) - if (!billedUserId) { - logger.error('Unable to resolve workspace billed account for A2A execution', { - agentId: agent.id, - workspaceId: agent.workspaceId, - }) - return NextResponse.json( - createError( - id, - A2A_ERROR_CODES.INTERNAL_ERROR, - 'Unable to resolve billing account for this workspace' - ), - { status: 500 } - ) - } - if (!(await isApiExecutionEntitled(billedUserId))) { - return NextResponse.json( - createError( - id, - A2A_ERROR_CODES.AGENT_UNAVAILABLE, - API_EXECUTION_REQUIRES_PAID_PLAN_MESSAGE - ), - { status: 402 } - ) - } - - const executionUserId = - isPersonalApiKeyCaller && authenticatedUserId ? authenticatedUserId : billedUserId - - logger.info(`A2A request: ${method} for agent ${agentId}`) - - switch (method) { - case A2A_METHODS.MESSAGE_SEND: { - const paramsValidation = a2aMessageSendParamsSchema.safeParse(rpcParams) - if (!paramsValidation.success) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.INVALID_PARAMS, 'Message is required'), - { status: 400 } - ) - } - - return handleMessageSend( - id, - agent, - paramsValidation.data, - apiKey, - executionUserId, - callerFingerprint - ) - } - - case A2A_METHODS.MESSAGE_STREAM: { - const paramsValidation = a2aMessageSendParamsSchema.safeParse(rpcParams) - if (!paramsValidation.success) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.INVALID_PARAMS, 'Message is required'), - { status: 400 } - ) - } - - return handleMessageStream( - request, - id, - agent, - paramsValidation.data, - apiKey, - executionUserId, - callerFingerprint - ) - } - - case A2A_METHODS.TASKS_GET: { - const paramsValidation = a2aTaskIdParamsSchema.safeParse(rpcParams) - if (!paramsValidation.success) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.INVALID_PARAMS, 'Task ID is required'), - { status: 400 } - ) - } - return handleTaskGet(id, agent.id, paramsValidation.data, callerFingerprint) - } - - case A2A_METHODS.TASKS_CANCEL: { - const paramsValidation = a2aTaskIdParamsSchema.safeParse(rpcParams) - if (!paramsValidation.success) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.INVALID_PARAMS, 'Task ID is required'), - { status: 400 } - ) - } - return handleTaskCancel(id, agent.id, paramsValidation.data, callerFingerprint) - } - - case A2A_METHODS.TASKS_RESUBSCRIBE: { - const paramsValidation = a2aTaskIdParamsSchema.safeParse(rpcParams) - if (!paramsValidation.success) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.INVALID_PARAMS, 'Task ID is required'), - { status: 400 } - ) - } - - return handleTaskResubscribe( - request, - id, - agent.id, - paramsValidation.data, - callerFingerprint - ) - } - - case A2A_METHODS.PUSH_NOTIFICATION_SET: { - const paramsValidation = a2aPushNotificationSetParamsSchema.safeParse(rpcParams) - if (!paramsValidation.success) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.INVALID_PARAMS, 'Invalid push notification params'), - { status: 400 } - ) - } - - return handlePushNotificationSet(id, agent.id, paramsValidation.data, callerFingerprint) - } - - case A2A_METHODS.PUSH_NOTIFICATION_GET: { - const paramsValidation = a2aTaskIdParamsSchema.safeParse(rpcParams) - if (!paramsValidation.success) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.INVALID_PARAMS, 'Task ID is required'), - { status: 400 } - ) - } - return handlePushNotificationGet(id, agent.id, paramsValidation.data, callerFingerprint) - } - - case A2A_METHODS.PUSH_NOTIFICATION_DELETE: { - const paramsValidation = a2aTaskIdParamsSchema.safeParse(rpcParams) - if (!paramsValidation.success) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.INVALID_PARAMS, 'Task ID is required'), - { status: 400 } - ) - } - - return handlePushNotificationDelete( - id, - agent.id, - paramsValidation.data, - callerFingerprint - ) - } - - default: - return NextResponse.json( - createError(id, A2A_ERROR_CODES.METHOD_NOT_FOUND, `Method not found: ${method}`), - { status: 404 } - ) - } - } catch (error) { - logger.error('Error handling A2A request:', error) - return NextResponse.json( - createError(null, A2A_ERROR_CODES.INTERNAL_ERROR, 'Internal error'), - { - status: 500, - } + // boundary-raw-json: A2A JSON-RPC envelope is parsed and validated by the @a2a-js/sdk JsonRpcTransportHandler + body = await request.json() + } catch { + return jsonRpcErrorResponse(null, A2AError.parseError('Invalid JSON body'), 400) + } + const requestId = extractJsonRpcId(body) + + const [agent] = await db + .select({ + id: a2aAgent.id, + name: a2aAgent.name, + workflowId: a2aAgent.workflowId, + workspaceId: a2aAgent.workspaceId, + isPublished: a2aAgent.isPublished, + authentication: a2aAgent.authentication, + }) + .from(a2aAgent) + .where(and(eq(a2aAgent.id, agentId), isNull(a2aAgent.archivedAt))) + .limit(1) + + if (!agent) { + return jsonRpcErrorResponse( + requestId, + new A2AError(A2A_SERVER_ERROR_CODE, 'Agent not found'), + 404 + ) + } + if (!agent.isPublished) { + return jsonRpcErrorResponse( + requestId, + new A2AError(A2A_SERVER_ERROR_CODE, 'Agent not published'), + 404 ) } - } -) - -async function getTaskForAgent(taskId: string, agentId: string, callerFingerprint?: string) { - const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, taskId)).limit(1) - if (!task || task.agentId !== agentId) { - return null - } - if (callerFingerprint && !hasCallerAccessToTask(task, callerFingerprint)) { - return null - } - return task -} - -/** - * Handle message/send - Send a message (v0.3) - */ -async function handleMessageSend( - id: A2AJsonRpcId, - agent: { - id: string - name: string - workflowId: string - workspaceId: string - }, - params: A2AMessageSendParams, - apiKey?: string | null, - executionUserId?: string, - callerFingerprint?: string -): Promise { - const message = params.message - const taskId = message.taskId || generateTaskId() - const contextId = message.contextId || generateId() - - // Distributed lock to prevent concurrent task processing - const lockKey = `a2a:task:${taskId}:lock` - const lockValue = generateId() - const acquired = await acquireLock(lockKey, lockValue, 60) - - if (!acquired) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.INTERNAL_ERROR, 'Task is currently being processed'), - { status: 409 } - ) - } - - try { - let existingTask: typeof a2aTask.$inferSelect | null = null - if (message.taskId) { - const [found] = await db.select().from(a2aTask).where(eq(a2aTask.id, message.taskId)).limit(1) - existingTask = found || null - - if (!existingTask) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), - { status: 404 } - ) - } - if (existingTask.agentId !== agent.id) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), - { status: 404 } + const authSchemes = (agent.authentication as { schemes?: string[] })?.schemes || [] + const requiresAuth = !authSchemes.includes('none') + let authenticatedUserId: string | null = null + let authenticatedAuthType: AuthResult['authType'] + let authenticatedApiKeyType: AuthResult['apiKeyType'] + + if (requiresAuth) { + const auth = await checkHybridAuth(request, { requireWorkflowId: false }) + if (!auth.success || !auth.userId) { + return jsonRpcErrorResponse( + requestId, + new A2AError(A2A_SERVER_ERROR_CODE, 'Unauthorized'), + 401 ) } - - if (callerFingerprint && !hasCallerAccessToTask(existingTask, callerFingerprint)) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), - { status: 404 } + authenticatedUserId = auth.userId + authenticatedAuthType = auth.authType + authenticatedApiKeyType = auth.apiKeyType + + if (auth.apiKeyType === 'workspace' && auth.workspaceId !== agent.workspaceId) { + return jsonRpcErrorResponse( + requestId, + new A2AError(A2A_SERVER_ERROR_CODE, 'Access denied'), + 403 ) } - if (isTerminalState(existingTask.status as TaskState)) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.TASK_ALREADY_COMPLETE, 'Task already in terminal state'), - { status: 400 } + const workspaceAccess = await checkWorkspaceAccess(agent.workspaceId, authenticatedUserId) + if (!workspaceAccess.exists || !workspaceAccess.hasAccess) { + return jsonRpcErrorResponse( + requestId, + new A2AError(A2A_SERVER_ERROR_CODE, 'Access denied'), + 403 ) } } - const history: Message[] = existingTask?.messages ? (existingTask.messages as Message[]) : [] - - history.push(message) - - if (history.length > A2A_MAX_HISTORY_LENGTH) { - history.splice(0, history.length - A2A_MAX_HISTORY_LENGTH) - } - - if (existingTask) { - await db - .update(a2aTask) - .set({ - status: 'working', - messages: history, - updatedAt: new Date(), - }) - .where(eq(a2aTask.id, taskId)) - } else { - await db.insert(a2aTask).values({ - id: taskId, - agentId: agent.id, - sessionId: contextId || null, - status: 'working', - messages: history, - metadata: callerFingerprint ? { callerFingerprint } : {}, - createdAt: new Date(), - updatedAt: new Date(), - }) - } - - const { - url: executeUrl, - headers, - useInternalAuth, - } = await buildExecuteRequest({ - workflowId: agent.workflowId, - apiKey, - userId: executionUserId, - }) - - logger.info(`Executing workflow ${agent.workflowId} for A2A task ${taskId}`) - - try { - const workflowInput = extractWorkflowInput(message) - if (!workflowInput) { - await db - .update(a2aTask) - .set({ - status: 'failed', - completedAt: new Date(), - updatedAt: new Date(), - }) - .where(eq(a2aTask.id, taskId)) - - notifyTaskStateChange(taskId, 'failed').catch((err) => { - logger.error('Failed to trigger push notification for invalid input', { - taskId, - error: err, - }) - }) - - return NextResponse.json( - createError( - id, - A2A_ERROR_CODES.INVALID_PARAMS, - 'Message must contain at least one part with content' - ), - { status: 400 } - ) - } - - const response = await fetch(executeUrl, { - method: 'POST', - headers, - body: JSON.stringify({ - ...workflowInput, - triggerType: 'a2a', - ...(useInternalAuth && { workflowId: agent.workflowId }), - }), - signal: AbortSignal.timeout(A2A_DEFAULT_TIMEOUT), - }) - - const executeResult = await response.json() - const executionId = executeResult.executionId || executeResult.metadata?.executionId - const executionSucceeded = response.ok && executeResult.success !== false - const finalState: TaskState = executionSucceeded ? 'completed' : 'failed' - - const agentContent = extractAgentContent(executeResult) - const agentMessage = createAgentMessage(agentContent) - agentMessage.taskId = taskId - if (contextId) agentMessage.contextId = contextId - history.push(agentMessage) - - const artifacts = executeResult.output?.artifacts || [] - - await db - .update(a2aTask) - .set({ - status: finalState, - messages: history, - artifacts, - executionId, - completedAt: new Date(), - updatedAt: new Date(), - }) - .where(eq(a2aTask.id, taskId)) - - if (isTerminalState(finalState)) { - notifyTaskStateChange(taskId, finalState).catch((err) => { - logger.error('Failed to trigger push notification', { taskId, error: err }) - }) - } - - const task = buildTaskResponse({ - taskId, - contextId, - state: finalState, - history, - artifacts, - }) - - return NextResponse.json(createResponse(id, task)) - } catch (error) { - const isTimeout = error instanceof Error && error.name === 'TimeoutError' - logger.error(`Error executing workflow for task ${taskId}:`, { error, isTimeout }) - - const errorMessage = isTimeout - ? `Workflow execution timed out after ${A2A_DEFAULT_TIMEOUT}ms` - : error instanceof Error - ? error.message - : 'Workflow execution failed' - - await db - .update(a2aTask) - .set({ - status: 'failed', - updatedAt: new Date(), - completedAt: new Date(), - }) - .where(eq(a2aTask.id, taskId)) - - notifyTaskStateChange(taskId, 'failed').catch((err) => { - logger.error('Failed to trigger push notification for failure', { taskId, error: err }) - }) - - return NextResponse.json(createError(id, A2A_ERROR_CODES.INTERNAL_ERROR, errorMessage), { - status: 500, - }) - } - } finally { - await releaseLock(lockKey, lockValue) - } -} - -/** - * Handle message/stream - Stream a message response (v0.3) - */ -async function handleMessageStream( - _request: NextRequest, - id: A2AJsonRpcId, - agent: { - id: string - name: string - workflowId: string - workspaceId: string - }, - params: A2AMessageSendParams, - apiKey?: string | null, - executionUserId?: string, - callerFingerprint?: string -): Promise { - const message = params.message - const contextId = message.contextId || generateId() - const taskId = message.taskId || generateTaskId() - - // Distributed lock to prevent concurrent task processing - const lockKey = `a2a:task:${taskId}:lock` - const lockValue = generateId() - const acquired = await acquireLock(lockKey, lockValue, 300) - - if (!acquired) { - const encoder = new TextEncoder() - const errorStream = new ReadableStream({ - start(controller) { - controller.enqueue( - encoder.encode( - `event: error\ndata: ${JSON.stringify({ code: A2A_ERROR_CODES.INTERNAL_ERROR, message: 'Task is currently being processed' })}\n\n` - ) - ) - controller.close() - }, - }) - return new NextResponse(errorStream, { headers: SSE_HEADERS }) - } - - let history: Message[] = [] - let existingTask: typeof a2aTask.$inferSelect | null = null - - if (message.taskId) { - const [found] = await db.select().from(a2aTask).where(eq(a2aTask.id, message.taskId)).limit(1) - existingTask = found || null - - if (!existingTask) { - await releaseLock(lockKey, lockValue) - return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), { - status: 404, - }) + const [wf] = await db + .select({ isDeployed: workflow.isDeployed }) + .from(workflow) + .where(and(eq(workflow.id, agent.workflowId), isNull(workflow.archivedAt))) + .limit(1) + + if (!wf?.isDeployed) { + return jsonRpcErrorResponse( + requestId, + new A2AError(A2A_SERVER_ERROR_CODE, 'Workflow is not deployed'), + 400 + ) } - if (existingTask.agentId !== agent.id) { - await releaseLock(lockKey, lockValue) - return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), { - status: 404, - }) - } + const requestApiKey = request.headers.get('X-API-Key') + const apiKey = authenticatedAuthType === AuthType.API_KEY ? requestApiKey : null + const isPersonalApiKeyCaller = + authenticatedAuthType === AuthType.API_KEY && authenticatedApiKeyType === 'personal' + const callerFingerprint = getCallerFingerprint(request, authenticatedUserId) - if (callerFingerprint && !hasCallerAccessToTask(existingTask, callerFingerprint)) { - await releaseLock(lockKey, lockValue) - return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), { - status: 404, + const billedUserId = await getWorkspaceBilledAccountUserId(agent.workspaceId) + if (!billedUserId) { + logger.error('Unable to resolve workspace billed account for A2A execution', { + agentId: agent.id, + workspaceId: agent.workspaceId, }) + return jsonRpcErrorResponse( + requestId, + A2AError.internalError('Unable to resolve billing account for this workspace'), + 500 + ) } - - if (isTerminalState(existingTask.status as TaskState)) { - await releaseLock(lockKey, lockValue) - return NextResponse.json( - createError(id, A2A_ERROR_CODES.TASK_ALREADY_COMPLETE, 'Task already in terminal state'), - { status: 400 } + if (!(await isApiExecutionEntitled(billedUserId))) { + return jsonRpcErrorResponse( + requestId, + new A2AError(A2A_SERVER_ERROR_CODE, API_EXECUTION_REQUIRES_PAID_PLAN_MESSAGE), + 402 ) } - history = existingTask.messages as Message[] - } - - history.push(message) - - if (history.length > A2A_MAX_HISTORY_LENGTH) { - history.splice(0, history.length - A2A_MAX_HISTORY_LENGTH) - } - - if (existingTask) { - await db - .update(a2aTask) - .set({ - status: 'working', - messages: history, - updatedAt: new Date(), - }) - .where(eq(a2aTask.id, taskId)) - } else { - await db.insert(a2aTask).values({ - id: taskId, - agentId: agent.id, - sessionId: contextId || null, - status: 'working', - messages: history, - metadata: callerFingerprint ? { callerFingerprint } : {}, - createdAt: new Date(), - updatedAt: new Date(), - }) - } - - const encoder = new TextEncoder() - - const stream = new ReadableStream({ - async start(controller) { - const sendEvent = (event: string, data: unknown) => { - try { - const jsonRpcResponse = { - jsonrpc: '2.0' as const, - id, - result: data, - } - controller.enqueue( - encoder.encode(`event: ${event}\ndata: ${JSON.stringify(jsonRpcResponse)}\n\n`) - ) - } catch (error) { - logger.error('Error sending SSE event:', error) - } - } - - sendEvent('status', { - kind: 'status', - taskId, - contextId, - status: { state: 'working', timestamp: new Date().toISOString() }, - }) - - try { - const { - url: executeUrl, - headers, - useInternalAuth, - } = await buildExecuteRequest({ - workflowId: agent.workflowId, - apiKey, - userId: executionUserId, - stream: true, - }) - - const workflowInput = extractWorkflowInput(message) - if (!workflowInput) { - await db - .update(a2aTask) - .set({ - status: 'failed', - completedAt: new Date(), - updatedAt: new Date(), - }) - .where(eq(a2aTask.id, taskId)) - - notifyTaskStateChange(taskId, 'failed').catch((err) => { - logger.error('Failed to trigger push notification for invalid streamed input', { - taskId, - error: err, - }) - }) - - sendEvent('error', { - code: A2A_ERROR_CODES.INVALID_PARAMS, - message: 'Message must contain at least one part with content', - }) - await releaseLock(lockKey, lockValue) - controller.close() - return - } - - const response = await fetch(executeUrl, { - method: 'POST', - headers, - body: JSON.stringify({ - ...workflowInput, - triggerType: 'a2a', - stream: true, - ...(useInternalAuth && { workflowId: agent.workflowId }), - }), - signal: AbortSignal.timeout(A2A_DEFAULT_TIMEOUT), - }) - - if (!response.ok) { - let errorMessage = 'Workflow execution failed' - try { - const errorResult = await response.json() - errorMessage = errorResult.error || errorMessage - } catch { - // Response may not be JSON - } - throw new Error(errorMessage) - } - - const contentType = response.headers.get('content-type') || '' - const streamingExecutionId = response.headers.get('X-Execution-Id') || undefined - const isStreamingResponse = - contentType.includes('text/event-stream') || contentType.includes('text/plain') - - if (response.body && isStreamingResponse) { - const reader = response.body.getReader() - const decoder = new TextDecoder() - const contentChunks: string[] = [] - let finalContent: string | undefined - let finalArtifacts: Artifact[] = [] - let sseBuffer = '' - - while (true) { - const { done, value } = await reader.read() - if (done) break - - sseBuffer += decoder.decode(value, { stream: true }) - const frames = sseBuffer.split('\n\n') - sseBuffer = frames.pop() ?? '' - - for (const frame of frames) { - const parsed = parseWorkflowSSEChunk(frame) - - if (parsed.content) { - contentChunks.push(parsed.content) - sendEvent('message', { - kind: 'message', - taskId, - contextId, - role: 'agent', - parts: [{ kind: 'text', text: parsed.content }], - final: false, - }) - } - - if (parsed.finalContent) { - finalContent = parsed.finalContent - } - if (parsed.finalArtifacts) { - finalArtifacts = parsed.finalArtifacts - } - if (parsed.terminalState === 'canceled') { - const agentMessage = createAgentMessage(finalContent || 'Task canceled') - agentMessage.taskId = taskId - if (contextId) agentMessage.contextId = contextId - history.push(agentMessage) - - await db - .update(a2aTask) - .set({ - status: 'canceled', - messages: history, - executionId: streamingExecutionId, - artifacts: finalArtifacts, - completedAt: new Date(), - updatedAt: new Date(), - }) - .where(eq(a2aTask.id, taskId)) - - notifyTaskStateChange(taskId, 'canceled').catch((err) => { - logger.error('Failed to trigger push notification', { taskId, error: err }) - }) - - sendEvent('task', { - kind: 'task', - id: taskId, - contextId, - status: { state: 'canceled', timestamp: new Date().toISOString() }, - history, - artifacts: finalArtifacts, - }) - return - } - - if (parsed.finalSuccess === false) { - throw new Error('Workflow execution failed') - } - } - } - - if (sseBuffer.trim().length > 0) { - const parsed = parseWorkflowSSEChunk(sseBuffer) - if (parsed.content) { - contentChunks.push(parsed.content) - sendEvent('message', { - kind: 'message', - taskId, - contextId, - role: 'agent', - parts: [{ kind: 'text', text: parsed.content }], - final: false, - }) - } - if (parsed.finalContent) { - finalContent = parsed.finalContent - } - if (parsed.finalArtifacts) { - finalArtifacts = parsed.finalArtifacts - } - if (parsed.finalSuccess === false) { - throw new Error('Workflow execution failed') - } - } - - const accumulatedContent = contentChunks.join('') - const messageContent = - (finalContent !== undefined && finalContent.length > 0 - ? finalContent - : accumulatedContent) || 'Task completed' - const agentMessage = createAgentMessage(messageContent) - agentMessage.taskId = taskId - if (contextId) agentMessage.contextId = contextId - history.push(agentMessage) - - await db - .update(a2aTask) - .set({ - status: 'completed', - messages: history, - executionId: streamingExecutionId, - artifacts: finalArtifacts, - completedAt: new Date(), - updatedAt: new Date(), - }) - .where(eq(a2aTask.id, taskId)) - - notifyTaskStateChange(taskId, 'completed').catch((err) => { - logger.error('Failed to trigger push notification', { taskId, error: err }) - }) - - sendEvent('task', { - kind: 'task', - id: taskId, - contextId, - status: { state: 'completed', timestamp: new Date().toISOString() }, - history, - artifacts: finalArtifacts, - }) - } else { - const result = await response.json() - const executionSucceeded = result.success !== false - - const content = extractAgentContent(result) - - sendEvent('message', { - kind: 'message', - taskId, - contextId, - role: 'agent', - parts: [{ kind: 'text', text: content }], - final: true, - }) - - const agentMessage = createAgentMessage(content) - agentMessage.taskId = taskId - if (contextId) agentMessage.contextId = contextId - history.push(agentMessage) - - const artifacts = (result.output?.artifacts as Artifact[]) || [] - - await db - .update(a2aTask) - .set({ - status: executionSucceeded ? 'completed' : 'failed', - messages: history, - artifacts, - executionId: result.executionId || result.metadata?.executionId, - completedAt: new Date(), - updatedAt: new Date(), - }) - .where(eq(a2aTask.id, taskId)) + const executionUserId = + isPersonalApiKeyCaller && authenticatedUserId ? authenticatedUserId : billedUserId - notifyTaskStateChange(taskId, executionSucceeded ? 'completed' : 'failed').catch( - (err) => { - logger.error('Failed to trigger push notification', { taskId, error: err }) - } - ) - - sendEvent('task', { - kind: 'task', - id: taskId, - contextId, - status: { - state: executionSucceeded ? 'completed' : 'failed', - timestamp: new Date().toISOString(), - }, - history, - artifacts, - }) - } - } catch (error) { - const isTimeout = error instanceof Error && error.name === 'TimeoutError' - logger.error(`Streaming error for task ${taskId}:`, { error, isTimeout }) - - const errorMessage = isTimeout - ? `Workflow execution timed out after ${A2A_DEFAULT_TIMEOUT}ms` - : error instanceof Error - ? error.message - : 'Streaming failed' - - await db - .update(a2aTask) - .set({ - status: 'failed', - completedAt: new Date(), - updatedAt: new Date(), - }) - .where(eq(a2aTask.id, taskId)) - - notifyTaskStateChange(taskId, 'failed').catch((err) => { - logger.error('Failed to trigger push notification for failure', { taskId, error: err }) - }) - - sendEvent('error', { - code: A2A_ERROR_CODES.INTERNAL_ERROR, - message: errorMessage, - }) - } finally { - await releaseLock(lockKey, lockValue) - controller.close() - } - }, - cancel() {}, - }) - - return new NextResponse(stream, { - headers: { - ...SSE_HEADERS, - 'X-Task-Id': taskId, - }, - }) -} - -/** - * Handle tasks/get - Query task status - */ -async function handleTaskGet( - id: A2AJsonRpcId, - agentId: string, - params: A2ATaskIdParams, - callerFingerprint?: string -): Promise { - const historyLength = - params.historyLength !== undefined && params.historyLength >= 0 - ? params.historyLength - : undefined - - const task = await getTaskForAgent(params.id, agentId, callerFingerprint) - - if (!task) { - return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), { - status: 404, - }) - } - - const taskResponse = buildTaskResponse({ - taskId: task.id, - contextId: task.sessionId || task.id, - state: task.status as TaskState, - history: task.messages as Message[], - artifacts: (task.artifacts as Artifact[]) || [], - }) - - const result = formatTaskResponse(taskResponse, historyLength) - - return NextResponse.json(createResponse(id, result)) -} - -/** - * Handle tasks/cancel - Cancel a running task - */ -async function handleTaskCancel( - id: A2AJsonRpcId, - agentId: string, - params: A2ATaskIdParams, - callerFingerprint?: string -): Promise { - const task = await getTaskForAgent(params.id, agentId, callerFingerprint) - - if (!task) { - return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), { - status: 404, - }) - } - - if (isTerminalState(task.status as TaskState)) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.TASK_ALREADY_COMPLETE, 'Task already in terminal state'), - { status: 400 } - ) - } - - if (task.executionId) { - try { - await markExecutionCancelled(task.executionId) - logger.info('Cancelled workflow execution', { - taskId: task.id, - executionId: task.executionId, - }) - } catch (error) { - logger.warn('Failed to cancel workflow execution', { - taskId: task.id, - executionId: task.executionId, - error, - }) + const cardResult = await getServedAgentCard(agentId) + if (!cardResult.ok) { + return jsonRpcErrorResponse( + requestId, + new A2AError(A2A_SERVER_ERROR_CODE, cardResult.error), + cardResult.status + ) } - } - - await db - .update(a2aTask) - .set({ - status: 'canceled', - updatedAt: new Date(), - completedAt: new Date(), - }) - .where(eq(a2aTask.id, params.id)) - - notifyTaskStateChange(params.id, 'canceled').catch((err) => { - logger.error('Failed to trigger push notification for cancellation', { - taskId: params.id, - error: err, - }) - }) - - const canceledTask = buildTaskResponse({ - taskId: task.id, - contextId: task.sessionId || task.id, - state: 'canceled', - history: task.messages as Message[], - artifacts: (task.artifacts as Artifact[]) || [], - }) - - return NextResponse.json(createResponse(id, canceledTask)) -} - -/** - * Handle tasks/resubscribe - Reconnect to SSE stream for an ongoing task - */ -async function handleTaskResubscribe( - request: NextRequest, - id: A2AJsonRpcId, - agentId: string, - params: A2ATaskIdParams, - callerFingerprint?: string -): Promise { - const task = await getTaskForAgent(params.id, agentId, callerFingerprint) - - if (!task) { - return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), { - status: 404, - }) - } - - const encoder = new TextEncoder() - if (isTerminalState(task.status as TaskState)) { - const completedTask = buildTaskResponse({ - taskId: task.id, - contextId: task.sessionId || task.id, - state: task.status as TaskState, - history: task.messages as Message[], - artifacts: (task.artifacts as Artifact[]) || [], - }) - const jsonRpcResponse = { jsonrpc: '2.0' as const, id, result: completedTask } - const sseData = `event: task\ndata: ${JSON.stringify(jsonRpcResponse)}\n\n` - const stream = new ReadableStream({ - start(controller) { - controller.enqueue(encoder.encode(sseData)) - controller.close() + const handler = new SimA2ARequestHandler({ + agent: { + id: agent.id, + name: agent.name, + workflowId: agent.workflowId, + workspaceId: agent.workspaceId, }, + agentCard: cardResult.card, + apiKey, + executionUserId, + callerFingerprint, + requestSignal: request.signal, }) - return new NextResponse(stream, { headers: SSE_HEADERS }) - } - let isCancelled = false - let pollTimeoutId: ReturnType | null = null - const abortSignal = request.signal - abortSignal.addEventListener( - 'abort', - () => { - isCancelled = true - if (pollTimeoutId) { - clearTimeout(pollTimeoutId) - pollTimeoutId = null - } - }, - { once: true } - ) + const transport = new JsonRpcTransportHandler(handler) + const result = await transport.handle(body) - const cleanup = () => { - isCancelled = true - if (pollTimeoutId) { - clearTimeout(pollTimeoutId) - pollTimeoutId = null + if (isAsyncIterable(result)) { + return streamJsonRpc(result, requestId) } - } - const stream = new ReadableStream({ - async start(controller) { - const sendEvent = (event: string, data: unknown): boolean => { - if (isCancelled || abortSignal.aborted) return false - try { - const jsonRpcResponse = { jsonrpc: '2.0' as const, id, result: data } - controller.enqueue( - encoder.encode(`event: ${event}\ndata: ${JSON.stringify(jsonRpcResponse)}\n\n`) - ) - return true - } catch (error) { - logger.error('Error sending SSE event:', error) - isCancelled = true - return false - } - } - - if ( - !sendEvent('status', { - kind: 'status', - taskId: task.id, - contextId: task.sessionId, - status: { state: task.status, timestamp: new Date().toISOString() }, - }) - ) { - cleanup() - return - } - - const pollInterval = 3000 // 3 seconds - const maxPolls = 100 // 5 minutes max - - let polls = 0 - const poll = async () => { - if (isCancelled || abortSignal.aborted) { - cleanup() - return - } - - polls++ - if (polls > maxPolls) { - cleanup() - try { - controller.close() - } catch { - // Already closed - } - return - } - - try { - const [updatedTask] = await db - .select() - .from(a2aTask) - .where(eq(a2aTask.id, params.id)) - .limit(1) - - if (isCancelled) { - cleanup() - return - } - - if (!updatedTask) { - sendEvent('error', { code: A2A_ERROR_CODES.TASK_NOT_FOUND, message: 'Task not found' }) - cleanup() - try { - controller.close() - } catch { - // Already closed - } - return - } - - if (updatedTask.status !== task.status) { - if ( - !sendEvent('status', { - kind: 'status', - taskId: updatedTask.id, - contextId: updatedTask.sessionId, - status: { state: updatedTask.status, timestamp: new Date().toISOString() }, - final: isTerminalState(updatedTask.status as TaskState), - }) - ) { - cleanup() - return - } - } - - if (isTerminalState(updatedTask.status as TaskState)) { - const messages = updatedTask.messages as Message[] - const lastMessage = messages[messages.length - 1] - if (lastMessage && lastMessage.role === 'agent') { - sendEvent('message', { - ...lastMessage, - taskId: updatedTask.id, - contextId: updatedTask.sessionId || updatedTask.id, - final: true, - }) - } - - cleanup() - try { - controller.close() - } catch { - // Already closed - } - return - } - - pollTimeoutId = setTimeout(poll, pollInterval) - } catch (error) { - logger.error('Error during SSE poll:', error) - sendEvent('error', { - code: A2A_ERROR_CODES.INTERNAL_ERROR, - message: getErrorMessage(error, 'Polling failed'), - }) - cleanup() - try { - controller.close() - } catch { - // Already closed - } - } - } - - poll() - }, - cancel() { - cleanup() - }, - }) - - return new NextResponse(stream, { - headers: { - ...SSE_HEADERS, - 'X-Task-Id': params.id, - }, - }) -} - -/** - * Handle tasks/pushNotificationConfig/set - Set webhook for task updates - */ -async function handlePushNotificationSet( - id: A2AJsonRpcId, - agentId: string, - params: A2APushNotificationSetParams, - callerFingerprint?: string -): Promise { - const urlValidation = await validateUrlWithDNS( - params.pushNotificationConfig.url, - 'Push notification URL' - ) - if (!urlValidation.isValid) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.INVALID_PARAMS, urlValidation.error || 'Invalid URL'), - { status: 400 } - ) - } - - const task = await getTaskForAgent(params.id, agentId, callerFingerprint) - - if (!task) { - return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), { - status: 404, - }) - } - - const [existingConfig] = await db - .select() - .from(a2aPushNotificationConfig) - .where(eq(a2aPushNotificationConfig.taskId, params.id)) - .limit(1) - - const config = params.pushNotificationConfig - - if (existingConfig) { - await db - .update(a2aPushNotificationConfig) - .set({ - url: config.url, - token: config.token || null, - isActive: true, - updatedAt: new Date(), - }) - .where(eq(a2aPushNotificationConfig.id, existingConfig.id)) - } else { - await db.insert(a2aPushNotificationConfig).values({ - id: generateId(), - taskId: params.id, - url: config.url, - token: config.token || null, - isActive: true, - createdAt: new Date(), - updatedAt: new Date(), - }) - } - - const result: PushNotificationConfig = { - url: config.url, - token: config.token, - } - - return NextResponse.json(createResponse(id, result)) -} - -/** - * Handle tasks/pushNotificationConfig/get - Get webhook config for a task - */ -async function handlePushNotificationGet( - id: A2AJsonRpcId, - agentId: string, - params: A2ATaskIdParams, - callerFingerprint?: string -): Promise { - const task = await getTaskForAgent(params.id, agentId, callerFingerprint) - - if (!task) { - return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), { - status: 404, - }) - } - - const [config] = await db - .select() - .from(a2aPushNotificationConfig) - .where(eq(a2aPushNotificationConfig.taskId, params.id)) - .limit(1) - - if (!config) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Push notification config not found'), - { status: 404 } - ) + return NextResponse.json(result) } - - const result: PushNotificationConfig = { - url: config.url, - token: config.token || undefined, - } - - return NextResponse.json(createResponse(id, result)) -} - -/** - * Handle tasks/pushNotificationConfig/delete - Delete webhook config for a task - */ -async function handlePushNotificationDelete( - id: A2AJsonRpcId, - agentId: string, - params: A2ATaskIdParams, - callerFingerprint?: string -): Promise { - const task = await getTaskForAgent(params.id, agentId, callerFingerprint) - - if (!task) { - return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), { - status: 404, - }) - } - - const [config] = await db - .select() - .from(a2aPushNotificationConfig) - .where(eq(a2aPushNotificationConfig.taskId, params.id)) - .limit(1) - - if (!config) { - return NextResponse.json( - createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Push notification config not found'), - { status: 404 } - ) - } - - await db.delete(a2aPushNotificationConfig).where(eq(a2aPushNotificationConfig.id, config.id)) - - return NextResponse.json(createResponse(id, { success: true })) -} +) diff --git a/apps/sim/app/api/a2a/serve/[agentId]/utils.ts b/apps/sim/app/api/a2a/serve/[agentId]/utils.ts index de4be12323a..ca2012b688a 100644 --- a/apps/sim/app/api/a2a/serve/[agentId]/utils.ts +++ b/apps/sim/app/api/a2a/serve/[agentId]/utils.ts @@ -1,95 +1,134 @@ -import type { Artifact, Message, PushNotificationConfig, Task, TaskState } from '@a2a-js/sdk' -import { generateId } from '@sim/utils/id' +import type { + AgentCard, + Artifact, + Message, + Task, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +} from '@a2a-js/sdk' +import { db } from '@sim/db' +import { a2aAgent } from '@sim/db/schema' +import { createLogger } from '@sim/logger' +import { and, eq, isNull } from 'drizzle-orm' +import { buildAgentCard } from '@/lib/a2a/agent-card' +import type { AgentAuthentication, AgentCapabilities, AgentSkill } from '@/lib/a2a/types' import { generateInternalToken } from '@/lib/auth/internal' -import { getInternalApiBaseUrl } from '@/lib/core/utils/urls' - -/** A2A v0.3 JSON-RPC method names */ -export const A2A_METHODS = { - MESSAGE_SEND: 'message/send', - MESSAGE_STREAM: 'message/stream', - TASKS_GET: 'tasks/get', - TASKS_CANCEL: 'tasks/cancel', - TASKS_RESUBSCRIBE: 'tasks/resubscribe', - PUSH_NOTIFICATION_SET: 'tasks/pushNotificationConfig/set', - PUSH_NOTIFICATION_GET: 'tasks/pushNotificationConfig/get', - PUSH_NOTIFICATION_DELETE: 'tasks/pushNotificationConfig/delete', -} as const - -/** A2A v0.3 error codes */ -export const A2A_ERROR_CODES = { - PARSE_ERROR: -32700, - INVALID_REQUEST: -32600, - METHOD_NOT_FOUND: -32601, - INVALID_PARAMS: -32602, - INTERNAL_ERROR: -32603, - TASK_NOT_FOUND: -32001, - TASK_ALREADY_COMPLETE: -32002, - AGENT_UNAVAILABLE: -32003, - AUTHENTICATION_REQUIRED: -32004, -} as const - -interface JSONRPCRequest { - jsonrpc: '2.0' - id: string | number - method: string - params?: unknown -} - -export interface JSONRPCResponse { - jsonrpc: '2.0' - id: string | number | null - result?: unknown - error?: { - code: number - message: string - data?: unknown +import { getRedisClient } from '@/lib/core/config/redis' +import { getBaseUrl, getInternalApiBaseUrl } from '@/lib/core/utils/urls' +import { getBrandConfig } from '@/ee/whitelabeling' + +const logger = createLogger('A2AServeUtils') + +const AGENT_CARD_CACHE_TTL_SECONDS = 60 + +export type ServedAgentCardResult = + | { ok: true; card: AgentCard; cacheHit: boolean } + | { ok: false; status: number; error: string } + +/** + * Load and build the public {@link AgentCard} for a published agent. + * + * Shared by the serve GET endpoint and the `.well-known/agent-card.json` + * discovery endpoint. Caches the built card in Redis (best-effort). + */ +export async function getServedAgentCard(agentId: string): Promise { + const redis = getRedisClient() + const cacheKey = `a2a:agent:${agentId}:card` + + if (redis) { + try { + const cached = await redis.get(cacheKey) + if (cached) { + return { ok: true, card: JSON.parse(cached) as AgentCard, cacheHit: true } + } + } catch (err) { + logger.warn('Redis cache read failed', { agentId, error: err }) + } } -} -interface MessageSendParams { - message: Message - configuration?: { - acceptedOutputModes?: string[] - historyLength?: number - pushNotificationConfig?: PushNotificationConfig + const [agent] = await db + .select({ + id: a2aAgent.id, + name: a2aAgent.name, + description: a2aAgent.description, + version: a2aAgent.version, + capabilities: a2aAgent.capabilities, + skills: a2aAgent.skills, + authentication: a2aAgent.authentication, + isPublished: a2aAgent.isPublished, + }) + .from(a2aAgent) + .where(and(eq(a2aAgent.id, agentId), isNull(a2aAgent.archivedAt))) + .limit(1) + + if (!agent) { + return { ok: false, status: 404, error: 'Agent not found' } } -} - -interface TaskIdParams { - id: string - historyLength?: number -} -interface PushNotificationSetParams { - id: string - pushNotificationConfig: PushNotificationConfig -} + if (!agent.isPublished) { + return { ok: false, status: 404, error: 'Agent not published' } + } -export function createResponse(id: string | number | null, result: unknown): JSONRPCResponse { - return { jsonrpc: '2.0', id, result } -} + const card = buildAgentCard({ + agent: { + id: agent.id, + name: agent.name, + description: agent.description, + version: agent.version, + capabilities: agent.capabilities as AgentCapabilities, + skills: agent.skills as AgentSkill[], + authentication: agent.authentication as AgentAuthentication, + }, + baseUrl: getBaseUrl(), + providerOrganization: getBrandConfig().name, + }) + + if (redis) { + try { + await redis.set(cacheKey, JSON.stringify(card), 'EX', AGENT_CARD_CACHE_TTL_SECONDS) + } catch (err) { + logger.warn('Redis cache write failed', { agentId, error: err }) + } + } -export function createError( - id: string | number | null, - code: number, - message: string, - data?: unknown -): JSONRPCResponse { - return { jsonrpc: '2.0', id, error: { code, message, data } } + return { ok: true, card, cacheHit: false } } -export function isJSONRPCRequest(obj: unknown): obj is JSONRPCRequest { - if (!obj || typeof obj !== 'object') return false - const r = obj as Record - return r.jsonrpc === '2.0' && typeof r.method === 'string' && r.id !== undefined +export function createTaskStatus(state: TaskState): TaskStatus { + return { state, timestamp: new Date().toISOString() } } -export function generateTaskId(): string { - return generateId() +export function buildTaskResponse(params: { + taskId: string + contextId: string + state: TaskState + history: Message[] + artifacts?: Artifact[] +}): Task { + return { + kind: 'task', + id: params.taskId, + contextId: params.contextId, + status: createTaskStatus(params.state), + history: params.history, + artifacts: params.artifacts || [], + } } -export function createTaskStatus(state: TaskState): { state: TaskState; timestamp: string } { - return { state, timestamp: new Date().toISOString() } +export function buildStatusUpdate(params: { + taskId: string + contextId: string + state: TaskState + final: boolean +}): TaskStatusUpdateEvent { + return { + kind: 'status-update', + taskId: params.taskId, + contextId: params.contextId, + status: createTaskStatus(params.state), + final: params.final, + } } export function formatTaskResponse(task: Task, historyLength?: number): Task { @@ -141,37 +180,16 @@ export function extractAgentContent(executeResult: { output?: { content?: string; [key: string]: unknown } error?: string }): string { - // Prefer explicit content field if (executeResult.output?.content) { return executeResult.output.content } - // If output is an object with meaningful data, stringify it if (typeof executeResult.output === 'object' && executeResult.output !== null) { const keys = Object.keys(executeResult.output) - // Skip empty objects or objects with only undefined values if (keys.length > 0 && keys.some((k) => executeResult.output![k] !== undefined)) { return JSON.stringify(executeResult.output) } } - // Fallback to error message or default return executeResult.error || 'Task completed' } - -export function buildTaskResponse(params: { - taskId: string - contextId: string - state: TaskState - history: Message[] - artifacts?: Artifact[] -}): Task { - return { - kind: 'task', - id: params.taskId, - contextId: params.contextId, - status: createTaskStatus(params.state), - history: params.history, - artifacts: params.artifacts || [], - } -} diff --git a/apps/sim/app/api/tools/a2a/get-agent-card/route.ts b/apps/sim/app/api/tools/a2a/get-agent-card/route.ts index fed318b8330..e169724c294 100644 --- a/apps/sim/app/api/tools/a2a/get-agent-card/route.ts +++ b/apps/sim/app/api/tools/a2a/get-agent-card/route.ts @@ -80,7 +80,8 @@ export const POST = withRouteHandler(async (request: NextRequest) => { name: agentCard.name, description: agentCard.description, url: agentCard.url, - version: agentCard.protocolVersion, + version: agentCard.version, + protocolVersion: agentCard.protocolVersion, capabilities: agentCard.capabilities, skills: agentCard.skills, defaultInputModes: agentCard.defaultInputModes, diff --git a/apps/sim/blocks/blocks/a2a.ts b/apps/sim/blocks/blocks/a2a.ts index b15e8dacfd4..9a131c88ec4 100644 --- a/apps/sim/blocks/blocks/a2a.ts +++ b/apps/sim/blocks/blocks/a2a.ts @@ -35,8 +35,10 @@ export interface A2AResponse extends ToolResponse { description?: string /** Agent URL (get_agent_card) */ url?: string - /** Agent version (get_agent_card) */ + /** Agent's own version (get_agent_card) */ version?: string + /** A2A protocol version supported by the agent (get_agent_card) */ + protocolVersion?: string /** Agent capabilities (get_agent_card) */ capabilities?: Record /** Agent skills (get_agent_card) */ @@ -320,7 +322,11 @@ export const A2ABlock: BlockConfig = { }, version: { type: 'string', - description: 'Agent version', + description: "Agent's own version", + }, + protocolVersion: { + type: 'string', + description: 'A2A protocol version supported by the agent', }, capabilities: { type: 'json', diff --git a/apps/sim/lib/a2a/agent-card.test.ts b/apps/sim/lib/a2a/agent-card.test.ts new file mode 100644 index 00000000000..6c1a6cbf251 --- /dev/null +++ b/apps/sim/lib/a2a/agent-card.test.ts @@ -0,0 +1,133 @@ +/** + * @vitest-environment node + */ +import { describe, expect, it } from 'vitest' +import { buildAgentCard } from '@/lib/a2a/agent-card' +import { A2A_PROTOCOL_VERSION } from '@/lib/a2a/constants' + +const BASE_URL = 'https://example.com' + +const baseAgent = { + id: 'agent-1', + name: 'Support Agent', + version: '2.1.0', +} + +describe('buildAgentCard', () => { + it('emits a spec-compliant v0.3 AgentCard', () => { + const card = buildAgentCard({ + agent: baseAgent, + baseUrl: BASE_URL, + providerOrganization: 'Sim', + }) + + expect(card.protocolVersion).toBe(A2A_PROTOCOL_VERSION) + expect(card.name).toBe('Support Agent') + expect(card.version).toBe('2.1.0') + expect(card.preferredTransport).toBe('JSONRPC') + expect(card.url).toBe('https://example.com/api/a2a/serve/agent-1') + expect(card.provider).toEqual({ organization: 'Sim', url: BASE_URL }) + expect(card.documentationUrl).toBe('https://example.com/docs/a2a') + }) + + it('uses MIME types (not "text") for default input/output modes', () => { + const card = buildAgentCard({ + agent: baseAgent, + baseUrl: BASE_URL, + providerOrganization: 'Sim', + }) + + expect(card.defaultInputModes).toEqual(['text/plain', 'application/json']) + expect(card.defaultOutputModes).toEqual(['text/plain', 'application/json']) + }) + + it('reports the agent version distinct from the protocol version', () => { + const card = buildAgentCard({ + agent: { ...baseAgent, version: '9.9.9' }, + baseUrl: BASE_URL, + providerOrganization: 'Sim', + }) + + expect(card.version).toBe('9.9.9') + expect(card.protocolVersion).toBe(A2A_PROTOCOL_VERSION) + expect(card.version).not.toBe(card.protocolVersion) + }) + + it('adds an apiKey security scheme when auth is required', () => { + const card = buildAgentCard({ + agent: { ...baseAgent, authentication: { schemes: ['bearer', 'apiKey'] } }, + baseUrl: BASE_URL, + providerOrganization: 'Sim', + }) + + expect(card.securitySchemes).toEqual({ + apiKey: { + type: 'apiKey', + name: 'X-API-Key', + in: 'header', + description: 'API key authentication', + }, + }) + expect(card.security).toEqual([{ apiKey: [] }]) + }) + + it('omits security schemes for public ("none") agents', () => { + const card = buildAgentCard({ + agent: { ...baseAgent, authentication: { schemes: ['none'] } }, + baseUrl: BASE_URL, + providerOrganization: 'Sim', + }) + + expect(card.securitySchemes).toBeUndefined() + expect(card.security).toBeUndefined() + }) + + it('synthesizes a default execute skill from the workflow when none are set', () => { + const card = buildAgentCard({ + agent: baseAgent, + baseUrl: BASE_URL, + providerOrganization: 'Sim', + workflow: { name: 'Triage', description: 'Triage inbound tickets' }, + }) + + expect(card.skills).toEqual([ + { + id: 'execute', + name: 'Execute Triage', + description: 'Triage inbound tickets', + tags: [], + }, + ]) + }) + + it('prefers explicit agent skills over the synthesized default', () => { + const skills = [ + { id: 'summarize', name: 'Summarize', description: 'Summarize text', tags: ['nlp'] }, + ] + const card = buildAgentCard({ + agent: { ...baseAgent, skills }, + baseUrl: BASE_URL, + providerOrganization: 'Sim', + workflow: { name: 'Triage' }, + }) + + expect(card.skills).toBe(skills) + }) + + it('falls back to workflow then a generated description', () => { + const withWorkflow = buildAgentCard({ + agent: baseAgent, + baseUrl: BASE_URL, + providerOrganization: 'Sim', + workflow: { description: 'From workflow' }, + }) + expect(withWorkflow.description).toBe('From workflow') + + const generated = buildAgentCard({ + agent: baseAgent, + baseUrl: BASE_URL, + providerOrganization: 'Acme', + }) + expect(generated.description).toBe('Support Agent - A2A Agent powered by Acme') + }) +}) diff --git a/apps/sim/lib/a2a/agent-card.ts b/apps/sim/lib/a2a/agent-card.ts index b80fd6476f4..16d14d0ad12 100644 --- a/apps/sim/lib/a2a/agent-card.ts +++ b/apps/sim/lib/a2a/agent-card.ts @@ -1,74 +1,97 @@ -import { getBaseUrl } from '@/lib/core/utils/urls' +import type { AgentCard } from '@a2a-js/sdk' import { A2A_DEFAULT_CAPABILITIES, A2A_DEFAULT_INPUT_MODES, A2A_DEFAULT_OUTPUT_MODES, A2A_PROTOCOL_VERSION, } from './constants' -import type { AgentCapabilities, AgentSkill } from './types' -import { buildA2AEndpointUrl, sanitizeAgentName } from './utils' +import type { AgentAuthentication, AgentCapabilities, AgentSkill } from './types' +import { buildA2AEndpointUrl } from './utils' -export interface AppAgentCard { - name: string - description: string - url: string - protocolVersion: string - documentationUrl?: string - provider?: { - organization: string - url: string - } - capabilities: AgentCapabilities - skills: AgentSkill[] - defaultInputModes: string[] - defaultOutputModes: string[] -} - -interface WorkflowData { - id: string - name: string - description?: string | null -} - -interface AgentData { +interface BuildAgentCardAgent { id: string name: string description?: string | null version: string capabilities?: AgentCapabilities skills?: AgentSkill[] + authentication?: AgentAuthentication | null } -export function generateAgentCard(agent: AgentData, workflow: WorkflowData): AppAgentCard { - const baseUrl = getBaseUrl() +interface BuildAgentCardWorkflow { + name?: string | null + description?: string | null +} + +interface BuildAgentCardInput { + agent: BuildAgentCardAgent + baseUrl: string + /** Provider organization name (whitelabel-aware brand name). */ + providerOrganization: string + /** Optional source workflow, used only for skill/description fallbacks. */ + workflow?: BuildAgentCardWorkflow +} + +/** + * Build a spec-compliant {@link AgentCard} (A2A v0.3) for a Sim agent. + * + * Single source of truth shared by the public serve endpoint, the + * `.well-known/agent-card.json` discovery endpoint, and the management endpoint + * so the three never drift. + */ +export function buildAgentCard({ + agent, + baseUrl, + providerOrganization, + workflow, +}: BuildAgentCardInput): AgentCard { const description = - agent.description || workflow.description || `${agent.name} - A2A Agent powered by Sim` + agent.description || + workflow?.description || + `${agent.name} - A2A Agent powered by ${providerOrganization}` + + const schemes = agent.authentication?.schemes ?? [] + const isPublic = schemes.includes('none') + + const skills: AgentSkill[] = + agent.skills && agent.skills.length > 0 + ? agent.skills + : generateSkillsFromWorkflow(workflow?.name || agent.name, workflow?.description) - return { + const card: AgentCard = { + protocolVersion: A2A_PROTOCOL_VERSION, name: agent.name, description, url: buildA2AEndpointUrl(baseUrl, agent.id), - protocolVersion: A2A_PROTOCOL_VERSION, + version: agent.version, + preferredTransport: 'JSONRPC', documentationUrl: `${baseUrl}/docs/a2a`, provider: { - organization: 'Sim', + organization: providerOrganization, url: baseUrl, }, capabilities: { ...A2A_DEFAULT_CAPABILITIES, ...agent.capabilities, }, - skills: agent.skills || [ - { - id: 'execute', - name: `Execute ${workflow.name}`, - description: workflow.description || `Execute the ${workflow.name} workflow`, - tags: [], - }, - ], + skills, defaultInputModes: [...A2A_DEFAULT_INPUT_MODES], defaultOutputModes: [...A2A_DEFAULT_OUTPUT_MODES], } + + if (!isPublic) { + card.securitySchemes = { + apiKey: { + type: 'apiKey', + name: 'X-API-Key', + in: 'header', + description: 'API key authentication', + }, + } + card.security = [{ apiKey: [] }] + } + + return card } export function generateSkillsFromWorkflow( @@ -85,54 +108,3 @@ export function generateSkillsFromWorkflow( return [skill] } - -export function generateDefaultAgentName(workflowName: string): string { - return sanitizeAgentName(workflowName) -} - -export function validateAgentCard(card: unknown): card is AppAgentCard { - if (!card || typeof card !== 'object') return false - - const c = card as Record - - if (typeof c.name !== 'string' || !c.name) return false - if (typeof c.url !== 'string' || !c.url) return false - if (typeof c.description !== 'string') return false - - // Validate URL format - try { - const url = new URL(c.url) - if (!['http:', 'https:'].includes(url.protocol)) return false - } catch { - return false - } - - if (c.capabilities && typeof c.capabilities !== 'object') return false - - if (!Array.isArray(c.skills)) return false - - return true -} - -export function mergeAgentCard( - existing: AppAgentCard, - updates: Partial -): AppAgentCard { - return { - ...existing, - ...updates, - capabilities: { - ...existing.capabilities, - ...updates.capabilities, - }, - skills: updates.skills || existing.skills, - } -} - -export function getAgentCardPaths(agentId: string) { - const baseUrl = getBaseUrl() - return { - card: `${baseUrl}/api/a2a/agents/${agentId}`, - serve: `${baseUrl}/api/a2a/serve/${agentId}`, - } -} diff --git a/apps/sim/lib/a2a/constants.ts b/apps/sim/lib/a2a/constants.ts index 41429c6b544..0a9c0201f9b 100644 --- a/apps/sim/lib/a2a/constants.ts +++ b/apps/sim/lib/a2a/constants.ts @@ -17,9 +17,9 @@ export const A2A_DEFAULT_CAPABILITIES = { stateTransitionHistory: true, } as const -export const A2A_DEFAULT_INPUT_MODES = ['text'] as const +export const A2A_DEFAULT_INPUT_MODES = ['text/plain', 'application/json'] as const -export const A2A_DEFAULT_OUTPUT_MODES = ['text'] as const +export const A2A_DEFAULT_OUTPUT_MODES = ['text/plain', 'application/json'] as const export const A2A_CACHE = { AGENT_CARD_TTL: 3600, // 1 hour diff --git a/apps/sim/lib/a2a/types.ts b/apps/sim/lib/a2a/types.ts index badb72f19a7..fd38c9f6ea7 100644 --- a/apps/sim/lib/a2a/types.ts +++ b/apps/sim/lib/a2a/types.ts @@ -1,103 +1,17 @@ /** * A2A (Agent-to-Agent) Protocol Types (v0.3) * @see https://a2a-protocol.org/specification + * + * Protocol shapes are owned by `@a2a-js/sdk`. Only Sim-specific types live here. */ -export type { - AgentCapabilities, - AgentSkill, -} from '@a2a-js/sdk' +export type { AgentCapabilities, AgentSkill } from '@a2a-js/sdk' /** - * App-specific: Extended MessageSendParams - * Note: Structured inputs should be passed via DataPart in message.parts (A2A spec compliant) - * Files should be passed via FilePart in message.parts - */ -interface ExtendedMessageSendParams { - message: import('@a2a-js/sdk').Message - configuration?: import('@a2a-js/sdk').MessageSendConfiguration -} - -/** - * App-specific: Database model for A2A Agent configuration - */ -interface A2AAgentConfig { - id: string - workspaceId: string - workflowId: string - name: string - description?: string - version: string - capabilities: import('@a2a-js/sdk').AgentCapabilities - skills: import('@a2a-js/sdk').AgentSkill[] - authentication?: AgentAuthentication - signatures?: AgentCardSignature[] - isPublished: boolean - publishedAt?: Date - createdAt: Date - updatedAt: Date -} - -/** - * App-specific: Agent authentication configuration + * Sim-specific: how an agent authenticates callers. This is mapped onto the A2A + * card's `securitySchemes` / `security` fields when the card is built. */ export interface AgentAuthentication { schemes: Array<'bearer' | 'apiKey' | 'oauth2' | 'none'> credentials?: string } - -/** - * App-specific: Agent card signature (v0.3) - */ -interface AgentCardSignature { - algorithm: string - keyId: string - value: string -} - -/** - * App-specific: Database model for A2A Task record - */ -interface A2ATaskRecord { - id: string - agentId: string - contextId?: string - status: import('@a2a-js/sdk').TaskState - history: import('@a2a-js/sdk').Message[] - artifacts?: import('@a2a-js/sdk').Artifact[] - executionId?: string - metadata?: Record - createdAt: Date - updatedAt: Date - completedAt?: Date -} - -/** - * App-specific: A2A API Response wrapper - */ -interface A2AApiResponse { - success: boolean - data?: T - error?: string -} - -/** - * App-specific: JSON Schema definition for skill input/output schemas - */ -interface JSONSchema { - type?: string - properties?: Record - items?: JSONSchema - required?: string[] - description?: string - enum?: unknown[] - default?: unknown - format?: string - minimum?: number - maximum?: number - minLength?: number - maxLength?: number - pattern?: string - additionalProperties?: boolean | JSONSchema - [key: string]: unknown -} diff --git a/apps/sim/lib/a2a/utils.ts b/apps/sim/lib/a2a/utils.ts index d89a8cec040..db93f6d7dd9 100644 --- a/apps/sim/lib/a2a/utils.ts +++ b/apps/sim/lib/a2a/utils.ts @@ -1,12 +1,13 @@ -import type { - Artifact, - DataPart, - FilePart, - Message, - Part, - Task, - TaskState, - TextPart, +import { + AGENT_CARD_PATH, + type Artifact, + type DataPart, + type FilePart, + type Message, + type Part, + type Task, + type TaskState, + type TextPart, } from '@a2a-js/sdk' import { type BeforeArgs, @@ -53,12 +54,13 @@ class ApiKeyInterceptor implements CallInterceptor { } /** - * Create an A2A client from an agent URL with optional API key authentication - * - * Supports both standard A2A agents (agent card at /.well-known/agent.json) - * and Sim Studio agents (agent card at root URL via GET). + * Create an A2A client from an agent URL with optional API key authentication. * - * Tries standard path first, falls back to root URL for compatibility. + * Resolves the agent card by trying, in order: + * 1. the A2A v0.3 well-known path (`.well-known/agent-card.json`, the SDK default), + * 2. the legacy pre-0.3 path (`/.well-known/agent.json`), + * 3. the provided URL directly (Sim serves the card at the same URL it serves + * JSON-RPC, so the empty path makes the resolver GET the URL as-is). */ export async function createA2AClient(agentUrl: string, apiKey?: string): Promise { const validation = await validateUrlWithDNS(agentUrl, 'agentUrl') @@ -138,18 +140,22 @@ export async function createA2AClient(agentUrl: string, apiKey?: string): Promis const factory = new ClientFactory(factoryOptions) - // Try standard A2A path first (/.well-known/agent.json) - try { - return await factory.createFromUrl(agentUrl, '/.well-known/agent.json') - } catch (standardError) { - logger.debug('Standard agent card path failed, trying root URL', { - agentUrl, - error: toError(standardError).message, - }) + const candidatePaths = [AGENT_CARD_PATH, '/.well-known/agent.json', ''] + let lastError: unknown + for (const path of candidatePaths) { + try { + return await factory.createFromUrl(agentUrl, path) + } catch (error) { + lastError = error + logger.debug('Agent card resolution attempt failed', { + agentUrl, + path, + error: toError(error).message, + }) + } } - // Fall back to root URL (Sim Studio compatibility) - return factory.createFromUrl(agentUrl, '') + throw toError(lastError) } export function isTerminalState(state: TaskState): boolean { diff --git a/apps/sim/lib/api/contracts/a2a-agents.ts b/apps/sim/lib/api/contracts/a2a-agents.ts index 69b48292fb5..9f6d8396b10 100644 --- a/apps/sim/lib/api/contracts/a2a-agents.ts +++ b/apps/sim/lib/api/contracts/a2a-agents.ts @@ -1,4 +1,4 @@ -import type { AgentCapabilities, AgentSkill, Message } from '@a2a-js/sdk' +import type { AgentCapabilities, AgentSkill } from '@a2a-js/sdk' import { isRecordLike } from '@sim/utils/object' import { z } from 'zod' import type { AgentAuthentication } from '@/lib/a2a/types' @@ -90,25 +90,29 @@ export const a2aAgentSchema = z.object({ taskCount: z.number().optional(), }) -export const a2aAgentCardSchema = z.object({ - name: z.string(), - description: z.string(), - url: z.string(), - protocolVersion: z.string(), - version: z.string().optional(), - documentationUrl: z.string().optional(), - provider: z - .object({ - organization: z.string(), - url: z.string().optional(), - }) - .optional(), - capabilities: a2aAgentCapabilitiesSchema, - skills: z.array(a2aAgentSkillSchema), - authentication: a2aAgentAuthenticationSchema.optional(), - defaultInputModes: z.array(z.string()), - defaultOutputModes: z.array(z.string()), -}) +export const a2aAgentCardSchema = z + .object({ + protocolVersion: z.string(), + name: z.string(), + description: z.string(), + url: z.string(), + version: z.string(), + preferredTransport: z.string().optional(), + documentationUrl: z.string().optional(), + provider: z + .object({ + organization: z.string(), + url: z.string().optional(), + }) + .optional(), + capabilities: a2aAgentCapabilitiesSchema, + skills: z.array(a2aAgentSkillSchema), + defaultInputModes: z.array(z.string()), + defaultOutputModes: z.array(z.string()), + }) + // Agent cards carry optional SDK fields (securitySchemes, security, iconUrl, + // additionalInterfaces, signatures) that pass through untouched. + .passthrough() export const listA2AAgentsContract = defineRouteContract({ method: 'GET', @@ -197,43 +201,7 @@ export const a2aServeAgentParamsSchema = z.object({ }) export type A2AServeAgentParams = z.output -export const a2aJsonRpcIdSchema = z.union([z.string(), z.number(), z.null()]) -export type A2AJsonRpcId = z.output - -export const a2aJsonRpcRequestSchema = z - .object({ - jsonrpc: z.literal('2.0'), - id: a2aJsonRpcIdSchema, - method: z.string(), - params: z.unknown().optional(), - }) - .passthrough() -export type A2AJsonRpcRequest = z.output - -export const a2aMessageSendParamsSchema = z - .object({ - message: z.custom(isRecordLike), - }) - .passthrough() -export type A2AMessageSendParams = z.output - -export const a2aTaskIdParamsSchema = z - .object({ - id: z.string().min(1), - historyLength: z.number().optional(), - }) - .passthrough() -export type A2ATaskIdParams = z.output - -export const a2aPushNotificationSetParamsSchema = z - .object({ - id: z.string().min(1), - pushNotificationConfig: z - .object({ - url: z.string().min(1), - token: z.string().optional(), - }) - .passthrough(), - }) - .passthrough() -export type A2APushNotificationSetParams = z.output +/** + * JSON-RPC request/param validation for the A2A serve endpoint is owned by the + * `@a2a-js/sdk` `JsonRpcTransportHandler`, so no boundary schemas live here. + */ diff --git a/apps/sim/tools/a2a/get_agent_card.ts b/apps/sim/tools/a2a/get_agent_card.ts index 6c7e3a3d4ca..f94a340653f 100644 --- a/apps/sim/tools/a2a/get_agent_card.ts +++ b/apps/sim/tools/a2a/get_agent_card.ts @@ -50,6 +50,7 @@ export const a2aGetAgentCardTool: ToolConfig