diff --git a/.github/workflows/release-please.yaml b/.github/workflows/release-please.yaml index 6d3142907..258cf90db 100644 --- a/.github/workflows/release-please.yaml +++ b/.github/workflows/release-please.yaml @@ -14,4 +14,4 @@ jobs: steps: - uses: googleapis/release-please-action@v4 with: - token: ${{ secrets.GITHUB_TOKEN }} + token: ${{ secrets.RELEASE_PLEASE_TOKEN }} diff --git a/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java similarity index 93% rename from a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java rename to a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java index b8ff39808..021786162 100644 --- a/a2a/src/main/java/com/google/adk/a2a/RemoteA2AAgent.java +++ b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java @@ -1,4 +1,19 @@ -package com.google.adk.a2a; +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.a2a.agent; import static com.google.common.base.Strings.nullToEmpty; @@ -44,26 +59,21 @@ import org.slf4j.LoggerFactory; /** - * Agent that communicates with a remote A2A agent via A2A client. - * - *

This agent supports multiple ways to specify the remote agent: + * Agent that communicates with a remote A2A agent via an A2A client. * - *

    - *
  1. Direct AgentCard object - *
  2. URL to agent card JSON - *
  3. File path to agent card JSON - *
+ *

The remote agent can be specified directly by providing an {@link AgentCard} to the builder, + * or it can be resolved automatically using the provided A2A client. * - *

The agent handles: + *

Key responsibilities of this agent include: * *

- * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. */ public class RemoteA2AAgent extends BaseAgent { @@ -436,7 +446,7 @@ private boolean mergeAggregatedContentIntoEvent(Event event) { } Content aggregatedContent = Content.builder().role("model").parts(parts).build(); - event.setContent(Optional.of(aggregatedContent)); + event.setContent(aggregatedContent); ImmutableList.Builder newMetadata = ImmutableList.builder(); event.customMetadata().ifPresent(newMetadata::addAll); diff --git a/a2a/src/main/java/com/google/adk/a2a/common/A2AClientError.java b/a2a/src/main/java/com/google/adk/a2a/common/A2AClientError.java index 8e8282742..466c89223 100644 --- a/a2a/src/main/java/com/google/adk/a2a/common/A2AClientError.java +++ b/a2a/src/main/java/com/google/adk/a2a/common/A2AClientError.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.common; /** Exception thrown when the A2A client encounters an error. */ diff --git a/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java b/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java index 5c75faeac..a5faeff2a 100644 --- a/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java +++ b/a2a/src/main/java/com/google/adk/a2a/common/A2AMetadata.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.common; /** Constants and utilities for A2A metadata keys. */ diff --git a/a2a/src/main/java/com/google/adk/a2a/common/GenAiFieldMissingException.java b/a2a/src/main/java/com/google/adk/a2a/common/GenAiFieldMissingException.java index a5947dcb8..0ac56fc01 100644 --- a/a2a/src/main/java/com/google/adk/a2a/common/GenAiFieldMissingException.java +++ b/a2a/src/main/java/com/google/adk/a2a/common/GenAiFieldMissingException.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.common; /** Exception thrown when the the genai class has an empty field. */ diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/A2ADataPartMetadataType.java b/a2a/src/main/java/com/google/adk/a2a/converters/A2ADataPartMetadataType.java index b5b53c49a..e0e97c8e9 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/A2ADataPartMetadataType.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/A2ADataPartMetadataType.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.converters; /** Enum for the type of A2A DataPart metadata. */ diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java index 1a49b0070..d823e3817 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.converters; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -13,12 +28,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Converter for ADK Events to A2A Messages. - * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ +/** Converter for ADK Events to A2A Messages. */ public final class EventConverter { private static final Logger logger = LoggerFactory.getLogger(EventConverter.class); diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index 05125d170..96ef66bc8 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.converters; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -32,12 +47,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Utility class for converting between Google GenAI Parts and A2A DataParts. - * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ +/** Utility class for converting between Google GenAI Parts and A2A DataParts. */ public final class PartConverter { private static final Logger logger = LoggerFactory.getLogger(PartConverter.class); diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java index ccbb1b9cf..f3be48c1b 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.converters; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -27,12 +42,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Utility for converting ADK events to A2A spec messages (and back). - * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ +/** Utility for converting ADK events to A2A spec messages (and back). */ public final class ResponseConverter { private static final Logger logger = LoggerFactory.getLogger(ResponseConverter.class); private static final ImmutableSet PENDING_STATES = @@ -76,7 +86,7 @@ private static Optional handleTaskUpdate( boolean isLastChunk = Objects.equals(artifactEvent.isLastChunk(), true); Event eventPart = artifactToEvent(artifactEvent.getArtifact(), context); - eventPart.setPartial(Optional.of(isAppend || !isLastChunk)); + eventPart.setPartial(isAppend || !isLastChunk); // append=true, lastChunk=false: emit as partial, update aggregation // append=false, lastChunk=false: emit as partial, reset aggregation // append=true, lastChunk=true: emit as partial, update aggregation and emit as non-partial diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java index b7b4e9953..7252cdec1 100644 --- a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java +++ b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.executor; import static java.util.Objects.requireNonNull; @@ -44,12 +59,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Implementation of the A2A AgentExecutor interface that uses ADK to execute agent tasks. - * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ +/** Implementation of the A2A AgentExecutor interface that uses ADK to execute agent tasks. */ public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor { private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class); private static final String USER_ID_PREFIX = "A2A_USER_"; diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java index ba0177dc4..3ee8656d2 100644 --- a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java +++ b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.executor; import com.google.adk.a2a.executor.Callbacks.AfterEventCallback; diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java b/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java index 666f1d8a0..3483c527f 100644 --- a/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java +++ b/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java @@ -1,3 +1,18 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.a2a.executor; import com.google.adk.events.Event; diff --git a/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java similarity index 99% rename from a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java rename to a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java index 87eaa2321..e75da64ba 100644 --- a/a2a/src/test/java/com/google/adk/a2a/RemoteA2AAgentTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java @@ -1,4 +1,4 @@ -package com.google.adk.a2a; +package com.google.adk.a2a.agent; import static com.google.common.truth.Truth.assertThat; import static java.util.concurrent.TimeUnit.SECONDS; diff --git a/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java index 5570f40d0..647aaf21f 100644 --- a/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java @@ -361,7 +361,7 @@ private RequestContext createRequestContext() { public void process_statefulAggregation_tracksArtifactIdAndAppendForAuthor() { Event partial1 = Event.builder() - .partial(Optional.of(true)) + .partial(true) .author("agent_author") .content( Content.builder() @@ -370,7 +370,7 @@ public void process_statefulAggregation_tracksArtifactIdAndAppendForAuthor() { .build(); Event partial2 = Event.builder() - .partial(Optional.of(true)) + .partial(true) .author("agent_author") .content( Content.builder() @@ -379,7 +379,7 @@ public void process_statefulAggregation_tracksArtifactIdAndAppendForAuthor() { .build(); Event finalEvent = Event.builder() - .partial(Optional.of(false)) + .partial(false) .author("agent_author") .content( Content.builder() diff --git a/contrib/samples/a2a_basic/A2AAgent.java b/contrib/samples/a2a_basic/A2AAgent.java index e4e79a4eb..e08a87a67 100644 --- a/contrib/samples/a2a_basic/A2AAgent.java +++ b/contrib/samples/a2a_basic/A2AAgent.java @@ -1,6 +1,6 @@ package com.example.a2a_basic; -import com.google.adk.a2a.RemoteA2AAgent; +import com.google.adk.a2a.agent.RemoteA2AAgent; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.LlmAgent; import com.google.adk.tools.FunctionTool; diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 226e61abe..ed6631c50 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -29,10 +29,10 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; -import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -312,37 +312,41 @@ public Flowable runAsync(InvocationContext parentContext) { private Flowable run( InvocationContext parentContext, Function> runImplementation) { + Context parentSpanContext = Context.current(); return Flowable.defer( () -> { InvocationContext invocationContext = createInvocationContext(parentContext); + Flowable mainAndAfterEvents = + Flowable.defer(() -> runImplementation.apply(invocationContext)) + .concatWith( + Flowable.defer( + () -> + callCallback( + afterCallbacksToFunctions( + invocationContext.pluginManager(), afterAgentCallback), + invocationContext) + .toFlowable())); + return callCallback( beforeCallbacksToFunctions( invocationContext.pluginManager(), beforeAgentCallback), invocationContext) .flatMapPublisher( - beforeEventOpt -> { + beforeEvent -> { if (invocationContext.endInvocation()) { - return Flowable.fromOptional(beforeEventOpt); + return Flowable.just(beforeEvent); } - - Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); - Flowable mainEvents = - Flowable.defer(() -> runImplementation.apply(invocationContext)); - Flowable afterEvents = - Flowable.defer( - () -> - callCallback( - afterCallbacksToFunctions( - invocationContext.pluginManager(), afterAgentCallback), - invocationContext) - .flatMapPublisher(Flowable::fromOptional)); - - return Flowable.concat(beforeEvents, mainEvents, afterEvents); + return Flowable.just(beforeEvent).concatWith(mainAndAfterEvents); }) + .switchIfEmpty(mainAndAfterEvents) .compose( - Tracing.traceAgent( - "invoke_agent " + name(), name(), description(), invocationContext)); + Tracing.trace("invoke_agent " + name()) + .setParent(parentSpanContext) + .configure( + span -> + Tracing.traceAgentInvocation( + span, name(), description(), invocationContext))); }); } @@ -383,13 +387,13 @@ private ImmutableList>> callbacksTo * * @param agentCallbacks Callback functions. * @param invocationContext Current invocation context. - * @return single emitting first event, or empty if none. + * @return maybe emitting first event, or empty if none. */ - private Single> callCallback( + private Maybe callCallback( List>> agentCallbacks, InvocationContext invocationContext) { if (agentCallbacks.isEmpty()) { - return Single.just(Optional.empty()); + return Maybe.empty(); } CallbackContext callbackContext = @@ -404,21 +408,20 @@ private Single> callCallback( .map( content -> { invocationContext.setEndInvocation(true); - return Optional.of( - Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(name()) - .branch(invocationContext.branch()) - .actions(callbackContext.eventActions()) - .content(content) - .build()); + return Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(name()) + .branch(invocationContext.branch().orElse(null)) + .actions(callbackContext.eventActions()) + .content(content) + .build(); }) .toFlowable(); }) .firstElement() .switchIfEmpty( - Single.defer( + Maybe.defer( () -> { if (callbackContext.state().hasDelta()) { Event.Builder eventBuilder = @@ -426,12 +429,12 @@ private Single> callCallback( .id(Event.generateEventId()) .invocationId(invocationContext.invocationId()) .author(name()) - .branch(invocationContext.branch()) + .branch(invocationContext.branch().orElse(null)) .actions(callbackContext.eventActions()); - return Single.just(Optional.of(eventBuilder.build())); + return Maybe.just(eventBuilder.build()); } else { - return Single.just(Optional.empty()); + return Maybe.empty(); } })); } diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 7602ca9f2..7f0e49d0c 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -75,7 +75,10 @@ protected InvocationContext(Builder builder) { this.eventsCompactionConfig = builder.eventsCompactionConfig; this.contextCacheConfig = builder.contextCacheConfig; this.invocationCostManager = builder.invocationCostManager; - this.callbackContextData = new ConcurrentHashMap<>(builder.callbackContextData); + // Don't copy the callback context data. This should be the same instance for the full + // invocation invocation so that Plugins can access the same data it during the invocation + // across all types of callbacks. + this.callbackContextData = builder.callbackContextData; } /** @@ -345,7 +348,10 @@ private Builder(InvocationContext context) { this.eventsCompactionConfig = context.eventsCompactionConfig; this.contextCacheConfig = context.contextCacheConfig; this.invocationCostManager = context.invocationCostManager; - this.callbackContextData = new ConcurrentHashMap<>(context.callbackContextData); + // Don't copy the callback context data. This should be the same instance for the full + // invocation invocation so that Plugins can access the same data it during the invocation + // across all types of callbacks. + this.callbackContextData = context.callbackContextData; } private BaseSessionService sessionService; diff --git a/core/src/main/java/com/google/adk/agents/LoopAgent.java b/core/src/main/java/com/google/adk/agents/LoopAgent.java index d9d049f80..743d569b9 100644 --- a/core/src/main/java/com/google/adk/agents/LoopAgent.java +++ b/core/src/main/java/com/google/adk/agents/LoopAgent.java @@ -21,7 +21,7 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.reactivex.rxjava3.core.Flowable; import java.util.List; -import java.util.Optional; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -34,7 +34,7 @@ public class LoopAgent extends BaseAgent { private static final Logger logger = LoggerFactory.getLogger(LoopAgent.class); - private final Optional maxIterations; + private final @Nullable Integer maxIterations; /** * Constructor for LoopAgent. @@ -50,7 +50,7 @@ private LoopAgent( String name, String description, List subAgents, - Optional maxIterations, + @Nullable Integer maxIterations, List beforeAgentCallback, List afterAgentCallback) { @@ -60,16 +60,10 @@ private LoopAgent( /** Builder for {@link LoopAgent}. */ public static class Builder extends BaseAgent.Builder { - private Optional maxIterations = Optional.empty(); + private @Nullable Integer maxIterations; @CanIgnoreReturnValue - public Builder maxIterations(int maxIterations) { - this.maxIterations = Optional.of(maxIterations); - return this; - } - - @CanIgnoreReturnValue - public Builder maxIterations(Optional maxIterations) { + public Builder maxIterations(@Nullable Integer maxIterations) { this.maxIterations = maxIterations; return this; } @@ -124,7 +118,7 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { return Flowable.fromIterable(subAgents) .concatMap(subAgent -> subAgent.runAsync(invocationContext)) - .repeat(maxIterations.orElse(Integer.MAX_VALUE)) + .repeat(maxIterations != null ? maxIterations : Integer.MAX_VALUE) .takeUntil(LoopAgent::hasEscalateAction); } @@ -137,4 +131,8 @@ protected Flowable runLiveImpl(InvocationContext invocationContext) { private static boolean hasEscalateAction(Event event) { return event.actions().escalate().orElse(false); } + + public @Nullable Integer maxIterations() { + return maxIterations; + } } diff --git a/core/src/main/java/com/google/adk/codeexecutors/VertexAiCodeExecutor.java b/core/src/main/java/com/google/adk/codeexecutors/VertexAiCodeExecutor.java index 5268edf39..af2219d18 100644 --- a/core/src/main/java/com/google/adk/codeexecutors/VertexAiCodeExecutor.java +++ b/core/src/main/java/com/google/adk/codeexecutors/VertexAiCodeExecutor.java @@ -36,7 +36,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Optional; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -140,7 +140,7 @@ public CodeExecutionResult executeCode( executeCodeInterpreter( getCodeWithImports(codeExecutionInput.code()), codeExecutionInput.inputFiles(), - codeExecutionInput.executionId()); + codeExecutionInput.executionId().orElse(null)); // Save output file as artifacts. List savedFiles = new ArrayList<>(); @@ -173,7 +173,7 @@ public CodeExecutionResult executeCode( } private Map executeCodeInterpreter( - String code, List inputFiles, Optional sessionId) { + String code, List inputFiles, @Nullable String sessionId) { ExtensionExecutionServiceClient codeInterpreterExtension = getCodeInterpreterExtension(); if (codeInterpreterExtension == null) { logger.warn("Vertex AI Code Interpreter execution is not available. Returning empty result."); @@ -196,8 +196,9 @@ private Map executeCodeInterpreter( paramsBuilder.putFields( "files", Value.newBuilder().setListValue(listBuilder.build()).build()); } - sessionId.ifPresent( - s -> paramsBuilder.putFields("session_id", Value.newBuilder().setStringValue(s).build())); + if (sessionId != null) { + paramsBuilder.putFields("session_id", Value.newBuilder().setStringValue(sessionId).build()); + } ExecuteExtensionRequest request = ExecuteExtensionRequest.newBuilder() diff --git a/core/src/main/java/com/google/adk/events/Event.java b/core/src/main/java/com/google/adk/events/Event.java index 91dc79a56..2677b635d 100644 --- a/core/src/main/java/com/google/adk/events/Event.java +++ b/core/src/main/java/com/google/adk/events/Event.java @@ -49,21 +49,21 @@ public class Event extends JsonBaseModel { private String id; private String invocationId; private String author; - private Optional content = Optional.empty(); + private @Nullable Content content; private EventActions actions; - private Optional> longRunningToolIds = Optional.empty(); - private Optional partial = Optional.empty(); - private Optional turnComplete = Optional.empty(); - private Optional errorCode = Optional.empty(); - private Optional errorMessage = Optional.empty(); - private Optional finishReason = Optional.empty(); - private Optional usageMetadata = Optional.empty(); - private Optional avgLogprobs = Optional.empty(); - private Optional interrupted = Optional.empty(); - private Optional branch = Optional.empty(); - private Optional groundingMetadata = Optional.empty(); - private Optional> customMetadata = Optional.empty(); - private Optional modelVersion = Optional.empty(); + private @Nullable Set longRunningToolIds; + private @Nullable Boolean partial; + private @Nullable Boolean turnComplete; + private @Nullable FinishReason errorCode; + private @Nullable String errorMessage; + private @Nullable FinishReason finishReason; + private @Nullable GenerateContentResponseUsageMetadata usageMetadata; + private @Nullable Double avgLogprobs; + private @Nullable Boolean interrupted; + private @Nullable String branch; + private @Nullable GroundingMetadata groundingMetadata; + private @Nullable List customMetadata; + private @Nullable String modelVersion; private long timestamp; private Event() {} @@ -104,10 +104,10 @@ public void setAuthor(String author) { @JsonProperty("content") public Optional content() { - return content; + return Optional.ofNullable(content); } - public void setContent(Optional content) { + public void setContent(@Nullable Content content) { this.content = content; } @@ -126,10 +126,10 @@ public void setActions(EventActions actions) { */ @JsonProperty("longRunningToolIds") public Optional> longRunningToolIds() { - return longRunningToolIds; + return Optional.ofNullable(longRunningToolIds); } - public void setLongRunningToolIds(Optional> longRunningToolIds) { + public void setLongRunningToolIds(@Nullable Set longRunningToolIds) { this.longRunningToolIds = longRunningToolIds; } @@ -139,73 +139,79 @@ public void setLongRunningToolIds(Optional> longRunningToolIds) { */ @JsonProperty("partial") public Optional partial() { - return partial; + return Optional.ofNullable(partial); } - public void setPartial(Optional partial) { + public void setPartial(@Nullable Boolean partial) { this.partial = partial; } @JsonProperty("turnComplete") public Optional turnComplete() { - return turnComplete; + return Optional.ofNullable(turnComplete); } - public void setTurnComplete(Optional turnComplete) { + public void setTurnComplete(@Nullable Boolean turnComplete) { this.turnComplete = turnComplete; } @JsonProperty("errorCode") public Optional errorCode() { - return errorCode; + return Optional.ofNullable(errorCode); } @JsonProperty("finishReason") public Optional finishReason() { - return finishReason; + return Optional.ofNullable(finishReason); } - public void setErrorCode(Optional errorCode) { + public void setErrorCode(@Nullable FinishReason errorCode) { this.errorCode = errorCode; } + @Deprecated + @SuppressWarnings("checkstyle:IllegalType") public void setFinishReason(Optional finishReason) { + this.finishReason = finishReason.orElse(null); + } + + public void setFinishReason(@Nullable FinishReason finishReason) { this.finishReason = finishReason; } @JsonProperty("errorMessage") public Optional errorMessage() { - return errorMessage; + return Optional.ofNullable(errorMessage); } - public void setErrorMessage(Optional errorMessage) { + public void setErrorMessage(@Nullable String errorMessage) { this.errorMessage = errorMessage; } @JsonProperty("usageMetadata") public Optional usageMetadata() { - return usageMetadata; + return Optional.ofNullable(usageMetadata); } - public void setUsageMetadata(Optional usageMetadata) { + public void setUsageMetadata(@Nullable GenerateContentResponseUsageMetadata usageMetadata) { this.usageMetadata = usageMetadata; } @JsonProperty("avgLogprobs") public Optional avgLogprobs() { - return avgLogprobs; + return Optional.ofNullable(avgLogprobs); } - public void setAvgLogprobs(Optional avgLogprobs) { + public void setAvgLogprobs(@Nullable Double avgLogprobs) { this.avgLogprobs = avgLogprobs; } @JsonProperty("interrupted") public Optional interrupted() { - return interrupted; + return Optional.ofNullable(interrupted); } - public void setInterrupted(Optional interrupted) { + public void setInterrupted(@Nullable Boolean interrupted) { this.interrupted = interrupted; } @@ -216,7 +222,7 @@ public void setInterrupted(Optional interrupted) { */ @JsonProperty("branch") public Optional branch() { - return branch; + return Optional.ofNullable(branch); } /** @@ -227,40 +233,36 @@ public Optional branch() { * @param branch Branch identifier. */ public void branch(@Nullable String branch) { - this.branch = Optional.ofNullable(branch); - } - - public void branch(Optional branch) { this.branch = branch; } /** The grounding metadata of the event. */ @JsonProperty("groundingMetadata") public Optional groundingMetadata() { - return groundingMetadata; + return Optional.ofNullable(groundingMetadata); } - public void setGroundingMetadata(Optional groundingMetadata) { + public void setGroundingMetadata(@Nullable GroundingMetadata groundingMetadata) { this.groundingMetadata = groundingMetadata; } /** The custom metadata of the event. */ @JsonProperty("customMetadata") public Optional> customMetadata() { - return customMetadata; + return Optional.ofNullable(customMetadata); } public void setCustomMetadata(@Nullable List customMetadata) { - this.customMetadata = Optional.ofNullable(customMetadata); + this.customMetadata = customMetadata; } /** The model version used to generate the response. */ @JsonProperty("modelVersion") public Optional modelVersion() { - return modelVersion; + return Optional.ofNullable(modelVersion); } - public void setModelVersion(Optional modelVersion) { + public void setModelVersion(@Nullable String modelVersion) { this.modelVersion = modelVersion; } @@ -345,22 +347,22 @@ public static class Builder { private String id; private String invocationId; private String author; - private Optional content = Optional.empty(); + private @Nullable Content content; private EventActions actions; - private Optional> longRunningToolIds = Optional.empty(); - private Optional partial = Optional.empty(); - private Optional turnComplete = Optional.empty(); - private Optional errorCode = Optional.empty(); - private Optional errorMessage = Optional.empty(); - private Optional finishReason = Optional.empty(); - private Optional usageMetadata = Optional.empty(); - private Optional avgLogprobs = Optional.empty(); - private Optional interrupted = Optional.empty(); - private Optional branch = Optional.empty(); - private Optional groundingMetadata = Optional.empty(); - private Optional> customMetadata = Optional.empty(); - private Optional modelVersion = Optional.empty(); - private Optional timestamp = Optional.empty(); + private @Nullable Set longRunningToolIds; + private @Nullable Boolean partial; + private @Nullable Boolean turnComplete; + private @Nullable FinishReason errorCode; + private @Nullable String errorMessage; + private @Nullable FinishReason finishReason; + private @Nullable GenerateContentResponseUsageMetadata usageMetadata; + private @Nullable Double avgLogprobs; + private @Nullable Boolean interrupted; + private @Nullable String branch; + private @Nullable GroundingMetadata groundingMetadata; + private @Nullable List customMetadata; + private @Nullable String modelVersion; + private @Nullable Long timestamp; @JsonCreator private static Builder create() { @@ -391,12 +393,6 @@ public Builder author(String value) { @CanIgnoreReturnValue @JsonProperty("content") public Builder content(@Nullable Content value) { - this.content = Optional.ofNullable(value); - return this; - } - - @CanIgnoreReturnValue - public Builder content(Optional value) { this.content = value; return this; } @@ -415,12 +411,6 @@ Optional actions() { @CanIgnoreReturnValue @JsonProperty("longRunningToolIds") public Builder longRunningToolIds(@Nullable Set value) { - this.longRunningToolIds = Optional.ofNullable(value); - return this; - } - - @CanIgnoreReturnValue - public Builder longRunningToolIds(Optional> value) { this.longRunningToolIds = value; return this; } @@ -428,12 +418,6 @@ public Builder longRunningToolIds(Optional> value) { @CanIgnoreReturnValue @JsonProperty("partial") public Builder partial(@Nullable Boolean value) { - this.partial = Optional.ofNullable(value); - return this; - } - - @CanIgnoreReturnValue - public Builder partial(Optional value) { this.partial = value; return this; } @@ -441,12 +425,6 @@ public Builder partial(Optional value) { @CanIgnoreReturnValue @JsonProperty("turnComplete") public Builder turnComplete(@Nullable Boolean value) { - this.turnComplete = Optional.ofNullable(value); - return this; - } - - @CanIgnoreReturnValue - public Builder turnComplete(Optional value) { this.turnComplete = value; return this; } @@ -454,12 +432,6 @@ public Builder turnComplete(Optional value) { @CanIgnoreReturnValue @JsonProperty("errorCode") public Builder errorCode(@Nullable FinishReason value) { - this.errorCode = Optional.ofNullable(value); - return this; - } - - @CanIgnoreReturnValue - public Builder errorCode(Optional value) { this.errorCode = value; return this; } @@ -467,12 +439,6 @@ public Builder errorCode(Optional value) { @CanIgnoreReturnValue @JsonProperty("errorMessage") public Builder errorMessage(@Nullable String value) { - this.errorMessage = Optional.ofNullable(value); - return this; - } - - @CanIgnoreReturnValue - public Builder errorMessage(Optional value) { this.errorMessage = value; return this; } @@ -480,12 +446,6 @@ public Builder errorMessage(Optional value) { @CanIgnoreReturnValue @JsonProperty("finishReason") public Builder finishReason(@Nullable FinishReason value) { - this.finishReason = Optional.ofNullable(value); - return this; - } - - @CanIgnoreReturnValue - public Builder finishReason(Optional value) { this.finishReason = value; return this; } @@ -493,12 +453,6 @@ public Builder finishReason(Optional value) { @CanIgnoreReturnValue @JsonProperty("usageMetadata") public Builder usageMetadata(@Nullable GenerateContentResponseUsageMetadata value) { - this.usageMetadata = Optional.ofNullable(value); - return this; - } - - @CanIgnoreReturnValue - public Builder usageMetadata(Optional value) { this.usageMetadata = value; return this; } @@ -506,12 +460,6 @@ public Builder usageMetadata(Optional valu @CanIgnoreReturnValue @JsonProperty("avgLogprobs") public Builder avgLogprobs(@Nullable Double value) { - this.avgLogprobs = Optional.ofNullable(value); - return this; - } - - @CanIgnoreReturnValue - public Builder avgLogprobs(Optional value) { this.avgLogprobs = value; return this; } @@ -519,12 +467,6 @@ public Builder avgLogprobs(Optional value) { @CanIgnoreReturnValue @JsonProperty("interrupted") public Builder interrupted(@Nullable Boolean value) { - this.interrupted = Optional.ofNullable(value); - return this; - } - - @CanIgnoreReturnValue - public Builder interrupted(Optional value) { this.interrupted = value; return this; } @@ -532,84 +474,52 @@ public Builder interrupted(Optional value) { @CanIgnoreReturnValue @JsonProperty("timestamp") public Builder timestamp(long value) { - this.timestamp = Optional.of(value); - return this; - } - - @CanIgnoreReturnValue - public Builder timestamp(Optional value) { this.timestamp = value; return this; } // Getter for builder's timestamp, used in build() Optional timestamp() { - return timestamp; + return Optional.ofNullable(timestamp); } @CanIgnoreReturnValue @JsonProperty("branch") public Builder branch(@Nullable String value) { - this.branch = Optional.ofNullable(value); - return this; - } - - @CanIgnoreReturnValue - public Builder branch(Optional value) { this.branch = value; return this; } // Getter for builder's branch, used in build() Optional branch() { - return branch; + return Optional.ofNullable(branch); } @CanIgnoreReturnValue @JsonProperty("groundingMetadata") public Builder groundingMetadata(@Nullable GroundingMetadata value) { - this.groundingMetadata = Optional.ofNullable(value); - return this; - } - - @CanIgnoreReturnValue - public Builder groundingMetadata(Optional value) { this.groundingMetadata = value; return this; } Optional groundingMetadata() { - return groundingMetadata; + return Optional.ofNullable(groundingMetadata); } @CanIgnoreReturnValue @JsonProperty("customMetadata") public Builder customMetadata(@Nullable List value) { - this.customMetadata = Optional.ofNullable(value); + this.customMetadata = value; return this; } - Optional> customMetadata() { - return customMetadata; - } - @CanIgnoreReturnValue @JsonProperty("modelVersion") public Builder modelVersion(@Nullable String value) { - this.modelVersion = Optional.ofNullable(value); - return this; - } - - @CanIgnoreReturnValue - public Builder modelVersion(Optional value) { this.modelVersion = value; return this; } - Optional modelVersion() { - return modelVersion; - } - public Event build() { Event event = new Event(); event.setId(id); @@ -627,7 +537,7 @@ public Event build() { event.setInterrupted(interrupted); event.branch(branch); event.setGroundingMetadata(groundingMetadata); - event.setCustomMetadata(customMetadata.orElse(null)); + event.setCustomMetadata(customMetadata); event.setModelVersion(modelVersion); event.setActions(actions().orElseGet(() -> EventActions.builder().build())); event.setTimestamp(timestamp().orElseGet(() -> Instant.now().toEpochMilli())); @@ -664,7 +574,7 @@ public Builder toBuilder() { .interrupted(this.interrupted) .branch(this.branch) .groundingMetadata(this.groundingMetadata) - .customMetadata(this.customMetadata.orElse(null)) + .customMetadata(this.customMetadata) .modelVersion(this.modelVersion); if (this.timestamp != 0) { builder.timestamp(this.timestamp); diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index bf25acfc7..0b167de93 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -28,36 +28,32 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import javax.annotation.Nullable; +import org.jspecify.annotations.Nullable; /** Represents the actions attached to an event. */ // TODO - b/414081262 make json wire camelCase @JsonDeserialize(builder = EventActions.Builder.class) public class EventActions extends JsonBaseModel { - private Optional skipSummarization; + private @Nullable Boolean skipSummarization; private ConcurrentMap stateDelta; private ConcurrentMap artifactDelta; private Set deletedArtifactIds; - private Optional transferToAgent; - private Optional escalate; + private @Nullable String transferToAgent; + private @Nullable Boolean escalate; private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent; - private Optional compaction; + private @Nullable EventCompaction compaction; /** Default constructor for Jackson. */ public EventActions() { - this.skipSummarization = Optional.empty(); this.stateDelta = new ConcurrentHashMap<>(); this.artifactDelta = new ConcurrentHashMap<>(); this.deletedArtifactIds = new HashSet<>(); - this.transferToAgent = Optional.empty(); - this.escalate = Optional.empty(); this.requestedAuthConfigs = new ConcurrentHashMap<>(); this.requestedToolConfirmations = new ConcurrentHashMap<>(); this.endOfAgent = false; - this.compaction = Optional.empty(); } private EventActions(Builder builder) { @@ -75,19 +71,15 @@ private EventActions(Builder builder) { @JsonProperty("skipSummarization") public Optional skipSummarization() { - return skipSummarization; + return Optional.ofNullable(skipSummarization); } public void setSkipSummarization(@Nullable Boolean skipSummarization) { - this.skipSummarization = Optional.ofNullable(skipSummarization); - } - - public void setSkipSummarization(Optional skipSummarization) { this.skipSummarization = skipSummarization; } public void setSkipSummarization(boolean skipSummarization) { - this.skipSummarization = Optional.of(skipSummarization); + this.skipSummarization = skipSummarization; } @JsonProperty("stateDelta") @@ -110,12 +102,12 @@ public void removeStateByKey(String key) { } @JsonProperty("artifactDelta") - public ConcurrentMap artifactDelta() { + public Map artifactDelta() { return artifactDelta; } - public void setArtifactDelta(ConcurrentMap artifactDelta) { - this.artifactDelta = artifactDelta; + public void setArtifactDelta(Map artifactDelta) { + this.artifactDelta = new ConcurrentHashMap<>(artifactDelta); } @JsonProperty("deletedArtifactIds") @@ -130,30 +122,22 @@ public void setDeletedArtifactIds(Set deletedArtifactIds) { @JsonProperty("transferToAgent") public Optional transferToAgent() { - return transferToAgent; + return Optional.ofNullable(transferToAgent); } - public void setTransferToAgent(Optional transferToAgent) { + public void setTransferToAgent(@Nullable String transferToAgent) { this.transferToAgent = transferToAgent; } - public void setTransferToAgent(String transferToAgent) { - this.transferToAgent = Optional.ofNullable(transferToAgent); - } - @JsonProperty("escalate") public Optional escalate() { - return escalate; + return Optional.ofNullable(escalate); } - public void setEscalate(Optional escalate) { + public void setEscalate(@Nullable Boolean escalate) { this.escalate = escalate; } - public void setEscalate(boolean escalate) { - this.escalate = Optional.of(escalate); - } - @JsonProperty("requestedAuthConfigs") public ConcurrentMap> requestedAuthConfigs() { return requestedAuthConfigs; @@ -165,13 +149,20 @@ public void setRequestedAuthConfigs( } @JsonProperty("requestedToolConfirmations") - public ConcurrentMap requestedToolConfirmations() { + public Map requestedToolConfirmations() { return requestedToolConfirmations; } public void setRequestedToolConfirmations( - ConcurrentMap requestedToolConfirmations) { - this.requestedToolConfirmations = requestedToolConfirmations; + Map requestedToolConfirmations) { + if (requestedToolConfirmations == null) { + this.requestedToolConfirmations = new ConcurrentHashMap<>(); + } else if (requestedToolConfirmations instanceof ConcurrentMap) { + this.requestedToolConfirmations = + (ConcurrentMap) requestedToolConfirmations; + } else { + this.requestedToolConfirmations = new ConcurrentHashMap<>(requestedToolConfirmations); + } } @JsonProperty("endOfAgent") @@ -192,14 +183,6 @@ public Optional endInvocation() { return endOfAgent ? Optional.of(true) : Optional.empty(); } - /** - * @deprecated Use {@link #setEndOfAgent(boolean)} instead. - */ - @Deprecated - public void setEndInvocation(Optional endInvocation) { - this.endOfAgent = endInvocation.orElse(false); - } - /** * @deprecated Use {@link #setEndOfAgent(boolean)} instead. */ @@ -210,10 +193,10 @@ public void setEndInvocation(boolean endInvocation) { @JsonProperty("compaction") public Optional compaction() { - return compaction; + return Optional.ofNullable(compaction); } - public void setCompaction(Optional compaction) { + public void setCompaction(@Nullable EventCompaction compaction) { this.compaction = compaction; } @@ -262,47 +245,43 @@ public int hashCode() { /** Builder for {@link EventActions}. */ public static class Builder { - private Optional skipSummarization; + private @Nullable Boolean skipSummarization; private ConcurrentMap stateDelta; private ConcurrentMap artifactDelta; private Set deletedArtifactIds; - private Optional transferToAgent; - private Optional escalate; + private @Nullable String transferToAgent; + private @Nullable Boolean escalate; private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent = false; - private Optional compaction; + private @Nullable EventCompaction compaction; public Builder() { - this.skipSummarization = Optional.empty(); this.stateDelta = new ConcurrentHashMap<>(); this.artifactDelta = new ConcurrentHashMap<>(); this.deletedArtifactIds = new HashSet<>(); - this.transferToAgent = Optional.empty(); - this.escalate = Optional.empty(); this.requestedAuthConfigs = new ConcurrentHashMap<>(); this.requestedToolConfirmations = new ConcurrentHashMap<>(); - this.compaction = Optional.empty(); } private Builder(EventActions eventActions) { - this.skipSummarization = eventActions.skipSummarization(); + this.skipSummarization = eventActions.skipSummarization; this.stateDelta = new ConcurrentHashMap<>(eventActions.stateDelta()); this.artifactDelta = new ConcurrentHashMap<>(eventActions.artifactDelta()); this.deletedArtifactIds = new HashSet<>(eventActions.deletedArtifactIds()); - this.transferToAgent = eventActions.transferToAgent(); - this.escalate = eventActions.escalate(); + this.transferToAgent = eventActions.transferToAgent; + this.escalate = eventActions.escalate; this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs()); this.requestedToolConfirmations = new ConcurrentHashMap<>(eventActions.requestedToolConfirmations()); - this.endOfAgent = eventActions.endOfAgent(); - this.compaction = eventActions.compaction(); + this.endOfAgent = eventActions.endOfAgent; + this.compaction = eventActions.compaction; } @CanIgnoreReturnValue @JsonProperty("skipSummarization") public Builder skipSummarization(boolean skipSummarization) { - this.skipSummarization = Optional.of(skipSummarization); + this.skipSummarization = skipSummarization; return this; } @@ -315,8 +294,8 @@ public Builder stateDelta(ConcurrentMap value) { @CanIgnoreReturnValue @JsonProperty("artifactDelta") - public Builder artifactDelta(ConcurrentMap value) { - this.artifactDelta = value; + public Builder artifactDelta(Map value) { + this.artifactDelta = new ConcurrentHashMap<>(value); return this; } @@ -329,15 +308,15 @@ public Builder deletedArtifactIds(Set value) { @CanIgnoreReturnValue @JsonProperty("transferToAgent") - public Builder transferToAgent(String agentId) { - this.transferToAgent = Optional.ofNullable(agentId); + public Builder transferToAgent(@Nullable String agentId) { + this.transferToAgent = agentId; return this; } @CanIgnoreReturnValue @JsonProperty("escalate") public Builder escalate(boolean escalate) { - this.escalate = Optional.of(escalate); + this.escalate = escalate; return this; } @@ -351,8 +330,16 @@ public Builder requestedAuthConfigs( @CanIgnoreReturnValue @JsonProperty("requestedToolConfirmations") - public Builder requestedToolConfirmations(ConcurrentMap value) { - this.requestedToolConfirmations = value; + public Builder requestedToolConfirmations(@Nullable Map value) { + if (value == null) { + this.requestedToolConfirmations = new ConcurrentHashMap<>(); + return this; + } + if (value instanceof ConcurrentMap) { + this.requestedToolConfirmations = (ConcurrentMap) value; + } else { + this.requestedToolConfirmations = new ConcurrentHashMap<>(value); + } return this; } @@ -376,8 +363,8 @@ public Builder endInvocation(boolean endInvocation) { @CanIgnoreReturnValue @JsonProperty("compaction") - public Builder compaction(EventCompaction value) { - this.compaction = Optional.ofNullable(value); + public Builder compaction(@Nullable EventCompaction value) { + this.compaction = value; return this; } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 1249728d8..79066b213 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -164,58 +164,60 @@ protected Flowable postprocess( * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ private Flowable callLlm( - InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { + Context spanContext, + InvocationContext context, + LlmRequest llmRequest, + Event eventForCallbackUsage) { LlmAgent agent = (LlmAgent) context.agent(); LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) - .flatMapPublisher( - beforeResponse -> { - if (beforeResponse.isPresent()) { - return Flowable.just(beforeResponse.get()); - } - BaseLlm llm = - agent.resolvedModel().model().isPresent() - ? agent.resolvedModel().model().get() - : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, llmRequestBuilder, eventForCallbackUsage, exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnNext( - llmResp -> - Tracing.traceCallLlm( - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp)) - .doOnError( - error -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) - .compose(Tracing.trace("call_llm")) - .concatMap( - llmResp -> - handleAfterModelCallback(context, llmResp, eventForCallbackUsage) - .toFlowable()); - }); + .toFlowable() + .switchIfEmpty( + Flowable.defer( + () -> { + BaseLlm llm = + agent.resolvedModel().model().isPresent() + ? agent.resolvedModel().model().get() + : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); + return llm.generateContent( + llmRequestBuilder.build(), + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, llmRequestBuilder, eventForCallbackUsage, exception) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnNext( + llmResp -> + Tracing.traceCallLlm( + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp)) + .doOnError( + error -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .compose(Tracing.trace("call_llm").setParent(spanContext)) + .concatMap( + llmResp -> + handleAfterModelCallback(context, llmResp, eventForCallbackUsage) + .toFlowable()); + })); } /** * Invokes {@link BeforeModelCallback}s. If any returns a response, it's used instead of calling * the LLM. * - * @return A {@link Single} with the callback result or {@link Optional#empty()}. + * @return A {@link Maybe} with the callback result. */ - private Single> handleBeforeModelCallback( + private Maybe handleBeforeModelCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = @@ -228,7 +230,7 @@ private Single> handleBeforeModelCallback( List callbacks = agent.canonicalBeforeModelCallbacks(); if (callbacks.isEmpty()) { - return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty()); + return pluginResult; } Maybe callbackResult = @@ -238,10 +240,7 @@ private Single> handleBeforeModelCallback( .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) .firstElement()); - return pluginResult - .switchIfEmpty(callbackResult) - .map(Optional::of) - .defaultIfEmpty(Optional.empty()); + return pluginResult.switchIfEmpty(callbackResult); } /** @@ -323,7 +322,7 @@ private Single handleAfterModelCallback( * @throws LlmCallsLimitExceededException if the agent exceeds allowed LLM invocations. * @throws IllegalStateException if a transfer agent is specified but not found. */ - private Flowable runOneStep(InvocationContext context) { + private Flowable runOneStep(Context spanContext, InvocationContext context) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); return Flowable.defer( @@ -351,11 +350,15 @@ private Flowable runOneStep(InvocationContext context) { .id(Event.generateEventId()) .invocationId(context.invocationId()) .author(context.agent().name()) - .branch(context.branch()) + .branch(context.branch().orElse(null)) .build(); mutableEventTemplate.setTimestamp(0L); - return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) + return callLlm( + spanContext, + context, + llmRequestAfterPreprocess, + mutableEventTemplate) .concatMap( llmResponse -> { try (Scope postScope = currentContext.makeCurrent()) { @@ -407,11 +410,12 @@ private Flowable runOneStep(InvocationContext context) { */ @Override public Flowable run(InvocationContext invocationContext) { - return run(invocationContext, 0); + return run(Context.current(), invocationContext, 0); } - private Flowable run(InvocationContext invocationContext, int stepsCompleted) { - Flowable currentStepEvents = runOneStep(invocationContext).cache(); + private Flowable run( + Context spanContext, InvocationContext invocationContext, int stepsCompleted) { + Flowable currentStepEvents = runOneStep(spanContext, invocationContext).cache(); if (stepsCompleted + 1 >= maxSteps) { logger.debug("Ending flow execution because max steps reached."); return currentStepEvents; @@ -431,7 +435,7 @@ private Flowable run(InvocationContext invocationContext, int stepsComple return Flowable.empty(); } else { logger.debug("Continuing to next step of the flow."); - return run(invocationContext, stepsCompleted + 1); + return run(spanContext, invocationContext, stepsCompleted + 1); } })); } @@ -448,6 +452,7 @@ private Flowable run(InvocationContext invocationContext, int stepsComple public Flowable runLive(InvocationContext invocationContext) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); Flowable preprocessEvents = preprocess(invocationContext, llmRequestRef); + Context spanContext = Context.current(); return preprocessEvents.concatWith( Flowable.defer( @@ -485,7 +490,7 @@ public Flowable runLive(InvocationContext invocationContext) { eventIdForSendData, llmRequestAfterPreprocess.contents()); }) - .compose(Tracing.trace("send_data")); + .compose(Tracing.trace("send_data").setParent(spanContext)); Flowable liveRequests = invocationContext @@ -535,7 +540,7 @@ public void onError(Throwable e) { Event.builder() .invocationId(invocationContext.invocationId()) .author(invocationContext.agent().name()) - .branch(invocationContext.branch()); + .branch(invocationContext.branch().orElse(null)); Flowable receiveFlow = connection @@ -639,17 +644,17 @@ private Event buildModelResponseEvent( Event baseEventForLlmResponse, LlmRequest llmRequest, LlmResponse llmResponse) { Event.Builder eventBuilder = baseEventForLlmResponse.toBuilder() - .content(llmResponse.content()) - .partial(llmResponse.partial()) - .errorCode(llmResponse.errorCode()) - .errorMessage(llmResponse.errorMessage()) - .interrupted(llmResponse.interrupted()) - .turnComplete(llmResponse.turnComplete()) - .groundingMetadata(llmResponse.groundingMetadata()) - .avgLogprobs(llmResponse.avgLogprobs()) - .finishReason(llmResponse.finishReason()) - .usageMetadata(llmResponse.usageMetadata()) - .modelVersion(llmResponse.modelVersion()); + .content(llmResponse.content().orElse(null)) + .partial(llmResponse.partial().orElse(null)) + .errorCode(llmResponse.errorCode().orElse(null)) + .errorMessage(llmResponse.errorMessage().orElse(null)) + .interrupted(llmResponse.interrupted().orElse(null)) + .turnComplete(llmResponse.turnComplete().orElse(null)) + .groundingMetadata(llmResponse.groundingMetadata().orElse(null)) + .avgLogprobs(llmResponse.avgLogprobs().orElse(null)) + .finishReason(llmResponse.finishReason().orElse(null)) + .usageMetadata(llmResponse.usageMetadata().orElse(null)) + .modelVersion(llmResponse.modelVersion().orElse(null)); Event event = eventBuilder.build(); @@ -661,7 +666,7 @@ private Event buildModelResponseEvent( Functions.getLongRunningFunctionCalls(event.functionCalls(), llmRequest.tools()); logger.debug("longRunningToolIds: {}", longRunningToolIds); if (!longRunningToolIds.isEmpty()) { - event.setLongRunningToolIds(Optional.of(longRunningToolIds)); + event.setLongRunningToolIds(longRunningToolIds); } } return event; diff --git a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java index be0504dd4..f2cbe967e 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java @@ -159,8 +159,7 @@ public Single processResponse( InvocationContext invocationContext, LlmResponse llmResponse) { if (llmResponse.partial().orElse(false)) { return Single.just( - ResponseProcessor.ResponseProcessingResult.create( - llmResponse, ImmutableList.of(), Optional.empty())); + ResponseProcessor.ResponseProcessingResult.create(llmResponse, ImmutableList.of())); } var llmResponseBuilder = llmResponse.toBuilder(); return runPostProcessor(invocationContext, llmResponseBuilder) @@ -168,7 +167,7 @@ public Single processResponse( .map( events -> ResponseProcessor.ResponseProcessingResult.create( - llmResponseBuilder.build(), events, Optional.empty())); + llmResponseBuilder.build(), events)); } } @@ -229,7 +228,7 @@ private static Flowable runPreProcessor( Event.builder() .invocationId(invocationContext.invocationId()) .author(llmAgent.name()) - .content(Optional.of(codeContent)) + .content(codeContent) .build(); return Flowable.defer( @@ -309,7 +308,7 @@ private static Flowable runPostProcessor( Event.builder() .invocationId(invocationContext.invocationId()) .author(llmAgent.name()) - .content(Optional.of(responseContent)) + .content(responseContent) .actions(EventActions.builder().build()) .build(); @@ -456,7 +455,7 @@ private static Single postProcessCodeExecutionResult( return Event.builder() .invocationId(invocationContext.invocationId()) .author(invocationContext.agent().name()) - .content(Optional.of(resultContent)) + .content(resultContent) .actions(eventActionsBuilder.build()) .build(); }); diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index a770808d4..6ebd39a9c 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -25,6 +25,7 @@ import com.google.adk.events.Event; import com.google.adk.events.EventCompaction; import com.google.adk.models.LlmRequest; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; @@ -41,6 +42,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import javax.annotation.Nullable; /** {@link RequestProcessor} that populates content in request for LLM flows. */ public final class Contents implements RequestProcessor { @@ -68,7 +70,7 @@ public Single processRequest( request.toBuilder() .contents( getCurrentTurnContents( - context.branch(), + context.branch().orElse(null), context.session().events(), context.agent().name(), modelName)) @@ -78,7 +80,10 @@ public Single processRequest( ImmutableList contents = getContents( - context.branch(), context.session().events(), context.agent().name(), modelName); + context.branch().orElse(null), + context.session().events(), + context.agent().name(), + modelName); return Single.just( RequestProcessor.RequestProcessingResult.create( @@ -87,7 +92,7 @@ public Single processRequest( /** Gets contents for the current turn only (no conversation history). */ private ImmutableList getCurrentTurnContents( - Optional currentBranch, List events, String agentName, String modelName) { + @Nullable String currentBranch, List events, String agentName, String modelName) { // Find the latest event that starts the current turn and process from there. for (int i = events.size() - 1; i >= 0; i--) { Event event = events.get(i); @@ -99,7 +104,7 @@ private ImmutableList getCurrentTurnContents( } private ImmutableList getContents( - Optional currentBranch, List events, String agentName, String modelName) { + @Nullable String currentBranch, List events, String agentName, String modelName) { List filteredEvents = new ArrayList<>(); boolean hasCompactEvent = false; @@ -414,16 +419,12 @@ private static String convertMapToJson(Map struct) { } } - private static boolean isEventBelongsToBranch(Optional invocationBranchOpt, Event event) { - Optional eventBranchOpt = event.branch(); + private static boolean isEventBelongsToBranch(@Nullable String invocationBranch, Event event) { + @Nullable String eventBranch = event.branch().orElse(null); - if (invocationBranchOpt.isEmpty() || invocationBranchOpt.get().isEmpty()) { - return true; - } - if (eventBranchOpt.isEmpty() || eventBranchOpt.get().isEmpty()) { - return true; - } - return invocationBranchOpt.get().startsWith(eventBranchOpt.get()); + return Strings.isNullOrEmpty(invocationBranch) + || Strings.isNullOrEmpty(eventBranch) + || invocationBranch.startsWith(eventBranch); } /** @@ -760,9 +761,7 @@ private static Event mergeFunctionResponseEvents(List functionResponseEve } return baseEvent.toBuilder() - .content( - Optional.of( - Content.builder().role(baseContent.role().get()).parts(partsInMergedEvent).build())) + .content(Content.builder().role(baseContent.role().get()).parts(partsInMergedEvent).build()) .build(); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 269764046..c1a996064 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -122,7 +122,7 @@ public static void populateClientFunctionCallId(Event modelResponseEvent) { new IllegalStateException( "Content role is missing in event: " + modelResponseEvent.id())); Content newContent = Content.builder().role(role).parts(newParts).build(); - modelResponseEvent.setContent(Optional.of(newContent)); + modelResponseEvent.setContent(newContent); } } @@ -178,7 +178,7 @@ public static Maybe handleFunctionCalls( if (events.size() > 1) { return Maybe.just(mergedEvent) .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)) - .compose(Tracing.trace("tool_response", parentContext)); + .compose(Tracing.trace("tool_response").setParent(parentContext)); } return Maybe.just(mergedEvent); }); @@ -432,8 +432,8 @@ private static Maybe postProcessFunctionResult( toolContext, invocationContext)) .compose( - Tracing.trace( - "tool_response [" + tool.name() + "]", parentContext)) + Tracing.trace("tool_response [" + tool.name() + "]") + .setParent(parentContext)) .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); }); } @@ -468,8 +468,8 @@ private static Optional mergeParallelFunctionResponseEvents( .id(Event.generateEventId()) .invocationId(baseEvent.invocationId()) .author(baseEvent.author()) - .branch(baseEvent.branch()) - .content(Optional.of(Content.builder().role("user").parts(mergedParts).build())) + .branch(baseEvent.branch().orElse(null)) + .content(Content.builder().role("user").parts(mergedParts).build()) .actions(mergedActionsBuilder.build()) .timestamp(baseEvent.timestamp()) .build()); @@ -593,7 +593,9 @@ private static Maybe> callTool( Tracing.traceToolCall( tool.name(), tool.description(), tool.getClass().getSimpleName(), args)) .doOnError(t -> Span.current().recordException(t)) - .compose(Tracing.trace("tool_call [" + tool.name() + "]", parentContext)) + .compose( + Tracing.>trace("tool_call [" + tool.name() + "]") + .setParent(parentContext)) .onErrorResumeNext( e -> Maybe.error( @@ -624,7 +626,7 @@ private static Event buildResponseEvent( .id(Event.generateEventId()) .invocationId(invocationContext.invocationId()) .author(invocationContext.agent().name()) - .branch(invocationContext.branch()) + .branch(invocationContext.branch().orElse(null)) .content(Content.builder().role("user").parts(partFunctionResponse).build()) .actions(toolContext.eventActions()) .build(); @@ -684,7 +686,7 @@ public static Optional generateRequestConfirmationEvent( Event.builder() .invocationId(invocationContext.invocationId()) .author(invocationContext.agent().name()) - .branch(invocationContext.branch()) + .branch(invocationContext.branch().orElse(null)) .content(contentBuilder.build()) .longRunningToolIds(longRunningToolIds) .build()); diff --git a/core/src/main/java/com/google/adk/flows/llmflows/ResponseProcessor.java b/core/src/main/java/com/google/adk/flows/llmflows/ResponseProcessor.java index 4baa29523..d8e5ce3ab 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/ResponseProcessor.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/ResponseProcessor.java @@ -50,11 +50,16 @@ public abstract static class ResponseProcessingResult { */ public abstract Optional transferToAgent(); - /** Creates a new {@link ResponseProcessingResult}. */ public static ResponseProcessingResult create( - LlmResponse updatedResponse, Iterable events, Optional transferToAgent) { + LlmResponse updatedResponse, Iterable events, String transferToAgent) { return new AutoValue_ResponseProcessor_ResponseProcessingResult( - updatedResponse, events, transferToAgent); + updatedResponse, events, Optional.of(transferToAgent)); + } + + public static ResponseProcessingResult create( + LlmResponse updatedResponse, Iterable events) { + return new AutoValue_ResponseProcessor_ResponseProcessingResult( + updatedResponse, events, /* transferToAgent= */ Optional.empty()); } } diff --git a/core/src/main/java/com/google/adk/models/LlmRequest.java b/core/src/main/java/com/google/adk/models/LlmRequest.java index e35969147..1a45c3a95 100644 --- a/core/src/main/java/com/google/adk/models/LlmRequest.java +++ b/core/src/main/java/com/google/adk/models/LlmRequest.java @@ -17,7 +17,6 @@ package com.google.adk.models; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -172,29 +171,33 @@ public final Builder appendInstructions(List instructions) { return liveConnectConfig(liveCfg.toBuilder().systemInstruction(newLiveSi).build()); } + // In this particular case we can keep the Optional as a type of a + // parameter, since the function is private and used in only one place while + // the Optional type plays nicely with flatMaps in the code (if we had a + // nullable here, we'd wrap it in the Optional anyway) private Content addInstructions( - Optional currentSystemInstruction, List additionalInstructions) { + @SuppressWarnings("checkstyle:IllegalType") Optional currentSystemInstruction, + List additionalInstructions) { checkArgument( - currentSystemInstruction.isEmpty() - || currentSystemInstruction.get().parts().map(parts -> parts.size()).orElse(0) <= 1, + currentSystemInstruction.flatMap(Content::parts).map(parts -> parts.size()).orElse(0) + <= 1, "At most one instruction is supported."); // Either append to the existing instruction, or create a new one. String instructions = String.join("\n\n", additionalInstructions); - Optional part = - currentSystemInstruction - .flatMap(Content::parts) - .flatMap(parts -> parts.stream().findFirst()); - if (part.isEmpty() || part.get().text().isEmpty()) { - part = Optional.of(Part.fromText(instructions)); - } else { - part = Optional.of(Part.fromText(part.get().text().get() + "\n\n" + instructions)); - } - checkState(part.isPresent(), "Failed to create instruction."); + Part part = + Part.fromText( + currentSystemInstruction + .flatMap(Content::parts) + .flatMap(parts -> parts.stream().findFirst()) + .flatMap(Part::text) + .map(text -> text + "\n\n" + instructions) + .orElse(instructions)); String role = currentSystemInstruction.flatMap(Content::role).orElse("user"); - return Content.builder().parts(part.get()).role(role).build(); + + return Content.builder().parts(part).role(role).build(); } @CanIgnoreReturnValue diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index 56dea936a..e534da787 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -34,6 +34,7 @@ import java.util.Map; import java.util.Optional; import java.util.function.Function; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,7 +48,7 @@ public class PluginManager extends BasePlugin { private static final Logger logger = LoggerFactory.getLogger(PluginManager.class); private final List plugins = new ArrayList<>(); - public PluginManager(List plugins) { + public PluginManager(@Nullable List plugins) { super("PluginManager"); if (plugins != null) { plugins.forEach(this::registerPlugin); @@ -123,10 +124,6 @@ public Maybe beforeRunCallback(InvocationContext invocationContext) { plugin -> plugin.beforeRunCallback(invocationContext), "beforeRunCallback"); } - public Completable runAfterRunCallback(InvocationContext invocationContext) { - return afterRunCallback(invocationContext); - } - @Override public Completable afterRunCallback(InvocationContext invocationContext) { return Flowable.fromIterable(plugins) @@ -155,41 +152,24 @@ public Completable close() { "[{}] Error during callback 'close'", plugin.getName(), e))); } - public Maybe runOnEventCallback(InvocationContext invocationContext, Event event) { - return onEventCallback(invocationContext, event); - } - @Override public Maybe onEventCallback(InvocationContext invocationContext, Event event) { return runMaybeCallbacks( plugin -> plugin.onEventCallback(invocationContext, event), "onEventCallback"); } - public Maybe runBeforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { - return beforeAgentCallback(agent, callbackContext); - } - @Override public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { return runMaybeCallbacks( plugin -> plugin.beforeAgentCallback(agent, callbackContext), "beforeAgentCallback"); } - public Maybe runAfterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { - return afterAgentCallback(agent, callbackContext); - } - @Override public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { return runMaybeCallbacks( plugin -> plugin.afterAgentCallback(agent, callbackContext), "afterAgentCallback"); } - public Maybe runBeforeModelCallback( - CallbackContext callbackContext, LlmRequest.Builder llmRequest) { - return beforeModelCallback(callbackContext, llmRequest); - } - @Override public Maybe beforeModelCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest) { @@ -197,11 +177,6 @@ public Maybe beforeModelCallback( plugin -> plugin.beforeModelCallback(callbackContext, llmRequest), "beforeModelCallback"); } - public Maybe runAfterModelCallback( - CallbackContext callbackContext, LlmResponse llmResponse) { - return afterModelCallback(callbackContext, llmResponse); - } - @Override public Maybe afterModelCallback( CallbackContext callbackContext, LlmResponse llmResponse) { @@ -209,11 +184,6 @@ public Maybe afterModelCallback( plugin -> plugin.afterModelCallback(callbackContext, llmResponse), "afterModelCallback"); } - public Maybe runOnModelErrorCallback( - CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { - return onModelErrorCallback(callbackContext, llmRequest, error); - } - @Override public Maybe onModelErrorCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { @@ -222,11 +192,6 @@ public Maybe onModelErrorCallback( "onModelErrorCallback"); } - public Maybe> runBeforeToolCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext) { - return beforeToolCallback(tool, toolArgs, toolContext); - } - @Override public Maybe> beforeToolCallback( BaseTool tool, Map toolArgs, ToolContext toolContext) { @@ -234,14 +199,6 @@ public Maybe> beforeToolCallback( plugin -> plugin.beforeToolCallback(tool, toolArgs, toolContext), "beforeToolCallback"); } - public Maybe> runAfterToolCallback( - BaseTool tool, - Map toolArgs, - ToolContext toolContext, - Map result) { - return afterToolCallback(tool, toolArgs, toolContext, result); - } - @Override public Maybe> afterToolCallback( BaseTool tool, @@ -253,11 +210,6 @@ public Maybe> afterToolCallback( "afterToolCallback"); } - public Maybe> runOnToolErrorCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { - return onToolErrorCallback(tool, toolArgs, toolContext, error); - } - @Override public Maybe> onToolErrorCallback( BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index f2cb5b9d5..29b2b76d3 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -340,7 +340,7 @@ private Single appendNewMessageToSession( .id(Event.generateEventId()) .invocationId(invocationContext.invocationId()) .author("user") - .content(Optional.of(newMessage)); + .content(newMessage); // Add state delta if provided if (stateDelta != null && !stateDelta.isEmpty()) { @@ -540,7 +540,7 @@ private Flowable runAgentWithFreshSession( .id(Event.generateEventId()) .invocationId(contextWithUpdatedSession.invocationId()) .author("model") - .content(Optional.of(content)) + .content(content) .build()); // Agent execution @@ -568,7 +568,7 @@ private Flowable runAgentWithFreshSession( .toFlowable() .switchIfEmpty(agentEvents) .concatWith( - Completable.defer(() -> pluginManager.runAfterRunCallback(contextWithUpdatedSession))) + Completable.defer(() -> pluginManager.afterRunCallback(contextWithUpdatedSession))) .concatWith(Completable.defer(() -> compactEvents(updatedSession))); } diff --git a/core/src/main/java/com/google/adk/sessions/ApiClient.java b/core/src/main/java/com/google/adk/sessions/ApiClient.java index 6bf69ee47..e850199e9 100644 --- a/core/src/main/java/com/google/adk/sessions/ApiClient.java +++ b/core/src/main/java/com/google/adk/sessions/ApiClient.java @@ -67,7 +67,7 @@ abstract class ApiClient { applyHttpOptions(customHttpOptions.get()); } - this.httpClient = createHttpClient(httpOptions.timeout()); + this.httpClient = createHttpClient(httpOptions.timeout().orElse(null)); } ApiClient( @@ -113,13 +113,13 @@ abstract class ApiClient { } this.apiKey = Optional.empty(); this.vertexAI = true; - this.httpClient = createHttpClient(httpOptions.timeout()); + this.httpClient = createHttpClient(httpOptions.timeout().orElse(null)); } - private OkHttpClient createHttpClient(Optional timeout) { + private OkHttpClient createHttpClient(@Nullable Integer timeout) { OkHttpClient.Builder builder = new OkHttpClient().newBuilder(); - if (timeout.isPresent()) { - builder.connectTimeout(Duration.ofMillis(timeout.get())); + if (timeout != null) { + builder.connectTimeout(Duration.ofMillis(timeout)); } return builder.build(); } diff --git a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java index 71b072695..0c2b33704 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java +++ b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java @@ -73,7 +73,7 @@ static String convertEventToJson(Event event, boolean useIsoString) { event.turnComplete().ifPresent(v -> metadataJson.put("turnComplete", v)); event.interrupted().ifPresent(v -> metadataJson.put("interrupted", v)); event.branch().ifPresent(v -> metadataJson.put("branch", v)); - putIfNotEmpty(metadataJson, "longRunningToolIds", event.longRunningToolIds()); + event.longRunningToolIds().ifPresent(v -> putIfNotEmpty(metadataJson, "longRunningToolIds", v)); event.groundingMetadata().ifPresent(v -> metadataJson.put("groundingMetadata", v)); event.usageMetadata().ifPresent(v -> metadataJson.put("usageMetadata", v)); Map eventJson = new HashMap<>(); @@ -208,9 +208,12 @@ static Event fromApiEvent(Map apiEvent) { .timestamp(convertToInstant(apiEvent.get("timestamp")).toEpochMilli()) .errorCode( Optional.ofNullable(apiEvent.get("errorCode")) - .map(value -> new FinishReason((String) value))) + .map(value -> new FinishReason((String) value)) + .orElse(null)) .errorMessage( - Optional.ofNullable(apiEvent.get("errorMessage")).map(value -> (String) value)) + Optional.ofNullable(apiEvent.get("errorMessage")) + .map(value -> (String) value) + .orElse(null)) .build(); Map eventMetadata = (Map) apiEvent.get("eventMetadata"); if (eventMetadata != null) { @@ -236,7 +239,7 @@ static Event fromApiEvent(Map apiEvent) { Optional.ofNullable((Boolean) eventMetadata.get("turnComplete")).orElse(false)) .interrupted( Optional.ofNullable((Boolean) eventMetadata.get("interrupted")).orElse(false)) - .branch(Optional.ofNullable((String) eventMetadata.get("branch"))) + .branch((String) eventMetadata.get("branch")) .groundingMetadata(groundingMetadata) .usageMetadata(usageMetadata) .longRunningToolIds( @@ -352,11 +355,6 @@ private static void putIfNotEmpty(Map map, String key, Map } } - private static void putIfNotEmpty( - Map map, String key, Optional> values) { - values.ifPresent(v -> putIfNotEmpty(map, key, v)); - } - private static void putIfNotEmpty( Map map, String key, @Nullable Collection values) { if (values != null && !values.isEmpty()) { diff --git a/core/src/main/java/com/google/adk/sessions/State.java b/core/src/main/java/com/google/adk/sessions/State.java index ec23857d9..70d2dfbf2 100644 --- a/core/src/main/java/com/google/adk/sessions/State.java +++ b/core/src/main/java/com/google/adk/sessions/State.java @@ -24,6 +24,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import javax.annotation.Nullable; /** A {@link State} object that also keeps track of the changes to the state. */ @SuppressWarnings("ShouldNotSubclass") @@ -39,13 +40,22 @@ public final class State implements ConcurrentMap { private final ConcurrentMap state; private final ConcurrentMap delta; - public State(ConcurrentMap state) { - this(state, new ConcurrentHashMap<>()); - } - - public State(ConcurrentMap state, ConcurrentMap delta) { - this.state = Objects.requireNonNull(state); - this.delta = delta; + public State(Map state) { + this(state, null); + } + + public State(Map state, @Nullable Map delta) { + Objects.requireNonNull(state, "state is null"); + this.state = + state instanceof ConcurrentMap + ? (ConcurrentMap) state + : new ConcurrentHashMap<>(state); + this.delta = + delta == null + ? new ConcurrentHashMap<>() + : delta instanceof ConcurrentMap + ? (ConcurrentMap) delta + : new ConcurrentHashMap<>(delta); } @Override diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 07a640c37..7f338fdcf 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -127,13 +127,13 @@ public class Tracing { private Tracing() {} - private static Optional getValidCurrentSpan(String methodName) { + private static void traceWithSpan(String methodName, Consumer traceAction) { Span span = Span.current(); if (!span.getSpanContext().isValid()) { log.trace("{}: No valid span in current context.", methodName); - return Optional.empty(); + return; } - return Optional.of(span); + traceAction.accept(span); } private static void setInvocationAttributes( @@ -206,16 +206,16 @@ public static void traceAgentInvocation( */ public static void traceToolCall( String toolName, String toolDescription, String toolType, Map args) { - getValidCurrentSpan("traceToolCall") - .ifPresent( - span -> { - setToolExecutionAttributes(span); - span.setAttribute(GEN_AI_TOOL_NAME, toolName); - span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); - span.setAttribute(GEN_AI_TOOL_TYPE, toolType); - - setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); - }); + traceWithSpan( + "traceToolCall", + span -> { + setToolExecutionAttributes(span); + span.setAttribute(GEN_AI_TOOL_NAME, toolName); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); + span.setAttribute(GEN_AI_TOOL_TYPE, toolType); + + setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); + }); } /** @@ -225,33 +225,33 @@ public static void traceToolCall( * @param functionResponseEvent The function response event. */ public static void traceToolResponse(String eventId, Event functionResponseEvent) { - getValidCurrentSpan("traceToolResponse") - .ifPresent( - span -> { - setToolExecutionAttributes(span); - span.setAttribute(ADK_EVENT_ID, eventId); - - FunctionResponse functionResponse = - functionResponseEvent.functionResponses().stream().findFirst().orElse(null); - - String toolCallId = ""; - Object toolResponse = ""; - if (functionResponse != null) { - toolCallId = functionResponse.id().orElse(toolCallId); - if (functionResponse.response().isPresent()) { - toolResponse = functionResponse.response().get(); - } - } - - span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); - - Object finalToolResponse = - (toolResponse instanceof Map) - ? toolResponse - : ImmutableMap.of("result", toolResponse); - - setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); - }); + traceWithSpan( + "traceToolResponse", + span -> { + setToolExecutionAttributes(span); + span.setAttribute(ADK_EVENT_ID, eventId); + + FunctionResponse functionResponse = + functionResponseEvent.functionResponses().stream().findFirst().orElse(null); + + String toolCallId = ""; + Object toolResponse = ""; + if (functionResponse != null) { + toolCallId = functionResponse.id().orElse(toolCallId); + if (functionResponse.response().isPresent()) { + toolResponse = functionResponse.response().get(); + } + } + + span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + + Object finalToolResponse = + (toolResponse instanceof Map) + ? toolResponse + : ImmutableMap.of("result", toolResponse); + + setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); + }); } /** @@ -296,58 +296,54 @@ public static void traceCallLlm( String eventId, LlmRequest llmRequest, LlmResponse llmResponse) { - getValidCurrentSpan("traceCallLlm") - .ifPresent( - span -> { - span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); - llmRequest - .model() - .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - - setInvocationAttributes(span, invocationContext, eventId); - - setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); - setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); - - llmRequest - .config() - .ifPresent( - config -> { - config - .topP() - .ifPresent( - topP -> - span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); - config - .maxOutputTokens() - .ifPresent( - maxTokens -> - span.setAttribute( - GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); - }); - llmResponse - .usageMetadata() - .ifPresent( - usage -> { - usage - .promptTokenCount() - .ifPresent( - tokens -> - span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); - usage - .candidatesTokenCount() - .ifPresent( - tokens -> - span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); - }); - llmResponse - .finishReason() - .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) - .ifPresent( - reason -> - span.setAttribute( - GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); - }); + traceWithSpan( + "traceCallLlm", + span -> { + span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + llmRequest + .model() + .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); + + setInvocationAttributes(span, invocationContext, eventId); + + setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); + setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + + llmRequest + .config() + .ifPresent( + config -> { + config + .topP() + .ifPresent( + topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); + config + .maxOutputTokens() + .ifPresent( + maxTokens -> + span.setAttribute( + GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + }); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage + .promptTokenCount() + .ifPresent( + tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); + usage + .candidatesTokenCount() + .ifPresent( + tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + }); + llmResponse + .finishReason() + .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) + .ifPresent( + reason -> + span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); + }); } /** @@ -359,17 +355,17 @@ public static void traceCallLlm( */ public static void traceSendData( InvocationContext invocationContext, String eventId, List data) { - getValidCurrentSpan("traceSendData") - .ifPresent( - span -> { - setInvocationAttributes(span, invocationContext, eventId); - - ImmutableList safeData = - Optional.ofNullable(data).orElse(ImmutableList.of()).stream() - .filter(Objects::nonNull) - .collect(toImmutableList()); - setJsonAttribute(span, ADK_DATA, safeData); - }); + traceWithSpan( + "traceSendData", + span -> { + setInvocationAttributes(span, invocationContext, eventId); + + ImmutableList safeData = + Optional.ofNullable(data).orElse(ImmutableList.of()).stream() + .filter(Objects::nonNull) + .collect(toImmutableList()); + setJsonAttribute(span, ADK_DATA, safeData); + }); } /** @@ -426,19 +422,6 @@ public static TracerProvider trace(String spanName) { return new TracerProvider<>(spanName); } - /** - * Returns a transformer that traces the execution of an RxJava stream with an explicit parent - * context. - * - * @param spanName The name of the span to create. - * @param parentContext The explicit parent context for the span. - * @param The type of the stream. - * @return A TracerProvider that can be used with .compose(). - */ - public static TracerProvider trace(String spanName, Context parentContext) { - return new TracerProvider(spanName).setParent(parentContext); - } - /** * Returns a transformer that traces an agent invocation. * diff --git a/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java b/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java index 4f3be46c2..b3d0f73bb 100644 --- a/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java +++ b/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.Objects; import java.util.Optional; +import org.jspecify.annotations.Nullable; /** * Represents the current state of the computer environment. @@ -31,11 +32,11 @@ */ public final class ComputerState { private final byte[] screenshot; - private final Optional url; + private final @Nullable String url; @JsonCreator private ComputerState( - @JsonProperty("screenshot") byte[] screenshot, @JsonProperty("url") Optional url) { + @JsonProperty("screenshot") byte[] screenshot, @JsonProperty("url") @Nullable String url) { this.screenshot = screenshot.clone(); this.url = url; } @@ -47,7 +48,7 @@ public byte[] screenshot() { @JsonProperty("url") public Optional url() { - return url; + return Optional.ofNullable(url); } public static Builder builder() { @@ -57,7 +58,7 @@ public static Builder builder() { /** Builder for {@link ComputerState}. */ public static final class Builder { private byte[] screenshot; - private Optional url = Optional.empty(); + private @Nullable String url; @CanIgnoreReturnValue public Builder screenshot(byte[] screenshot) { @@ -66,17 +67,11 @@ public Builder screenshot(byte[] screenshot) { } @CanIgnoreReturnValue - public Builder url(Optional url) { + public Builder url(@Nullable String url) { this.url = url; return this; } - @CanIgnoreReturnValue - public Builder url(String url) { - this.url = Optional.ofNullable(url); - return this; - } - public ComputerState build() { return new ComputerState(screenshot, url); } diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java b/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java index 73af9cc6a..bcc786d69 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java +++ b/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java @@ -22,6 +22,8 @@ import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; import com.google.adk.tools.NamedToolPredicate; +import com.google.adk.tools.ToolPredicate; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.modelcontextprotocol.client.McpAsyncClient; @@ -32,8 +34,8 @@ import java.time.Duration; import java.util.List; import java.util.Objects; -import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -59,14 +61,14 @@ public class McpAsyncToolset implements BaseToolset { private final McpSessionManager mcpSessionManager; private final ObjectMapper objectMapper; - private final Optional toolFilter; + private final @Nullable Object toolFilter; private final AtomicReference>> mcpTools = new AtomicReference<>(); /** Builder for McpAsyncToolset */ public static class Builder { private Object connectionParams = null; private ObjectMapper objectMapper = null; - private Optional toolFilter = null; + private @Nullable Object toolFilter = null; @CanIgnoreReturnValue public Builder connectionParams(ServerParameters connectionParams) { @@ -87,14 +89,14 @@ public Builder objectMapper(ObjectMapper objectMapper) { } @CanIgnoreReturnValue - public Builder toolFilter(Optional toolFilter) { - this.toolFilter = toolFilter; + public Builder toolFilter(List toolNames) { + this.toolFilter = new NamedToolPredicate(Preconditions.checkNotNull(toolNames)); return this; } @CanIgnoreReturnValue - public Builder toolFilter(List toolNames) { - this.toolFilter = Optional.of(new NamedToolPredicate(toolNames)); + public Builder toolFilter(@Nullable ToolPredicate toolPredicate) { + this.toolFilter = toolPredicate; return this; } @@ -118,12 +120,12 @@ public McpAsyncToolset build() { * * @param connectionParams The SSE connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolFilter Either a ToolPredicate or a List of tool names. */ - public McpAsyncToolset( + McpAsyncToolset( SseServerParameters connectionParams, ObjectMapper objectMapper, - Optional toolFilter) { + @Nullable Object toolFilter) { Objects.requireNonNull(connectionParams); Objects.requireNonNull(objectMapper); this.objectMapper = objectMapper; @@ -136,10 +138,10 @@ public McpAsyncToolset( * * @param connectionParams The local server connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolFilter Either a ToolPredicate or a List of tool names or null. */ - public McpAsyncToolset( - ServerParameters connectionParams, ObjectMapper objectMapper, Optional toolFilter) { + McpAsyncToolset( + ServerParameters connectionParams, ObjectMapper objectMapper, @Nullable Object toolFilter) { Objects.requireNonNull(connectionParams); Objects.requireNonNull(objectMapper); this.objectMapper = objectMapper; @@ -147,22 +149,6 @@ public McpAsyncToolset( this.toolFilter = toolFilter; } - /** - * Initializes the McpAsyncToolset with a provided McpSessionManager. - * - * @param mcpSessionManager The session manager for MCP connections. - * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. - */ - public McpAsyncToolset( - McpSessionManager mcpSessionManager, ObjectMapper objectMapper, Optional toolFilter) { - Objects.requireNonNull(mcpSessionManager); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = mcpSessionManager; - this.toolFilter = toolFilter; - } - @Override public Flowable getTools(ReadonlyContext readonlyContext) { return Maybe.defer(() -> Maybe.fromCompletionStage(this.initAndGetTools().toFuture())) @@ -170,7 +156,7 @@ public Flowable getTools(ReadonlyContext readonlyContext) { .map( tools -> tools.stream() - .filter(tool -> isToolSelected(tool, toolFilter.orElse(null), readonlyContext)) + .filter(tool -> isToolSelected(tool, toolFilter, readonlyContext)) .toList()) .onErrorResumeNext( err -> { diff --git a/core/src/test/java/com/google/adk/agents/LoopAgentTest.java b/core/src/test/java/com/google/adk/agents/LoopAgentTest.java index 5c04ac74b..b2d0778c6 100644 --- a/core/src/test/java/com/google/adk/agents/LoopAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LoopAgentTest.java @@ -33,7 +33,6 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import java.util.List; -import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; import org.junit.runner.RunWith; @@ -165,11 +164,7 @@ public void runAsync_withNoMaxIterations_keepsLooping() { Event event2 = createEvent("event2"); TestBaseAgent subAgent = createSubAgent("subAgent", () -> Flowable.just(event1, event2)); LoopAgent loopAgent = - LoopAgent.builder() - .name("loopAgent") - .subAgents(ImmutableList.of(subAgent)) - .maxIterations(Optional.empty()) - .build(); + LoopAgent.builder().name("loopAgent").subAgents(ImmutableList.of(subAgent)).build(); InvocationContext invocationContext = createInvocationContext(loopAgent); Iterable result = loopAgent.runAsync(invocationContext).blockingIterable(); diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 28123bab8..22bb94e64 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -26,6 +26,7 @@ import com.google.genai.types.Part; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -109,6 +110,16 @@ public void merge_mergesAllFields() { assertThat(merged.compaction()).hasValue(COMPACTION); } + @Test + public void setArtifactDelta_copiesRegularMap() { + EventActions eventActions = new EventActions(); + ImmutableMap artifactDelta = ImmutableMap.of("artifact1", 1); + + eventActions.setArtifactDelta(artifactDelta); + + assertThat(eventActions.artifactDelta()).containsExactly("artifact1", 1); + } + @Test public void removeStateByKey_marksKeyAsRemoved() { EventActions eventActions = new EventActions(); @@ -165,4 +176,27 @@ public void merge_failsOnMismatchedKeyTypesNestedInStateDelta() { assertThrows( IllegalArgumentException.class, () -> eventActions1.toBuilder().merge(eventActions2)); } + + @Test + public void setRequestedToolConfirmations_withConcurrentMap_usesSameInstance() { + ConcurrentHashMap map = new ConcurrentHashMap<>(); + map.put("tool", TOOL_CONFIRMATION); + + EventActions actions = new EventActions(); + actions.setRequestedToolConfirmations(map); + + assertThat(actions.requestedToolConfirmations()).isSameInstanceAs(map); + } + + @Test + public void setRequestedToolConfirmations_withRegularMap_createsConcurrentMap() { + ImmutableMap map = ImmutableMap.of("tool", TOOL_CONFIRMATION); + + EventActions actions = new EventActions(); + actions.setRequestedToolConfirmations(map); + + assertThat(actions.requestedToolConfirmations()).isNotSameInstanceAs(map); + assertThat(actions.requestedToolConfirmations()).isInstanceOf(ConcurrentMap.class); + assertThat(actions.requestedToolConfirmations()).containsExactly("tool", TOOL_CONFIRMATION); + } } diff --git a/core/src/test/java/com/google/adk/events/EventTest.java b/core/src/test/java/com/google/adk/events/EventTest.java index cbfb6ef0b..a4feab5c1 100644 --- a/core/src/test/java/com/google/adk/events/EventTest.java +++ b/core/src/test/java/com/google/adk/events/EventTest.java @@ -76,6 +76,7 @@ public final class EventTest { .avgLogprobs(0.5) .interrupted(true) .timestamp(123456789L) + .modelVersion("model_version") .build(); @Test @@ -99,6 +100,7 @@ public void event_builder_works() { assertThat(EVENT.interrupted()).hasValue(true); assertThat(EVENT.timestamp()).isEqualTo(123456789L); assertThat(EVENT.actions()).isEqualTo(EVENT_ACTIONS); + assertThat(EVENT.modelVersion()).hasValue("model_version"); } @Test diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index ff151a0b2..4a0b345c6 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -524,19 +524,14 @@ private static RequestProcessor createRequestProcessor( private static ResponseProcessor createResponseProcessor() { return (context, response) -> - Single.just( - ResponseProcessingResult.create( - response, ImmutableList.of(), /* transferToAgent= */ Optional.empty())); + Single.just(ResponseProcessingResult.create(response, ImmutableList.of())); } private static ResponseProcessor createResponseProcessor( Function responseUpdater) { return (context, response) -> Single.just( - ResponseProcessingResult.create( - responseUpdater.apply(response), - ImmutableList.of(), - /* transferToAgent= */ Optional.empty())); + ResponseProcessingResult.create(responseUpdater.apply(response), ImmutableList.of())); } private static class TestTool extends BaseTool { diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 3041a855b..85e78666d 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -784,7 +784,7 @@ private static Event createUserEvent(String id, String text) { return Event.builder() .id(id) .author(USER) - .content(Optional.of(Content.fromParts(Part.fromText(text)))) + .content(Content.fromParts(Part.fromText(text))) .invocationId("invocationId") .build(); } @@ -794,7 +794,7 @@ private static Event createUserEvent( return Event.builder() .id(id) .author(USER) - .content(Optional.of(Content.fromParts(Part.fromText(text)))) + .content(Content.fromParts(Part.fromText(text))) .invocationId(invocationId) .timestamp(timestamp) .build(); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java index 97092f68c..d5db4d4b3 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java @@ -33,7 +33,6 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; -import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -43,12 +42,7 @@ public final class FunctionsTest { private static final Event EVENT_WITH_NO_CONTENT = - Event.builder() - .id("event1") - .invocationId("invocation1") - .author("agent") - .content(Optional.empty()) - .build(); + Event.builder().id("event1").invocationId("invocation1").author("agent").build(); private static final Event EVENT_WITH_NO_PARTS = Event.builder() diff --git a/core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java b/core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java index 764525ff0..a08599c9a 100644 --- a/core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java @@ -65,9 +65,7 @@ public class LoggingPluginTest { Event.builder() .id("event_id") .author("author") - .content(Optional.empty()) .actions(EventActions.builder().build()) - .longRunningToolIds(Optional.empty()) .build(); private final LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("default").contents(ImmutableList.of()); diff --git a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java index f6120cf08..335f7f1d0 100644 --- a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java +++ b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java @@ -22,7 +22,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import org.junit.Test; @@ -49,12 +48,12 @@ public void convertEventToJson_fullEvent_success() throws JsonProcessingExceptio .author("user") .invocationId("inv-123") .timestamp(Instant.parse("2023-01-01T00:00:00Z").toEpochMilli()) - .errorCode(Optional.of(new FinishReason("OTHER"))) - .errorMessage(Optional.of("Something was not found")) + .errorCode(new FinishReason("OTHER")) + .errorMessage("Something was not found") .partial(true) .turnComplete(true) .interrupted(false) - .branch(Optional.of("branch-1")) + .branch("branch-1") .content(Content.fromParts(Part.fromText("Hello"))) .actions(actions) .build(); diff --git a/core/src/test/java/com/google/adk/sessions/StateTest.java b/core/src/test/java/com/google/adk/sessions/StateTest.java new file mode 100644 index 000000000..e1fcaeadc --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/StateTest.java @@ -0,0 +1,50 @@ +package com.google.adk.sessions; + +import static com.google.common.truth.Truth.assertThat; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class StateTest { + @Test + public void constructor_nullDelta_createsEmptyConcurrentHashMap() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + State state = new State(stateMap, null); + assertThat(state.hasDelta()).isFalse(); + state.put("key", "value"); + assertThat(state.hasDelta()).isTrue(); + } + + @Test + public void constructor_nullState_throwsException() { + Assert.assertThrows(NullPointerException.class, () -> new State(null, new HashMap<>())); + } + + @Test + public void constructor_regularMapState() { + Map stateMap = new HashMap<>(); + stateMap.put("initial", "val"); + State state = new State(stateMap, null); + // It should have copied the contents + assertThat(state).containsEntry("initial", "val"); + state.put("key", "value"); + // The original map should NOT be updated because a copy was created + assertThat(stateMap).doesNotContainKey("key"); + } + + @Test + public void constructor_singleArgument() { + ConcurrentMap stateMap = new ConcurrentHashMap<>(); + State state = new State(stateMap); + assertThat(state.hasDelta()).isFalse(); + state.put("key", "value"); + assertThat(state.hasDelta()).isTrue(); + } +} diff --git a/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java index 20fb146cf..236172b27 100644 --- a/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java +++ b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java @@ -30,7 +30,6 @@ import java.lang.reflect.Method; import java.util.Base64; import java.util.Map; -import java.util.Optional; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -128,10 +127,7 @@ public void testNormalizeDragAndDrop() throws NoSuchMethodException { public void testResultFormatting() throws NoSuchMethodException { byte[] screenshot = new byte[] {1, 2, 3}; computerMock.nextState = - ComputerState.builder() - .screenshot(screenshot) - .url(Optional.of("https://example.com")) - .build(); + ComputerState.builder().screenshot(screenshot).url("https://example.com").build(); Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); ComputerUseTool tool = @@ -226,8 +222,7 @@ public static class ComputerMock { public int lastY; public int lastDestX; public int lastDestY; - public ComputerState nextState = - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build(); + public ComputerState nextState = ComputerState.builder().screenshot(new byte[0]).build(); public Single clickAt(@Schema(name = "x") int x, @Schema(name = "y") int y) { this.lastX = x; diff --git a/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java index 1ed49419e..8051a018d 100644 --- a/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java +++ b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java @@ -173,87 +173,73 @@ public Single environment() { @Override public Single openWebBrowser() { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single clickAt(int x, int y) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single hoverAt(int x, int y) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single typeTextAt( int x, int y, String text, Boolean pressEnter, Boolean clearBeforeTyping) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single scrollDocument(String direction) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single scrollAt(int x, int y, String direction, int magnitude) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single wait(Duration duration) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single goBack() { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single goForward() { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single search() { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single navigate(String url) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.of(url)).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).url(url).build()); } @Override public Single keyCombination(List keys) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single dragAndDrop(int x, int y, int destinationX, int destinationY) { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override public Single currentState() { - return Single.just( - ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + return Single.just(ComputerState.builder().screenshot(new byte[0]).build()); } @Override