diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index 6ebd39a9c..840a370c6 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -64,6 +64,11 @@ public Single processRequest( modelName = ""; } + ImmutableList sessionEvents; + synchronized (context.session().events()) { + sessionEvents = ImmutableList.copyOf(context.session().events()); + } + if (llmAgent.includeContents() == LlmAgent.IncludeContents.NONE) { return Single.just( RequestProcessor.RequestProcessingResult.create( @@ -71,7 +76,7 @@ public Single processRequest( .contents( getCurrentTurnContents( context.branch().orElse(null), - context.session().events(), + sessionEvents, context.agent().name(), modelName)) .build(), @@ -80,10 +85,7 @@ public Single processRequest( ImmutableList contents = getContents( - context.branch().orElse(null), - context.session().events(), - context.agent().name(), - modelName); + context.branch().orElse(null), sessionEvents, context.agent().name(), modelName); return Single.just( RequestProcessor.RequestProcessingResult.create( diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 85e78666d..7164991f3 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -36,10 +36,13 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -780,6 +783,63 @@ public void processRequest_notEmptyContent() { assertThat(contents).containsExactly(e.content().get()); } + @Test + public void processRequest_concurrentReadAndWrite_noException() throws Exception { + LlmAgent agent = + LlmAgent.builder().name(AGENT).includeContents(LlmAgent.IncludeContents.DEFAULT).build(); + Session session = + sessionService + .createSession("test-app", "test-user", new HashMap<>(), "test-session") + .blockingGet(); + + // Seed with dummy events to widen the race capability + for (int i = 0; i < 5000; i++) { + session.events().add(createUserEvent("dummy" + i, "dummy")); + } + + InvocationContext context = + InvocationContext.builder() + .invocationId("test-invocation") + .agent(agent) + .session(session) + .sessionService(sessionService) + .build(); + + LlmRequest initialRequest = LlmRequest.builder().build(); + + AtomicReference writerError = new AtomicReference<>(); + CountDownLatch startLatch = new CountDownLatch(1); + + Thread writerThread = + new Thread( + () -> { + startLatch.countDown(); + try { + for (int i = 0; i < 2000; i++) { + session.events().add(createUserEvent("writer" + i, "new data")); + } + } catch (Throwable t) { + writerError.set(t); + } + }); + + writerThread.start(); + startLatch.await(); // wait for writer to be ready + + // Process (read) requests concurrently to trigger race conditions + for (int i = 0; i < 200; i++) { + var unused = contentsProcessor.processRequest(context, initialRequest).blockingGet(); + if (writerError.get() != null) { + throw new RuntimeException("Writer failed", writerError.get()); + } + } + + writerThread.join(); + if (writerError.get() != null) { + throw new RuntimeException("Writer failed", writerError.get()); + } + } + private static Event createUserEvent(String id, String text) { return Event.builder() .id(id)