Skip to content

Commit 7e8f9dc

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: add callbacks functionality to the agent executor
PiperOrigin-RevId: 878441462
1 parent 05fbcfc commit 7e8f9dc

File tree

2 files changed

+369
-31
lines changed

2 files changed

+369
-31
lines changed

a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java

Lines changed: 99 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,19 @@
2020
import io.a2a.server.agentexecution.RequestContext;
2121
import io.a2a.server.events.EventQueue;
2222
import io.a2a.server.tasks.TaskUpdater;
23+
import io.a2a.spec.Artifact;
2324
import io.a2a.spec.InvalidAgentResponseError;
2425
import io.a2a.spec.Message;
2526
import io.a2a.spec.Part;
27+
import io.a2a.spec.TaskArtifactUpdateEvent;
28+
import io.a2a.spec.TaskState;
29+
import io.a2a.spec.TaskStatus;
30+
import io.a2a.spec.TaskStatusUpdateEvent;
2631
import io.a2a.spec.TextPart;
32+
import io.reactivex.rxjava3.core.Completable;
33+
import io.reactivex.rxjava3.core.Flowable;
2734
import io.reactivex.rxjava3.core.Maybe;
35+
import io.reactivex.rxjava3.core.Single;
2836
import io.reactivex.rxjava3.disposables.CompositeDisposable;
2937
import io.reactivex.rxjava3.disposables.Disposable;
3038
import java.util.HashMap;
@@ -43,10 +51,8 @@
4351
* use in production code.
4452
*/
4553
public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor {
46-
4754
private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class);
4855
private static final String USER_ID_PREFIX = "A2A_USER_";
49-
5056
private final Map<String, Disposable> activeTasks = new ConcurrentHashMap<>();
5157
private final Runner.Builder runnerBuilder;
5258
private final AgentExecutorConfig agentExecutorConfig;
@@ -137,7 +143,6 @@ public Builder plugins(List<? extends Plugin> plugins) {
137143
return this;
138144
}
139145

