diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 9c977240c..c3c921be5 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -289,10 +289,16 @@ public Builder skipSummarization(@Nullable Boolean skipSummarization) { @CanIgnoreReturnValue @JsonProperty("stateDelta") public Builder stateDelta(@Nullable Map value) { - if (value == null) { - this.stateDelta = new ConcurrentHashMap<>(); - } else { - this.stateDelta = new ConcurrentHashMap<>(value); + this.stateDelta = new ConcurrentHashMap<>(); + if (value != null) { + // Convert null values to State.REMOVED to avoid NPEs. + value + .entrySet() + .forEach( + entry -> { + stateDelta.put( + entry.getKey(), Optional.ofNullable(entry.getValue()).orElse(State.REMOVED)); + }); } return this; } diff --git a/core/src/main/java/com/google/adk/sessions/State.java b/core/src/main/java/com/google/adk/sessions/State.java index 577559f85..9a7042a72 100644 --- a/core/src/main/java/com/google/adk/sessions/State.java +++ b/core/src/main/java/com/google/adk/sessions/State.java @@ -21,6 +21,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -46,16 +47,24 @@ public State(Map state) { public State(Map state, @Nullable Map delta) { Objects.requireNonNull(state, "state is null"); - this.state = - state instanceof ConcurrentMap - ? (ConcurrentMap) state - : new ConcurrentHashMap<>(state); - this.delta = - delta == null - ? new ConcurrentHashMap<>() - : delta instanceof ConcurrentMap - ? (ConcurrentMap) delta - : new ConcurrentHashMap<>(delta); + this.state = toConcurrentMap(state); + this.delta = delta == null ? new ConcurrentHashMap<>() : toConcurrentMap(delta); + } + + /** + * Converts a map to a concurrent map. Null values are converted to {@link #REMOVED} to avoid + * NPEs. + * + *

If the map is already a concurrent map, it is returned as is. Otherwise, a new concurrent + * map is created and returned. + */ + private static ConcurrentMap toConcurrentMap(Map map) { + if (map instanceof ConcurrentMap) { + return (ConcurrentMap) map; + } + ConcurrentMap concurrentMap = new ConcurrentHashMap<>(); + map.forEach((key, value) -> concurrentMap.put(key, Optional.ofNullable(value).orElse(REMOVED))); + return concurrentMap; } @Override diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index b1e645e1a..4b542c0e9 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.Part; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -129,6 +130,33 @@ public void removeStateByKey_marksKeyAsRemoved() { assertThat(eventActions.stateDelta()).containsExactly("key1", State.REMOVED); } + @Test + public void builderStateDelta_withNullMap_initializesEmptyMap() { + EventActions eventActions = EventActions.builder().stateDelta(null).build(); + + assertThat(eventActions.stateDelta()).isEmpty(); + } + + @Test + public void builderStateDelta_withNullValue_marksKeyAsRemoved() { + Map inputDelta = new HashMap<>(); + inputDelta.put("key1", "value1"); + inputDelta.put("key2", null); + + EventActions eventActions = EventActions.builder().stateDelta(inputDelta).build(); + + assertThat(eventActions.stateDelta()).containsExactly("key1", "value1", "key2", State.REMOVED); + } + + @Test + public void jsonDeserialization_withNullValueInStateDelta_deserializesAsRemoved() + throws Exception { + String json = "{\"stateDelta\":{\"key1\":\"value1\",\"key2\":null}}"; + EventActions deserialized = EventActions.fromJsonString(json, EventActions.class); + + assertThat(deserialized.stateDelta()).containsExactly("key1", "value1", "key2", State.REMOVED); + } + @Test public void jsonSerialization_works() throws Exception { EventActions eventActions = diff --git a/core/src/test/java/com/google/adk/sessions/StateTest.java b/core/src/test/java/com/google/adk/sessions/StateTest.java index e1fcaeadc..295466d1c 100644 --- a/core/src/test/java/com/google/adk/sessions/StateTest.java +++ b/core/src/test/java/com/google/adk/sessions/StateTest.java @@ -6,7 +6,6 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -22,11 +21,6 @@ public void constructor_nullDelta_createsEmptyConcurrentHashMap() { assertThat(state.hasDelta()).isTrue(); } - @Test - public void constructor_nullState_throwsException() { - Assert.assertThrows(NullPointerException.class, () -> new State(null, new HashMap<>())); - } - @Test public void constructor_regularMapState() { Map stateMap = new HashMap<>(); @@ -47,4 +41,14 @@ public void constructor_singleArgument() { state.put("key", "value"); assertThat(state.hasDelta()).isTrue(); } + + @Test + public void constructor_stateMapWithNullValues_replacesWithRemoved() { + Map stateMap = new HashMap<>(); + stateMap.put("key1", "value1"); + stateMap.put("key2", null); + State state = new State(stateMap); + assertThat(state).containsEntry("key1", "value1"); + assertThat(state).containsEntry("key2", State.REMOVED); + } }