diff --git a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java index 2023fc7ac..5567b5123 100644 --- a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java +++ b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java @@ -495,7 +495,17 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte pushConfigStore.setInfo(createdTask.id(), params.configuration().pushNotificationConfig()); } - if (blocking && interruptedOrNonBlocking) { + // Check if task requires immediate return (AUTH_REQUIRED) + // AUTH_REQUIRED expects the client to receive it immediately and handle it out-of-band, + // while the agent continues executing in the background + boolean requiresImmediateReturn = kind instanceof Task task && + task.status().state() == io.a2a.spec.TaskState.TASK_STATE_AUTH_REQUIRED; + if (requiresImmediateReturn) { + LOGGER.debug("DefaultRequestHandler: Task {} in AUTH_REQUIRED state, skipping fire-and-forget handling", + taskId.get()); + } + + if (blocking && interruptedOrNonBlocking && !requiresImmediateReturn) { // For blocking calls: ensure all consumed events are persisted to TaskStore before returning // Order of operations is critical to avoid circular dependency and race conditions: // 1. Wait for agent to finish enqueueing events (or timeout) diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java index e69de29bb..00db599f2 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java @@ -0,0 +1,512 @@ +package io.a2a.server.requesthandlers; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import io.a2a.server.ServerCallContext; +import io.a2a.server.agentexecution.AgentExecutor; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.events.EventQueue; +import io.a2a.server.events.EventQueueItem; +import io.a2a.server.events.EventQueueUtil; +import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; +import io.a2a.server.tasks.AgentEmitter; +import io.a2a.server.tasks.InMemoryPushNotificationConfigStore; +import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotificationConfigStore; +import io.a2a.server.tasks.PushNotificationSender; +import io.a2a.server.tasks.TaskStore; +import io.a2a.spec.A2AError; +import io.a2a.spec.Event; +import io.a2a.spec.EventKind; +import io.a2a.spec.Message; +import io.a2a.spec.MessageSendConfiguration; +import io.a2a.spec.MessageSendParams; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.spec.TextPart; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Integration tests for DefaultRequestHandler focusing on AUTH_REQUIRED workflow. + * Tests verify the special interrupt behavior where AUTH_REQUIRED tasks: + * 1. Return immediately to the client + * 2. Continue agent execution in background + * 3. Keep queues open for late events + * 4. Perform async cleanup + */ +public class DefaultRequestHandlerTest { + + private static final MessageSendConfiguration DEFAULT_CONFIG = MessageSendConfiguration.builder() + .blocking(false) + .acceptedOutputModes(List.of()) + .build(); + + private static final ServerCallContext NULL_CONTEXT = null; + + private static final Message MESSAGE = Message.builder() + .messageId("111") + .role(Message.Role.ROLE_AGENT) + .parts(new TextPart("test message")) + .build(); + + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; + + // Test infrastructure components + protected AgentExecutor executor; + protected TaskStore taskStore; + protected RequestHandler requestHandler; + protected InMemoryQueueManager queueManager; + protected MainEventBus mainEventBus; + protected MainEventBusProcessor mainEventBusProcessor; + protected AgentExecutorMethod agentExecutorExecute; + protected AgentExecutorMethod agentExecutorCancel; + + protected final Executor internalExecutor = Executors.newCachedThreadPool(); + + @BeforeEach + public void init() { + // Create test AgentExecutor with mocked execute/cancel methods + executor = new AgentExecutor() { + @Override + public void execute(RequestContext context, AgentEmitter agentEmitter) throws A2AError { + if (agentExecutorExecute != null) { + agentExecutorExecute.invoke(context, agentEmitter); + } + } + + @Override + public void cancel(RequestContext context, AgentEmitter agentEmitter) throws A2AError { + if (agentExecutorCancel != null) { + agentExecutorCancel.invoke(context, agentEmitter); + } + } + }; + + // Set up infrastructure + InMemoryTaskStore inMemoryTaskStore = new InMemoryTaskStore(); + taskStore = inMemoryTaskStore; + + PushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); + + // Create MainEventBus and MainEventBusProcessor + mainEventBus = new MainEventBus(); + queueManager = new InMemoryQueueManager(inMemoryTaskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + + // Create DefaultRequestHandler + requestHandler = DefaultRequestHandler.create( + executor, taskStore, queueManager, pushConfigStore, mainEventBusProcessor, internalExecutor, internalExecutor); + } + + @AfterEach + public void cleanup() { + agentExecutorExecute = null; + agentExecutorCancel = null; + + // Stop MainEventBusProcessor background thread + if (mainEventBusProcessor != null) { + EventQueueUtil.stop(mainEventBusProcessor); + } + } + + /** + * Functional interface for test agent executor methods. + */ + protected interface AgentExecutorMethod { + void invoke(RequestContext context, AgentEmitter agentEmitter) throws A2AError; + } + + /** + * Test 1: Non-streaming AUTH_REQUIRED returns immediately while agent continues. + * Verifies: + * - Task returned immediately with AUTH_REQUIRED state + * - Agent still running in background (not blocked) + * - TaskStore persisted AUTH_REQUIRED state + * - Agent completes after release + * - Final state persisted to TaskStore + */ + @Test + void testAuthRequired_NonStreaming_ReturnsImmediately() throws Exception { + // Arrange: Set up agent that emits AUTH_REQUIRED then waits + CountDownLatch authRequiredEmitted = new CountDownLatch(1); + CountDownLatch continueAgent = new CountDownLatch(1); + + agentExecutorExecute = (context, emitter) -> { + // Emit AUTH_REQUIRED - client should receive immediately + emitter.requiresAuth(Message.builder() + .role(Message.Role.ROLE_AGENT) + .parts(new TextPart("Please authenticate with OAuth provider")) + .build()); + authRequiredEmitted.countDown(); + + // Agent continues processing (simulating waiting for out-of-band auth) + try { + continueAgent.await(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + + // Complete after "auth received" + emitter.complete(); + }; + + // Create MessageSendParams + MessageSendParams params = MessageSendParams.builder() + .message(MESSAGE) + .configuration(DEFAULT_CONFIG) + .build(); + + // Act: Send message (non-streaming) + EventKind eventKind = requestHandler.onMessageSend(params, NULL_CONTEXT); + + // Assert: Task returned immediately with AUTH_REQUIRED state + assertNotNull(eventKind, "Result should not be null"); + assertInstanceOf(Task.class, eventKind, "Result should be a Task"); + Task result = (Task) eventKind; + + assertEquals(TaskState.TASK_STATE_AUTH_REQUIRED, result.status().state(), + "Task should be in AUTH_REQUIRED state"); + assertTrue(authRequiredEmitted.await(2, TimeUnit.SECONDS), + "AUTH_REQUIRED should be emitted quickly"); + + // Verify agent still running (continueAgent latch not counted down yet) + assertFalse(continueAgent.await(100, TimeUnit.MILLISECONDS), + "Agent should still be waiting (not completed yet)"); + + // Verify TaskStore has AUTH_REQUIRED state + Task storedTask = taskStore.get(result.id()); + assertNotNull(storedTask, "Task should be persisted in TaskStore"); + assertEquals(TaskState.TASK_STATE_AUTH_REQUIRED, storedTask.status().state(), + "TaskStore should have AUTH_REQUIRED state"); + + // Release agent to complete + continueAgent.countDown(); + + // Wait for completion and verify final state + Thread.sleep(1000); // Allow time for completion to process through MainEventBus + Task finalTask = taskStore.get(result.id()); + assertEquals(TaskState.TASK_STATE_COMPLETED, finalTask.status().state(), + "TaskStore should have COMPLETED state after agent finishes"); + } + + /** + * Test 2: Queue remains open after AUTH_REQUIRED for late events. + * Verifies: + * - Queue stays open after AUTH_REQUIRED response + * - Can tap into queue after AUTH_REQUIRED + * - Late artifacts arrive on tapped queue + * - Completion event arrives on tapped queue + */ + @Test + void testAuthRequired_QueueRemainsOpen() throws Exception { + // Arrange: Agent emits AUTH_REQUIRED then continues with late events + CountDownLatch authEmitted = new CountDownLatch(1); + CountDownLatch continueAgent = new CountDownLatch(1); + + agentExecutorExecute = (context, emitter) -> { + // Emit AUTH_REQUIRED + emitter.requiresAuth(Message.builder() + .role(Message.Role.ROLE_AGENT) + .parts(new TextPart("Authenticate required")) + .build()); + authEmitted.countDown(); + + // Wait for test to tap queue + try { + continueAgent.await(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + + // Emit late artifact after AUTH_REQUIRED + emitter.addArtifact(List.of(new TextPart("Late artifact after auth"))); + emitter.complete(); + }; + + // Create MessageSendParams + MessageSendParams params = MessageSendParams.builder() + .message(MESSAGE) + .configuration(DEFAULT_CONFIG) + .build(); + + // Act: Send message, get AUTH_REQUIRED response + EventKind eventKind = requestHandler.onMessageSend(params, NULL_CONTEXT); + assertInstanceOf(Task.class, eventKind); + Task task = (Task) eventKind; + + assertTrue(authEmitted.await(2, TimeUnit.SECONDS), + "AUTH_REQUIRED should be emitted"); + + // Tap into the queue (simulates client resubscription after AUTH_REQUIRED) + EventQueue tappedQueue = queueManager.tap(task.id()); + assertNotNull(tappedQueue, "Queue should remain open after AUTH_REQUIRED"); + + // Release agent to continue and emit late events + continueAgent.countDown(); + + // Assert: Late events arrive on tapped queue + + // First event should be the late artifact + EventQueueItem item = tappedQueue.dequeueEventItem(5000); + assertNotNull(item, "Should receive late artifact event"); + Event event = item.getEvent(); + assertInstanceOf(TaskArtifactUpdateEvent.class, event, + "First event should be TaskArtifactUpdateEvent"); + + // Second event should be completion + item = tappedQueue.dequeueEventItem(5000); + assertNotNull(item, "Should receive completion event"); + event = item.getEvent(); + assertInstanceOf(TaskStatusUpdateEvent.class, event, + "Second event should be TaskStatusUpdateEvent"); + assertEquals(TaskState.TASK_STATE_COMPLETED, + ((TaskStatusUpdateEvent) event).status().state(), + "Task should be completed"); + } + + /** + * Test 3: TaskStore persistence through AUTH_REQUIRED lifecycle. + * Verifies: + * - AUTH_REQUIRED state persisted correctly + * - State transitions persisted (AUTH_REQUIRED → WORKING → COMPLETED) + * - TaskStore always reflects current state + */ + @Test + void testAuthRequired_TaskStorePersistence() throws Exception { + // Arrange: Agent emits AUTH_REQUIRED, then WORKING, then COMPLETED + CountDownLatch authEmitted = new CountDownLatch(1); + CountDownLatch continueAgent = new CountDownLatch(1); + + agentExecutorExecute = (context, emitter) -> { + // Emit AUTH_REQUIRED + emitter.requiresAuth(); + authEmitted.countDown(); + + // Wait for test to verify AUTH_REQUIRED persisted + try { + continueAgent.await(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + + // Continue working (simulating auth received out-of-band) + emitter.startWork(); + + // Complete the task + emitter.complete(); + }; + + // Create MessageSendParams + MessageSendParams params = MessageSendParams.builder() + .message(MESSAGE) + .configuration(DEFAULT_CONFIG) + .build(); + + // Act: Send message + EventKind eventKind = requestHandler.onMessageSend(params, NULL_CONTEXT); + assertInstanceOf(Task.class, eventKind); + Task task = (Task) eventKind; + + assertTrue(authEmitted.await(2, TimeUnit.SECONDS), + "AUTH_REQUIRED should be emitted"); + + // Assert: Verify AUTH_REQUIRED state persisted + Task storedTask1 = taskStore.get(task.id()); + assertNotNull(storedTask1, "Task should be in TaskStore"); + assertEquals(TaskState.TASK_STATE_AUTH_REQUIRED, storedTask1.status().state(), + "TaskStore should have AUTH_REQUIRED state"); + + // Release agent to continue + continueAgent.countDown(); + + // Wait for state transitions to process + Thread.sleep(1000); + + // Verify WORKING state persisted + Task storedTask2 = taskStore.get(task.id()); + // Note: WORKING might be skipped if processing is fast, so we accept either WORKING or COMPLETED + TaskState state2 = storedTask2.status().state(); + assertTrue(state2 == TaskState.TASK_STATE_WORKING || state2 == TaskState.TASK_STATE_COMPLETED, + "TaskStore should have WORKING or COMPLETED state"); + + // Wait a bit more and verify final COMPLETED state + Thread.sleep(500); + Task storedTask3 = taskStore.get(task.id()); + assertEquals(TaskState.TASK_STATE_COMPLETED, storedTask3.status().state(), + "TaskStore should have COMPLETED state after agent finishes"); + } + + /** + * Test 4: Streaming with AUTH_REQUIRED continues in background. + * Verifies: + * - Client receives AUTH_REQUIRED in stream + * - Agent continues emitting artifacts after AUTH_REQUIRED + * - Artifacts stream to client + * - Completion event arrives in stream + */ + @Test + void testAuthRequired_Streaming_ContinuesInBackground() throws Exception { + // Arrange: Agent emits AUTH_REQUIRED, then streams artifacts + CountDownLatch authEmitted = new CountDownLatch(1); + CountDownLatch continueAgent = new CountDownLatch(1); + + agentExecutorExecute = (context, emitter) -> { + // Emit AUTH_REQUIRED + emitter.requiresAuth(); + authEmitted.countDown(); + + // Wait briefly (simulating auth happening out-of-band) + try { + continueAgent.await(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + + // Continue streaming artifacts + emitter.addArtifact(List.of(new TextPart("Artifact 1"))); + emitter.addArtifact(List.of(new TextPart("Artifact 2"))); + emitter.complete(); + }; + + // Create MessageSendParams + MessageSendParams params = MessageSendParams.builder() + .message(MESSAGE) + .configuration(DEFAULT_CONFIG) + .build(); + + // Act: Send message with streaming enabled + EventKind eventKind = requestHandler.onMessageSend(params, NULL_CONTEXT); + assertInstanceOf(Task.class, eventKind); + Task result = (Task) eventKind; + + assertTrue(authEmitted.await(2, TimeUnit.SECONDS), + "AUTH_REQUIRED should be emitted"); + + // Verify AUTH_REQUIRED received + assertEquals(TaskState.TASK_STATE_AUTH_REQUIRED, result.status().state(), + "Should receive AUTH_REQUIRED state"); + + // Tap queue to receive subsequent events + EventQueue tappedQueue = queueManager.tap(result.id()); + + // Release agent to continue streaming + continueAgent.countDown(); + + // Assert: Verify artifacts stream through + EventQueueItem item1 = tappedQueue.dequeueEventItem(5000); + assertNotNull(item1, "Should receive first artifact"); + assertInstanceOf(TaskArtifactUpdateEvent.class, item1.getEvent()); + + EventQueueItem item2 = tappedQueue.dequeueEventItem(5000); + assertNotNull(item2, "Should receive second artifact"); + assertInstanceOf(TaskArtifactUpdateEvent.class, item2.getEvent()); + + // Verify completion arrives + EventQueueItem completionItem = tappedQueue.dequeueEventItem(5000); + assertNotNull(completionItem, "Should receive completion"); + Event completionEvent = completionItem.getEvent(); + assertInstanceOf(TaskStatusUpdateEvent.class, completionEvent); + assertEquals(TaskState.TASK_STATE_COMPLETED, + ((TaskStatusUpdateEvent) completionEvent).status().state()); + } + + /** + * Test 5: Resubscription after AUTH_REQUIRED works correctly. + * Verifies: + * - Queue stays open after AUTH_REQUIRED and client disconnect + * - Can resubscribe (tap) after AUTH_REQUIRED + * - Late events received on resubscribed queue + * - Completion event arrives on resubscribed queue + */ + @Test + void testAuthRequired_Resubscription() throws Exception { + // Arrange: Agent emits AUTH_REQUIRED, simulates client disconnect, then continues + CountDownLatch authEmitted = new CountDownLatch(1); + CountDownLatch continueAgent = new CountDownLatch(1); + + agentExecutorExecute = (context, emitter) -> { + // Emit AUTH_REQUIRED + emitter.requiresAuth(); + authEmitted.countDown(); + + // Wait for test to simulate disconnect and resubscribe + try { + continueAgent.await(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + + // Emit late events after "client reconnect" + emitter.addArtifact(List.of(new TextPart("Event after reconnect"))); + emitter.complete(); + }; + + // Create MessageSendParams + MessageSendParams params = MessageSendParams.builder() + .message(MESSAGE) + .configuration(DEFAULT_CONFIG) + .build(); + + // Act: Send message, get AUTH_REQUIRED + EventKind eventKind = requestHandler.onMessageSend(params, NULL_CONTEXT); + assertInstanceOf(Task.class, eventKind); + Task task = (Task) eventKind; + + assertTrue(authEmitted.await(2, TimeUnit.SECONDS), + "AUTH_REQUIRED should be emitted"); + + assertEquals(TaskState.TASK_STATE_AUTH_REQUIRED, task.status().state(), + "Should receive AUTH_REQUIRED state"); + + // Simulate client disconnect by just waiting + Thread.sleep(100); + + // Client reconnects: tap into queue (resubscription) + EventQueue resubscribedQueue = queueManager.tap(task.id()); + assertNotNull(resubscribedQueue, + "Should be able to resubscribe after AUTH_REQUIRED"); + + // Release agent to continue + continueAgent.countDown(); + + // Assert: Late events arrive on resubscribed queue + EventQueueItem item = resubscribedQueue.dequeueEventItem(5000); + assertNotNull(item, "Should receive late artifact on resubscribed queue"); + assertInstanceOf(TaskArtifactUpdateEvent.class, item.getEvent(), + "Should receive artifact update event"); + + // Verify completion arrives + EventQueueItem completionItem = resubscribedQueue.dequeueEventItem(5000); + assertNotNull(completionItem, "Should receive completion event"); + Event completionEvent = completionItem.getEvent(); + assertInstanceOf(TaskStatusUpdateEvent.class, completionEvent, + "Should receive status update event"); + assertEquals(TaskState.TASK_STATE_COMPLETED, + ((TaskStatusUpdateEvent) completionEvent).status().state(), + "Task should be completed"); + } +} diff --git a/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java b/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java index fa4eb8ef5..a347d9cd0 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java @@ -1,6 +1,7 @@ package io.a2a.server.tasks; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.atMost; @@ -27,6 +28,7 @@ import io.a2a.spec.Task; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; import org.junit.jupiter.api.BeforeEach; @@ -278,4 +280,185 @@ void testConsumeAndBreakNonBlocking() throws Exception { // Cleanup: stop the processor EventQueueUtil.stop(processor); } + + // AUTH_REQUIRED Tests + + @Test + void testConsumeAndBreakOnAuthRequired_Blocking() throws Exception { + // Test that AUTH_REQUIRED with blocking=true sets interrupted=true and continues consumption in background + String taskId = "auth-required-blocking-task"; + Task authRequiredTask = createSampleTask(taskId, TaskState.TASK_STATE_AUTH_REQUIRED, "ctx1"); + + when(mockTaskManager.getTask()).thenReturn(authRequiredTask); + + // Create event queue infrastructure + MainEventBus mainEventBus = new MainEventBus(); + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + InMemoryQueueManager queueManager = + new InMemoryQueueManager(new MockTaskStateProvider(), mainEventBus); + MainEventBusProcessor processor = new MainEventBusProcessor(mainEventBus, taskStore, task -> {}, queueManager); + EventQueueUtil.start(processor); + + EventQueue queue = queueManager.getEventQueueBuilder(taskId).build().tap(); + + try { + // Enqueue AUTH_REQUIRED task using callback pattern + waitForEventProcessing(processor, () -> queue.enqueueEvent(authRequiredTask)); + + // Create EventConsumer + EventConsumer eventConsumer = new EventConsumer(queue); + + // Call consumeAndBreakOnInterrupt with blocking=true + ResultAggregator.EventTypeAndInterrupt result = + aggregator.consumeAndBreakOnInterrupt(eventConsumer, true); + + // Assert: interrupted=true for AUTH_REQUIRED + assertTrue(result.interrupted(), "AUTH_REQUIRED should trigger interrupt in blocking mode"); + assertEquals(authRequiredTask, result.eventType(), "Event type should be the AUTH_REQUIRED task"); + + // Verify consumption continues in background (consumptionFuture should be running) + // For blocking mode, the consumption future should complete after processing + assertNotNull(result, "Result should not be null"); + } finally { + queue.close(); + EventQueueUtil.stop(processor); + } + } + + @Test + void testConsumeAndBreakOnAuthRequired_NonBlocking() throws Exception { + // Test that AUTH_REQUIRED with blocking=false sets interrupted=true and completes immediately + String taskId = "auth-required-nonblocking-task"; + Task authRequiredTask = createSampleTask(taskId, TaskState.TASK_STATE_AUTH_REQUIRED, "ctx1"); + + when(mockTaskManager.getTask()).thenReturn(authRequiredTask); + + // Create event queue infrastructure + MainEventBus mainEventBus = new MainEventBus(); + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + InMemoryQueueManager queueManager = + new InMemoryQueueManager(new MockTaskStateProvider(), mainEventBus); + MainEventBusProcessor processor = new MainEventBusProcessor(mainEventBus, taskStore, task -> {}, queueManager); + EventQueueUtil.start(processor); + + EventQueue queue = queueManager.getEventQueueBuilder(taskId).build().tap(); + + try { + // Enqueue AUTH_REQUIRED task + waitForEventProcessing(processor, () -> queue.enqueueEvent(authRequiredTask)); + + // Create EventConsumer + EventConsumer eventConsumer = new EventConsumer(queue); + + // Call consumeAndBreakOnInterrupt with blocking=false + ResultAggregator.EventTypeAndInterrupt result = + aggregator.consumeAndBreakOnInterrupt(eventConsumer, false); + + // Assert: interrupted=true for AUTH_REQUIRED + assertTrue(result.interrupted(), "AUTH_REQUIRED should trigger interrupt in non-blocking mode"); + assertEquals(authRequiredTask, result.eventType(), "Event type should be the AUTH_REQUIRED task"); + + // For non-blocking mode, consumption should complete immediately + assertNotNull(result, "Result should not be null"); + } finally { + queue.close(); + EventQueueUtil.stop(processor); + } + } + + @Test + void testAuthRequiredWithTaskStatusUpdateEvent() throws Exception { + // Test that TaskStatusUpdateEvent with AUTH_REQUIRED state triggers same interrupt behavior + String taskId = "auth-required-status-update-task"; + TaskStatusUpdateEvent authRequiredEvent = new TaskStatusUpdateEvent( + taskId, + new TaskStatus(TaskState.TASK_STATE_AUTH_REQUIRED), + "ctx1", + false, // isFinal=false for AUTH_REQUIRED + null + ); + + Task authRequiredTask = createSampleTask(taskId, TaskState.TASK_STATE_AUTH_REQUIRED, "ctx1"); + when(mockTaskManager.getTask()).thenReturn(authRequiredTask); + + // Create event queue infrastructure + MainEventBus mainEventBus = new MainEventBus(); + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + InMemoryQueueManager queueManager = + new InMemoryQueueManager(new MockTaskStateProvider(), mainEventBus); + MainEventBusProcessor processor = new MainEventBusProcessor(mainEventBus, taskStore, task -> {}, queueManager); + EventQueueUtil.start(processor); + + EventQueue queue = queueManager.getEventQueueBuilder(taskId).build().tap(); + + try { + // Enqueue TaskStatusUpdateEvent + waitForEventProcessing(processor, () -> queue.enqueueEvent(authRequiredEvent)); + + // Create EventConsumer + EventConsumer eventConsumer = new EventConsumer(queue); + + // Call consumeAndBreakOnInterrupt + ResultAggregator.EventTypeAndInterrupt result = + aggregator.consumeAndBreakOnInterrupt(eventConsumer, true); + + // Assert: interrupted=true for AUTH_REQUIRED (TaskStatusUpdateEvent) + assertTrue(result.interrupted(), "AUTH_REQUIRED via TaskStatusUpdateEvent should trigger interrupt"); + + // Note: ResultAggregator returns a Task (from getCurrentResult or capturedTask), + // not the TaskStatusUpdateEvent itself. The event triggers interrupt behavior, + // but the returned eventType is a Task. + assertNotNull(result.eventType(), "Result should have an event type"); + assertTrue(result.eventType() instanceof Task, "Event type should be a Task"); + Task resultTask = (Task) result.eventType(); + assertEquals(TaskState.TASK_STATE_AUTH_REQUIRED, resultTask.status().state(), + "Task should have AUTH_REQUIRED state"); + } finally { + queue.close(); + EventQueueUtil.stop(processor); + } + } + + @Test + void testAuthRequiredWithTaskEvent() throws Exception { + // Test that Task event with AUTH_REQUIRED state triggers interrupt correctly + String taskId = "auth-required-task-event"; + Task authRequiredTask = createSampleTask(taskId, TaskState.TASK_STATE_AUTH_REQUIRED, "ctx1"); + + when(mockTaskManager.getTask()).thenReturn(authRequiredTask); + + // Create event queue infrastructure + MainEventBus mainEventBus = new MainEventBus(); + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + InMemoryQueueManager queueManager = + new InMemoryQueueManager(new MockTaskStateProvider(), mainEventBus); + MainEventBusProcessor processor = new MainEventBusProcessor(mainEventBus, taskStore, task -> {}, queueManager); + EventQueueUtil.start(processor); + + EventQueue queue = queueManager.getEventQueueBuilder(taskId).build().tap(); + + try { + // Enqueue Task event with AUTH_REQUIRED + waitForEventProcessing(processor, () -> queue.enqueueEvent(authRequiredTask)); + + // Create EventConsumer + EventConsumer eventConsumer = new EventConsumer(queue); + + // Call consumeAndBreakOnInterrupt + ResultAggregator.EventTypeAndInterrupt result = + aggregator.consumeAndBreakOnInterrupt(eventConsumer, true); + + // Assert: interrupted=true for AUTH_REQUIRED + assertTrue(result.interrupted(), "AUTH_REQUIRED Task event should trigger interrupt"); + assertEquals(authRequiredTask, result.eventType(), "Event type should be the AUTH_REQUIRED task"); + + // Verify both Task and TaskStatusUpdateEvent can trigger AUTH_REQUIRED interrupt + // (this test validates Task event, testAuthRequiredWithTaskStatusUpdateEvent validates TaskStatusUpdateEvent) + TaskState state = ((Task) result.eventType()).status().state(); + assertEquals(TaskState.TASK_STATE_AUTH_REQUIRED, state, "Task state should be AUTH_REQUIRED"); + } finally { + queue.close(); + EventQueueUtil.stop(processor); + } + } } diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java index 75ba229ea..5c852c642 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java @@ -1616,6 +1616,132 @@ public void testInputRequiredWorkflow() throws Exception { } } + /** + * Test AUTH_REQUIRED workflow: agent emits AUTH_REQUIRED, continues in background, completes after out-of-band auth. + *

+ * Flow: + * 1. Send initial message → Agent emits AUTH_REQUIRED and returns immediately + * 2. Verify client receives AUTH_REQUIRED state (non-streaming blocking call) + * 3. Subscribe to task to catch background completion + * 4. Verify agent completes in background (simulating out-of-band auth) + * 5. Verify COMPLETED state received via subscription + *

+ * Key behaviors: + * - AUTH_REQUIRED causes immediate return from blocking call (like INPUT_REQUIRED) + * - Agent continues executing in background after returning AUTH_REQUIRED + * - No second message sent (auth happens out-of-band) + * - Subscription receives completion event when agent finishes + */ + @Test + @Timeout(value = 1, unit = TimeUnit.MINUTES) + public void testAuthRequiredWorkflow() throws Exception { + String authRequiredTaskId = "auth-required-test-" + java.util.UUID.randomUUID(); + boolean taskCreated = false; + try { + // 1. Send initial message - AgentExecutor will transition task to AUTH_REQUIRED then continue in background + Message initialMessage = Message.builder(MESSAGE) + .taskId(authRequiredTaskId) + .contextId("test-context") + .parts(new TextPart("Initial request requiring auth")) + .build(); + + CountDownLatch initialLatch = new CountDownLatch(1); + AtomicReference initialState = new AtomicReference<>(); + AtomicBoolean initialUnexpectedEvent = new AtomicBoolean(false); + + BiConsumer initialConsumer = (event, agentCard) -> { + // Idempotency guard: prevent late events from modifying state after latch countdown + if (initialLatch.getCount() == 0) { + return; + } + if (event instanceof TaskEvent te) { + TaskState state = te.getTask().status().state(); + initialState.set(state); + // Only count down when we receive AUTH_REQUIRED, not intermediate states like WORKING + if (state == TaskState.TASK_STATE_AUTH_REQUIRED) { + initialLatch.countDown(); + } + } else { + initialUnexpectedEvent.set(true); + } + }; + + // Send initial message - task will go to AUTH_REQUIRED state and return immediately + getNonStreamingClient().sendMessage(initialMessage, List.of(initialConsumer), null); + assertTrue(initialLatch.await(10, TimeUnit.SECONDS), "Should receive AUTH_REQUIRED state"); + assertFalse(initialUnexpectedEvent.get(), "Should only receive TaskEvent"); + assertEquals(TaskState.TASK_STATE_AUTH_REQUIRED, initialState.get(), "Task should be in AUTH_REQUIRED state"); + taskCreated = true; + + // 2. Subscribe to task to catch background completion + // Agent continues executing after returning AUTH_REQUIRED (simulating out-of-band auth flow) + CountDownLatch completionLatch = new CountDownLatch(1); + AtomicReference completedState = new AtomicReference<>(); + AtomicBoolean completionUnexpectedEvent = new AtomicBoolean(false); + AtomicReference errorRef = new AtomicReference<>(); + + BiConsumer subscriptionConsumer = (event, agentCard) -> { + // Idempotency guard: prevent late events from modifying state after latch countdown + if (completionLatch.getCount() == 0) { + return; + } + // subscribeToTask returns initial state as TaskEvent, then subsequent events as TaskUpdateEvent + TaskState state = null; + if (event instanceof TaskEvent te) { + state = te.getTask().status().state(); + } else if (event instanceof TaskUpdateEvent tue) { + io.a2a.spec.UpdateEvent updateEvent = tue.getUpdateEvent(); + if (updateEvent instanceof TaskStatusUpdateEvent statusUpdate) { + state = statusUpdate.status().state(); + } else { + // Ignore other update events like TaskArtifactUpdateEvent as they don't change the task state + return; + } + } else { + completionUnexpectedEvent.set(true); + return; + } + + completedState.set(state); + // A2A spec: first event from subscribeToTask is TaskEvent with current state (AUTH_REQUIRED) + // Then we receive TaskUpdateEvent with COMPLETED when agent finishes + if (state == TaskState.TASK_STATE_COMPLETED) { + completionLatch.countDown(); + } + }; + + Consumer errorHandler = errorRef::set; + + // Wait for subscription to be established + CountDownLatch subscriptionLatch = new CountDownLatch(1); + awaitStreamingSubscription() + .whenComplete((unused, throwable) -> subscriptionLatch.countDown()); + + getClient().subscribeToTask(new TaskIdParams(authRequiredTaskId), + List.of(subscriptionConsumer), + errorHandler); + + assertTrue(subscriptionLatch.await(15, TimeUnit.SECONDS), "Subscription should be established"); + + // Note: We don't use awaitChildQueueCountStable() here because the agent is already running + // in the background (sleeping for 3s). By the time we check, it might have already completed. + // The subscriptionLatch already ensures the subscription is established, and completionLatch + // below will catch the COMPLETED event from the background agent. + + // 3. Verify subscription receives COMPLETED state from background agent execution + // Agent should complete after simulating out-of-band auth delay (500ms) + assertTrue(completionLatch.await(10, TimeUnit.SECONDS), "Should receive COMPLETED state from background agent"); + assertFalse(completionUnexpectedEvent.get(), "Should only receive TaskEvent"); + assertNull(errorRef.get(), "Should not receive errors"); + assertEquals(TaskState.TASK_STATE_COMPLETED, completedState.get(), "Task should be COMPLETED after background auth"); + + } finally { + if (taskCreated) { + deleteTaskInTaskStore(authRequiredTaskId); + } + } + } + @Test public void testMalformedJSONRPCRequest() { // skip this test for non-JSONRPC transports diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java index 1db5f2d78..74fe78eab 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java @@ -97,6 +97,31 @@ public void execute(RequestContext context, AgentEmitter agentEmitter) throws A2 } } + // Special handling for auth-required test + if (taskId != null && taskId.startsWith("auth-required-test")) { + // AUTH_REQUIRED workflow: agent emits AUTH_REQUIRED, simulates out-of-band auth delay, then completes + // Go directly to AUTH_REQUIRED without intermediate WORKING state + // This avoids race condition where blocking call interrupts on WORKING + // before AUTH_REQUIRED is persisted to TaskStore + agentEmitter.requiresAuth(agentEmitter.newAgentMessage( + List.of(new TextPart("Please authenticate with OAuth provider")), + context.getMessage().metadata())); + + try { + // Simulate out-of-band authentication delay (user authenticates externally) + // Sleep long enough for test to establish subscription and wait for completion + Thread.sleep(2000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new InternalError("Auth simulation interrupted: " + e.getMessage()); + } + + // Complete task (auth "received" out-of-band) + // Agent continues after AUTH_REQUIRED without new request + agentEmitter.complete(); + return; + } + if (context.getTaskId().equals("task-not-supported-123")) { throw new UnsupportedOperationError(); }