140-
@CanIgnoreReturnValue
141146
public AgentExecutor build() {
142147
return new AgentExecutor(
143148
app,
@@ -165,46 +170,88 @@ public void execute(RequestContext ctx, EventQueue eventQueue) {
165170
if (message == null) {
166171
throw new IllegalArgumentException("Message cannot be null");
167172
}
168-
169173
// Submits a new task if there is no active task.
170174
if (ctx.getTask() == null) {
171175
updater.submit();
172176
}
173-
174177
// Group all reactive work for this task into one container
175178
CompositeDisposable taskDisposables = new CompositeDisposable();
176179
// Check if the task with the task id is already running, put if absent.
177180
if (activeTasks.putIfAbsent(ctx.getTaskId(), taskDisposables) != null) {
178181
throw new IllegalStateException(String.format("Task %s already running", ctx.getTaskId()));
179182
}
180-
181183
EventProcessor p = new EventProcessor(agentExecutorConfig.outputMode());
182184
Content content = PartConverter.messageToContent(message);
183-
Runner runner = runnerBuilder.build();
185+
Single<Boolean> skipExecution =
186+
agentExecutorConfig.beforeExecuteCallback() != null
187+
? agentExecutorConfig.beforeExecuteCallback().call(ctx)
188+
: Single.just(false);
184189

190+
Runner runner = runnerBuilder.build();
185191
taskDisposables.add(
186-
prepareSession(ctx, runner.appName(), runner.sessionService())
192+
skipExecution
187193
.flatMapPublisher(
188-
session -> {
189-
updater.startWork();
190-
return runner.runAsync(
191-
getUserId(ctx), session.id(), content, agentExecutorConfig.runConfig());
194+
skip -> {
195+
if (skip) {
196+
cancel(ctx, eventQueue);
197+
return Flowable.empty();
198+
}
199+
return Maybe.defer(
200+
() -> {
201+
return prepareSession(ctx, runner.appName(), runner.sessionService());
202+
})
203+
.flatMapPublisher(
204+
session -> {
205+
updater.startWork();
206+
return runner.runAsync(
207+
getUserId(ctx),
208+
session.id(),
209+
content,
210+
agentExecutorConfig.runConfig());
211+
});
192212
})
193-
.subscribe(
213+
.concatMap(
194214
event -> {
195-
p.process(event, updater);
196-
},
215+
return p.process(event, ctx, agentExecutorConfig.afterEventCallback(), eventQueue)
216+
.toFlowable();
217+
})
218+
// Ignore all events from the runner, since they are already processed.
219+
.ignoreElements()
220+
.materialize()
221+
.flatMapCompletable(
222+
notification -> {
223+
Throwable error = notification.getError();
224+
if (error != null) {
225+
logger.error("Runner failed to execute", error);
226+
}
227+
return handleExecutionEnd(ctx, error, eventQueue);
228+
})
229+
.doFinally(() -> cleanupTask(ctx.getTaskId()))
230+
.subscribe(
231+
() -> {},
197232
error -> {
198-
logger.error("Runner failed with {}", error);
199-
updater.fail(failedMessage(ctx, error));
200-
cleanupTask(ctx.getTaskId());
201-
},
202-
() -> {
203-
updater.complete();
204-
cleanupTask(ctx.getTaskId());
233+
logger.error("Failed to handle execution end", error);
205234
}));
206235
}
207236

237+
private Completable handleExecutionEnd(
238+
RequestContext ctx, Throwable error, EventQueue eventQueue) {
239+
TaskState state = error != null ? TaskState.FAILED : TaskState.COMPLETED;
240+
Message message = error != null ? failedMessage(ctx, error) : null;
241+
TaskStatusUpdateEvent initialEvent =
242+
new TaskStatusUpdateEvent.Builder()
243+
.taskId(ctx.getTaskId())
244+
.contextId(ctx.getContextId())
245+
.isFinal(true)
246+
.status(new TaskStatus(state, message, null))
247+
.build();
248+
Maybe<TaskStatusUpdateEvent> afterExecute =
249+
agentExecutorConfig.afterExecuteCallback() != null
250+
? agentExecutorConfig.afterExecuteCallback().call(ctx, initialEvent)
251+
: Maybe.just(initialEvent);
252+
return afterExecute.doOnSuccess(event -> eventQueue.enqueueEvent(event)).ignoreElement();
253+
}
254+
208255
private void cleanupTask(String taskId) {
209256
Disposable d = activeTasks.remove(taskId);
210257
if (d != null) {
@@ -249,16 +296,19 @@ private EventProcessor(AgentExecutorConfig.OutputMode outputMode) {
249296
this.outputMode = outputMode;
250297
}
251298

252-
private void process(Event event, TaskUpdater updater) {
299+
private Maybe<TaskArtifactUpdateEvent> process(
300+
Event event,
301+
RequestContext ctx,
302+
Callbacks.AfterEventCallback callback,
303+
EventQueue eventQueue) {
253304
if (event.errorCode().isPresent()) {
254-
throw new InvalidAgentResponseError(
255-
null, // Uses default code -32006
256-
"Agent returned an error: " + event.errorCode().get(),
257-
null);
305+
return Maybe.error(
306+
new InvalidAgentResponseError(
307+
null, // Uses default code -32006
308+
"Agent returned an error: " + event.errorCode().get(),
309+
null));
258310
}
259-
260311
ImmutableList<Part<?>> parts = EventConverter.contentToParts(event.content());
261-
262312
// Mark all parts as partial if the event is partial.
263313
if (event.partial().orElse(false)) {
264314
parts.forEach(
@@ -302,7 +352,26 @@ private void process(Event event, TaskUpdater updater) {
302352
}
303353
}
304354

305-
updater.addArtifact(parts, artifactId, null, metadata, append, lastChunk);
355+
TaskArtifactUpdateEvent initialEvent =
356+
new TaskArtifactUpdateEvent.Builder()
357+
.taskId(ctx.getTaskId())
358+
.contextId(ctx.getContextId())
359+
.lastChunk(lastChunk)
360+
.append(append)
361+
.artifact(
362+
new Artifact.Builder()
363+
.artifactId(artifactId)
364+
.parts(parts)
365+
.metadata(metadata)
366+
.build())
367+
.build();
368+
369+
Maybe<TaskArtifactUpdateEvent> afterEvent =
370+
callback != null ? callback.call(ctx, initialEvent, event) : Maybe.just(initialEvent);
371+
return afterEvent.doOnSuccess(
372+
finalEvent -> {
373+
eventQueue.enqueueEvent(finalEvent);
374+
});
306375
}
307376
}
308377
}

0 commit comments

Comments
 (0)