Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
Expand Down Expand Up @@ -85,10 +84,19 @@ private CollectionReference getSessionsCollection(String userId) {
.collection(SESSION_COLLECTION_NAME);
}

@Override
public Single<Session> createSession(
String appName,
String userId,
@Nullable ConcurrentMap<String, Object> state,
@Nullable String sessionId) {
return createSession(appName, userId, (Map<String, Object>) state, sessionId);
}

/** Creates a new session in Firestore. */
@Override
public Single<Session> createSession(
String appName, String userId, ConcurrentMap<String, Object> state, String sessionId) {
String appName, String userId, Map<String, Object> state, String sessionId) {
return Single.fromCallable(
() -> {
Objects.requireNonNull(appName, "appName cannot be null");
Expand All @@ -100,21 +108,17 @@ public Single<Session> createSession(
.filter(s -> !s.isEmpty())
.orElseGet(() -> UUID.randomUUID().toString());

ConcurrentMap<String, Object> initialState =
(state == null) ? new ConcurrentHashMap<>() : new ConcurrentHashMap<>(state);
logger.info(
"Creating session for userId: {} with sessionId: {} and initial state: {}",
userId,
resolvedSessionId,
initialState);
List<Event> initialEvents = new ArrayList<>();
state);
Instant now = Instant.now();
Session newSession =
Session.builder(resolvedSessionId)
.appName(appName)
.userId(userId)
.state(initialState)
.events(initialEvents)
.state(state)
.lastUpdateTime(now)
.build();

Expand Down Expand Up @@ -200,8 +204,7 @@ public Maybe<Session> getSession(
})
.map(
events -> {
ConcurrentMap<String, Object> state =
new ConcurrentHashMap<>((Map<String, Object>) data.get(STATE_KEY));
Map<String, Object> state = (Map<String, Object>) data.get(STATE_KEY);
return Session.builder((String) data.get(ID_KEY))
.appName((String) data.get(APP_NAME_KEY))
.userId((String) data.get(USER_ID_KEY))
Expand Down Expand Up @@ -451,8 +454,6 @@ public Single<ListSessionsResponse> listSessions(String appName, String userId)
.appName((String) data.get(APP_NAME_KEY))
.userId((String) data.get(USER_ID_KEY))
.lastUpdateTime(Instant.parse((String) data.get(UPDATE_TIME_KEY)))
.state(new ConcurrentHashMap<>()) // Empty state
.events(new ArrayList<>()) // Empty events
.build();
sessions.add(session);
}
Expand Down
118 changes: 49 additions & 69 deletions core/src/main/java/com/google/adk/events/EventActions.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.google.adk.JsonBaseModel;
import com.google.adk.sessions.State;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.util.HashSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import javax.annotation.Nullable;

/** Represents the actions attached to an event. */
Expand All @@ -35,39 +34,37 @@
public class EventActions extends JsonBaseModel {

private Optional<Boolean> skipSummarization;
private ConcurrentMap<String, Object> stateDelta;
private ConcurrentMap<String, Integer> artifactDelta;
private Set<String> deletedArtifactIds;
private Map<String, Object> stateDelta;
private Map<String, Integer> artifactDelta;
private Optional<String> transferToAgent;
private Optional<Boolean> escalate;
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs;
private ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations;
private Map<String, Map<String, Object>> requestedAuthConfigs;
private Map<String, ToolConfirmation> requestedToolConfirmations;
private boolean endOfAgent;
private Optional<EventCompaction> compaction;

/** Default constructor for Jackson. */
public EventActions() {
this.skipSummarization = Optional.empty();
this.stateDelta = new ConcurrentHashMap<>();
this.artifactDelta = new ConcurrentHashMap<>();
this.deletedArtifactIds = new HashSet<>();
this.stateDelta = Collections.synchronizedMap(new HashMap<>());
this.artifactDelta = Collections.synchronizedMap(new HashMap<>());
this.transferToAgent = Optional.empty();
this.escalate = Optional.empty();
this.requestedAuthConfigs = new ConcurrentHashMap<>();
this.requestedToolConfirmations = new ConcurrentHashMap<>();
this.requestedAuthConfigs = Collections.synchronizedMap(new HashMap<>());
this.requestedToolConfirmations = Collections.synchronizedMap(new HashMap<>());
this.endOfAgent = false;
this.compaction = Optional.empty();
}

private EventActions(Builder builder) {
this.skipSummarization = builder.skipSummarization;
this.stateDelta = builder.stateDelta;
this.artifactDelta = builder.artifactDelta;
this.deletedArtifactIds = builder.deletedArtifactIds;
this.stateDelta = Collections.synchronizedMap(builder.stateDelta);
this.artifactDelta = Collections.synchronizedMap(builder.artifactDelta);
this.transferToAgent = builder.transferToAgent;
this.escalate = builder.escalate;
this.requestedAuthConfigs = builder.requestedAuthConfigs;
this.requestedToolConfirmations = builder.requestedToolConfirmations;
this.requestedAuthConfigs = Collections.synchronizedMap(builder.requestedAuthConfigs);
this.requestedToolConfirmations =
Collections.synchronizedMap(builder.requestedToolConfirmations);
this.endOfAgent = builder.endOfAgent;
this.compaction = builder.compaction;
}
Expand All @@ -90,41 +87,32 @@ public void setSkipSummarization(boolean skipSummarization) {
}

@JsonProperty("stateDelta")
public ConcurrentMap<String, Object> stateDelta() {
public Map<String, Object> stateDelta() {
return stateDelta;
}

@Deprecated // Use stateDelta(), addState() and removeStateByKey() instead.
public void setStateDelta(ConcurrentMap<String, Object> stateDelta) {
this.stateDelta = stateDelta;
public void setStateDelta(Map<String, Object> stateDelta) {
this.stateDelta = Collections.synchronizedMap(new HashMap<>(stateDelta));
}

/**
* Removes a key from the state delta.
*
* @param key The key to remove.
* @deprecated Use {@link #stateDelta()}.put(key, null) instead.
*/
@Deprecated
public void removeStateByKey(String key) {
stateDelta.put(key, State.REMOVED);
stateDelta().put(key, null);
}

@JsonProperty("artifactDelta")
public ConcurrentMap<String, Integer> artifactDelta() {
public Map<String, Integer> artifactDelta() {
return artifactDelta;
}

public void setArtifactDelta(ConcurrentMap<String, Integer> artifactDelta) {
this.artifactDelta = artifactDelta;
}

@JsonProperty("deletedArtifactIds")
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public Set<String> deletedArtifactIds() {
return deletedArtifactIds;
}

public void setDeletedArtifactIds(Set<String> deletedArtifactIds) {
this.deletedArtifactIds = deletedArtifactIds;
public void setArtifactDelta(Map<String, Integer> artifactDelta) {
this.artifactDelta = Collections.synchronizedMap(new HashMap<>(artifactDelta));
}

@JsonProperty("transferToAgent")
Expand Down Expand Up @@ -154,23 +142,23 @@ public void setEscalate(boolean escalate) {
}

@JsonProperty("requestedAuthConfigs")
public ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs() {
public Map<String, Map<String, Object>> requestedAuthConfigs() {
return requestedAuthConfigs;
}

public void setRequestedAuthConfigs(
ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs) {
public void setRequestedAuthConfigs(Map<String, Map<String, Object>> requestedAuthConfigs) {
this.requestedAuthConfigs = requestedAuthConfigs;
}

@JsonProperty("requestedToolConfirmations")
public ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations() {
public Map<String, ToolConfirmation> requestedToolConfirmations() {
return requestedToolConfirmations;
}

public void setRequestedToolConfirmations(
ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations) {
this.requestedToolConfirmations = requestedToolConfirmations;
Map<String, ToolConfirmation> requestedToolConfirmations) {
this.requestedToolConfirmations =
Collections.synchronizedMap(new HashMap<>(requestedToolConfirmations));
}

@JsonProperty("endOfAgent")
Expand Down Expand Up @@ -235,7 +223,6 @@ public boolean equals(Object o) {
return Objects.equals(skipSummarization, that.skipSummarization)
&& Objects.equals(stateDelta, that.stateDelta)
&& Objects.equals(artifactDelta, that.artifactDelta)
&& Objects.equals(deletedArtifactIds, that.deletedArtifactIds)
&& Objects.equals(transferToAgent, that.transferToAgent)
&& Objects.equals(escalate, that.escalate)
&& Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs)
Expand All @@ -250,7 +237,6 @@ public int hashCode() {
skipSummarization,
stateDelta,
artifactDelta,
deletedArtifactIds,
transferToAgent,
escalate,
requestedAuthConfigs,
Expand All @@ -262,38 +248,34 @@ public int hashCode() {
/** Builder for {@link EventActions}. */
public static class Builder {
private Optional<Boolean> skipSummarization;
private ConcurrentMap<String, Object> stateDelta;
private ConcurrentMap<String, Integer> artifactDelta;
private Set<String> deletedArtifactIds;
private Map<String, Object> stateDelta;
private Map<String, Integer> artifactDelta;
private Optional<String> transferToAgent;
private Optional<Boolean> escalate;
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs;
private ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations;
private Map<String, Map<String, Object>> requestedAuthConfigs;
private Map<String, ToolConfirmation> requestedToolConfirmations;
private boolean endOfAgent = false;
private Optional<EventCompaction> compaction;

public Builder() {
this.skipSummarization = Optional.empty();
this.stateDelta = new ConcurrentHashMap<>();
this.artifactDelta = new ConcurrentHashMap<>();
this.deletedArtifactIds = new HashSet<>();
this.stateDelta = new HashMap<>();
this.artifactDelta = new HashMap<>();
this.transferToAgent = Optional.empty();
this.escalate = Optional.empty();
this.requestedAuthConfigs = new ConcurrentHashMap<>();
this.requestedToolConfirmations = new ConcurrentHashMap<>();
this.requestedAuthConfigs = new HashMap<>();
this.requestedToolConfirmations = new HashMap<>();
this.compaction = Optional.empty();
}

private Builder(EventActions eventActions) {
this.skipSummarization = eventActions.skipSummarization();
this.stateDelta = new ConcurrentHashMap<>(eventActions.stateDelta());
this.artifactDelta = new ConcurrentHashMap<>(eventActions.artifactDelta());
this.deletedArtifactIds = new HashSet<>(eventActions.deletedArtifactIds());
this.stateDelta = new HashMap<>(eventActions.stateDelta());
this.artifactDelta = new HashMap<>(eventActions.artifactDelta());
this.transferToAgent = eventActions.transferToAgent();
this.escalate = eventActions.escalate();
this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs());
this.requestedToolConfirmations =
new ConcurrentHashMap<>(eventActions.requestedToolConfirmations());
this.requestedAuthConfigs = new HashMap<>(eventActions.requestedAuthConfigs());
this.requestedToolConfirmations = new HashMap<>(eventActions.requestedToolConfirmations());
this.endOfAgent = eventActions.endOfAgent();
this.compaction = eventActions.compaction();
}
Expand All @@ -307,22 +289,22 @@ public Builder skipSummarization(boolean skipSummarization) {

@CanIgnoreReturnValue
@JsonProperty("stateDelta")
public Builder stateDelta(ConcurrentMap<String, Object> value) {
public Builder stateDelta(Map<String, Object> value) {
this.stateDelta = value;
return this;
}

@CanIgnoreReturnValue
@JsonProperty("artifactDelta")
public Builder artifactDelta(ConcurrentMap<String, Integer> value) {
public Builder artifactDelta(Map<String, Integer> value) {
this.artifactDelta = value;
return this;
}

@CanIgnoreReturnValue
@JsonProperty("deletedArtifactIds")
public Builder deletedArtifactIds(Set<String> value) {
this.deletedArtifactIds = value;
value.forEach(v -> artifactDelta.put(v, null));
return this;
}

Expand All @@ -342,16 +324,15 @@ public Builder escalate(boolean escalate) {

@CanIgnoreReturnValue
@JsonProperty("requestedAuthConfigs")
public Builder requestedAuthConfigs(
ConcurrentMap<String, ConcurrentMap<String, Object>> value) {
public Builder requestedAuthConfigs(Map<String, Map<String, Object>> value) {
this.requestedAuthConfigs = value;
return this;
}

@CanIgnoreReturnValue
@JsonProperty("requestedToolConfirmations")
public Builder requestedToolConfirmations(ConcurrentMap<String, ToolConfirmation> value) {
this.requestedToolConfirmations = value;
public Builder requestedToolConfirmations(Map<String, ToolConfirmation> value) {
this.requestedToolConfirmations = Collections.synchronizedMap(new HashMap<>(value));
return this;
}

Expand Down Expand Up @@ -385,7 +366,6 @@ public Builder merge(EventActions other) {
other.skipSummarization().ifPresent(this::skipSummarization);
this.stateDelta.putAll(other.stateDelta());
this.artifactDelta.putAll(other.artifactDelta());
this.deletedArtifactIds.addAll(other.deletedArtifactIds());
other.transferToAgent().ifPresent(this::transferToAgent);
other.escalate().ifPresent(this::escalate);
this.requestedAuthConfigs.putAll(other.requestedAuthConfigs());
Expand Down
5 changes: 4 additions & 1 deletion core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -337,7 +338,9 @@ private Single<Event> appendNewMessageToSession(
// Add state delta if provided
if (stateDelta != null && !stateDelta.isEmpty()) {
eventBuilder.actions(
EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build());
EventActions.builder()
.stateDelta(stateDelta == null ? new HashMap<>() : new HashMap<>(stateDelta))
.build());
}

return this.sessionService.appendEvent(session, eventBuilder.build());
Expand Down
Loading
Loading