Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -88,7 +89,20 @@ private CollectionReference getSessionsCollection(String userId) {
/** Creates a new session in Firestore. */
@Override
public Single<Session> createSession(
String appName, String userId, ConcurrentMap<String, Object> state, String sessionId) {
String appName,
String userId,
@Nullable ConcurrentMap<String, Object> state,
@Nullable String sessionId) {
return createSession(appName, userId, (Map<String, Object>) state, sessionId);
}

/** Creates a new session in Firestore. */
@Override
public Single<Session> createSession(
String appName,
String userId,
@Nullable Map<String, Object> state,
@Nullable String sessionId) {
return Single.fromCallable(
() -> {
Objects.requireNonNull(appName, "appName cannot be null");
Expand Down
45 changes: 43 additions & 2 deletions core/src/main/java/com/google/adk/sessions/BaseSessionService.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import javax.annotation.Nullable;

Expand All @@ -47,13 +49,35 @@ public interface BaseSessionService {
* service should generate a unique ID.
* @return The newly created {@link Session} instance.
* @throws SessionException if creation fails.
* @deprecated Use {@link #createSession(String, String, Map, String)} instead.
*/
@Deprecated
Single<Session> createSession(
String appName,
String userId,
@Nullable ConcurrentMap<String, Object> state,
@Nullable String sessionId);

/**
* Creates a new session with the specified parameters.
*
* @param appName The name of the application associated with the session.
* @param userId The identifier for the user associated with the session.
* @param state An optional map representing the initial state of the session. Can be null or
* empty.
* @param sessionId An optional client-provided identifier for the session. If empty or null, the
* service should generate a unique ID.
* @return The newly created {@link Session} instance.
* @throws SessionException if creation fails.
*/
default Single<Session> createSession(
String appName,
String userId,
@Nullable Map<String, Object> state,
@Nullable String sessionId) {
return createSession(appName, userId, ensureConcurrentMap(state), sessionId);
}

/**
* Creates a new session with the specified application name and user ID, using a default state
* (null) and allowing the service to generate a unique session ID.
Expand Down Expand Up @@ -165,9 +189,9 @@ default Single<Event> appendEvent(Session session, Event event) {

EventActions actions = event.actions();
if (actions != null) {
ConcurrentMap<String, Object> stateDelta = actions.stateDelta();
Map<String, Object> stateDelta = actions.stateDelta();
if (stateDelta != null && !stateDelta.isEmpty()) {
ConcurrentMap<String, Object> sessionState = session.state();
Map<String, Object> sessionState = session.state();
if (sessionState != null) {
stateDelta.forEach(
(key, value) -> {
Expand All @@ -190,4 +214,21 @@ default Single<Event> appendEvent(Session session, Event event) {

return Single.just(event);
}

/**
* Ensures the given {@link Map} is a {@link ConcurrentMap}. If the input is null, returns null.
* If the input is already a {@link ConcurrentMap}, it is cast and returned. Otherwise, a new
* {@link ConcurrentHashMap} is created from the input map.
*/
@Nullable
private static ConcurrentMap<String, Object> ensureConcurrentMap(
@Nullable Map<String, Object> state) {
if (state == null) {
return null;
}
if (state instanceof ConcurrentMap<String, Object> concurrentMap) {
return concurrentMap;
}
return new ConcurrentHashMap<>(state);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ public Single<Session> createSession(
String userId,
@Nullable ConcurrentMap<String, Object> state,
@Nullable String sessionId) {
return createSession(appName, userId, (Map<String, Object>) state, sessionId);
}

@Override
public Single<Session> createSession(
String appName,
String userId,
@Nullable Map<String, Object> state,
@Nullable String sessionId) {
Objects.requireNonNull(appName, "appName cannot be null");
Objects.requireNonNull(userId, "userId cannot be null");

Expand All @@ -83,15 +92,13 @@ public Single<Session> createSession(
// Ensure state map and events list are mutable for the new session
ConcurrentMap<String, Object> initialState =
(state == null) ? new ConcurrentHashMap<>() : new ConcurrentHashMap<>(state);
List<Event> initialEvents = new ArrayList<>();

// Assuming Session constructor or setters allow setting these mutable collections
Session newSession =
Session.builder(resolvedSessionId)
.appName(appName)
.userId(userId)
.state(initialState)
.events(initialEvents)
.lastUpdateTime(Instant.now())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
import io.reactivex.rxjava3.core.Single;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeoutException;
import javax.annotation.Nullable;
import okhttp3.ResponseBody;
Expand Down Expand Up @@ -51,8 +51,8 @@ final class VertexAiClient {
}

Maybe<JsonNode> createSession(
String reasoningEngineId, String userId, ConcurrentMap<String, Object> state) {
ConcurrentHashMap<String, Object> sessionJsonMap = new ConcurrentHashMap<>();
String reasoningEngineId, String userId, Map<String, Object> state) {
Map<String, Object> sessionJsonMap = new HashMap<>();
sessionJsonMap.put("userId", userId);
if (state != null) {
sessionJsonMap.put("sessionState", state);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ public Single<Session> createSession(
String userId,
@Nullable ConcurrentMap<String, Object> state,
@Nullable String sessionId) {
return createSession(appName, userId, (Map<String, Object>) state, sessionId);
}

@Override
public Single<Session> createSession(
String appName,
String userId,
@Nullable Map<String, Object> state,
@Nullable String sessionId) {

String reasoningEngineId = parseReasoningEngineId(appName);
return client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.adk.events.Event;
import com.google.adk.events.EventActions;
import io.reactivex.rxjava3.core.Single;
import java.util.HashMap;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
Expand Down Expand Up @@ -84,7 +85,7 @@ public void lifecycle_listSessions() {

Session session =
sessionService
.createSession("app-name", "user-id", new ConcurrentHashMap<>(), "session-1")
.createSession("app-name", "user-id", new HashMap<>(), "session-1")
.blockingGet();

ConcurrentMap<String, Object> stateDelta = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -130,9 +131,7 @@ public void lifecycle_deleteSession() {
public void appendEvent_updatesSessionState() {
InMemorySessionService sessionService = new InMemorySessionService();
Session session =
sessionService
.createSession("app", "user", new ConcurrentHashMap<>(), "session1")
.blockingGet();
sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet();

ConcurrentMap<String, Object> stateDelta = new ConcurrentHashMap<>();
stateDelta.put("sessionKey", "sessionValue");
Expand Down Expand Up @@ -167,9 +166,7 @@ public void appendEvent_updatesSessionState() {
public void appendEvent_removesState() {
InMemorySessionService sessionService = new InMemorySessionService();
Session session =
sessionService
.createSession("app", "user", new ConcurrentHashMap<>(), "session1")
.blockingGet();
sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet();

ConcurrentMap<String, Object> stateDeltaAdd = new ConcurrentHashMap<>();
stateDeltaAdd.put("sessionKey", "sessionValue");
Expand Down Expand Up @@ -221,9 +218,7 @@ public void appendEvent_removesState() {
public void sequentialAgents_shareTempState() {
InMemorySessionService sessionService = new InMemorySessionService();
Session session =
sessionService
.createSession("app", "user", new ConcurrentHashMap<>(), "session1")
.blockingGet();
sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet();

// Agent 1 writes to temp state
ConcurrentMap<String, Object> stateDelta1 = new ConcurrentHashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ public void setUp() throws Exception {

@Test
public void createSession_success() throws Exception {
ConcurrentMap<String, Object> sessionStateMap =
new ConcurrentHashMap<>(ImmutableMap.of("new_key", "new_value"));
Map<String, Object> sessionStateMap = new HashMap<>(ImmutableMap.of("new_key", "new_value"));
Single<Session> sessionSingle =
vertexAiSessionService.createSession("123", "test_user", sessionStateMap, null);
Session createdSession = sessionSingle.blockingGet();
Expand All @@ -190,8 +189,7 @@ public void createSession_success() throws Exception {

@Test
public void createSession_getSession_success() throws Exception {
ConcurrentMap<String, Object> sessionStateMap =
new ConcurrentHashMap<>(ImmutableMap.of("new_key", "new_value"));
Map<String, Object> sessionStateMap = new HashMap<>(ImmutableMap.of("new_key", "new_value"));
Single<Session> sessionSingle =
vertexAiSessionService.createSession("789", "test_user", sessionStateMap, null);
Session createdSession = sessionSingle.blockingGet();
Expand Down Expand Up @@ -252,8 +250,7 @@ public void getAndDeleteSession_success() throws Exception {

@Test
public void createSessionAndGetSession_success() throws Exception {
ConcurrentMap<String, Object> sessionStateMap =
new ConcurrentHashMap<>(ImmutableMap.of("key", "value"));
Map<String, Object> sessionStateMap = new HashMap<>(ImmutableMap.of("key", "value"));
Single<Session> sessionSingle =
vertexAiSessionService.createSession("123", "user", sessionStateMap, null);
Session createdSession = sessionSingle.blockingGet();
Expand Down Expand Up @@ -341,8 +338,8 @@ public void listEmptySession_success() {
@Test
public void appendEvent_withStateRemoved_updatesSessionState() {
String userId = "userB";
ConcurrentMap<String, Object> initialState =
new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", "value2"));
Map<String, Object> initialState =
new HashMap<>(ImmutableMap.of("key1", "value1", "key2", "value2"));
Session session =
vertexAiSessionService.createSession("987", userId, initialState, null).blockingGet();

Expand Down