Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.a2a.spec.Event;
import io.a2a.spec.Message;
import io.a2a.spec.Task;
import io.a2a.spec.TaskState;
import io.a2a.spec.TaskStatusUpdateEvent;
import mutiny.zero.BackpressureStrategy;
import mutiny.zero.TubeConfiguration;
Expand Down Expand Up @@ -77,7 +78,7 @@ public Flow.Publisher<EventQueueItem> consumeAll() {
} else if (event instanceof Message) {
isFinalEvent = true;
} else if (event instanceof Task task) {
isFinalEvent = task.status().state().isFinal();
isFinalEvent = isStreamTerminatingTask(task);
} else if (event instanceof QueueClosedEvent) {
// Poison pill event - signals queue closure from remote node
// Do NOT send to subscribers - just close the queue
Expand All @@ -94,7 +95,7 @@ public Flow.Publisher<EventQueueItem> consumeAll() {
}

if (isFinalEvent) {
LOGGER.debug("Final event detected, closing queue and breaking loop for queue {}", System.identityHashCode(queue));
LOGGER.debug("Final or interrupted event detected, closing queue and breaking loop for queue {}", System.identityHashCode(queue));
queue.close();
LOGGER.debug("Queue closed, breaking loop for queue {}", System.identityHashCode(queue));
break;
Expand All @@ -120,6 +121,21 @@ public Flow.Publisher<EventQueueItem> consumeAll() {
});
}

/**
* Determines if a task is in a state for terminating the stream.
* <p>A task is terminating if:</p>
* <ul>
* <li>Its state is final (e.g., completed, canceled, rejected, failed), OR</li>
* <li>Its state is interrupted (e.g., input-required)</li>
* </ul>
* @param task the task to check
* @return true if the task has a final state or an interrupted state, false otherwise
*/
private boolean isStreamTerminatingTask(Task task) {
TaskState state = task.status().state();
return state.isFinal() || state == TaskState.INPUT_REQUIRED;
}

public EnhancedRunnable.DoneCallback createAgentRunnableDoneCallback() {
return agentRunnable -> {
if (agentRunnable.getError() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,32 +114,7 @@ public void testConsumeAllMultipleEvents() throws JsonProcessingException {
final List<Event> receivedEvents = new ArrayList<>();
final AtomicReference<Throwable> error = new AtomicReference<>();

publisher.subscribe(new Flow.Subscriber<>() {
private Flow.Subscription subscription;

@Override
public void onSubscribe(Flow.Subscription subscription) {
this.subscription = subscription;
subscription.request(1);
}

@Override
public void onNext(EventQueueItem item) {
receivedEvents.add(item.getEvent());
subscription.request(1);

}

@Override
public void onError(Throwable throwable) {
error.set(throwable);
}

@Override
public void onComplete() {
subscription.cancel();
}
});
publisher.subscribe(getSubscriber(receivedEvents, error));

assertNull(error.get());
assertEquals(events.size(), receivedEvents.size());
Expand Down Expand Up @@ -175,32 +150,7 @@ public void testConsumeUntilMessage() throws Exception {
final List<Event> receivedEvents = new ArrayList<>();
final AtomicReference<Throwable> error = new AtomicReference<>();

publisher.subscribe(new Flow.Subscriber<>() {
private Flow.Subscription subscription;

@Override
public void onSubscribe(Flow.Subscription subscription) {
this.subscription = subscription;
subscription.request(1);
}

@Override
public void onNext(EventQueueItem item) {
receivedEvents.add(item.getEvent());
subscription.request(1);

}

@Override
public void onError(Throwable throwable) {
error.set(throwable);
}

@Override
public void onComplete() {
subscription.cancel();
}
});
publisher.subscribe(getSubscriber(receivedEvents, error));

assertNull(error.get());
assertEquals(3, receivedEvents.size());
Expand All @@ -224,7 +174,55 @@ public void testConsumeMessageEvents() throws Exception {
final List<Event> receivedEvents = new ArrayList<>();
final AtomicReference<Throwable> error = new AtomicReference<>();

publisher.subscribe(new Flow.Subscriber<>() {
publisher.subscribe(getSubscriber(receivedEvents, error));

assertNull(error.get());
// The stream is closed after the first Message
assertEquals(1, receivedEvents.size());
assertSame(message, receivedEvents.get(0));
}

@Test
public void testConsumeTaskInputRequired() {
Task task = Task.builder()
.id("task-id")
.contextId("task-context")
.status(new TaskStatus(TaskState.INPUT_REQUIRED))
.build();
List<Event> events = List.of(
task,
TaskArtifactUpdateEvent.builder()
.taskId("task-123")
.contextId("session-xyz")
.artifact(Artifact.builder()
.artifactId("11")
.parts(new TextPart("text"))
.build())
.build(),
TaskStatusUpdateEvent.builder()
.taskId("task-123")
.contextId("session-xyz")
.status(new TaskStatus(TaskState.WORKING))
.isFinal(true)
.build());
for (Event event : events) {
eventQueue.enqueueEvent(event);
}

Flow.Publisher<EventQueueItem> publisher = eventConsumer.consumeAll();
final List<Event> receivedEvents = new ArrayList<>();
final AtomicReference<Throwable> error = new AtomicReference<>();

publisher.subscribe(getSubscriber(receivedEvents, error));

assertNull(error.get());
// The stream is closed after the input_required task
assertEquals(1, receivedEvents.size());
assertSame(task, receivedEvents.get(0));
}

private Flow.Subscriber<EventQueueItem> getSubscriber(List<Event> receivedEvents, AtomicReference<Throwable> error) {
return new Flow.Subscriber<>() {
private Flow.Subscription subscription;

@Override
Expand All @@ -237,7 +235,6 @@ public void onSubscribe(Flow.Subscription subscription) {
public void onNext(EventQueueItem item) {
receivedEvents.add(item.getEvent());
subscription.request(1);

}

@Override
Expand All @@ -249,12 +246,7 @@ public void onError(Throwable throwable) {
public void onComplete() {
subscription.cancel();
}
});

assertNull(error.get());
// The stream is closed after the first Message
assertEquals(1, receivedEvents.size());
assertSame(message, receivedEvents.get(0));
};
}

@Test
Expand Down