diff --git a/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java b/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java index d3295a6ef..db236fc88 100644 --- a/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java +++ b/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java @@ -50,6 +50,7 @@ import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.regex.Matcher; +import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -88,7 +89,20 @@ private CollectionReference getSessionsCollection(String userId) { /** Creates a new session in Firestore. */ @Override public Single createSession( - String appName, String userId, ConcurrentMap state, String sessionId) { + String appName, + String userId, + @Nullable ConcurrentMap state, + @Nullable String sessionId) { + return createSession(appName, userId, (Map) state, sessionId); + } + + /** Creates a new session in Firestore. */ + @Override + public Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { return Single.fromCallable( () -> { Objects.requireNonNull(appName, "appName cannot be null"); diff --git a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java index 540153460..94e8cd7ba 100644 --- a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java @@ -23,8 +23,10 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import javax.annotation.Nullable; @@ -47,13 +49,35 @@ public interface BaseSessionService { * service should generate a unique ID. * @return The newly created {@link Session} instance. * @throws SessionException if creation fails. + * @deprecated Use {@link #createSession(String, String, Map, String)} instead. */ + @Deprecated Single createSession( String appName, String userId, @Nullable ConcurrentMap state, @Nullable String sessionId); + /** + * Creates a new session with the specified parameters. + * + * @param appName The name of the application associated with the session. + * @param userId The identifier for the user associated with the session. + * @param state An optional map representing the initial state of the session. Can be null or + * empty. + * @param sessionId An optional client-provided identifier for the session. If empty or null, the + * service should generate a unique ID. + * @return The newly created {@link Session} instance. + * @throws SessionException if creation fails. + */ + default Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { + return createSession(appName, userId, ensureConcurrentMap(state), sessionId); + } + /** * Creates a new session with the specified application name and user ID, using a default state * (null) and allowing the service to generate a unique session ID. @@ -165,9 +189,9 @@ default Single appendEvent(Session session, Event event) { EventActions actions = event.actions(); if (actions != null) { - ConcurrentMap stateDelta = actions.stateDelta(); + Map stateDelta = actions.stateDelta(); if (stateDelta != null && !stateDelta.isEmpty()) { - ConcurrentMap sessionState = session.state(); + Map sessionState = session.state(); if (sessionState != null) { stateDelta.forEach( (key, value) -> { @@ -190,4 +214,21 @@ default Single appendEvent(Session session, Event event) { return Single.just(event); } + + /** + * Ensures the given {@link Map} is a {@link ConcurrentMap}. If the input is null, returns null. + * If the input is already a {@link ConcurrentMap}, it is cast and returned. Otherwise, a new + * {@link ConcurrentHashMap} is created from the input map. + */ + @Nullable + private static ConcurrentMap ensureConcurrentMap( + @Nullable Map state) { + if (state == null) { + return null; + } + if (state instanceof ConcurrentMap concurrentMap) { + return concurrentMap; + } + return new ConcurrentHashMap<>(state); + } } diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index 060fcaf60..b2a584b11 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -71,6 +71,15 @@ public Single createSession( String userId, @Nullable ConcurrentMap state, @Nullable String sessionId) { + return createSession(appName, userId, (Map) state, sessionId); + } + + @Override + public Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { Objects.requireNonNull(appName, "appName cannot be null"); Objects.requireNonNull(userId, "userId cannot be null"); @@ -83,7 +92,6 @@ public Single createSession( // Ensure state map and events list are mutable for the new session ConcurrentMap initialState = (state == null) ? new ConcurrentHashMap<>() : new ConcurrentHashMap<>(state); - List initialEvents = new ArrayList<>(); // Assuming Session constructor or setters allow setting these mutable collections Session newSession = @@ -91,7 +99,6 @@ public Single createSession( .appName(appName) .userId(userId) .state(initialState) - .events(initialEvents) .lastUpdateTime(Instant.now()) .build(); diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java index d35bbccae..718738b92 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java @@ -14,10 +14,10 @@ import io.reactivex.rxjava3.core.Single; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; import okhttp3.ResponseBody; @@ -51,8 +51,8 @@ final class VertexAiClient { } Maybe createSession( - String reasoningEngineId, String userId, ConcurrentMap state) { - ConcurrentHashMap sessionJsonMap = new ConcurrentHashMap<>(); + String reasoningEngineId, String userId, Map state) { + Map sessionJsonMap = new HashMap<>(); sessionJsonMap.put("userId", userId); if (state != null) { sessionJsonMap.put("sessionState", state); diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index 7878daf22..2fff7a752 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -76,6 +76,15 @@ public Single createSession( String userId, @Nullable ConcurrentMap state, @Nullable String sessionId) { + return createSession(appName, userId, (Map) state, sessionId); + } + + @Override + public Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { String reasoningEngineId = parseReasoningEngineId(appName); return client diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 6223dd2f0..41e156ffd 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -20,6 +20,7 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; import io.reactivex.rxjava3.core.Single; +import java.util.HashMap; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -84,7 +85,7 @@ public void lifecycle_listSessions() { Session session = sessionService - .createSession("app-name", "user-id", new ConcurrentHashMap<>(), "session-1") + .createSession("app-name", "user-id", new HashMap<>(), "session-1") .blockingGet(); ConcurrentMap stateDelta = new ConcurrentHashMap<>(); @@ -130,9 +131,7 @@ public void lifecycle_deleteSession() { public void appendEvent_updatesSessionState() { InMemorySessionService sessionService = new InMemorySessionService(); Session session = - sessionService - .createSession("app", "user", new ConcurrentHashMap<>(), "session1") - .blockingGet(); + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); ConcurrentMap stateDelta = new ConcurrentHashMap<>(); stateDelta.put("sessionKey", "sessionValue"); @@ -167,9 +166,7 @@ public void appendEvent_updatesSessionState() { public void appendEvent_removesState() { InMemorySessionService sessionService = new InMemorySessionService(); Session session = - sessionService - .createSession("app", "user", new ConcurrentHashMap<>(), "session1") - .blockingGet(); + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); ConcurrentMap stateDeltaAdd = new ConcurrentHashMap<>(); stateDeltaAdd.put("sessionKey", "sessionValue"); @@ -221,9 +218,7 @@ public void appendEvent_removesState() { public void sequentialAgents_shareTempState() { InMemorySessionService sessionService = new InMemorySessionService(); Session session = - sessionService - .createSession("app", "user", new ConcurrentHashMap<>(), "session1") - .blockingGet(); + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); // Agent 1 writes to temp state ConcurrentMap stateDelta1 = new ConcurrentHashMap<>(); diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index 36eab1d16..def4faf4c 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -167,8 +167,7 @@ public void setUp() throws Exception { @Test public void createSession_success() throws Exception { - ConcurrentMap sessionStateMap = - new ConcurrentHashMap<>(ImmutableMap.of("new_key", "new_value")); + Map sessionStateMap = new HashMap<>(ImmutableMap.of("new_key", "new_value")); Single sessionSingle = vertexAiSessionService.createSession("123", "test_user", sessionStateMap, null); Session createdSession = sessionSingle.blockingGet(); @@ -190,8 +189,7 @@ public void createSession_success() throws Exception { @Test public void createSession_getSession_success() throws Exception { - ConcurrentMap sessionStateMap = - new ConcurrentHashMap<>(ImmutableMap.of("new_key", "new_value")); + Map sessionStateMap = new HashMap<>(ImmutableMap.of("new_key", "new_value")); Single sessionSingle = vertexAiSessionService.createSession("789", "test_user", sessionStateMap, null); Session createdSession = sessionSingle.blockingGet(); @@ -252,8 +250,7 @@ public void getAndDeleteSession_success() throws Exception { @Test public void createSessionAndGetSession_success() throws Exception { - ConcurrentMap sessionStateMap = - new ConcurrentHashMap<>(ImmutableMap.of("key", "value")); + Map sessionStateMap = new HashMap<>(ImmutableMap.of("key", "value")); Single sessionSingle = vertexAiSessionService.createSession("123", "user", sessionStateMap, null); Session createdSession = sessionSingle.blockingGet(); @@ -341,8 +338,8 @@ public void listEmptySession_success() { @Test public void appendEvent_withStateRemoved_updatesSessionState() { String userId = "userB"; - ConcurrentMap initialState = - new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", "value2")); + Map initialState = + new HashMap<>(ImmutableMap.of("key1", "value1", "key2", "value2")); Session session = vertexAiSessionService.createSession("987", userId, initialState, null).blockingGet();