diff --git a/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java b/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java index d3295a6ef..ae5a26f99 100644 --- a/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java +++ b/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java @@ -46,7 +46,6 @@ import java.util.Optional; import java.util.Set; import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.regex.Matcher; @@ -85,10 +84,19 @@ private CollectionReference getSessionsCollection(String userId) { .collection(SESSION_COLLECTION_NAME); } + @Override + public Single createSession( + String appName, + String userId, + @Nullable ConcurrentMap state, + @Nullable String sessionId) { + return createSession(appName, userId, (Map) state, sessionId); + } + /** Creates a new session in Firestore. */ @Override public Single createSession( - String appName, String userId, ConcurrentMap state, String sessionId) { + String appName, String userId, Map state, String sessionId) { return Single.fromCallable( () -> { Objects.requireNonNull(appName, "appName cannot be null"); @@ -100,21 +108,17 @@ public Single createSession( .filter(s -> !s.isEmpty()) .orElseGet(() -> UUID.randomUUID().toString()); - ConcurrentMap initialState = - (state == null) ? new ConcurrentHashMap<>() : new ConcurrentHashMap<>(state); logger.info( "Creating session for userId: {} with sessionId: {} and initial state: {}", userId, resolvedSessionId, - initialState); - List initialEvents = new ArrayList<>(); + state); Instant now = Instant.now(); Session newSession = Session.builder(resolvedSessionId) .appName(appName) .userId(userId) - .state(initialState) - .events(initialEvents) + .state(state) .lastUpdateTime(now) .build(); @@ -200,8 +204,7 @@ public Maybe getSession( }) .map( events -> { - ConcurrentMap state = - new ConcurrentHashMap<>((Map) data.get(STATE_KEY)); + Map state = (Map) data.get(STATE_KEY); return Session.builder((String) data.get(ID_KEY)) .appName((String) data.get(APP_NAME_KEY)) .userId((String) data.get(USER_ID_KEY)) @@ -451,8 +454,6 @@ public Single listSessions(String appName, String userId) .appName((String) data.get(APP_NAME_KEY)) .userId((String) data.get(USER_ID_KEY)) .lastUpdateTime(Instant.parse((String) data.get(UPDATE_TIME_KEY))) - .state(new ConcurrentHashMap<>()) // Empty state - .events(new ArrayList<>()) // Empty events .build(); sessions.add(session); } diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 6d8c698dd..a4fe4461d 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -19,14 +19,13 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.google.adk.JsonBaseModel; -import com.google.adk.sessions.State; import com.google.errorprone.annotations.CanIgnoreReturnValue; -import java.util.HashSet; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import javax.annotation.Nullable; /** Represents the actions attached to an event. */ @@ -35,39 +34,37 @@ public class EventActions extends JsonBaseModel { private Optional skipSummarization; - private ConcurrentMap stateDelta; - private ConcurrentMap artifactDelta; - private Set deletedArtifactIds; + private Map stateDelta; + private Map artifactDelta; private Optional transferToAgent; private Optional escalate; - private ConcurrentMap> requestedAuthConfigs; - private ConcurrentMap requestedToolConfirmations; + private Map> requestedAuthConfigs; + private Map requestedToolConfirmations; private boolean endOfAgent; private Optional compaction; /** Default constructor for Jackson. */ public EventActions() { this.skipSummarization = Optional.empty(); - this.stateDelta = new ConcurrentHashMap<>(); - this.artifactDelta = new ConcurrentHashMap<>(); - this.deletedArtifactIds = new HashSet<>(); + this.stateDelta = Collections.synchronizedMap(new HashMap<>()); + this.artifactDelta = Collections.synchronizedMap(new HashMap<>()); this.transferToAgent = Optional.empty(); this.escalate = Optional.empty(); - this.requestedAuthConfigs = new ConcurrentHashMap<>(); - this.requestedToolConfirmations = new ConcurrentHashMap<>(); + this.requestedAuthConfigs = Collections.synchronizedMap(new HashMap<>()); + this.requestedToolConfirmations = Collections.synchronizedMap(new HashMap<>()); this.endOfAgent = false; this.compaction = Optional.empty(); } private EventActions(Builder builder) { this.skipSummarization = builder.skipSummarization; - this.stateDelta = builder.stateDelta; - this.artifactDelta = builder.artifactDelta; - this.deletedArtifactIds = builder.deletedArtifactIds; + this.stateDelta = Collections.synchronizedMap(builder.stateDelta); + this.artifactDelta = Collections.synchronizedMap(builder.artifactDelta); this.transferToAgent = builder.transferToAgent; this.escalate = builder.escalate; - this.requestedAuthConfigs = builder.requestedAuthConfigs; - this.requestedToolConfirmations = builder.requestedToolConfirmations; + this.requestedAuthConfigs = Collections.synchronizedMap(builder.requestedAuthConfigs); + this.requestedToolConfirmations = + Collections.synchronizedMap(builder.requestedToolConfirmations); this.endOfAgent = builder.endOfAgent; this.compaction = builder.compaction; } @@ -90,41 +87,32 @@ public void setSkipSummarization(boolean skipSummarization) { } @JsonProperty("stateDelta") - public ConcurrentMap stateDelta() { + public Map stateDelta() { return stateDelta; } - @Deprecated // Use stateDelta(), addState() and removeStateByKey() instead. - public void setStateDelta(ConcurrentMap stateDelta) { - this.stateDelta = stateDelta; + public void setStateDelta(Map stateDelta) { + this.stateDelta = Collections.synchronizedMap(new HashMap<>(stateDelta)); } /** * Removes a key from the state delta. * * @param key The key to remove. + * @deprecated Use {@link #stateDelta()}.put(key, null) instead. */ + @Deprecated public void removeStateByKey(String key) { - stateDelta.put(key, State.REMOVED); + stateDelta().put(key, null); } @JsonProperty("artifactDelta") - public ConcurrentMap artifactDelta() { + public Map artifactDelta() { return artifactDelta; } - public void setArtifactDelta(ConcurrentMap artifactDelta) { - this.artifactDelta = artifactDelta; - } - - @JsonProperty("deletedArtifactIds") - @JsonInclude(JsonInclude.Include.NON_EMPTY) - public Set deletedArtifactIds() { - return deletedArtifactIds; - } - - public void setDeletedArtifactIds(Set deletedArtifactIds) { - this.deletedArtifactIds = deletedArtifactIds; + public void setArtifactDelta(Map artifactDelta) { + this.artifactDelta = Collections.synchronizedMap(new HashMap<>(artifactDelta)); } @JsonProperty("transferToAgent") @@ -154,23 +142,23 @@ public void setEscalate(boolean escalate) { } @JsonProperty("requestedAuthConfigs") - public ConcurrentMap> requestedAuthConfigs() { + public Map> requestedAuthConfigs() { return requestedAuthConfigs; } - public void setRequestedAuthConfigs( - ConcurrentMap> requestedAuthConfigs) { + public void setRequestedAuthConfigs(Map> requestedAuthConfigs) { this.requestedAuthConfigs = requestedAuthConfigs; } @JsonProperty("requestedToolConfirmations") - public ConcurrentMap requestedToolConfirmations() { + public Map requestedToolConfirmations() { return requestedToolConfirmations; } public void setRequestedToolConfirmations( - ConcurrentMap requestedToolConfirmations) { - this.requestedToolConfirmations = requestedToolConfirmations; + Map requestedToolConfirmations) { + this.requestedToolConfirmations = + Collections.synchronizedMap(new HashMap<>(requestedToolConfirmations)); } @JsonProperty("endOfAgent") @@ -235,7 +223,6 @@ public boolean equals(Object o) { return Objects.equals(skipSummarization, that.skipSummarization) && Objects.equals(stateDelta, that.stateDelta) && Objects.equals(artifactDelta, that.artifactDelta) - && Objects.equals(deletedArtifactIds, that.deletedArtifactIds) && Objects.equals(transferToAgent, that.transferToAgent) && Objects.equals(escalate, that.escalate) && Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs) @@ -250,7 +237,6 @@ public int hashCode() { skipSummarization, stateDelta, artifactDelta, - deletedArtifactIds, transferToAgent, escalate, requestedAuthConfigs, @@ -262,38 +248,34 @@ public int hashCode() { /** Builder for {@link EventActions}. */ public static class Builder { private Optional skipSummarization; - private ConcurrentMap stateDelta; - private ConcurrentMap artifactDelta; - private Set deletedArtifactIds; + private Map stateDelta; + private Map artifactDelta; private Optional transferToAgent; private Optional escalate; - private ConcurrentMap> requestedAuthConfigs; - private ConcurrentMap requestedToolConfirmations; + private Map> requestedAuthConfigs; + private Map requestedToolConfirmations; private boolean endOfAgent = false; private Optional compaction; public Builder() { this.skipSummarization = Optional.empty(); - this.stateDelta = new ConcurrentHashMap<>(); - this.artifactDelta = new ConcurrentHashMap<>(); - this.deletedArtifactIds = new HashSet<>(); + this.stateDelta = new HashMap<>(); + this.artifactDelta = new HashMap<>(); this.transferToAgent = Optional.empty(); this.escalate = Optional.empty(); - this.requestedAuthConfigs = new ConcurrentHashMap<>(); - this.requestedToolConfirmations = new ConcurrentHashMap<>(); + this.requestedAuthConfigs = new HashMap<>(); + this.requestedToolConfirmations = new HashMap<>(); this.compaction = Optional.empty(); } private Builder(EventActions eventActions) { this.skipSummarization = eventActions.skipSummarization(); - this.stateDelta = new ConcurrentHashMap<>(eventActions.stateDelta()); - this.artifactDelta = new ConcurrentHashMap<>(eventActions.artifactDelta()); - this.deletedArtifactIds = new HashSet<>(eventActions.deletedArtifactIds()); + this.stateDelta = new HashMap<>(eventActions.stateDelta()); + this.artifactDelta = new HashMap<>(eventActions.artifactDelta()); this.transferToAgent = eventActions.transferToAgent(); this.escalate = eventActions.escalate(); - this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs()); - this.requestedToolConfirmations = - new ConcurrentHashMap<>(eventActions.requestedToolConfirmations()); + this.requestedAuthConfigs = new HashMap<>(eventActions.requestedAuthConfigs()); + this.requestedToolConfirmations = new HashMap<>(eventActions.requestedToolConfirmations()); this.endOfAgent = eventActions.endOfAgent(); this.compaction = eventActions.compaction(); } @@ -307,14 +289,14 @@ public Builder skipSummarization(boolean skipSummarization) { @CanIgnoreReturnValue @JsonProperty("stateDelta") - public Builder stateDelta(ConcurrentMap value) { + public Builder stateDelta(Map value) { this.stateDelta = value; return this; } @CanIgnoreReturnValue @JsonProperty("artifactDelta") - public Builder artifactDelta(ConcurrentMap value) { + public Builder artifactDelta(Map value) { this.artifactDelta = value; return this; } @@ -322,7 +304,7 @@ public Builder artifactDelta(ConcurrentMap value) { @CanIgnoreReturnValue @JsonProperty("deletedArtifactIds") public Builder deletedArtifactIds(Set value) { - this.deletedArtifactIds = value; + value.forEach(v -> artifactDelta.put(v, null)); return this; } @@ -342,16 +324,15 @@ public Builder escalate(boolean escalate) { @CanIgnoreReturnValue @JsonProperty("requestedAuthConfigs") - public Builder requestedAuthConfigs( - ConcurrentMap> value) { + public Builder requestedAuthConfigs(Map> value) { this.requestedAuthConfigs = value; return this; } @CanIgnoreReturnValue @JsonProperty("requestedToolConfirmations") - public Builder requestedToolConfirmations(ConcurrentMap value) { - this.requestedToolConfirmations = value; + public Builder requestedToolConfirmations(Map value) { + this.requestedToolConfirmations = Collections.synchronizedMap(new HashMap<>(value)); return this; } @@ -385,7 +366,6 @@ public Builder merge(EventActions other) { other.skipSummarization().ifPresent(this::skipSummarization); this.stateDelta.putAll(other.stateDelta()); this.artifactDelta.putAll(other.artifactDelta()); - this.deletedArtifactIds.addAll(other.deletedArtifactIds()); other.transferToAgent().ifPresent(this::transferToAgent); other.escalate().ifPresent(this::escalate); this.requestedAuthConfigs.putAll(other.requestedAuthConfigs()); diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 0ddfdaea1..8908b4ecb 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -58,6 +58,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -337,7 +338,9 @@ private Single appendNewMessageToSession( // Add state delta if provided if (stateDelta != null && !stateDelta.isEmpty()) { eventBuilder.actions( - EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build()); + EventActions.builder() + .stateDelta(stateDelta == null ? new HashMap<>() : new HashMap<>(stateDelta)) + .build()); } return this.sessionService.appendEvent(session, eventBuilder.build()); diff --git a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java index 540153460..a16443888 100644 --- a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java @@ -23,8 +23,10 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import javax.annotation.Nullable; @@ -47,13 +49,36 @@ public interface BaseSessionService { * service should generate a unique ID. * @return The newly created {@link Session} instance. * @throws SessionException if creation fails. + * @deprecated Use {@link #createSession(String, String, Map, String)} instead. */ + @Deprecated Single createSession( String appName, String userId, @Nullable ConcurrentMap state, @Nullable String sessionId); + /** + * Creates a new session with the specified parameters. + * + * @param appName The name of the application associated with the session. + * @param userId The identifier for the user associated with the session. + * @param state An optional map representing the initial state of the session. Can be null or + * empty. + * @param sessionId An optional client-provided identifier for the session. If empty or null, the + * service should generate a unique ID. + * @return The newly created {@link Session} instance. + * @throws SessionException if creation fails. + */ + default Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { + return createSession( + appName, userId, state == null ? null : new ConcurrentHashMap<>(state), sessionId); + } + /** * Creates a new session with the specified application name and user ID, using a default state * (null) and allowing the service to generate a unique session ID. @@ -165,9 +190,9 @@ default Single appendEvent(Session session, Event event) { EventActions actions = event.actions(); if (actions != null) { - ConcurrentMap stateDelta = actions.stateDelta(); + Map stateDelta = actions.stateDelta(); if (stateDelta != null && !stateDelta.isEmpty()) { - ConcurrentMap sessionState = session.state(); + Map sessionState = session.state(); if (sessionState != null) { stateDelta.forEach( (key, value) -> { diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index 060fcaf60..5900263ac 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -28,6 +28,7 @@ import java.time.Instant; import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -71,6 +72,15 @@ public Single createSession( String userId, @Nullable ConcurrentMap state, @Nullable String sessionId) { + return createSession(appName, userId, (Map) state, sessionId); + } + + @Override + public Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { Objects.requireNonNull(appName, "appName cannot be null"); Objects.requireNonNull(userId, "userId cannot be null"); @@ -81,8 +91,7 @@ public Single createSession( .orElseGet(() -> UUID.randomUUID().toString()); // Ensure state map and events list are mutable for the new session - ConcurrentMap initialState = - (state == null) ? new ConcurrentHashMap<>() : new ConcurrentHashMap<>(state); + Map initialState = (state == null) ? new HashMap<>() : new HashMap<>(state); List initialEvents = new ArrayList<>(); // Assuming Session constructor or setters allow setting these mutable collections @@ -95,10 +104,7 @@ public Single createSession( .lastUpdateTime(Instant.now()) .build(); - sessions - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .put(resolvedSessionId, newSession); + getUserSessionsMap(appName, userId).put(resolvedSessionId, newSession); // Create a mutable copy for the return value Session returnCopy = copySession(newSession); @@ -114,11 +120,7 @@ public Maybe getSession( Objects.requireNonNull(sessionId, "sessionId cannot be null"); Objects.requireNonNull(configOpt, "configOpt cannot be null"); - Session storedSession = - sessions - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .get(sessionId); + Session storedSession = getUserSessionsMap(appName, userId).get(sessionId); if (storedSession == null) { return Maybe.empty(); @@ -165,10 +167,9 @@ public Single listSessions(String appName, String userId) Objects.requireNonNull(appName, "appName cannot be null"); Objects.requireNonNull(userId, "userId cannot be null"); - Map userSessionsMap = - sessions.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()).get(userId); + Map userSessionsMap = getUserSessionsMap(appName, userId); - if (userSessionsMap == null || userSessionsMap.isEmpty()) { + if (userSessionsMap.isEmpty()) { return Single.just(ListSessionsResponse.builder().build()); } @@ -185,13 +186,7 @@ public Completable deleteSession(String appName, String userId, String sessionId Objects.requireNonNull(userId, "userId cannot be null"); Objects.requireNonNull(sessionId, "sessionId cannot be null"); - ConcurrentMap> appSessionsMap = sessions.get(appName); - if (appSessionsMap != null) { - ConcurrentMap userSessionsMap = appSessionsMap.get(userId); - if (userSessionsMap != null) { - userSessionsMap.remove(sessionId); - } - } + getUserSessionsMap(appName, userId).remove(sessionId); return Completable.complete(); } @@ -201,11 +196,7 @@ public Single listEvents(String appName, String userId, Stri Objects.requireNonNull(userId, "userId cannot be null"); Objects.requireNonNull(sessionId, "sessionId cannot be null"); - Session storedSession = - sessions - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .get(sessionId); + Session storedSession = getUserSessionsMap(appName, userId).get(sessionId); if (storedSession == null) { return Single.just(ListEventsResponse.builder().build()); @@ -237,30 +228,21 @@ public Single appendEvent(Session session, Event event) { (key, value) -> { if (key.startsWith(State.APP_PREFIX)) { String appStateKey = key.substring(State.APP_PREFIX.length()); - if (value == State.REMOVED) { - appState - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .remove(appStateKey); + ConcurrentMap currentAppState = getAppStateMap(appName); + if (value == null) { + currentAppState.remove(appStateKey); } else { - appState - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .put(appStateKey, value); + currentAppState.put(appStateKey, value); } } else if (key.startsWith(State.USER_PREFIX)) { String userStateKey = key.substring(State.USER_PREFIX.length()); - if (value == State.REMOVED) { - userState - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .remove(userStateKey); + if (value == null) { + getUserStateMap(appName, userId).remove(userStateKey); } else { - userState - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .put(userStateKey, value); + getUserStateMap(appName, userId).put(userStateKey, value); } } else { - if (value == State.REMOVED) { + if (value == null) { session.state().remove(key); } else { session.state().put(key, value); @@ -274,10 +256,7 @@ public Single appendEvent(Session session, Event event) { session.lastUpdateTime(getInstantFromEvent(event)); // --- Update the session stored in this service --- - sessions - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) - .put(sessionId, session); + getUserSessionsMap(appName, userId).put(sessionId, session); mergeWithGlobalState(appName, userId, session); @@ -304,7 +283,7 @@ private Session copySession(Session original) { return Session.builder(original.id()) .appName(original.appName()) .userId(original.userId()) - .state(new ConcurrentHashMap<>(original.state())) + .state(new HashMap<>(original.state())) .events(new ArrayList<>(original.events())) .lastUpdateTime(original.lastUpdateTime()) .build(); @@ -324,13 +303,10 @@ private Session mergeWithGlobalState(String appName, String userId, Session sess Map sessionState = session.state(); // Merge App State directly into the session's state map - appState - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + getAppStateMap(appName) .forEach((key, value) -> sessionState.put(State.APP_PREFIX + key, value)); - userState - .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) + getUserStateMap(appName, userId) .forEach((key, value) -> sessionState.put(State.USER_PREFIX + key, value)); return session; @@ -355,4 +331,20 @@ private List prepareSessionsForListResponse( .map(s -> mergeWithGlobalState(appName, userId, s)) .collect(toCollection(ArrayList::new)); } + + private ConcurrentMap getAppStateMap(String appName) { + return appState.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()); + } + + private ConcurrentMap getUserStateMap(String appName, String userId) { + return userState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()); + } + + private ConcurrentMap getUserSessionsMap(String appName, String userId) { + return sessions + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()); + } } diff --git a/core/src/main/java/com/google/adk/sessions/Session.java b/core/src/main/java/com/google/adk/sessions/Session.java index 3bf27b55e..0f773c942 100644 --- a/core/src/main/java/com/google/adk/sessions/Session.java +++ b/core/src/main/java/com/google/adk/sessions/Session.java @@ -26,8 +26,7 @@ import java.time.Instant; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; +import java.util.Map; /** A {@link Session} object that encapsulates the {@link State} and {@link Event}s of a session. */ @JsonDeserialize(builder = Session.Builder.class) @@ -53,7 +52,7 @@ public static final class Builder { private String id; private String appName; private String userId; - private State state = new State(new ConcurrentHashMap<>()); + private State state = new State(); private List events = new ArrayList<>(); private Instant lastUpdateTime = Instant.EPOCH; @@ -79,7 +78,7 @@ public Builder state(State state) { @CanIgnoreReturnValue @JsonProperty("state") - public Builder state(ConcurrentMap state) { + public Builder state(Map state) { this.state = new State(state); return this; } @@ -135,7 +134,7 @@ public String id() { } @JsonProperty("state") - public ConcurrentMap state() { + public Map state() { return state; } diff --git a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java index 71b072695..6e188886b 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java +++ b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java @@ -37,8 +37,6 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -99,7 +97,7 @@ static String convertEventToJson(Event event, boolean useIsoString) { Map actionsJson = new HashMap<>(); EventActions actions = event.actions(); actions.skipSummarization().ifPresent(v -> actionsJson.put("skipSummarization", v)); - actionsJson.put("stateDelta", stateDeltaToJson(actions.stateDelta())); + actionsJson.put("stateDelta", actions.stateDelta()); putIfNotEmpty(actionsJson, "artifactDelta", actions.artifactDelta()); actions .transferToAgent() @@ -168,9 +166,7 @@ static Event fromApiEvent(Map apiEvent) { eventActionsBuilder.stateDelta(stateDeltaFromJson(actionsMap.get("stateDelta"))); Object artifactDelta = actionsMap.get("artifactDelta"); eventActionsBuilder.artifactDelta( - artifactDelta != null - ? convertToArtifactDeltaMap(artifactDelta) - : new ConcurrentHashMap<>()); + artifactDelta != null ? convertToArtifactDeltaMap(artifactDelta) : new HashMap<>()); String transferAgent = (String) actionsMap.get("transferAgent"); if (transferAgent == null) { transferAgent = (String) actionsMap.get("transferToAgent"); @@ -186,12 +182,12 @@ static Event fromApiEvent(Map apiEvent) { } eventActionsBuilder.requestedAuthConfigs( Optional.ofNullable(actionsMap.get("requestedAuthConfigs")) - .map(SessionJsonConverter::asConcurrentMapOfConcurrentMaps) - .orElse(new ConcurrentHashMap<>())); + .map(value -> (Map>) value) + .orElse(new HashMap<>())); eventActionsBuilder.requestedToolConfirmations( Optional.ofNullable(actionsMap.get("requestedToolConfirmations")) - .map(SessionJsonConverter::asConcurrentMapOfToolConfirmations) - .orElse(new ConcurrentHashMap<>())); + .map(SessionJsonConverter::asMapOfToolConfirmations) + .orElse(new HashMap<>())); } Event event = @@ -247,29 +243,11 @@ static Event fromApiEvent(Map apiEvent) { } @SuppressWarnings("unchecked") // stateDeltaFromMap is a Map from JSON. - private static ConcurrentMap stateDeltaFromJson(Object stateDeltaFromMap) { + private static Map stateDeltaFromJson(Object stateDeltaFromMap) { if (stateDeltaFromMap == null) { - return new ConcurrentHashMap<>(); + return new HashMap<>(); } - return ((Map) stateDeltaFromMap) - .entrySet().stream() - .collect( - ConcurrentHashMap::new, - (map, entry) -> - map.put( - entry.getKey(), - entry.getValue() == null ? State.REMOVED : entry.getValue()), - ConcurrentHashMap::putAll); - } - - private static Map stateDeltaToJson(Map stateDelta) { - return stateDelta.entrySet().stream() - .collect( - HashMap::new, - (map, entry) -> - map.put( - entry.getKey(), entry.getValue() == State.REMOVED ? null : entry.getValue()), - HashMap::putAll); + return (Map) stateDeltaFromMap; } /** @@ -291,18 +269,18 @@ private static Instant convertToInstant(Object timestampObj) { } /** - * Converts a raw object from "artifactDelta" into a {@link ConcurrentMap} of {@link String} to - * {@link Part}. + * Converts a raw object from "artifactDelta" into a {@link Map} of {@link String} to {@link + * Integer}. * - * @param artifactDeltaObj The raw object from which to parse the artifact delta. - * @return A {@link ConcurrentMap} representing the artifact delta. + * @param artifactDeltaObj The raw object from which to parse the artifact deltas. + * @return A {@link Map} representing the artifact deltas. */ - @SuppressWarnings("unchecked") - private static ConcurrentMap convertToArtifactDeltaMap(Object artifactDeltaObj) { + @SuppressWarnings("unchecked") // artifactDeltaObj is a Map from JSON. + private static Map convertToArtifactDeltaMap(Object artifactDeltaObj) { if (!(artifactDeltaObj instanceof Map)) { - return new ConcurrentHashMap<>(); + return new HashMap<>(); } - ConcurrentMap artifactDeltaMap = new ConcurrentHashMap<>(); + HashMap artifactDeltaMap = new HashMap<>(); Map rawMap = (Map) artifactDeltaObj; for (Map.Entry entry : rawMap.entrySet()) { try { @@ -316,34 +294,17 @@ private static ConcurrentMap convertToArtifactDeltaMap(Object a return artifactDeltaMap; } - /** - * Converts a nested map into a {@link ConcurrentMap} of {@link ConcurrentMap}s. - * - * @return thread-safe nested map. - */ - @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. - private static ConcurrentMap> - asConcurrentMapOfConcurrentMaps(Object value) { - return ((Map>) value) - .entrySet().stream() - .collect( - ConcurrentHashMap::new, - (map, entry) -> map.put(entry.getKey(), new ConcurrentHashMap<>(entry.getValue())), - ConcurrentHashMap::putAll); - } - @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. - private static ConcurrentMap asConcurrentMapOfToolConfirmations( - Object value) { + private static Map asMapOfToolConfirmations(Object value) { return ((Map) value) .entrySet().stream() .collect( - ConcurrentHashMap::new, + HashMap::new, (map, entry) -> map.put( entry.getKey(), objectMapper.convertValue(entry.getValue(), ToolConfirmation.class)), - ConcurrentHashMap::putAll); + HashMap::putAll); } private static void putIfNotEmpty(Map map, String key, Map values) { diff --git a/core/src/main/java/com/google/adk/sessions/State.java b/core/src/main/java/com/google/adk/sessions/State.java index ec23857d9..b9fbc921f 100644 --- a/core/src/main/java/com/google/adk/sessions/State.java +++ b/core/src/main/java/com/google/adk/sessions/State.java @@ -16,36 +16,54 @@ package com.google.adk.sessions; -import com.fasterxml.jackson.annotation.JsonValue; import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; -import java.util.Objects; +import java.util.Optional; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; +import javax.annotation.Nullable; /** A {@link State} object that also keeps track of the changes to the state. */ -@SuppressWarnings("ShouldNotSubclass") -public final class State implements ConcurrentMap { +@SuppressWarnings("ShouldNotSubclass") // Implementing Map is the desired interface for State. +public final class State implements Map { public static final String APP_PREFIX = "app:"; public static final String USER_PREFIX = "user:"; public static final String TEMP_PREFIX = "temp:"; /** Sentinel object to mark removed entries in the delta map. */ - public static final Object REMOVED = RemovedSentinel.INSTANCE; + public static final Object REMOVED = null; - private final ConcurrentMap state; - private final ConcurrentMap delta; + /** The underlying state map. A Map that supports null values as "removed". */ + private final Map state; - public State(ConcurrentMap state) { - this(state, new ConcurrentHashMap<>()); + /** The delta map. A Map that supports null values as "removed". */ + private final Map delta; + + public State() { + this(null, null); + } + + public State(@Nullable Map state) { + this(state, null); } - public State(ConcurrentMap state, ConcurrentMap delta) { - this.state = Objects.requireNonNull(state); - this.delta = delta; + /** + * Creates a {@link State} object with the given state and delta maps. + * + * @param state the underlying state map. The Map needs to accept null values. + * @param delta the delta map. The Map needs to accept null values. + */ + public State(@Nullable Map state, @Nullable Map delta) { + this.state = Optional.ofNullable(state).orElseGet(State::createSynchronizedHashMap); + this.delta = Optional.ofNullable(delta).orElseGet(State::createSynchronizedHashMap); + } + + /** Creates a new synchronized {@link HashMap}. */ + private static Map createSynchronizedHashMap() { + return Collections.synchronizedMap(new HashMap<>()); } @Override @@ -124,7 +142,7 @@ public void putAll(Map m) { @Override public Object remove(Object key) { if (state.containsKey(key)) { - delta.put((String) key, REMOVED); + delta.put((String) key, null); } return state.remove(key); } @@ -133,7 +151,7 @@ public Object remove(Object key) { public boolean remove(Object key, Object value) { boolean removed = state.remove(key, value); if (removed) { - delta.put((String) key, REMOVED); + delta.put((String) key, null); } return removed; } @@ -169,17 +187,4 @@ public Collection values() { public boolean hasDelta() { return !delta.isEmpty(); } - - private static final class RemovedSentinel { - public static final RemovedSentinel INSTANCE = new RemovedSentinel(); - - private RemovedSentinel() { - // Enforce singleton. - } - - @JsonValue - public String toJson() { - return "__ADK_SENTINEL_REMOVED__"; - } - } } diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java index d35bbccae..718738b92 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java @@ -14,10 +14,10 @@ import io.reactivex.rxjava3.core.Single; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; import okhttp3.ResponseBody; @@ -51,8 +51,8 @@ final class VertexAiClient { } Maybe createSession( - String reasoningEngineId, String userId, ConcurrentMap state) { - ConcurrentHashMap sessionJsonMap = new ConcurrentHashMap<>(); + String reasoningEngineId, String userId, Map state) { + Map sessionJsonMap = new HashMap<>(); sessionJsonMap.put("userId", userId); if (state != null) { sessionJsonMap.put("sessionState", state); diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index 7878daf22..cf2de0e53 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -33,10 +33,10 @@ import java.time.Instant; import java.util.ArrayList; import java.util.Comparator; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -76,6 +76,15 @@ public Single createSession( String userId, @Nullable ConcurrentMap state, @Nullable String sessionId) { + return createSession(appName, userId, (Map) state, sessionId); + } + + @Override + public Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { String reasoningEngineId = parseReasoningEngineId(appName); return client @@ -93,20 +102,20 @@ private static Session parseSession( .map(name -> Iterables.getLast(Splitter.on('/').splitToList(name.asText()))) .orElse(fallbackSessionId); Instant updateTimestamp = Instant.parse(getSessionResponseMap.get("updateTime").asText()); - ConcurrentMap sessionState = null; + Map sessionState = null; if (getSessionResponseMap != null && getSessionResponseMap.has("sessionState")) { JsonNode sessionStateNode = getSessionResponseMap.get("sessionState"); if (sessionStateNode != null) { sessionState = objectMapper.convertValue( - sessionStateNode, new TypeReference>() {}); + sessionStateNode, new TypeReference>() {}); } } return Session.builder(sessId) .appName(appName) .userId(userId) .lastUpdateTime(updateTimestamp) - .state(sessionState == null ? new ConcurrentHashMap<>() : sessionState) + .state(sessionState == null ? new HashMap<>() : sessionState) .build(); } @@ -140,10 +149,10 @@ private ListSessionsResponse parseListSessionsResponse( .userId(userId) .state( apiSession.get("sessionState") == null - ? new ConcurrentHashMap<>() + ? new HashMap<>() : objectMapper.convertValue( apiSession.get("sessionState"), - new TypeReference>() {})) + new TypeReference>() {})) .lastUpdateTime(updateTimestamp) .build(); sessions.add(session); @@ -168,8 +177,7 @@ private ListEventsResponse parseListEventsResponse(JsonNode listEventsResponse) return ListEventsResponse.builder() .events( objectMapper - .convertValue( - sessionEventsNode, new TypeReference>>() {}) + .convertValue(sessionEventsNode, new TypeReference>>() {}) .stream() .map(SessionJsonConverter::fromApiEvent) .collect(toCollection(ArrayList::new))) @@ -193,12 +201,12 @@ public Maybe getSession( .map(updateTime -> Instant.parse(updateTime.asText())) .orElse(null); - ConcurrentMap sessionState = new ConcurrentHashMap<>(); + Map sessionState = new HashMap<>(); if (getSessionResponseMap != null && getSessionResponseMap.has("sessionState")) { sessionState.putAll( objectMapper.convertValue( getSessionResponseMap.get("sessionState"), - new TypeReference>() {})); + new TypeReference>() {})); } return listEvents(appName, userId, sessionId) diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 94cd399df..0920016ac 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -18,11 +18,11 @@ import static com.google.common.truth.Truth.assertThat; -import com.google.adk.sessions.State; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.Part; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.junit.Test; import org.junit.runner.RunWith; @@ -44,11 +44,13 @@ public final class EventActionsTest { @Test public void toBuilder_createsBuilderWithSameValues() { + Map artifactDelta = new HashMap<>(); + artifactDelta.put("d1", null); EventActions eventActionsWithSkipSummarization = EventActions.builder() .skipSummarization(true) .compaction(COMPACTION) - .deletedArtifactIds(ImmutableSet.of("d1")) + .artifactDelta(artifactDelta) .build(); EventActions eventActionsAfterRebuild = eventActionsWithSkipSummarization.toBuilder().build(); @@ -59,12 +61,14 @@ public void toBuilder_createsBuilderWithSameValues() { @Test public void merge_mergesAllFields() { + Map artifactDelta1 = new HashMap<>(); + artifactDelta1.put("artifact1", 1); + artifactDelta1.put("deleted1", null); EventActions eventActions1 = EventActions.builder() .skipSummarization(true) - .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1"))) - .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact1", 1))) - .deletedArtifactIds(ImmutableSet.of("deleted1")) + .stateDelta(new HashMap<>(ImmutableMap.of("key1", "value1"))) + .artifactDelta(artifactDelta1) .requestedAuthConfigs( new ConcurrentHashMap<>( ImmutableMap.of("config1", new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))))) @@ -72,18 +76,18 @@ public void merge_mergesAllFields() { new ConcurrentHashMap<>(ImmutableMap.of("tool1", TOOL_CONFIRMATION))) .compaction(COMPACTION) .build(); + Map artifactDelta2 = new HashMap<>(); + artifactDelta2.put("artifact2", 2); + artifactDelta2.put("deleted2", null); EventActions eventActions2 = EventActions.builder() - .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key2", "value2"))) - .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact2", 2))) - .deletedArtifactIds(ImmutableSet.of("deleted2")) + .stateDelta(new HashMap<>(ImmutableMap.of("key2", "value2"))) + .artifactDelta(artifactDelta2) .transferToAgent("agentId") .escalate(true) .requestedAuthConfigs( - new ConcurrentHashMap<>( - ImmutableMap.of("config2", new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))))) - .requestedToolConfirmations( - new ConcurrentHashMap<>(ImmutableMap.of("tool2", TOOL_CONFIRMATION))) + new HashMap<>(ImmutableMap.of("config2", new HashMap<>(ImmutableMap.of("k", "v"))))) + .requestedToolConfirmations(new HashMap<>(ImmutableMap.of("tool2", TOOL_CONFIRMATION))) .endOfAgent(true) .build(); @@ -91,16 +95,13 @@ public void merge_mergesAllFields() { assertThat(merged.skipSummarization()).hasValue(true); assertThat(merged.stateDelta()).containsExactly("key1", "value1", "key2", "value2"); - assertThat(merged.artifactDelta()).containsExactly("artifact1", 1, "artifact2", 2); - assertThat(merged.deletedArtifactIds()).containsExactly("deleted1", "deleted2"); + assertThat(merged.artifactDelta()) + .containsExactly("artifact1", 1, "artifact2", 2, "deleted1", null, "deleted2", null); assertThat(merged.transferToAgent()).hasValue("agentId"); assertThat(merged.escalate()).hasValue(true); assertThat(merged.requestedAuthConfigs()) .containsExactly( - "config1", - new ConcurrentHashMap<>(ImmutableMap.of("k", "v")), - "config2", - new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))); + "config1", ImmutableMap.of("k", "v"), "config2", ImmutableMap.of("k", "v")); assertThat(merged.requestedToolConfirmations()) .containsExactly("tool1", TOOL_CONFIRMATION, "tool2", TOOL_CONFIRMATION); assertThat(merged.endOfAgent()).isTrue(); @@ -111,23 +112,8 @@ public void merge_mergesAllFields() { public void removeStateByKey_marksKeyAsRemoved() { EventActions eventActions = new EventActions(); eventActions.stateDelta().put("key1", "value1"); - eventActions.removeStateByKey("key1"); - - assertThat(eventActions.stateDelta()).containsExactly("key1", State.REMOVED); - } - - @Test - public void jsonSerialization_works() throws Exception { - EventActions eventActions = - EventActions.builder() - .deletedArtifactIds(ImmutableSet.of("d1", "d2")) - .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))) - .build(); - - String json = eventActions.toJson(); - EventActions deserialized = EventActions.fromJsonString(json, EventActions.class); + eventActions.stateDelta().put("key1", null); - assertThat(deserialized).isEqualTo(eventActions); - assertThat(deserialized.deletedArtifactIds()).containsExactly("d1", "d2"); + assertThat(eventActions.stateDelta()).containsExactly("key1", null); } } diff --git a/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java b/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java index 345314256..e110e7245 100644 --- a/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java @@ -44,7 +44,7 @@ public class GlobalInstructionPluginTest { @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); @Mock private CallbackContext mockCallbackContext; @Mock private InvocationContext mockInvocationContext; - private final State state = new State(new ConcurrentHashMap<>()); + private final State state = new State(); private final Session session = Session.builder("session_id").state(state).build(); @Mock private BaseArtifactService mockArtifactService; diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 6223dd2f0..381ad86a7 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -20,9 +20,9 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; import io.reactivex.rxjava3.core.Single; +import java.util.HashMap; +import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -84,10 +84,10 @@ public void lifecycle_listSessions() { Session session = sessionService - .createSession("app-name", "user-id", new ConcurrentHashMap<>(), "session-1") + .createSession("app-name", "user-id", new HashMap<>(), "session-1") .blockingGet(); - ConcurrentMap stateDelta = new ConcurrentHashMap<>(); + Map stateDelta = new HashMap<>(); stateDelta.put("sessionKey", "sessionValue"); stateDelta.put("_app_appKey", "appValue"); stateDelta.put("_user_userKey", "userValue"); @@ -130,11 +130,9 @@ public void lifecycle_deleteSession() { public void appendEvent_updatesSessionState() { InMemorySessionService sessionService = new InMemorySessionService(); Session session = - sessionService - .createSession("app", "user", new ConcurrentHashMap<>(), "session1") - .blockingGet(); + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); - ConcurrentMap stateDelta = new ConcurrentHashMap<>(); + Map stateDelta = new HashMap<>(); stateDelta.put("sessionKey", "sessionValue"); stateDelta.put("_app_appKey", "appValue"); stateDelta.put("_user_userKey", "userValue"); @@ -167,11 +165,9 @@ public void appendEvent_updatesSessionState() { public void appendEvent_removesState() { InMemorySessionService sessionService = new InMemorySessionService(); Session session = - sessionService - .createSession("app", "user", new ConcurrentHashMap<>(), "session1") - .blockingGet(); + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); - ConcurrentMap stateDeltaAdd = new ConcurrentHashMap<>(); + Map stateDeltaAdd = new HashMap<>(); stateDeltaAdd.put("sessionKey", "sessionValue"); stateDeltaAdd.put("_app_appKey", "appValue"); stateDeltaAdd.put("_user_userKey", "userValue"); @@ -193,11 +189,11 @@ public void appendEvent_removesState() { assertThat(retrievedSessionAdd.state()).containsEntry("temp:tempKey", "tempValue"); // Prepare and append event to remove state - ConcurrentMap stateDeltaRemove = new ConcurrentHashMap<>(); - stateDeltaRemove.put("sessionKey", State.REMOVED); - stateDeltaRemove.put("_app_appKey", State.REMOVED); - stateDeltaRemove.put("_user_userKey", State.REMOVED); - stateDeltaRemove.put("temp:tempKey", State.REMOVED); + Map stateDeltaRemove = new HashMap<>(); + stateDeltaRemove.put("sessionKey", null); + stateDeltaRemove.put("_app_appKey", null); + stateDeltaRemove.put("_user_userKey", null); + stateDeltaRemove.put("temp:tempKey", null); Event eventRemove = Event.builder() @@ -221,12 +217,10 @@ public void appendEvent_removesState() { public void sequentialAgents_shareTempState() { InMemorySessionService sessionService = new InMemorySessionService(); Session session = - sessionService - .createSession("app", "user", new ConcurrentHashMap<>(), "session1") - .blockingGet(); + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); // Agent 1 writes to temp state - ConcurrentMap stateDelta1 = new ConcurrentHashMap<>(); + Map stateDelta1 = new HashMap<>(); stateDelta1.put("temp:agent1_output", "data"); Event event1 = Event.builder().actions(EventActions.builder().stateDelta(stateDelta1).build()).build(); @@ -237,9 +231,9 @@ public void sequentialAgents_shareTempState() { // Agent 2 reads "agent1_output", processes it, writes "agent2_output", and removes // "agent1_output" - ConcurrentMap stateDelta2 = new ConcurrentHashMap<>(); + Map stateDelta2 = new HashMap<>(); stateDelta2.put("temp:agent2_output", "processed_data"); - stateDelta2.put("temp:agent1_output", State.REMOVED); + stateDelta2.put("temp:agent1_output", null); Event event2 = Event.builder().actions(EventActions.builder().stateDelta(stateDelta2).build()).build(); unused = sessionService.appendEvent(session, event2).blockingGet(); diff --git a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java index f6120cf08..f772bf5fd 100644 --- a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java +++ b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java @@ -23,8 +23,6 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -38,8 +36,8 @@ public void convertEventToJson_fullEvent_success() throws JsonProcessingExceptio EventActions actions = EventActions.builder() .skipSummarization(true) - .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key", "value"))) - .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact", 1))) + .stateDelta(new HashMap<>(ImmutableMap.of("key", "value"))) + .artifactDelta(new HashMap<>(ImmutableMap.of("artifact", 1))) .transferToAgent("agent") .escalate(true) .build(); @@ -174,10 +172,10 @@ public void fromApiEvent_withTransferToAgent_success() { @Test public void convertEventToJson_complexActions_success() throws JsonProcessingException { - ConcurrentMap> authConfigs = new ConcurrentHashMap<>(); - authConfigs.put("auth1", new ConcurrentHashMap<>(ImmutableMap.of("param1", "value1"))); + Map> authConfigs = new HashMap<>(); + authConfigs.put("auth1", new HashMap<>(ImmutableMap.of("param1", "value1"))); - ConcurrentMap toolConfirmations = new ConcurrentHashMap<>(); + Map toolConfirmations = new HashMap<>(); toolConfirmations.put( "tool1", ToolConfirmation.builder().hint("hint1").confirmed(true).build()); @@ -225,6 +223,7 @@ public void convertEventToJson_complexActions_success() throws JsonProcessingExc } @Test + // Testing conversion of raw Map following a known schema. public void fromApiEvent_complexActions_success() { Map apiEvent = new HashMap<>(); apiEvent.put("name", "sessions/123/events/456"); @@ -333,11 +332,10 @@ public void fromApiEvent_missingMetadataFields_success() { @Test public void convertEventToJson_withStateRemoved_success() throws JsonProcessingException { - EventActions actions = - EventActions.builder() - .stateDelta( - new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", State.REMOVED))) - .build(); + HashMap stateDelta = new HashMap<>(); + stateDelta.put("key1", "value1"); + stateDelta.put("key2", null); + EventActions actions = EventActions.builder().stateDelta(stateDelta).build(); Event event = Event.builder() @@ -422,6 +420,6 @@ public void fromApiEvent_withNullStateDeltaValue_success() { EventActions eventActions = event.actions(); assertThat(eventActions.stateDelta()).containsEntry("key1", "value1"); - assertThat(eventActions.stateDelta()).containsEntry("key2", State.REMOVED); + assertThat(eventActions.stateDelta()).containsEntry("key2", null); } } diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index 36eab1d16..fc2e2ed96 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -104,7 +104,7 @@ public class VertexAiSessionServiceTest { ] """; - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Casting raw Object from JSON Map to Map. private static Session getMockSession() throws Exception { Map sessionJson = mapper.readValue(MOCK_SESSION_STRING_1, new TypeReference>() {}); @@ -341,13 +341,12 @@ public void listEmptySession_success() { @Test public void appendEvent_withStateRemoved_updatesSessionState() { String userId = "userB"; - ConcurrentMap initialState = - new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", "value2")); + ImmutableMap initialState = ImmutableMap.of("key1", "value1", "key2", "value2"); Session session = vertexAiSessionService.createSession("987", userId, initialState, null).blockingGet(); - ConcurrentMap stateDelta = - new ConcurrentHashMap<>(ImmutableMap.of("key2", State.REMOVED)); + Map stateDelta = new HashMap<>(); + stateDelta.put("key2", null); Event event = Event.builder() .invocationId("456") diff --git a/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java b/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java index f29298bce..403c73ea2 100644 --- a/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java +++ b/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java @@ -59,8 +59,7 @@ class ReplayPluginTest { void setUp() { plugin = new ReplayPlugin(); mockSession = mock(Session.class); - sessionState = new ConcurrentHashMap<>(); - state = new State(sessionState); + state = new State(); when(mockSession.state()).thenReturn(sessionState); }