From 5dcfc68f6a7ffa35edb52de0ef4d5e5f52a564b0 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 19 Feb 2026 08:44:02 -0800 Subject: [PATCH] refactor: Simplifying State interfaces This is really two changes: 1. Replace the interface of state from ConcurrentMap to Map 2. Under the covers use extenral synchornization (Collections.synchronize, etc) along with HashMap wich allows nulls to represent "remove this variable from the session" This devX improvement comes with a subtle assumption that State will be passed in as a HashMap. This change may cause subtle breaking changes. PiperOrigin-RevId: 872418434 --- .../adk/sessions/FirestoreSessionService.java | 25 ++-- .../com/google/adk/events/EventActions.java | 118 ++++++++---------- .../java/com/google/adk/runner/Runner.java | 5 +- .../adk/sessions/BaseSessionService.java | 29 ++++- .../adk/sessions/InMemorySessionService.java | 98 +++++++-------- .../java/com/google/adk/sessions/Session.java | 9 +- .../adk/sessions/SessionJsonConverter.java | 79 +++--------- .../java/com/google/adk/sessions/State.java | 63 +++++----- .../google/adk/sessions/VertexAiClient.java | 8 +- .../adk/sessions/VertexAiSessionService.java | 28 +++-- .../google/adk/events/EventActionsTest.java | 58 ++++----- .../plugins/GlobalInstructionPluginTest.java | 2 +- .../sessions/InMemorySessionServiceTest.java | 40 +++--- .../sessions/SessionJsonConverterTest.java | 24 ++-- .../sessions/VertexAiSessionServiceTest.java | 9 +- .../google/adk/plugins/ReplayPluginTest.java | 3 +- 16 files changed, 274 insertions(+), 324 deletions(-) diff --git a/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java b/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java index d3295a6ef..ae5a26f99 100644 --- a/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java +++ b/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java @@ -46,7 +46,6 @@ import java.util.Optional; import java.util.Set; import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.regex.Matcher; @@ -85,10 +84,19 @@ private CollectionReference getSessionsCollection(String userId) { .collection(SESSION_COLLECTION_NAME); } + @Override + public Single createSession( + String appName, + String userId, + @Nullable ConcurrentMap state, + @Nullable String sessionId) { + return createSession(appName, userId, (Map) state, sessionId); + } + /** Creates a new session in Firestore. */ @Override public Single createSession( - String appName, String userId, ConcurrentMap state, String sessionId) { + String appName, String userId, Map state, String sessionId) { return Single.fromCallable( () -> { Objects.requireNonNull(appName, "appName cannot be null"); @@ -100,21 +108,17 @@ public Single createSession( .filter(s -> !s.isEmpty()) .orElseGet(() -> UUID.randomUUID().toString()); - ConcurrentMap initialState = - (state == null) ? new ConcurrentHashMap<>() : new ConcurrentHashMap<>(state); logger.info( "Creating session for userId: {} with sessionId: {} and initial state: {}", userId, resolvedSessionId, - initialState); - List initialEvents = new ArrayList<>(); + state); Instant now = Instant.now(); Session newSession = Session.builder(resolvedSessionId) .appName(appName) .userId(userId) - .state(initialState) - .events(initialEvents) + .state(state) .lastUpdateTime(now) .build(); @@ -200,8 +204,7 @@ public Maybe getSession( }) .map( events -> { - ConcurrentMap state = - new ConcurrentHashMap<>((Map) data.get(STATE_KEY)); + Map state = (Map) data.get(STATE_KEY); return Session.builder((String) data.get(ID_KEY)) .appName((String) data.get(APP_NAME_KEY)) .userId((String) data.get(USER_ID_KEY)) @@ -451,8 +454,6 @@ public Single listSessions(String appName, String userId) .appName((String) data.get(APP_NAME_KEY)) .userId((String) data.get(USER_ID_KEY)) .lastUpdateTime(Instant.parse((String) data.get(UPDATE_TIME_KEY))) - .state(new ConcurrentHashMap<>()) // Empty state - .events(new ArrayList<>()) // Empty events .build(); sessions.add(session); } diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 6d8c698dd..a4fe4461d 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -19,14 +19,13 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.google.adk.JsonBaseModel; -import com.google.adk.sessions.State; import com.google.errorprone.annotations.CanIgnoreReturnValue; -import java.util.HashSet; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import javax.annotation.Nullable; /** Represents the actions attached to an event. */ @@ -35,39 +34,37 @@ public class EventActions extends JsonBaseModel { private Optional skipSummarization; - private ConcurrentMap stateDelta; - private ConcurrentMap artifactDelta; - private Set deletedArtifactIds; + private Map stateDelta; + private Map artifactDelta; private Optional transferToAgent; private Optional escalate; - private ConcurrentMap> requestedAuthConfigs; - private ConcurrentMap requestedToolConfirmations; + private Map> requestedAuthConfigs; + private Map requestedToolConfirmations; private boolean endOfAgent; private Optional compaction; /** Default constructor for Jackson. */ public EventActions() { this.skipSummarization = Optional.empty(); - this.stateDelta = new ConcurrentHashMap<>(); - this.artifactDelta = new ConcurrentHashMap<>(); - this.deletedArtifactIds = new HashSet<>(); + this.stateDelta = Collections.synchronizedMap(new HashMap<>()); + this.artifactDelta = Collections.synchronizedMap(new HashMap<>()); this.transferToAgent = Optional.empty(); this.escalate = Optional.empty(); - this.requestedAuthConfigs = new ConcurrentHashMap<>(); - this.requestedToolConfirmations = new ConcurrentHashMap<>(); + this.requestedAuthConfigs = Collections.synchronizedMap(new HashMap<>()); + this.requestedToolConfirmations = Collections.synchronizedMap(new HashMap<>()); this.endOfAgent = false; this.compaction = Optional.empty(); } private EventActions(Builder builder) { this.skipSummarization = builder.skipSummarization; - this.stateDelta = builder.stateDelta; - this.artifactDelta = builder.artifactDelta; - this.deletedArtifactIds = builder.deletedArtifactIds; + this.stateDelta = Collections.synchronizedMap(builder.stateDelta); + this.artifactDelta = Collections.synchronizedMap(builder.artifactDelta); this.transferToAgent = builder.transferToAgent; this.escalate = builder.escalate; - this.requestedAuthConfigs = builder.requestedAuthConfigs; - this.requestedToolConfirmations = builder.requestedToolConfirmations; + this.requestedAuthConfigs = Collections.synchronizedMap(builder.requestedAuthConfigs); + this.requestedToolConfirmations = + Collections.synchronizedMap(builder.requestedToolConfirmations); this.endOfAgent = builder.endOfAgent; this.compaction = builder.compaction; } @@ -90,41 +87,32 @@ public void setSkipSummarization(boolean skipSummarization) { } @JsonProperty("stateDelta") - public ConcurrentMap stateDelta() { + public Map stateDelta() { return stateDelta; } - @Deprecated // Use stateDelta(), addState() and removeStateByKey() instead. - public void setStateDelta(ConcurrentMap stateDelta) { - this.stateDelta = stateDelta; + public void setStateDelta(Map stateDelta) { + this.stateDelta = Collections.synchronizedMap(new HashMap<>(stateDelta)); } /** * Removes a key from the state delta. * * @param key The key to remove. + * @deprecated Use {@link #stateDelta()}.put(key, null) instead. */ + @Deprecated public void removeStateByKey(String key) { - stateDelta.put(key, State.REMOVED); + stateDelta().put(key, null); } @JsonProperty("artifactDelta") - public ConcurrentMap artifactDelta() { + public Map artifactDelta() { return artifactDelta; } - public void setArtifactDelta(ConcurrentMap artifactDelta) { - this.artifactDelta = artifactDelta; - } - - @JsonProperty("deletedArtifactIds") - @JsonInclude(JsonInclude.Include.NON_EMPTY) - public Set deletedArtifactIds() { - return deletedArtifactIds; - } - - public void setDeletedArtifactIds(Set deletedArtifactIds) { - this.deletedArtifactIds = deletedArtifactIds; + public void setArtifactDelta(Map artifactDelta) { + this.artifactDelta = Collections.synchronizedMap(new HashMap<>(artifactDelta)); } @JsonProperty("transferToAgent") @@ -154,23 +142,23 @@ public void setEscalate(boolean escalate) { } @JsonProperty("requestedAuthConfigs") - public ConcurrentMap> requestedAuthConfigs() { + public Map> requestedAuthConfigs() { return requestedAuthConfigs; } - public void setRequestedAuthConfigs( - ConcurrentMap> requestedAuthConfigs) { + public void setRequestedAuthConfigs(Map> requestedAuthConfigs) { this.requestedAuthConfigs = requestedAuthConfigs; } @JsonProperty("requestedToolConfirmations") - public ConcurrentMap requestedToolConfirmations() { + public Map requestedToolConfirmations() { return requestedToolConfirmations; } public void setRequestedToolConfirmations( - ConcurrentMap requestedToolConfirmations) { - this.requestedToolConfirmations = requestedToolConfirmations; + Map requestedToolConfirmations) { + this.requestedToolConfirmations = + Collections.synchronizedMap(new HashMap<>(requestedToolConfirmations)); } @JsonProperty("endOfAgent") @@ -235,7 +223,6 @@ public boolean equals(Object o) { return Objects.equals(skipSummarization, that.skipSummarization) && Objects.equals(stateDelta, that.stateDelta) && Objects.equals(artifactDelta, that.artifactDelta) - && Objects.equals(deletedArtifactIds, that.deletedArtifactIds) && Objects.equals(transferToAgent, that.transferToAgent) && Objects.equals(escalate, that.escalate) && Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs) @@ -250,7 +237,6 @@ public int hashCode() { skipSummarization, stateDelta, artifactDelta, - deletedArtifactIds, transferToAgent, escalate, requestedAuthConfigs, @@ -262,38 +248,34 @@ public int hashCode() { /** Builder for {@link EventActions}. */ public static class Builder { private Optional skipSummarization; - private ConcurrentMap stateDelta; - private ConcurrentMap artifactDelta; - private Set deletedArtifactIds; + private Map stateDelta; + private Map artifactDelta; private Optional transferToAgent; private Optional escalate; - private ConcurrentMap> requestedAuthConfigs; - private ConcurrentMap requestedToolConfirmations; + private Map> requestedAuthConfigs; + private Map requestedToolConfirmations; private boolean endOfAgent = false; private Optional compaction; public Builder() { this.skipSummarization = Optional.empty(); - this.stateDelta = new ConcurrentHashMap<>(); - this.artifactDelta = new ConcurrentHashMap<>(); - this.deletedArtifactIds = new HashSet<>(); + this.stateDelta = new HashMap<>(); + this.artifactDelta = new HashMap<>(); this.transferToAgent = Optional.empty(); this.escalate = Optional.empty(); - this.requestedAuthConfigs = new ConcurrentHashMap<>(); - this.requestedToolConfirmations = new ConcurrentHashMap<>(); + this.requestedAuthConfigs = new HashMap<>(); + this.requestedToolConfirmations = new HashMap<>(); this.compaction = Optional.empty(); } private Builder(EventActions eventActions) { this.skipSummarization = eventActions.skipSummarization(); - this.stateDelta = new ConcurrentHashMap<>(eventActions.stateDelta()); - this.artifactDelta = new ConcurrentHashMap<>(eventActions.artifactDelta()); - this.deletedArtifactIds = new HashSet<>(eventActions.deletedArtifactIds()); + this.stateDelta = new HashMap<>(eventActions.stateDelta()); + this.artifactDelta = new HashMap<>(eventActions.artifactDelta()); this.transferToAgent = eventActions.transferToAgent(); this.escalate = eventActions.escalate(); - this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs()); - this.requestedToolConfirmations = - new ConcurrentHashMap<>(eventActions.requestedToolConfirmations()); + this.requestedAuthConfigs = new HashMap<>(eventActions.requestedAuthConfigs()); + this.requestedToolConfirmations = new HashMap<>(eventActions.requestedToolConfirmations()); this.endOfAgent = eventActions.endOfAgent(); this.compaction = eventActions.compaction(); } @@ -307,14 +289,14 @@ public Builder skipSummarization(boolean skipSummarization) { @CanIgnoreReturnValue @JsonProperty("stateDelta") - public Builder stateDelta(ConcurrentMap value) { + public Builder stateDelta(Map value) { this.stateDelta = value; return this; } @CanIgnoreReturnValue @JsonProperty("artifactDelta") - public Builder artifactDelta(ConcurrentMap value) { + public Builder artifactDelta(Map value) { this.artifactDelta = value; return this; } @@ -322,7 +304,7 @@ public Builder artifactDelta(ConcurrentMap value) { @CanIgnoreReturnValue @JsonProperty("deletedArtifactIds") public Builder deletedArtifactIds(Set value) { - this.deletedArtifactIds = value; + value.forEach(v -> artifactDelta.put(v, null)); return this; } @@ -342,16 +324,15 @@ public Builder escalate(boolean escalate) { @CanIgnoreReturnValue @JsonProperty("requestedAuthConfigs") - public Builder requestedAuthConfigs( - ConcurrentMap> value) { + public Builder requestedAuthConfigs(Map> value) { this.requestedAuthConfigs = value; return this; } @CanIgnoreReturnValue @JsonProperty("requestedToolConfirmations") - public Builder requestedToolConfirmations(ConcurrentMap value) { - this.requestedToolConfirmations = value; + public Builder requestedToolConfirmations(Map value) { + this.requestedToolConfirmations = Collections.synchronizedMap(new HashMap<>(value)); return this; } @@ -385,7 +366,6 @@ public Builder merge(EventActions other) { other.skipSummarization().ifPresent(this::skipSummarization); this.stateDelta.putAll(other.stateDelta()); this.artifactDelta.putAll(other.artifactDelta()); - this.deletedArtifactIds.addAll(other.deletedArtifactIds()); other.transferToAgent().ifPresent(this::transferToAgent); other.escalate().ifPresent(this::escalate); this.requestedAuthConfigs.putAll(other.requestedAuthConfigs()); diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 0ddfdaea1..8908b4ecb 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -58,6 +58,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -337,7 +338,9 @@ private Single appendNewMessageToSession( // Add state delta if provided if (stateDelta != null && !stateDelta.isEmpty()) { eventBuilder.actions( - EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build()); + EventActions.builder() + .stateDelta(stateDelta == null ? new HashMap<>() : new HashMap<>(stateDelta)) + .build()); } return this.sessionService.appendEvent(session, eventBuilder.build()); diff --git a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java index 540153460..a16443888 100644 --- a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java @@ -23,8 +23,10 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import javax.annotation.Nullable; @@ -47,13 +49,36 @@ public interface BaseSessionService { * service should generate a unique ID. * @return The newly created {@link Session} instance. * @throws SessionException if creation fails. + * @deprecated Use {@link #createSession(String, String, Map, String)} instead. */ + @Deprecated Single createSession( String appName, String userId, @Nullable ConcurrentMap state, @Nullable String sessionId); + /** + * Creates a new session with the specified parameters. + * + * @param appName The name of the application associated with the session. + * @param userId The identifier for the user associated with the session. + * @param state An optional map representing the initial state of the session. Can be null or + * empty. + * @param sessionId An optional client-provided identifier for the session. If empty or null, the + * service should generate a unique ID. + * @return The newly created {@link Session} instance. + * @throws SessionException if creation fails. + */ + default Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { + return createSession( + appName, userId, state == null ? null : new ConcurrentHashMap<>(state), sessionId); + } + /** * Creates a new session with the specified application name and user ID, using a default state * (null) and allowing the service to generate a unique session ID. @@ -165,9 +190,9 @@ default Single appendEvent(Session session, Event event) { EventActions actions = event.actions(); if (actions != null) { - ConcurrentMap stateDelta = actions.stateDelta(); + Map stateDelta = actions.stateDelta(); if (stateDelta != null && !stateDelta.isEmpty()) { - ConcurrentMap sessionState = session.state(); + Map sessionState = session.state(); if (sessionState != null) { stateDelta.forEach( (key, value) -> { diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index 060fcaf60..5900263ac 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -28,6 +28,7 @@ import java.time.Instant; import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -71,6 +72,15 @@ public Single createSession( String userId, @Nullable ConcurrentMap state, @Nullable String sessionId) { + return createSession(appName, userId, (Map) state, sessionId); + } + + @Override + public Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { Objects.requireNonNull(appName, "appName cannot be null"); Objects.requireNonNull(userId, "userId cannot be null"); @@ -81,8 +91,7 @@ public Single createSession( .orElseGet(() -> UUID.randomUUID().toString()); // Ensure state map and events list are mutable for the new session - ConcurrentMap initialState = - (state == null) ? new ConcurrentHashMap<>() : new ConcurrentHashMap<>(state); + Map initialState = (state == null) ? new HashMap<>() : new HashMap<>(state); List initialEvents = new ArrayList<>(); // Assuming Session constructor or setters allow setting these mutable collections @@ -95,10 +104,7 @@ public Single createSession( .lastUpdateTime(Instant.now()) .build(); - sessions - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .put(resolvedSessionId, newSession); + getUserSessionsMap(appName, userId).put(resolvedSessionId, newSession); // Create a mutable copy for the return value Session returnCopy = copySession(newSession); @@ -114,11 +120,7 @@ public Maybe getSession( Objects.requireNonNull(sessionId, "sessionId cannot be null"); Objects.requireNonNull(configOpt, "configOpt cannot be null"); - Session storedSession = - sessions - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .get(sessionId); + Session storedSession = getUserSessionsMap(appName, userId).get(sessionId); if (storedSession == null) { return Maybe.empty(); @@ -165,10 +167,9 @@ public Single listSessions(String appName, String userId) Objects.requireNonNull(appName, "appName cannot be null"); Objects.requireNonNull(userId, "userId cannot be null"); - Map userSessionsMap = - sessions.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()).get(userId); + Map userSessionsMap = getUserSessionsMap(appName, userId); - if (userSessionsMap == null || userSessionsMap.isEmpty()) { + if (userSessionsMap.isEmpty()) { return Single.just(ListSessionsResponse.builder().build()); } @@ -185,13 +186,7 @@ public Completable deleteSession(String appName, String userId, String sessionId Objects.requireNonNull(userId, "userId cannot be null"); Objects.requireNonNull(sessionId, "sessionId cannot be null"); - ConcurrentMap> appSessionsMap = sessions.get(appName); - if (appSessionsMap != null) { - ConcurrentMap userSessionsMap = appSessionsMap.get(userId); - if (userSessionsMap != null) { - userSessionsMap.remove(sessionId); - } - } + getUserSessionsMap(appName, userId).remove(sessionId); return Completable.complete(); } @@ -201,11 +196,7 @@ public Single listEvents(String appName, String userId, Stri Objects.requireNonNull(userId, "userId cannot be null"); Objects.requireNonNull(sessionId, "sessionId cannot be null"); - Session storedSession = - sessions - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .get(sessionId); + Session storedSession = getUserSessionsMap(appName, userId).get(sessionId); if (storedSession == null) { return Single.just(ListEventsResponse.builder().build()); @@ -237,30 +228,21 @@ public Single appendEvent(Session session, Event event) { (key, value) -> { if (key.startsWith(State.APP_PREFIX)) { String appStateKey = key.substring(State.APP_PREFIX.length()); - if (value == State.REMOVED) { - appState - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .remove(appStateKey); + ConcurrentMap currentAppState = getAppStateMap(appName); + if (value == null) { + currentAppState.remove(appStateKey); } else { - appState - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .put(appStateKey, value); + currentAppState.put(appStateKey, value); } } else if (key.startsWith(State.USER_PREFIX)) { String userStateKey = key.substring(State.USER_PREFIX.length()); - if (value == State.REMOVED) { - userState - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .remove(userStateKey); + if (value == null) { + getUserStateMap(appName, userId).remove(userStateKey); } else { - userState - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .put(userStateKey, value); + getUserStateMap(appName, userId).put(userStateKey, value); } } else { - if (value == State.REMOVED) { + if (value == null) { session.state().remove(key); } else { session.state().put(key, value); @@ -274,10 +256,7 @@ public Single appendEvent(Session session, Event event) { session.lastUpdateTime(getInstantFromEvent(event)); // --- Update the session stored in this service --- - sessions - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .put(sessionId, session); + getUserSessionsMap(appName, userId).put(sessionId, session); mergeWithGlobalState(appName, userId, session); @@ -304,7 +283,7 @@ private Session copySession(Session original) { return Session.builder(original.id()) .appName(original.appName()) .userId(original.userId()) - .state(new ConcurrentHashMap<>(original.state())) + .state(new HashMap<>(original.state())) .events(new ArrayList<>(original.events())) .lastUpdateTime(original.lastUpdateTime()) .build(); @@ -324,13 +303,10 @@ private Session mergeWithGlobalState(String appName, String userId, Session sess Map sessionState = session.state(); // Merge App State directly into the session's state map - appState - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + getAppStateMap(appName) .forEach((key, value) -> sessionState.put(State.APP_PREFIX + key, value)); - userState - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) + getUserStateMap(appName, userId) .forEach((key, value) -> sessionState.put(State.USER_PREFIX + key, value)); return session; @@ -355,4 +331,20 @@ private List prepareSessionsForListResponse( .map(s -> mergeWithGlobalState(appName, userId, s)) .collect(toCollection(ArrayList::new)); } + + private ConcurrentMap getAppStateMap(String appName) { + return appState.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()); + } + + private ConcurrentMap getUserStateMap(String appName, String userId) { + return userState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()); + } + + private ConcurrentMap getUserSessionsMap(String appName, String userId) { + return sessions + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()); + } } diff --git a/core/src/main/java/com/google/adk/sessions/Session.java b/core/src/main/java/com/google/adk/sessions/Session.java index 3bf27b55e..0f773c942 100644 --- a/core/src/main/java/com/google/adk/sessions/Session.java +++ b/core/src/main/java/com/google/adk/sessions/Session.java @@ -26,8 +26,7 @@ import java.time.Instant; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; +import java.util.Map; /** A {@link Session} object that encapsulates the {@link State} and {@link Event}s of a session. */ @JsonDeserialize(builder = Session.Builder.class) @@ -53,7 +52,7 @@ public static final class Builder { private String id; private String appName; private String userId; - private State state = new State(new ConcurrentHashMap<>()); + private State state = new State(); private List events = new ArrayList<>(); private Instant lastUpdateTime = Instant.EPOCH; @@ -79,7 +78,7 @@ public Builder state(State state) { @CanIgnoreReturnValue @JsonProperty("state") - public Builder state(ConcurrentMap state) { + public Builder state(Map state) { this.state = new State(state); return this; } @@ -135,7 +134,7 @@ public String id() { } @JsonProperty("state") - public ConcurrentMap state() { + public Map state() { return state; } diff --git a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java index 71b072695..6e188886b 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java +++ b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java @@ -37,8 +37,6 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -99,7 +97,7 @@ static String convertEventToJson(Event event, boolean useIsoString) { Map actionsJson = new HashMap<>(); EventActions actions = event.actions(); actions.skipSummarization().ifPresent(v -> actionsJson.put("skipSummarization", v)); - actionsJson.put("stateDelta", stateDeltaToJson(actions.stateDelta())); + actionsJson.put("stateDelta", actions.stateDelta()); putIfNotEmpty(actionsJson, "artifactDelta", actions.artifactDelta()); actions .transferToAgent() @@ -168,9 +166,7 @@ static Event fromApiEvent(Map apiEvent) { eventActionsBuilder.stateDelta(stateDeltaFromJson(actionsMap.get("stateDelta"))); Object artifactDelta = actionsMap.get("artifactDelta"); eventActionsBuilder.artifactDelta( - artifactDelta != null - ? convertToArtifactDeltaMap(artifactDelta) - : new ConcurrentHashMap<>()); + artifactDelta != null ? convertToArtifactDeltaMap(artifactDelta) : new HashMap<>()); String transferAgent = (String) actionsMap.get("transferAgent"); if (transferAgent == null) { transferAgent = (String) actionsMap.get("transferToAgent"); @@ -186,12 +182,12 @@ static Event fromApiEvent(Map apiEvent) { } eventActionsBuilder.requestedAuthConfigs( Optional.ofNullable(actionsMap.get("requestedAuthConfigs")) - .map(SessionJsonConverter::asConcurrentMapOfConcurrentMaps) - .orElse(new ConcurrentHashMap<>())); + .map(value -> (Map>) value) + .orElse(new HashMap<>())); eventActionsBuilder.requestedToolConfirmations( Optional.ofNullable(actionsMap.get("requestedToolConfirmations")) - .map(SessionJsonConverter::asConcurrentMapOfToolConfirmations) - .orElse(new ConcurrentHashMap<>())); + .map(SessionJsonConverter::asMapOfToolConfirmations) + .orElse(new HashMap<>())); } Event event = @@ -247,29 +243,11 @@ static Event fromApiEvent(Map apiEvent) { } @SuppressWarnings("unchecked") // stateDeltaFromMap is a Map from JSON. - private static ConcurrentMap stateDeltaFromJson(Object stateDeltaFromMap) { + private static Map stateDeltaFromJson(Object stateDeltaFromMap) { if (stateDeltaFromMap == null) { - return new ConcurrentHashMap<>(); + return new HashMap<>(); } - return ((Map) stateDeltaFromMap) - .entrySet().stream() - .collect( - ConcurrentHashMap::new, - (map, entry) -> - map.put( - entry.getKey(), - entry.getValue() == null ? State.REMOVED : entry.getValue()), - ConcurrentHashMap::putAll); - } - - private static Map stateDeltaToJson(Map stateDelta) { - return stateDelta.entrySet().stream() - .collect( - HashMap::new, - (map, entry) -> - map.put( - entry.getKey(), entry.getValue() == State.REMOVED ? null : entry.getValue()), - HashMap::putAll); + return (Map) stateDeltaFromMap; } /** @@ -291,18 +269,18 @@ private static Instant convertToInstant(Object timestampObj) { } /** - * Converts a raw object from "artifactDelta" into a {@link ConcurrentMap} of {@link String} to - * {@link Part}. + * Converts a raw object from "artifactDelta" into a {@link Map} of {@link String} to {@link + * Integer}. * - * @param artifactDeltaObj The raw object from which to parse the artifact delta. - * @return A {@link ConcurrentMap} representing the artifact delta. + * @param artifactDeltaObj The raw object from which to parse the artifact deltas. + * @return A {@link Map} representing the artifact deltas. */ - @SuppressWarnings("unchecked") - private static ConcurrentMap convertToArtifactDeltaMap(Object artifactDeltaObj) { + @SuppressWarnings("unchecked") // artifactDeltaObj is a Map from JSON. + private static Map convertToArtifactDeltaMap(Object artifactDeltaObj) { if (!(artifactDeltaObj instanceof Map)) { - return new ConcurrentHashMap<>(); + return new HashMap<>(); } - ConcurrentMap artifactDeltaMap = new ConcurrentHashMap<>(); + HashMap artifactDeltaMap = new HashMap<>(); Map rawMap = (Map) artifactDeltaObj; for (Map.Entry entry : rawMap.entrySet()) { try { @@ -316,34 +294,17 @@ private static ConcurrentMap convertToArtifactDeltaMap(Object a return artifactDeltaMap; } - /** - * Converts a nested map into a {@link ConcurrentMap} of {@link ConcurrentMap}s. - * - * @return thread-safe nested map. - */ - @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. - private static ConcurrentMap> - asConcurrentMapOfConcurrentMaps(Object value) { - return ((Map>) value) - .entrySet().stream() - .collect( - ConcurrentHashMap::new, - (map, entry) -> map.put(entry.getKey(), new ConcurrentHashMap<>(entry.getValue())), - ConcurrentHashMap::putAll); - } - @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. - private static ConcurrentMap asConcurrentMapOfToolConfirmations( - Object value) { + private static Map asMapOfToolConfirmations(Object value) { return ((Map) value) .entrySet().stream() .collect( - ConcurrentHashMap::new, + HashMap::new, (map, entry) -> map.put( entry.getKey(), objectMapper.convertValue(entry.getValue(), ToolConfirmation.class)), - ConcurrentHashMap::putAll); + HashMap::putAll); } private static void putIfNotEmpty(Map map, String key, Map values) { diff --git a/core/src/main/java/com/google/adk/sessions/State.java b/core/src/main/java/com/google/adk/sessions/State.java index ec23857d9..b9fbc921f 100644 --- a/core/src/main/java/com/google/adk/sessions/State.java +++ b/core/src/main/java/com/google/adk/sessions/State.java @@ -16,36 +16,54 @@ package com.google.adk.sessions; -import com.fasterxml.jackson.annotation.JsonValue; import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; -import java.util.Objects; +import java.util.Optional; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; +import javax.annotation.Nullable; /** A {@link State} object that also keeps track of the changes to the state. */ -@SuppressWarnings("ShouldNotSubclass") -public final class State implements ConcurrentMap { +@SuppressWarnings("ShouldNotSubclass") // Implementing Map is the desired interface for State. +public final class State implements Map { public static final String APP_PREFIX = "app:"; public static final String USER_PREFIX = "user:"; public static final String TEMP_PREFIX = "temp:"; /** Sentinel object to mark removed entries in the delta map. */ - public static final Object REMOVED = RemovedSentinel.INSTANCE; + public static final Object REMOVED = null; - private final ConcurrentMap state; - private final ConcurrentMap delta; + /** The underlying state map. A Map that supports null values as "removed". */ + private final Map state; - public State(ConcurrentMap state) { - this(state, new ConcurrentHashMap<>()); + /** The delta map. A Map that supports null values as "removed". */ + private final Map delta; + + public State() { + this(null, null); + } + + public State(@Nullable Map state) { + this(state, null); } - public State(ConcurrentMap state, ConcurrentMap delta) { - this.state = Objects.requireNonNull(state); - this.delta = delta; + /** + * Creates a {@link State} object with the given state and delta maps. + * + * @param state the underlying state map. The Map needs to accept null values. + * @param delta the delta map. The Map needs to accept null values. + */ + public State(@Nullable Map state, @Nullable Map delta) { + this.state = Optional.ofNullable(state).orElseGet(State::createSynchronizedHashMap); + this.delta = Optional.ofNullable(delta).orElseGet(State::createSynchronizedHashMap); + } + + /** Creates a new synchronized {@link HashMap}. */ + private static Map createSynchronizedHashMap() { + return Collections.synchronizedMap(new HashMap<>()); } @Override @@ -124,7 +142,7 @@ public void putAll(Map m) { @Override public Object remove(Object key) { if (state.containsKey(key)) { - delta.put((String) key, REMOVED); + delta.put((String) key, null); } return state.remove(key); } @@ -133,7 +151,7 @@ public Object remove(Object key) { public boolean remove(Object key, Object value) { boolean removed = state.remove(key, value); if (removed) { - delta.put((String) key, REMOVED); + delta.put((String) key, null); } return removed; } @@ -169,17 +187,4 @@ public Collection values() { public boolean hasDelta() { return !delta.isEmpty(); } - - private static final class RemovedSentinel { - public static final RemovedSentinel INSTANCE = new RemovedSentinel(); - - private RemovedSentinel() { - // Enforce singleton. - } - - @JsonValue - public String toJson() { - return "__ADK_SENTINEL_REMOVED__"; - } - } } diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java index d35bbccae..718738b92 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java @@ -14,10 +14,10 @@ import io.reactivex.rxjava3.core.Single; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; import okhttp3.ResponseBody; @@ -51,8 +51,8 @@ final class VertexAiClient { } Maybe createSession( - String reasoningEngineId, String userId, ConcurrentMap state) { - ConcurrentHashMap sessionJsonMap = new ConcurrentHashMap<>(); + String reasoningEngineId, String userId, Map state) { + Map sessionJsonMap = new HashMap<>(); sessionJsonMap.put("userId", userId); if (state != null) { sessionJsonMap.put("sessionState", state); diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index 7878daf22..cf2de0e53 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -33,10 +33,10 @@ import java.time.Instant; import java.util.ArrayList; import java.util.Comparator; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -76,6 +76,15 @@ public Single createSession( String userId, @Nullable ConcurrentMap state, @Nullable String sessionId) { + return createSession(appName, userId, (Map) state, sessionId); + } + + @Override + public Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { String reasoningEngineId = parseReasoningEngineId(appName); return client @@ -93,20 +102,20 @@ private static Session parseSession( .map(name -> Iterables.getLast(Splitter.on('/').splitToList(name.asText()))) .orElse(fallbackSessionId); Instant updateTimestamp = Instant.parse(getSessionResponseMap.get("updateTime").asText()); - ConcurrentMap sessionState = null; + Map sessionState = null; if (getSessionResponseMap != null && getSessionResponseMap.has("sessionState")) { JsonNode sessionStateNode = getSessionResponseMap.get("sessionState"); if (sessionStateNode != null) { sessionState = objectMapper.convertValue( - sessionStateNode, new TypeReference>() {}); + sessionStateNode, new TypeReference>() {}); } } return Session.builder(sessId) .appName(appName) .userId(userId) .lastUpdateTime(updateTimestamp) - .state(sessionState == null ? new ConcurrentHashMap<>() : sessionState) + .state(sessionState == null ? new HashMap<>() : sessionState) .build(); } @@ -140,10 +149,10 @@ private ListSessionsResponse parseListSessionsResponse( .userId(userId) .state( apiSession.get("sessionState") == null - ? new ConcurrentHashMap<>() + ? new HashMap<>() : objectMapper.convertValue( apiSession.get("sessionState"), - new TypeReference>() {})) + new TypeReference>() {})) .lastUpdateTime(updateTimestamp) .build(); sessions.add(session); @@ -168,8 +177,7 @@ private ListEventsResponse parseListEventsResponse(JsonNode listEventsResponse) return ListEventsResponse.builder() .events( objectMapper - .convertValue( - sessionEventsNode, new TypeReference>>() {}) + .convertValue(sessionEventsNode, new TypeReference>>() {}) .stream() .map(SessionJsonConverter::fromApiEvent) .collect(toCollection(ArrayList::new))) @@ -193,12 +201,12 @@ public Maybe getSession( .map(updateTime -> Instant.parse(updateTime.asText())) .orElse(null); - ConcurrentMap sessionState = new ConcurrentHashMap<>(); + Map sessionState = new HashMap<>(); if (getSessionResponseMap != null && getSessionResponseMap.has("sessionState")) { sessionState.putAll( objectMapper.convertValue( getSessionResponseMap.get("sessionState"), - new TypeReference>() {})); + new TypeReference>() {})); } return listEvents(appName, userId, sessionId) diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 94cd399df..0920016ac 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -18,11 +18,11 @@ import static com.google.common.truth.Truth.assertThat; -import com.google.adk.sessions.State; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.Part; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.junit.Test; import org.junit.runner.RunWith; @@ -44,11 +44,13 @@ public final class EventActionsTest { @Test public void toBuilder_createsBuilderWithSameValues() { + Map artifactDelta = new HashMap<>(); + artifactDelta.put("d1", null); EventActions eventActionsWithSkipSummarization = EventActions.builder() .skipSummarization(true) .compaction(COMPACTION) - .deletedArtifactIds(ImmutableSet.of("d1")) + .artifactDelta(artifactDelta) .build(); EventActions eventActionsAfterRebuild = eventActionsWithSkipSummarization.toBuilder().build(); @@ -59,12 +61,14 @@ public void toBuilder_createsBuilderWithSameValues() { @Test public void merge_mergesAllFields() { + Map artifactDelta1 = new HashMap<>(); + artifactDelta1.put("artifact1", 1); + artifactDelta1.put("deleted1", null); EventActions eventActions1 = EventActions.builder() .skipSummarization(true) - .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1"))) - .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact1", 1))) - .deletedArtifactIds(ImmutableSet.of("deleted1")) + .stateDelta(new HashMap<>(ImmutableMap.of("key1", "value1"))) + .artifactDelta(artifactDelta1) .requestedAuthConfigs( new ConcurrentHashMap<>( ImmutableMap.of("config1", new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))))) @@ -72,18 +76,18 @@ public void merge_mergesAllFields() { new ConcurrentHashMap<>(ImmutableMap.of("tool1", TOOL_CONFIRMATION))) .compaction(COMPACTION) .build(); + Map artifactDelta2 = new HashMap<>(); + artifactDelta2.put("artifact2", 2); + artifactDelta2.put("deleted2", null); EventActions eventActions2 = EventActions.builder() - .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key2", "value2"))) - .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact2", 2))) - .deletedArtifactIds(ImmutableSet.of("deleted2")) + .stateDelta(new HashMap<>(ImmutableMap.of("key2", "value2"))) + .artifactDelta(artifactDelta2) .transferToAgent("agentId") .escalate(true) .requestedAuthConfigs( - new ConcurrentHashMap<>( - ImmutableMap.of("config2", new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))))) - .requestedToolConfirmations( - new ConcurrentHashMap<>(ImmutableMap.of("tool2", TOOL_CONFIRMATION))) + new HashMap<>(ImmutableMap.of("config2", new HashMap<>(ImmutableMap.of("k", "v"))))) + .requestedToolConfirmations(new HashMap<>(ImmutableMap.of("tool2", TOOL_CONFIRMATION))) .endOfAgent(true) .build(); @@ -91,16 +95,13 @@ public void merge_mergesAllFields() { assertThat(merged.skipSummarization()).hasValue(true); assertThat(merged.stateDelta()).containsExactly("key1", "value1", "key2", "value2"); - assertThat(merged.artifactDelta()).containsExactly("artifact1", 1, "artifact2", 2); - assertThat(merged.deletedArtifactIds()).containsExactly("deleted1", "deleted2"); + assertThat(merged.artifactDelta()) + .containsExactly("artifact1", 1, "artifact2", 2, "deleted1", null, "deleted2", null); assertThat(merged.transferToAgent()).hasValue("agentId"); assertThat(merged.escalate()).hasValue(true); assertThat(merged.requestedAuthConfigs()) .containsExactly( - "config1", - new ConcurrentHashMap<>(ImmutableMap.of("k", "v")), - "config2", - new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))); + "config1", ImmutableMap.of("k", "v"), "config2", ImmutableMap.of("k", "v")); assertThat(merged.requestedToolConfirmations()) .containsExactly("tool1", TOOL_CONFIRMATION, "tool2", TOOL_CONFIRMATION); assertThat(merged.endOfAgent()).isTrue(); @@ -111,23 +112,8 @@ public void merge_mergesAllFields() { public void removeStateByKey_marksKeyAsRemoved() { EventActions eventActions = new EventActions(); eventActions.stateDelta().put("key1", "value1"); - eventActions.removeStateByKey("key1"); - - assertThat(eventActions.stateDelta()).containsExactly("key1", State.REMOVED); - } - - @Test - public void jsonSerialization_works() throws Exception { - EventActions eventActions = - EventActions.builder() - .deletedArtifactIds(ImmutableSet.of("d1", "d2")) - .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))) - .build(); - - String json = eventActions.toJson(); - EventActions deserialized = EventActions.fromJsonString(json, EventActions.class); + eventActions.stateDelta().put("key1", null); - assertThat(deserialized).isEqualTo(eventActions); - assertThat(deserialized.deletedArtifactIds()).containsExactly("d1", "d2"); + assertThat(eventActions.stateDelta()).containsExactly("key1", null); } } diff --git a/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java b/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java index 345314256..e110e7245 100644 --- a/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java @@ -44,7 +44,7 @@ public class GlobalInstructionPluginTest { @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); @Mock private CallbackContext mockCallbackContext; @Mock private InvocationContext mockInvocationContext; - private final State state = new State(new ConcurrentHashMap<>()); + private final State state = new State(); private final Session session = Session.builder("session_id").state(state).build(); @Mock private BaseArtifactService mockArtifactService; diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 6223dd2f0..381ad86a7 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -20,9 +20,9 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; import io.reactivex.rxjava3.core.Single; +import java.util.HashMap; +import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -84,10 +84,10 @@ public void lifecycle_listSessions() { Session session = sessionService - .createSession("app-name", "user-id", new ConcurrentHashMap<>(), "session-1") + .createSession("app-name", "user-id", new HashMap<>(), "session-1") .blockingGet(); - ConcurrentMap stateDelta = new ConcurrentHashMap<>(); + Map stateDelta = new HashMap<>(); stateDelta.put("sessionKey", "sessionValue"); stateDelta.put("_app_appKey", "appValue"); stateDelta.put("_user_userKey", "userValue"); @@ -130,11 +130,9 @@ public void lifecycle_deleteSession() { public void appendEvent_updatesSessionState() { InMemorySessionService sessionService = new InMemorySessionService(); Session session = - sessionService - .createSession("app", "user", new ConcurrentHashMap<>(), "session1") - .blockingGet(); + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); - ConcurrentMap stateDelta = new ConcurrentHashMap<>(); + Map stateDelta = new HashMap<>(); stateDelta.put("sessionKey", "sessionValue"); stateDelta.put("_app_appKey", "appValue"); stateDelta.put("_user_userKey", "userValue"); @@ -167,11 +165,9 @@ public void appendEvent_updatesSessionState() { public void appendEvent_removesState() { InMemorySessionService sessionService = new InMemorySessionService(); Session session = - sessionService - .createSession("app", "user", new ConcurrentHashMap<>(), "session1") - .blockingGet(); + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); - ConcurrentMap stateDeltaAdd = new ConcurrentHashMap<>(); + Map stateDeltaAdd = new HashMap<>(); stateDeltaAdd.put("sessionKey", "sessionValue"); stateDeltaAdd.put("_app_appKey", "appValue"); stateDeltaAdd.put("_user_userKey", "userValue"); @@ -193,11 +189,11 @@ public void appendEvent_removesState() { assertThat(retrievedSessionAdd.state()).containsEntry("temp:tempKey", "tempValue"); // Prepare and append event to remove state - ConcurrentMap stateDeltaRemove = new ConcurrentHashMap<>(); - stateDeltaRemove.put("sessionKey", State.REMOVED); - stateDeltaRemove.put("_app_appKey", State.REMOVED); - stateDeltaRemove.put("_user_userKey", State.REMOVED); - stateDeltaRemove.put("temp:tempKey", State.REMOVED); + Map stateDeltaRemove = new HashMap<>(); + stateDeltaRemove.put("sessionKey", null); + stateDeltaRemove.put("_app_appKey", null); + stateDeltaRemove.put("_user_userKey", null); + stateDeltaRemove.put("temp:tempKey", null); Event eventRemove = Event.builder() @@ -221,12 +217,10 @@ public void appendEvent_removesState() { public void sequentialAgents_shareTempState() { InMemorySessionService sessionService = new InMemorySessionService(); Session session = - sessionService - .createSession("app", "user", new ConcurrentHashMap<>(), "session1") - .blockingGet(); + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); // Agent 1 writes to temp state - ConcurrentMap stateDelta1 = new ConcurrentHashMap<>(); + Map stateDelta1 = new HashMap<>(); stateDelta1.put("temp:agent1_output", "data"); Event event1 = Event.builder().actions(EventActions.builder().stateDelta(stateDelta1).build()).build(); @@ -237,9 +231,9 @@ public void sequentialAgents_shareTempState() { // Agent 2 reads "agent1_output", processes it, writes "agent2_output", and removes // "agent1_output" - ConcurrentMap stateDelta2 = new ConcurrentHashMap<>(); + Map stateDelta2 = new HashMap<>(); stateDelta2.put("temp:agent2_output", "processed_data"); - stateDelta2.put("temp:agent1_output", State.REMOVED); + stateDelta2.put("temp:agent1_output", null); Event event2 = Event.builder().actions(EventActions.builder().stateDelta(stateDelta2).build()).build(); unused = sessionService.appendEvent(session, event2).blockingGet(); diff --git a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java index f6120cf08..f772bf5fd 100644 --- a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java +++ b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java @@ -23,8 +23,6 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -38,8 +36,8 @@ public void convertEventToJson_fullEvent_success() throws JsonProcessingExceptio EventActions actions = EventActions.builder() .skipSummarization(true) - .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key", "value"))) - .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact", 1))) + .stateDelta(new HashMap<>(ImmutableMap.of("key", "value"))) + .artifactDelta(new HashMap<>(ImmutableMap.of("artifact", 1))) .transferToAgent("agent") .escalate(true) .build(); @@ -174,10 +172,10 @@ public void fromApiEvent_withTransferToAgent_success() { @Test public void convertEventToJson_complexActions_success() throws JsonProcessingException { - ConcurrentMap> authConfigs = new ConcurrentHashMap<>(); - authConfigs.put("auth1", new ConcurrentHashMap<>(ImmutableMap.of("param1", "value1"))); + Map> authConfigs = new HashMap<>(); + authConfigs.put("auth1", new HashMap<>(ImmutableMap.of("param1", "value1"))); - ConcurrentMap toolConfirmations = new ConcurrentHashMap<>(); + Map toolConfirmations = new HashMap<>(); toolConfirmations.put( "tool1", ToolConfirmation.builder().hint("hint1").confirmed(true).build()); @@ -225,6 +223,7 @@ public void convertEventToJson_complexActions_success() throws JsonProcessingExc } @Test + // Testing conversion of raw Map following a known schema. public void fromApiEvent_complexActions_success() { Map apiEvent = new HashMap<>(); apiEvent.put("name", "sessions/123/events/456"); @@ -333,11 +332,10 @@ public void fromApiEvent_missingMetadataFields_success() { @Test public void convertEventToJson_withStateRemoved_success() throws JsonProcessingException { - EventActions actions = - EventActions.builder() - .stateDelta( - new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", State.REMOVED))) - .build(); + HashMap stateDelta = new HashMap<>(); + stateDelta.put("key1", "value1"); + stateDelta.put("key2", null); + EventActions actions = EventActions.builder().stateDelta(stateDelta).build(); Event event = Event.builder() @@ -422,6 +420,6 @@ public void fromApiEvent_withNullStateDeltaValue_success() { EventActions eventActions = event.actions(); assertThat(eventActions.stateDelta()).containsEntry("key1", "value1"); - assertThat(eventActions.stateDelta()).containsEntry("key2", State.REMOVED); + assertThat(eventActions.stateDelta()).containsEntry("key2", null); } } diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index 36eab1d16..fc2e2ed96 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -104,7 +104,7 @@ public class VertexAiSessionServiceTest { ] """; - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Casting raw Object from JSON Map to Map. private static Session getMockSession() throws Exception { Map sessionJson = mapper.readValue(MOCK_SESSION_STRING_1, new TypeReference>() {}); @@ -341,13 +341,12 @@ public void listEmptySession_success() { @Test public void appendEvent_withStateRemoved_updatesSessionState() { String userId = "userB"; - ConcurrentMap initialState = - new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", "value2")); + ImmutableMap initialState = ImmutableMap.of("key1", "value1", "key2", "value2"); Session session = vertexAiSessionService.createSession("987", userId, initialState, null).blockingGet(); - ConcurrentMap stateDelta = - new ConcurrentHashMap<>(ImmutableMap.of("key2", State.REMOVED)); + Map stateDelta = new HashMap<>(); + stateDelta.put("key2", null); Event event = Event.builder() .invocationId("456") diff --git a/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java b/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java index f29298bce..403c73ea2 100644 --- a/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java +++ b/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java @@ -59,8 +59,7 @@ class ReplayPluginTest { void setUp() { plugin = new ReplayPlugin(); mockSession = mock(Session.class); - sessionState = new ConcurrentHashMap<>(); - state = new State(sessionState); + state = new State(); when(mockSession.state()).thenReturn(sessionState); }