Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion apps/sim/app/api/a2a/agents/[agentId]/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ export async function GET(request: NextRequest, { params }: { params: Promise<Ro

if (!agent.agent.isPublished) {
const auth = await checkSessionOrInternalAuth(request, { requireWorkflowId: false })
if (!auth.success) {
if (!auth.success || !auth.userId) {
return NextResponse.json({ error: 'Agent not published' }, { status: 404 })
}

const workspaceAccess = await checkWorkspaceAccess(agent.agent.workspaceId, auth.userId)
if (!workspaceAccess.exists || !workspaceAccess.hasAccess) {
return NextResponse.json({ error: 'Agent not published' }, { status: 404 })
}
}
Expand Down
17 changes: 14 additions & 3 deletions apps/sim/app/api/a2a/agents/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { sanitizeAgentName } from '@/lib/a2a/utils'
import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid'
import { loadWorkflowFromNormalizedTables } from '@/lib/workflows/persistence/utils'
import { hasValidStartBlockInState } from '@/lib/workflows/triggers/trigger-utils'
import { getWorkspaceById } from '@/lib/workspaces/permissions/utils'
import { checkWorkspaceAccess } from '@/lib/workspaces/permissions/utils'

const logger = createLogger('A2AAgentsAPI')

Expand All @@ -39,10 +39,13 @@ export async function GET(request: NextRequest) {
return NextResponse.json({ error: 'workspaceId is required' }, { status: 400 })
}

const ws = await getWorkspaceById(workspaceId)
if (!ws) {
const workspaceAccess = await checkWorkspaceAccess(workspaceId, auth.userId)
if (!workspaceAccess.exists) {
return NextResponse.json({ error: 'Workspace not found' }, { status: 404 })
}
if (!workspaceAccess.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}

const agents = await db
.select({
Expand Down Expand Up @@ -103,6 +106,14 @@ export async function POST(request: NextRequest) {
)
}

const workspaceAccess = await checkWorkspaceAccess(workspaceId, auth.userId)
if (!workspaceAccess.exists) {
return NextResponse.json({ error: 'Workspace not found' }, { status: 404 })
}
if (!workspaceAccess.canWrite) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}

const [wf] = await db
.select({
id: workflow.id,
Expand Down
121 changes: 101 additions & 20 deletions apps/sim/app/api/a2a/serve/[agentId]/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import {
isTerminalState,
parseWorkflowSSEChunk,
} from '@/lib/a2a/utils'
import { checkHybridAuth } from '@/lib/auth/hybrid'
import { type AuthResult, checkHybridAuth } from '@/lib/auth/hybrid'
import { acquireLock, getRedisClient, releaseLock } from '@/lib/core/config/redis'
import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server'
import { SSE_HEADERS } from '@/lib/core/utils/sse'
import { getBaseUrl } from '@/lib/core/utils/urls'
import { markExecutionCancelled } from '@/lib/execution/cancellation'
import { checkWorkspaceAccess } from '@/lib/workspaces/permissions/utils'
import { getWorkspaceBilledAccountUserId } from '@/lib/workspaces/utils'
import {
A2A_ERROR_CODES,
A2A_METHODS,
Expand Down Expand Up @@ -191,6 +193,9 @@ export async function POST(request: NextRequest, { params }: { params: Promise<R

const authSchemes = (agent.authentication as { schemes?: string[] })?.schemes || []
const requiresAuth = !authSchemes.includes('none')
let authenticatedUserId: string | null = null
let authenticatedAuthType: AuthResult['authType']
let authenticatedApiKeyType: AuthResult['apiKeyType']

if (requiresAuth) {
const auth = await checkHybridAuth(request, { requireWorkflowId: false })
Expand All @@ -200,6 +205,17 @@ export async function POST(request: NextRequest, { params }: { params: Promise<R
{ status: 401 }
)
}
authenticatedUserId = auth.userId
authenticatedAuthType = auth.authType
authenticatedApiKeyType = auth.apiKeyType

const workspaceAccess = await checkWorkspaceAccess(agent.workspaceId, authenticatedUserId)
if (!workspaceAccess.exists || !workspaceAccess.hasAccess) {
return NextResponse.json(
createError(null, A2A_ERROR_CODES.AUTHENTICATION_REQUIRED, 'Access denied'),
{ status: 403 }
)
}
}

const [wf] = await db
Expand All @@ -225,34 +241,61 @@ export async function POST(request: NextRequest, { params }: { params: Promise<R
}

const { id, method, params: rpcParams } = body
const apiKey = request.headers.get('X-API-Key')
const requestApiKey = request.headers.get('X-API-Key')
const apiKey = authenticatedAuthType === 'api_key' ? requestApiKey : null
const isPersonalApiKeyCaller =
authenticatedAuthType === 'api_key' && authenticatedApiKeyType === 'personal'
const billedUserId = await getWorkspaceBilledAccountUserId(agent.workspaceId)
if (!billedUserId) {
logger.error('Unable to resolve workspace billed account for A2A execution', {
agentId: agent.id,
workspaceId: agent.workspaceId,
})
return NextResponse.json(
createError(
id,
A2A_ERROR_CODES.INTERNAL_ERROR,
'Unable to resolve billing account for this workspace'
),
{ status: 500 }
)
}
const executionUserId =
isPersonalApiKeyCaller && authenticatedUserId ? authenticatedUserId : billedUserId

logger.info(`A2A request: ${method} for agent ${agentId}`)

switch (method) {
case A2A_METHODS.MESSAGE_SEND:
return handleMessageSend(id, agent, rpcParams as MessageSendParams, apiKey)
return handleMessageSend(id, agent, rpcParams as MessageSendParams, apiKey, executionUserId)

case A2A_METHODS.MESSAGE_STREAM:
return handleMessageStream(request, id, agent, rpcParams as MessageSendParams, apiKey)
return handleMessageStream(
request,
id,
agent,
rpcParams as MessageSendParams,
apiKey,
executionUserId
)

case A2A_METHODS.TASKS_GET:
return handleTaskGet(id, rpcParams as TaskIdParams)
return handleTaskGet(id, agent.id, rpcParams as TaskIdParams)

case A2A_METHODS.TASKS_CANCEL:
return handleTaskCancel(id, rpcParams as TaskIdParams)
return handleTaskCancel(id, agent.id, rpcParams as TaskIdParams)

case A2A_METHODS.TASKS_RESUBSCRIBE:
return handleTaskResubscribe(request, id, rpcParams as TaskIdParams)
return handleTaskResubscribe(request, id, agent.id, rpcParams as TaskIdParams)

case A2A_METHODS.PUSH_NOTIFICATION_SET:
return handlePushNotificationSet(id, rpcParams as PushNotificationSetParams)
return handlePushNotificationSet(id, agent.id, rpcParams as PushNotificationSetParams)

case A2A_METHODS.PUSH_NOTIFICATION_GET:
return handlePushNotificationGet(id, rpcParams as TaskIdParams)
return handlePushNotificationGet(id, agent.id, rpcParams as TaskIdParams)

case A2A_METHODS.PUSH_NOTIFICATION_DELETE:
return handlePushNotificationDelete(id, rpcParams as TaskIdParams)
return handlePushNotificationDelete(id, agent.id, rpcParams as TaskIdParams)

default:
return NextResponse.json(
Expand All @@ -268,6 +311,14 @@ export async function POST(request: NextRequest, { params }: { params: Promise<R
}
}

async function getTaskForAgent(taskId: string, agentId: string) {
const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, taskId)).limit(1)
if (!task || task.agentId !== agentId) {
return null
}
return task
}

/**
* Handle message/send - Send a message (v0.3)
*/
Expand All @@ -280,7 +331,8 @@ async function handleMessageSend(
workspaceId: string
},
params: MessageSendParams,
apiKey?: string | null
apiKey?: string | null,
executionUserId?: string
): Promise<NextResponse> {
if (!params?.message) {
return NextResponse.json(
Expand Down Expand Up @@ -318,6 +370,13 @@ async function handleMessageSend(
)
}

if (existingTask.agentId !== agent.id) {
return NextResponse.json(
createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'),
{ status: 404 }
)
}

if (isTerminalState(existingTask.status as TaskState)) {
return NextResponse.json(
createError(id, A2A_ERROR_CODES.TASK_ALREADY_COMPLETE, 'Task already in terminal state'),
Expand Down Expand Up @@ -363,6 +422,7 @@ async function handleMessageSend(
} = await buildExecuteRequest({
workflowId: agent.workflowId,
apiKey,
userId: executionUserId,
})

logger.info(`Executing workflow ${agent.workflowId} for A2A task ${taskId}`)
Expand Down Expand Up @@ -475,7 +535,8 @@ async function handleMessageStream(
workspaceId: string
},
params: MessageSendParams,
apiKey?: string | null
apiKey?: string | null,
executionUserId?: string
): Promise<NextResponse> {
if (!params?.message) {
return NextResponse.json(
Expand Down Expand Up @@ -522,6 +583,13 @@ async function handleMessageStream(
})
}

if (existingTask.agentId !== agent.id) {
await releaseLock(lockKey, lockValue)
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
status: 404,
})
}

if (isTerminalState(existingTask.status as TaskState)) {
await releaseLock(lockKey, lockValue)
return NextResponse.json(
Expand Down Expand Up @@ -595,6 +663,7 @@ async function handleMessageStream(
} = await buildExecuteRequest({
workflowId: agent.workflowId,
apiKey,
userId: executionUserId,
stream: true,
})

Expand Down Expand Up @@ -788,7 +857,11 @@ async function handleMessageStream(
/**
* Handle tasks/get - Query task status
*/
async function handleTaskGet(id: string | number, params: TaskIdParams): Promise<NextResponse> {
async function handleTaskGet(
id: string | number,
agentId: string,
params: TaskIdParams
): Promise<NextResponse> {
if (!params?.id) {
return NextResponse.json(
createError(id, A2A_ERROR_CODES.INVALID_PARAMS, 'Task ID is required'),
Expand All @@ -801,7 +874,7 @@ async function handleTaskGet(id: string | number, params: TaskIdParams): Promise
? params.historyLength
: undefined

const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, params.id)).limit(1)
const task = await getTaskForAgent(params.id, agentId)

if (!task) {
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
Expand All @@ -825,15 +898,19 @@ async function handleTaskGet(id: string | number, params: TaskIdParams): Promise
/**
* Handle tasks/cancel - Cancel a running task
*/
async function handleTaskCancel(id: string | number, params: TaskIdParams): Promise<NextResponse> {
async function handleTaskCancel(
id: string | number,
agentId: string,
params: TaskIdParams
): Promise<NextResponse> {
if (!params?.id) {
return NextResponse.json(
createError(id, A2A_ERROR_CODES.INVALID_PARAMS, 'Task ID is required'),
{ status: 400 }
)
}

const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, params.id)).limit(1)
const task = await getTaskForAgent(params.id, agentId)

if (!task) {
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
Expand Down Expand Up @@ -897,6 +974,7 @@ async function handleTaskCancel(id: string | number, params: TaskIdParams): Prom
async function handleTaskResubscribe(
request: NextRequest,
id: string | number,
agentId: string,
params: TaskIdParams
): Promise<NextResponse> {
if (!params?.id) {
Expand All @@ -906,7 +984,7 @@ async function handleTaskResubscribe(
)
}

const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, params.id)).limit(1)
const task = await getTaskForAgent(params.id, agentId)

if (!task) {
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
Expand Down Expand Up @@ -1103,6 +1181,7 @@ async function handleTaskResubscribe(
*/
async function handlePushNotificationSet(
id: string | number,
agentId: string,
params: PushNotificationSetParams
): Promise<NextResponse> {
if (!params?.id) {
Expand Down Expand Up @@ -1130,7 +1209,7 @@ async function handlePushNotificationSet(
)
}

const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, params.id)).limit(1)
const task = await getTaskForAgent(params.id, agentId)

if (!task) {
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
Expand Down Expand Up @@ -1181,6 +1260,7 @@ async function handlePushNotificationSet(
*/
async function handlePushNotificationGet(
id: string | number,
agentId: string,
params: TaskIdParams
): Promise<NextResponse> {
if (!params?.id) {
Expand All @@ -1190,7 +1270,7 @@ async function handlePushNotificationGet(
)
}

const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, params.id)).limit(1)
const task = await getTaskForAgent(params.id, agentId)

if (!task) {
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
Expand Down Expand Up @@ -1224,6 +1304,7 @@ async function handlePushNotificationGet(
*/
async function handlePushNotificationDelete(
id: string | number,
agentId: string,
params: TaskIdParams
): Promise<NextResponse> {
if (!params?.id) {
Expand All @@ -1233,7 +1314,7 @@ async function handlePushNotificationDelete(
)
}

const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, params.id)).limit(1)
const task = await getTaskForAgent(params.id, agentId)

if (!task) {
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
Expand Down
3 changes: 2 additions & 1 deletion apps/sim/app/api/a2a/serve/[agentId]/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ export function formatTaskResponse(task: Task, historyLength?: number): Task {
export interface ExecuteRequestConfig {
workflowId: string
apiKey?: string | null
userId?: string
stream?: boolean
}

Expand All @@ -124,7 +125,7 @@ export async function buildExecuteRequest(
if (config.apiKey) {
headers['X-API-Key'] = config.apiKey
} else {
const internalToken = await generateInternalToken()
const internalToken = await generateInternalToken(config.userId)
headers.Authorization = `Bearer ${internalToken}`
useInternalAuth = true
}
Expand Down
Loading