diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 6d8c698dd..bf25acfc7 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -22,6 +22,7 @@ import com.google.adk.sessions.State; import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.HashSet; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -383,7 +384,7 @@ public Builder compaction(EventCompaction value) { @CanIgnoreReturnValue public Builder merge(EventActions other) { other.skipSummarization().ifPresent(this::skipSummarization); - this.stateDelta.putAll(other.stateDelta()); + other.stateDelta().forEach((key, value) -> stateDelta.merge(key, value, Builder::deepMerge)); this.artifactDelta.putAll(other.artifactDelta()); this.deletedArtifactIds.addAll(other.deletedArtifactIds()); other.transferToAgent().ifPresent(this::transferToAgent); @@ -395,6 +396,34 @@ public Builder merge(EventActions other) { return this; } + private static Object deepMerge(Object target, Object source) { + if (!(target instanceof Map) || !(source instanceof Map)) { + // If one of them is not a map, the source value overwrites the target. + return source; + } + + Map targetMap = (Map) target; + Map sourceMap = (Map) source; + + if (!targetMap.isEmpty() && !sourceMap.isEmpty()) { + Object targetKey = targetMap.keySet().iterator().next(); + Object sourceKey = sourceMap.keySet().iterator().next(); + if (targetKey != null + && sourceKey != null + && !targetKey.getClass().equals(sourceKey.getClass())) { + throw new IllegalArgumentException( + String.format( + "Cannot merge maps with different key types: %s vs %s", + targetKey.getClass().getName(), sourceKey.getClass().getName())); + } + } + + // Create a new map to prevent UnsupportedOperationException from immutable maps + Map mergedMap = new ConcurrentHashMap<>(targetMap); + sourceMap.forEach((key, value) -> mergedMap.merge(key, value, Builder::deepMerge)); + return mergedMap; + } + public EventActions build() { return new EventActions(this); } diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 94cd399df..28123bab8 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -17,12 +17,14 @@ package com.google.adk.events; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.adk.sessions.State; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.Part; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.junit.Test; import org.junit.runner.RunWith; @@ -130,4 +132,37 @@ public void jsonSerialization_works() throws Exception { assertThat(deserialized).isEqualTo(eventActions); assertThat(deserialized.deletedArtifactIds()).containsExactly("d1", "d2"); } + + @Test + @SuppressWarnings("unchecked") // the nested map is known to be Map + public void merge_deeplyMergesStateDelta() { + EventActions eventActions1 = EventActions.builder().build(); + eventActions1.stateDelta().put("a", 1); + eventActions1.stateDelta().put("b", ImmutableMap.of("nested1", 10, "nested2", 20)); + eventActions1.stateDelta().put("c", 100); + EventActions eventActions2 = EventActions.builder().build(); + eventActions2.stateDelta().put("a", 2); + eventActions2.stateDelta().put("b", ImmutableMap.of("nested2", 22, "nested3", 30)); + eventActions2.stateDelta().put("d", 200); + + EventActions merged = eventActions1.toBuilder().merge(eventActions2).build(); + + assertThat(merged.stateDelta().keySet()).containsExactly("a", "b", "c", "d"); + assertThat(merged.stateDelta()).containsEntry("a", 2); + assertThat((Map) merged.stateDelta().get("b")) + .containsExactly("nested1", 10, "nested2", 22, "nested3", 30); + assertThat(merged.stateDelta()).containsEntry("c", 100); + assertThat(merged.stateDelta()).containsEntry("d", 200); + } + + @Test + public void merge_failsOnMismatchedKeyTypesNestedInStateDelta() { + EventActions eventActions1 = EventActions.builder().build(); + eventActions1.stateDelta().put("nested", ImmutableMap.of("a", 1)); + EventActions eventActions2 = EventActions.builder().build(); + eventActions2.stateDelta().put("nested", ImmutableMap.of(1, 2)); + + assertThrows( + IllegalArgumentException.class, () -> eventActions1.toBuilder().merge(eventActions2)); + } }