From 611376ce78f8fbf26ffc80b0b221ca63f17dfa4b Mon Sep 17 00:00:00 2001 From: Hemasekhar Puchuginjala Date: Tue, 16 Jun 2026 19:30:56 +0530 Subject: [PATCH] fix: Enable state context in AgentExecutor --- .../adk/a2a/executor/AgentExecutor.java | 16 ++++++++++++- .../adk/a2a/executor/AgentExecutorTest.java | 23 +++++++------------ 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java index 77a0d62d7..5713d1b2e 100644 --- a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java +++ b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java @@ -17,6 +17,7 @@ import static java.util.Objects.requireNonNull; +import com.google.adk.a2a.converters.AdkMetadataKey; import com.google.adk.a2a.converters.EventConverter; import com.google.adk.a2a.converters.PartConverter; import com.google.adk.agents.BaseAgent; @@ -217,7 +218,8 @@ public void execute(RequestContext ctx, EventQueue eventQueue) { getUserId(ctx), session.id(), content, - agentExecutorConfig.runConfig()); + agentExecutorConfig.runConfig(), + getStateDelta(ctx)); }); }) .concatMap( @@ -273,6 +275,18 @@ private String getUserId(RequestContext ctx) { return USER_ID_PREFIX + ctx.getContextId(); } + private Map getStateDelta(RequestContext ctx) { + Map metadata = new HashMap<>(); + + if (ctx.getTaskId() != null) { + metadata.put(AdkMetadataKey.TASK_ID.getType(), ctx.getTaskId()); + } + if (ctx.getContextId() != null) { + metadata.put(AdkMetadataKey.CONTEXT_ID.getType(), ctx.getContextId()); + } + return metadata; + } + private Maybe prepareSession( RequestContext ctx, String appName, BaseSessionService service) { return service diff --git a/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java index 02998063d..7db08b234 100644 --- a/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java @@ -4,11 +4,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; @@ -23,19 +19,11 @@ import com.google.genai.types.Part; import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; -import io.a2a.spec.Message; -import io.a2a.spec.TaskArtifactUpdateEvent; -import io.a2a.spec.TaskState; -import io.a2a.spec.TaskStatus; -import io.a2a.spec.TaskStatusUpdateEvent; -import io.a2a.spec.TextPart; +import io.a2a.spec.*; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.UUID; +import java.util.*; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -350,10 +338,15 @@ private RequestContext createRequestContext() { .parts(ImmutableList.of(new TextPart("trigger"))) .build(); + MessageSendParams params = mock(MessageSendParams.class); + when(params.message()).thenReturn(message); + when(params.metadata()).thenReturn(ImmutableMap.of("key", "value")); + RequestContext ctx = mock(RequestContext.class); when(ctx.getMessage()).thenReturn(message); when(ctx.getTaskId()).thenReturn("task-" + UUID.randomUUID()); when(ctx.getContextId()).thenReturn("ctx-" + UUID.randomUUID()); + when(ctx.getParams()).thenReturn(params); return ctx; }