@@ -13,12 +13,14 @@ import {
1313 isTerminalState ,
1414 parseWorkflowSSEChunk ,
1515} from '@/lib/a2a/utils'
16- import { checkHybridAuth } from '@/lib/auth/hybrid'
16+ import { type AuthResult , checkHybridAuth } from '@/lib/auth/hybrid'
1717import { acquireLock , getRedisClient , releaseLock } from '@/lib/core/config/redis'
1818import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server'
1919import { SSE_HEADERS } from '@/lib/core/utils/sse'
2020import { getBaseUrl } from '@/lib/core/utils/urls'
2121import { markExecutionCancelled } from '@/lib/execution/cancellation'
22+ import { checkWorkspaceAccess } from '@/lib/workspaces/permissions/utils'
23+ import { getWorkspaceBilledAccountUserId } from '@/lib/workspaces/utils'
2224import {
2325 A2A_ERROR_CODES ,
2426 A2A_METHODS ,
@@ -191,6 +193,9 @@ export async function POST(request: NextRequest, { params }: { params: Promise<R
191193
192194 const authSchemes = ( agent . authentication as { schemes ?: string [ ] } ) ?. schemes || [ ]
193195 const requiresAuth = ! authSchemes . includes ( 'none' )
196+ let authenticatedUserId : string | null = null
197+ let authenticatedAuthType : AuthResult [ 'authType' ]
198+ let authenticatedApiKeyType : AuthResult [ 'apiKeyType' ]
194199
195200 if ( requiresAuth ) {
196201 const auth = await checkHybridAuth ( request , { requireWorkflowId : false } )
@@ -200,6 +205,17 @@ export async function POST(request: NextRequest, { params }: { params: Promise<R
200205 { status : 401 }
201206 )
202207 }
208+ authenticatedUserId = auth . userId
209+ authenticatedAuthType = auth . authType
210+ authenticatedApiKeyType = auth . apiKeyType
211+
212+ const workspaceAccess = await checkWorkspaceAccess ( agent . workspaceId , authenticatedUserId )
213+ if ( ! workspaceAccess . exists || ! workspaceAccess . hasAccess ) {
214+ return NextResponse . json (
215+ createError ( null , A2A_ERROR_CODES . AUTHENTICATION_REQUIRED , 'Access denied' ) ,
216+ { status : 403 }
217+ )
218+ }
203219 }
204220
205221 const [ wf ] = await db
@@ -225,34 +241,61 @@ export async function POST(request: NextRequest, { params }: { params: Promise<R
225241 }
226242
227243 const { id, method, params : rpcParams } = body
228- const apiKey = request . headers . get ( 'X-API-Key' )
244+ const requestApiKey = request . headers . get ( 'X-API-Key' )
245+ const apiKey = authenticatedAuthType === 'api_key' ? requestApiKey : null
246+ const isPersonalApiKeyCaller =
247+ authenticatedAuthType === 'api_key' && authenticatedApiKeyType === 'personal'
248+ const billedUserId = await getWorkspaceBilledAccountUserId ( agent . workspaceId )
249+ if ( ! billedUserId ) {
250+ logger . error ( 'Unable to resolve workspace billed account for A2A execution' , {
251+ agentId : agent . id ,
252+ workspaceId : agent . workspaceId ,
253+ } )
254+ return NextResponse . json (
255+ createError (
256+ id ,
257+ A2A_ERROR_CODES . INTERNAL_ERROR ,
258+ 'Unable to resolve billing account for this workspace'
259+ ) ,
260+ { status : 500 }
261+ )
262+ }
263+ const executionUserId =
264+ isPersonalApiKeyCaller && authenticatedUserId ? authenticatedUserId : billedUserId
229265
230266 logger . info ( `A2A request: ${ method } for agent ${ agentId } ` )
231267
232268 switch ( method ) {
233269 case A2A_METHODS . MESSAGE_SEND :
234- return handleMessageSend ( id , agent , rpcParams as MessageSendParams , apiKey )
270+ return handleMessageSend ( id , agent , rpcParams as MessageSendParams , apiKey , executionUserId )
235271
236272 case A2A_METHODS . MESSAGE_STREAM :
237- return handleMessageStream ( request , id , agent , rpcParams as MessageSendParams , apiKey )
273+ return handleMessageStream (
274+ request ,
275+ id ,
276+ agent ,
277+ rpcParams as MessageSendParams ,
278+ apiKey ,
279+ executionUserId
280+ )
238281
239282 case A2A_METHODS . TASKS_GET :
240- return handleTaskGet ( id , rpcParams as TaskIdParams )
283+ return handleTaskGet ( id , agent . id , rpcParams as TaskIdParams )
241284
242285 case A2A_METHODS . TASKS_CANCEL :
243- return handleTaskCancel ( id , rpcParams as TaskIdParams )
286+ return handleTaskCancel ( id , agent . id , rpcParams as TaskIdParams )
244287
245288 case A2A_METHODS . TASKS_RESUBSCRIBE :
246- return handleTaskResubscribe ( request , id , rpcParams as TaskIdParams )
289+ return handleTaskResubscribe ( request , id , agent . id , rpcParams as TaskIdParams )
247290
248291 case A2A_METHODS . PUSH_NOTIFICATION_SET :
249- return handlePushNotificationSet ( id , rpcParams as PushNotificationSetParams )
292+ return handlePushNotificationSet ( id , agent . id , rpcParams as PushNotificationSetParams )
250293
251294 case A2A_METHODS . PUSH_NOTIFICATION_GET :
252- return handlePushNotificationGet ( id , rpcParams as TaskIdParams )
295+ return handlePushNotificationGet ( id , agent . id , rpcParams as TaskIdParams )
253296
254297 case A2A_METHODS . PUSH_NOTIFICATION_DELETE :
255- return handlePushNotificationDelete ( id , rpcParams as TaskIdParams )
298+ return handlePushNotificationDelete ( id , agent . id , rpcParams as TaskIdParams )
256299
257300 default :
258301 return NextResponse . json (
@@ -268,6 +311,14 @@ export async function POST(request: NextRequest, { params }: { params: Promise<R
268311 }
269312}
270313
314+ async function getTaskForAgent ( taskId : string , agentId : string ) {
315+ const [ task ] = await db . select ( ) . from ( a2aTask ) . where ( eq ( a2aTask . id , taskId ) ) . limit ( 1 )
316+ if ( ! task || task . agentId !== agentId ) {
317+ return null
318+ }
319+ return task
320+ }
321+
271322/**
272323 * Handle message/send - Send a message (v0.3)
273324 */
@@ -280,7 +331,8 @@ async function handleMessageSend(
280331 workspaceId : string
281332 } ,
282333 params : MessageSendParams ,
283- apiKey ?: string | null
334+ apiKey ?: string | null ,
335+ executionUserId ?: string
284336) : Promise < NextResponse > {
285337 if ( ! params ?. message ) {
286338 return NextResponse . json (
@@ -318,6 +370,13 @@ async function handleMessageSend(
318370 )
319371 }
320372
373+ if ( existingTask . agentId !== agent . id ) {
374+ return NextResponse . json (
375+ createError ( id , A2A_ERROR_CODES . TASK_NOT_FOUND , 'Task not found' ) ,
376+ { status : 404 }
377+ )
378+ }
379+
321380 if ( isTerminalState ( existingTask . status as TaskState ) ) {
322381 return NextResponse . json (
323382 createError ( id , A2A_ERROR_CODES . TASK_ALREADY_COMPLETE , 'Task already in terminal state' ) ,
@@ -363,6 +422,7 @@ async function handleMessageSend(
363422 } = await buildExecuteRequest ( {
364423 workflowId : agent . workflowId ,
365424 apiKey,
425+ userId : executionUserId ,
366426 } )
367427
368428 logger . info ( `Executing workflow ${ agent . workflowId } for A2A task ${ taskId } ` )
@@ -475,7 +535,8 @@ async function handleMessageStream(
475535 workspaceId : string
476536 } ,
477537 params : MessageSendParams ,
478- apiKey ?: string | null
538+ apiKey ?: string | null ,
539+ executionUserId ?: string
479540) : Promise < NextResponse > {
480541 if ( ! params ?. message ) {
481542 return NextResponse . json (
@@ -522,6 +583,13 @@ async function handleMessageStream(
522583 } )
523584 }
524585
586+ if ( existingTask . agentId !== agent . id ) {
587+ await releaseLock ( lockKey , lockValue )
588+ return NextResponse . json ( createError ( id , A2A_ERROR_CODES . TASK_NOT_FOUND , 'Task not found' ) , {
589+ status : 404 ,
590+ } )
591+ }
592+
525593 if ( isTerminalState ( existingTask . status as TaskState ) ) {
526594 await releaseLock ( lockKey , lockValue )
527595 return NextResponse . json (
@@ -595,6 +663,7 @@ async function handleMessageStream(
595663 } = await buildExecuteRequest ( {
596664 workflowId : agent . workflowId ,
597665 apiKey,
666+ userId : executionUserId ,
598667 stream : true ,
599668 } )
600669
@@ -788,7 +857,11 @@ async function handleMessageStream(
788857/**
789858 * Handle tasks/get - Query task status
790859 */
791- async function handleTaskGet ( id : string | number , params : TaskIdParams ) : Promise < NextResponse > {
860+ async function handleTaskGet (
861+ id : string | number ,
862+ agentId : string ,
863+ params : TaskIdParams
864+ ) : Promise < NextResponse > {
792865 if ( ! params ?. id ) {
793866 return NextResponse . json (
794867 createError ( id , A2A_ERROR_CODES . INVALID_PARAMS , 'Task ID is required' ) ,
@@ -801,7 +874,7 @@ async function handleTaskGet(id: string | number, params: TaskIdParams): Promise
801874 ? params . historyLength
802875 : undefined
803876
804- const [ task ] = await db . select ( ) . from ( a2aTask ) . where ( eq ( a2aTask . id , params . id ) ) . limit ( 1 )
877+ const task = await getTaskForAgent ( params . id , agentId )
805878
806879 if ( ! task ) {
807880 return NextResponse . json ( createError ( id , A2A_ERROR_CODES . TASK_NOT_FOUND , 'Task not found' ) , {
@@ -825,15 +898,19 @@ async function handleTaskGet(id: string | number, params: TaskIdParams): Promise
825898/**
826899 * Handle tasks/cancel - Cancel a running task
827900 */
828- async function handleTaskCancel ( id : string | number , params : TaskIdParams ) : Promise < NextResponse > {
901+ async function handleTaskCancel (
902+ id : string | number ,
903+ agentId : string ,
904+ params : TaskIdParams
905+ ) : Promise < NextResponse > {
829906 if ( ! params ?. id ) {
830907 return NextResponse . json (
831908 createError ( id , A2A_ERROR_CODES . INVALID_PARAMS , 'Task ID is required' ) ,
832909 { status : 400 }
833910 )
834911 }
835912
836- const [ task ] = await db . select ( ) . from ( a2aTask ) . where ( eq ( a2aTask . id , params . id ) ) . limit ( 1 )
913+ const task = await getTaskForAgent ( params . id , agentId )
837914
838915 if ( ! task ) {
839916 return NextResponse . json ( createError ( id , A2A_ERROR_CODES . TASK_NOT_FOUND , 'Task not found' ) , {
@@ -897,6 +974,7 @@ async function handleTaskCancel(id: string | number, params: TaskIdParams): Prom
897974async function handleTaskResubscribe (
898975 request : NextRequest ,
899976 id : string | number ,
977+ agentId : string ,
900978 params : TaskIdParams
901979) : Promise < NextResponse > {
902980 if ( ! params ?. id ) {
@@ -906,7 +984,7 @@ async function handleTaskResubscribe(
906984 )
907985 }
908986
909- const [ task ] = await db . select ( ) . from ( a2aTask ) . where ( eq ( a2aTask . id , params . id ) ) . limit ( 1 )
987+ const task = await getTaskForAgent ( params . id , agentId )
910988
911989 if ( ! task ) {
912990 return NextResponse . json ( createError ( id , A2A_ERROR_CODES . TASK_NOT_FOUND , 'Task not found' ) , {
@@ -1103,6 +1181,7 @@ async function handleTaskResubscribe(
11031181 */
11041182async function handlePushNotificationSet (
11051183 id : string | number ,
1184+ agentId : string ,
11061185 params : PushNotificationSetParams
11071186) : Promise < NextResponse > {
11081187 if ( ! params ?. id ) {
@@ -1130,7 +1209,7 @@ async function handlePushNotificationSet(
11301209 )
11311210 }
11321211
1133- const [ task ] = await db . select ( ) . from ( a2aTask ) . where ( eq ( a2aTask . id , params . id ) ) . limit ( 1 )
1212+ const task = await getTaskForAgent ( params . id , agentId )
11341213
11351214 if ( ! task ) {
11361215 return NextResponse . json ( createError ( id , A2A_ERROR_CODES . TASK_NOT_FOUND , 'Task not found' ) , {
@@ -1181,6 +1260,7 @@ async function handlePushNotificationSet(
11811260 */
11821261async function handlePushNotificationGet (
11831262 id : string | number ,
1263+ agentId : string ,
11841264 params : TaskIdParams
11851265) : Promise < NextResponse > {
11861266 if ( ! params ?. id ) {
@@ -1190,7 +1270,7 @@ async function handlePushNotificationGet(
11901270 )
11911271 }
11921272
1193- const [ task ] = await db . select ( ) . from ( a2aTask ) . where ( eq ( a2aTask . id , params . id ) ) . limit ( 1 )
1273+ const task = await getTaskForAgent ( params . id , agentId )
11941274
11951275 if ( ! task ) {
11961276 return NextResponse . json ( createError ( id , A2A_ERROR_CODES . TASK_NOT_FOUND , 'Task not found' ) , {
@@ -1224,6 +1304,7 @@ async function handlePushNotificationGet(
12241304 */
12251305async function handlePushNotificationDelete (
12261306 id : string | number ,
1307+ agentId : string ,
12271308 params : TaskIdParams
12281309) : Promise < NextResponse > {
12291310 if ( ! params ?. id ) {
@@ -1233,7 +1314,7 @@ async function handlePushNotificationDelete(
12331314 )
12341315 }
12351316
1236- const [ task ] = await db . select ( ) . from ( a2aTask ) . where ( eq ( a2aTask . id , params . id ) ) . limit ( 1 )
1317+ const task = await getTaskForAgent ( params . id , agentId )
12371318
12381319 if ( ! task ) {
12391320 return NextResponse . json ( createError ( id , A2A_ERROR_CODES . TASK_NOT_FOUND , 'Task not found' ) , {
0 commit comments