Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion core/src/main/java/com/google/adk/events/EventActions.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.adk.sessions.State;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -383,7 +384,7 @@ public Builder compaction(EventCompaction value) {
@CanIgnoreReturnValue
public Builder merge(EventActions other) {
other.skipSummarization().ifPresent(this::skipSummarization);
this.stateDelta.putAll(other.stateDelta());
other.stateDelta().forEach((key, value) -> stateDelta.merge(key, value, Builder::deepMerge));
this.artifactDelta.putAll(other.artifactDelta());
this.deletedArtifactIds.addAll(other.deletedArtifactIds());
other.transferToAgent().ifPresent(this::transferToAgent);
Expand All @@ -395,6 +396,34 @@ public Builder merge(EventActions other) {
return this;
}

private static Object deepMerge(Object target, Object source) {
if (!(target instanceof Map) || !(source instanceof Map)) {
// If one of them is not a map, the source value overwrites the target.
return source;
}

Map<?, ?> targetMap = (Map<?, ?>) target;
Map<?, ?> sourceMap = (Map<?, ?>) source;

if (!targetMap.isEmpty() && !sourceMap.isEmpty()) {
Object targetKey = targetMap.keySet().iterator().next();
Object sourceKey = sourceMap.keySet().iterator().next();
if (targetKey != null
&& sourceKey != null
&& !targetKey.getClass().equals(sourceKey.getClass())) {
throw new IllegalArgumentException(
String.format(
"Cannot merge maps with different key types: %s vs %s",
targetKey.getClass().getName(), sourceKey.getClass().getName()));
}
}

// Create a new map to prevent UnsupportedOperationException from immutable maps
Map<Object, Object> mergedMap = new ConcurrentHashMap<>(targetMap);
sourceMap.forEach((key, value) -> mergedMap.merge(key, value, Builder::deepMerge));
return mergedMap;
}

public EventActions build() {
return new EventActions(this);
}
Expand Down
35 changes: 35 additions & 0 deletions core/src/test/java/com/google/adk/events/EventActionsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
package com.google.adk.events;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import com.google.adk.sessions.State;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.genai.types.Content;
import com.google.genai.types.Part;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -130,4 +132,37 @@ public void jsonSerialization_works() throws Exception {
assertThat(deserialized).isEqualTo(eventActions);
assertThat(deserialized.deletedArtifactIds()).containsExactly("d1", "d2");
}

@Test
@SuppressWarnings("unchecked") // the nested map is known to be Map<String, Object>
public void merge_deeplyMergesStateDelta() {
EventActions eventActions1 = EventActions.builder().build();
eventActions1.stateDelta().put("a", 1);
eventActions1.stateDelta().put("b", ImmutableMap.of("nested1", 10, "nested2", 20));
eventActions1.stateDelta().put("c", 100);
EventActions eventActions2 = EventActions.builder().build();
eventActions2.stateDelta().put("a", 2);
eventActions2.stateDelta().put("b", ImmutableMap.of("nested2", 22, "nested3", 30));
eventActions2.stateDelta().put("d", 200);

EventActions merged = eventActions1.toBuilder().merge(eventActions2).build();

assertThat(merged.stateDelta().keySet()).containsExactly("a", "b", "c", "d");
assertThat(merged.stateDelta()).containsEntry("a", 2);
assertThat((Map<String, Object>) merged.stateDelta().get("b"))
.containsExactly("nested1", 10, "nested2", 22, "nested3", 30);
assertThat(merged.stateDelta()).containsEntry("c", 100);
assertThat(merged.stateDelta()).containsEntry("d", 200);
}

@Test
public void merge_failsOnMismatchedKeyTypesNestedInStateDelta() {
EventActions eventActions1 = EventActions.builder().build();
eventActions1.stateDelta().put("nested", ImmutableMap.of("a", 1));
EventActions eventActions2 = EventActions.builder().build();
eventActions2.stateDelta().put("nested", ImmutableMap.of(1, 2));

assertThrows(
IllegalArgumentException.class, () -> eventActions1.toBuilder().merge(eventActions2));
}
}