Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions apps/sim/app/api/a2a/agents/[agentId]/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ import { a2aAgent, workflow } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, eq, isNull } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { generateAgentCard, generateSkillsFromWorkflow } from '@/lib/a2a/agent-card'
import type { AgentCapabilities, AgentSkill } from '@/lib/a2a/types'
import { buildAgentCard, generateSkillsFromWorkflow } from '@/lib/a2a/agent-card'
import type { AgentAuthentication, AgentCapabilities, AgentSkill } from '@/lib/a2a/types'
import {
a2aAgentParamsSchema,
publishA2AAgentContract,
Expand All @@ -13,10 +13,12 @@ import {
import { parseRequest } from '@/lib/api/server'
import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid'
import { getRedisClient } from '@/lib/core/config/redis'
import { getBaseUrl } from '@/lib/core/utils/urls'
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
import { captureServerEvent } from '@/lib/posthog/server'
import { loadWorkflowFromNormalizedTables } from '@/lib/workflows/persistence/utils'
import { checkWorkspaceAccess } from '@/lib/workspaces/permissions/utils'
import { getBrandConfig } from '@/ee/whitelabeling'

const logger = createLogger('A2AAgentCardAPI')

Expand Down Expand Up @@ -60,21 +62,23 @@ export const GET = withRouteHandler(
}
}

const agentCard = generateAgentCard(
{
const agentCard = buildAgentCard({
agent: {
id: agent.agent.id,
name: agent.agent.name,
description: agent.agent.description,
version: agent.agent.version,
capabilities: agent.agent.capabilities as AgentCapabilities,
skills: agent.agent.skills as AgentSkill[],
authentication: agent.agent.authentication as AgentAuthentication,
},
{
id: agent.workflow.id,
baseUrl: getBaseUrl(),
providerOrganization: getBrandConfig().name,
workflow: {
name: agent.workflow.name,
description: agent.workflow.description,
}
)
},
})

return NextResponse.json(agentCard, {
headers: {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import type { NextRequest } from 'next/server'
import { NextResponse } from 'next/server'
import { a2aServeAgentParamsSchema } from '@/lib/api/contracts/a2a-agents'
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
import { getServedAgentCard } from '@/app/api/a2a/serve/[agentId]/utils'

export const dynamic = 'force-dynamic'
export const runtime = 'nodejs'

interface RouteParams {
agentId: string
}

/**
* GET - A2A v0.3 well-known discovery endpoint.
*
* Serves the Agent Card at the RFC 8615 context path
* (`/api/a2a/serve/{agentId}/.well-known/agent-card.json`) so standard A2A
* clients that append the well-known path to an agent's base URL can discover
* Sim-hosted agents.
*/
export const GET = withRouteHandler(
async (_request: NextRequest, { params }: { params: Promise<RouteParams> }) => {
const { agentId } = a2aServeAgentParamsSchema.parse(await params)

const result = await getServedAgentCard(agentId)
if (!result.ok) {
return NextResponse.json({ error: result.error }, { status: result.status })
}

return NextResponse.json(result.card, {
headers: {
'Content-Type': 'application/json',
'Cache-Control': 'public, max-age=60',
'X-Cache': result.cacheHit ? 'HIT' : 'MISS',
},
})
}
)
123 changes: 123 additions & 0 deletions apps/sim/app/api/a2a/serve/[agentId]/request-handler.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/**
* @vitest-environment node
*/
import { A2AError } from '@a2a-js/sdk/server'
import { dbChainMock, dbChainMockFns, resetDbChainMock } from '@sim/testing'
import { beforeEach, describe, expect, it, vi } from 'vitest'

vi.mock('@sim/db', () => dbChainMock)
vi.mock('@/lib/a2a/push-notifications', () => ({
notifyTaskStateChange: vi.fn().mockResolvedValue(undefined),
}))
vi.mock('@/lib/execution/cancellation', () => ({
markExecutionCancelled: vi.fn().mockResolvedValue(undefined),
}))
vi.mock('@/lib/core/config/redis', () => ({
acquireLock: vi.fn().mockResolvedValue(true),
releaseLock: vi.fn().mockResolvedValue(undefined),
getRedisClient: vi.fn(() => null),
}))
vi.mock('@/lib/core/security/input-validation.server', () => ({
validateUrlWithDNS: vi.fn().mockResolvedValue({ isValid: true, resolvedIP: '1.2.3.4' }),
secureFetchWithPinnedIP: vi.fn(),
}))
vi.mock('@/lib/auth/internal', () => ({
generateInternalToken: vi.fn().mockResolvedValue('internal-token'),
}))
vi.mock('@/ee/whitelabeling', () => ({
getBrandConfig: () => ({ name: 'Sim' }),
}))

import { buildAgentCard } from '@/lib/a2a/agent-card'
import { SimA2ARequestHandler } from '@/app/api/a2a/serve/[agentId]/request-handler'

const agent = { id: 'agent-1', name: 'A', workflowId: 'wf-1', workspaceId: 'ws-1' }
const agentCard = buildAgentCard({
agent: { id: agent.id, name: agent.name, version: '1.0.0' },
baseUrl: 'https://example.com',
providerOrganization: 'Sim',
})

function makeHandler(callerFingerprint = 'user:u1') {
return new SimA2ARequestHandler({ agent, agentCard, callerFingerprint })
}

function taskRow(overrides: Record<string, unknown> = {}) {
return {
id: 't1',
agentId: 'agent-1',
sessionId: 'ctx-1',
status: 'completed',
messages: [],
artifacts: [],
executionId: null,
metadata: { callerFingerprint: 'user:u1' },
...overrides,
}
}

describe('SimA2ARequestHandler', () => {
beforeEach(() => {
vi.clearAllMocks()
resetDbChainMock()
})

it('getAgentCard returns the configured card', async () => {
await expect(makeHandler().getAgentCard()).resolves.toBe(agentCard)
})

it('getAuthenticatedExtendedAgentCard rejects when not configured', async () => {
await expect(makeHandler().getAuthenticatedExtendedAgentCard()).rejects.toMatchObject({
code: -32007,
})
})

it('getTask returns an SDK Task for an owned task', async () => {
dbChainMockFns.limit.mockResolvedValueOnce([taskRow({ status: 'completed' })])

const task = await makeHandler('user:u1').getTask({ id: 't1' })

expect(task.kind).toBe('task')
expect(task.id).toBe('t1')
expect(task.contextId).toBe('ctx-1')
expect(task.status.state).toBe('completed')
})

it('hides a task owned by a different caller (taskNotFound)', async () => {
dbChainMockFns.limit.mockResolvedValueOnce([
taskRow({ metadata: { callerFingerprint: 'user:someone-else' } }),
])

await expect(makeHandler('user:u1').getTask({ id: 't1' })).rejects.toMatchObject({
code: -32001,
})
})

it('cancelTask rejects a task already in a terminal state (taskNotCancelable)', async () => {
dbChainMockFns.limit.mockResolvedValueOnce([taskRow({ status: 'completed' })])

await expect(makeHandler('user:u1').cancelTask({ id: 't1' })).rejects.toMatchObject({
code: -32002,
})
})

it('cancelTask cancels a running task and returns canceled state', async () => {
dbChainMockFns.limit.mockResolvedValueOnce([
taskRow({ status: 'working', executionId: 'exec-1' }),
])

const task = await makeHandler('user:u1').cancelTask({ id: 't1' })

expect(task.status.state).toBe('canceled')
expect(task.id).toBe('t1')
})

it('all thrown errors are SDK A2AError instances', async () => {
dbChainMockFns.limit.mockResolvedValueOnce([])
const error = await makeHandler('user:u1')
.getTask({ id: 'missing' })
.catch((e) => e)
expect(error).toBeInstanceOf(A2AError)
expect(error.code).toBe(-32001)
})
})
Loading
Loading