diff --git a/core/src/main/java/com/google/adk/codeexecutors/CodeExecutionUtils.java b/core/src/main/java/com/google/adk/codeexecutors/CodeExecutionUtils.java index b9afdcaff..a4d3771c3 100644 --- a/core/src/main/java/com/google/adk/codeexecutors/CodeExecutionUtils.java +++ b/core/src/main/java/com/google/adk/codeexecutors/CodeExecutionUtils.java @@ -34,6 +34,7 @@ import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.jspecify.annotations.Nullable; /** Utility functions for code execution. */ public final class CodeExecutionUtils { @@ -237,8 +238,7 @@ public abstract static class CodeExecutionInput extends JsonBaseModel { public static Builder builder() { return new AutoValue_CodeExecutionUtils_CodeExecutionInput.Builder() - .inputFiles(ImmutableList.of()) - .executionId(Optional.empty()); + .inputFiles(ImmutableList.of()); } /** Builder for {@link CodeExecutionInput}. */ @@ -248,7 +248,7 @@ public abstract static class Builder { public abstract Builder inputFiles(List inputFiles); - public abstract Builder executionId(Optional executionId); + public abstract Builder executionId(@Nullable String executionId); public abstract CodeExecutionInput build(); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java index f2cbe967e..d76cd1a04 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java @@ -240,7 +240,8 @@ private static Flowable runPreProcessor( .code(codeStr) .inputFiles(ImmutableList.of(file)) .executionId( - getOrSetExecutionId(invocationContext, codeExecutorContext)) + getOrSetExecutionId(invocationContext, codeExecutorContext) + .orElse(null)) .build()); codeExecutorContext.updateCodeExecutionResult( @@ -320,7 +321,9 @@ private static Flowable runPostProcessor( CodeExecutionInput.builder() .code(codeStr) .inputFiles(codeExecutorContext.getInputFiles()) - .executionId(getOrSetExecutionId(invocationContext, codeExecutorContext)) + .executionId( + getOrSetExecutionId(invocationContext, codeExecutorContext) + .orElse(null)) .build()); codeExecutorContext.updateCodeExecutionResult( invocationContext.invocationId(), diff --git a/core/src/test/java/com/google/adk/flows/llmflows/CodeExecutionTest.java b/core/src/test/java/com/google/adk/flows/llmflows/CodeExecutionTest.java index 353504dac..1485ca2c4 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/CodeExecutionTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/CodeExecutionTest.java @@ -20,6 +20,7 @@ import static com.google.adk.testing.TestUtils.createTestAgentBuilder; import static com.google.adk.testing.TestUtils.createTestLlm; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.verify; @@ -32,14 +33,19 @@ import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionInput; import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionResult; import com.google.adk.events.Event; +import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; +import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; import com.google.common.collect.ImmutableList; +import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.observers.TestObserver; +import java.util.ArrayList; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -115,4 +121,38 @@ public void testResponseProcessor_withCode_executesCode() { assertThat(executionResultPart.codeExecutionResult().get().output()) .hasValue("Code execution result:\nhello\n\n"); } + + @Test + public void testRequestProcessor_withCode_hasNoErrors() throws Exception { + // arrange + LlmRequest.Builder llmReqBuilder = LlmRequest.builder(); + when(mockCodeExecutor.codeBlockDelimiters()) + .thenReturn(ImmutableList.of(ImmutableList.of("```tool_code", "\n```"))); + when(mockCodeExecutor.optimizeDataFile()).thenReturn(true); + when(mockCodeExecutor.errorRetryAttempts()).thenReturn(2); + CodeExecutionResult executionResult = CodeExecutionResult.builder().stdout("hello\n").build(); + when(mockCodeExecutor.executeCode(any(), any())).thenReturn(executionResult); + llmReqBuilder.contents( + new ArrayList<>( + ImmutableList.of( + Content.builder() + .role("user") + .parts( + ImmutableList.of( + Part.builder() + .inlineData( + Blob.builder() + .mimeType("text/csv") + .data("1,2,3\n".getBytes(UTF_8))) + .build())) + .build()))); + + // act + Single result = + CodeExecution.requestProcessor.processRequest(invocationContext, llmReqBuilder.build()); + TestObserver testObserver = result.test(); + + // assert + testObserver.assertNoErrors(); + } }