Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 101 additions & 30 deletions a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,19 @@
import io.a2a.server.agentexecution.RequestContext;
import io.a2a.server.events.EventQueue;
import io.a2a.server.tasks.TaskUpdater;
import io.a2a.spec.Artifact;
import io.a2a.spec.InvalidAgentResponseError;
import io.a2a.spec.Message;
import io.a2a.spec.Part;
import io.a2a.spec.TaskArtifactUpdateEvent;
import io.a2a.spec.TaskState;
import io.a2a.spec.TaskStatus;
import io.a2a.spec.TaskStatusUpdateEvent;
import io.a2a.spec.TextPart;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.disposables.CompositeDisposable;
import io.reactivex.rxjava3.disposables.Disposable;
import java.util.List;
Expand All @@ -42,10 +50,8 @@
* use in production code.
*/
public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor {

private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class);
private static final String USER_ID_PREFIX = "A2A_USER_";

private final Map<String, Disposable> activeTasks = new ConcurrentHashMap<>();
private final Runner.Builder runnerBuilder;
private final AgentExecutorConfig agentExecutorConfig;
Expand Down Expand Up @@ -136,7 +142,6 @@ public Builder plugins(List<? extends Plugin> plugins) {
return this;
}

@CanIgnoreReturnValue
public AgentExecutor build() {
return new AgentExecutor(
app,
Expand Down Expand Up @@ -164,46 +169,95 @@ public void execute(RequestContext ctx, EventQueue eventQueue) {
if (message == null) {
throw new IllegalArgumentException("Message cannot be null");
}

// Submits a new task if there is no active task.
if (ctx.getTask() == null) {
updater.submit();
}

// Group all reactive work for this task into one container
CompositeDisposable taskDisposables = new CompositeDisposable();
// Check if the task with the task id is already running, put if absent.
if (activeTasks.putIfAbsent(ctx.getTaskId(), taskDisposables) != null) {
throw new IllegalStateException(String.format("Task %s already running", ctx.getTaskId()));
}

EventProcessor p = new EventProcessor();
Content content = PartConverter.messageToContent(message);
Runner runner = runnerBuilder.build();
Single<Boolean> skipExecution =
agentExecutorConfig.beforeExecuteCallback() != null
? agentExecutorConfig.beforeExecuteCallback().call(ctx)
: Single.just(false);

Runner runner = runnerBuilder.build();
taskDisposables.add(
prepareSession(ctx, runner.appName(), runner.sessionService())
skipExecution
.flatMapPublisher(
session -> {
updater.startWork();
return runner.runAsync(
getUserId(ctx), session.id(), content, agentExecutorConfig.runConfig());
skip -> {
if (skip) {
cancel(ctx, eventQueue);
return Flowable.empty();
}
return Maybe.defer(
() -> {
return prepareSession(ctx, runner.appName(), runner.sessionService());
})
.flatMapPublisher(
session -> {
updater.startWork();
return runner.runAsync(
getUserId(ctx),
session.id(),
content,
agentExecutorConfig.runConfig());
});
})
.subscribe(
.concatMap(
event -> {
p.process(event, updater);
},
return p.process(event, ctx, agentExecutorConfig.afterEventCallback(), eventQueue)
.toFlowable();
})
.subscribe(
// Events are already processed and enqueued in concatMap
event -> {},
error -> {
logger.error("Runner failed with {}", error);
updater.fail(failedMessage(ctx, error));
cleanupTask(ctx.getTaskId());
var unused =
handleExecutionEnd(ctx, error, eventQueue)
.subscribe(
() -> cleanupTask(ctx.getTaskId()),
e -> {
logger.error("Failed to handle execution end", e);
cleanupTask(ctx.getTaskId());
});
},
() -> {
updater.complete();
cleanupTask(ctx.getTaskId());
var unused =
handleExecutionEnd(ctx, null, eventQueue)
.subscribe(
() -> cleanupTask(ctx.getTaskId()),
e -> {
logger.error("Failed to handle execution end", e);
cleanupTask(ctx.getTaskId());
});
}));
}

private Completable handleExecutionEnd(
RequestContext ctx, Throwable error, EventQueue eventQueue) {
TaskState state = error != null ? TaskState.FAILED : TaskState.COMPLETED;
Message message = error != null ? failedMessage(ctx, error) : null;
TaskStatusUpdateEvent initialEvent =
new TaskStatusUpdateEvent.Builder()
.taskId(ctx.getTaskId())
.contextId(ctx.getContextId())
.isFinal(true)
.status(new TaskStatus(state, message, null))
.build();
Maybe<TaskStatusUpdateEvent> afterExecute =
agentExecutorConfig.afterExecuteCallback() != null
? agentExecutorConfig.afterExecuteCallback().call(ctx, initialEvent)
: Maybe.just(initialEvent);
return afterExecute.doOnSuccess(event -> eventQueue.enqueueEvent(event)).ignoreElement();
}

private void cleanupTask(String taskId) {
Disposable d = activeTasks.remove(taskId);
if (d != null) {
Expand Down Expand Up @@ -238,24 +292,26 @@ private static Message failedMessage(RequestContext context, Throwable e) {

// Processor that will process all events related to the one runner invocation.
private static class EventProcessor {
private final String artifactId;

// All artifacts related to the invocation should have the same artifact id.
private EventProcessor() {
artifactId = UUID.randomUUID().toString();
}

private final String artifactId;

private void process(Event event, TaskUpdater updater) {
private Maybe<TaskArtifactUpdateEvent> process(
Event event,
RequestContext ctx,
Callbacks.AfterEventCallback callback,
EventQueue eventQueue) {
if (event.errorCode().isPresent()) {
throw new InvalidAgentResponseError(
null, // Uses default code -32006
"Agent returned an error: " + event.errorCode().get(),
null);
return Maybe.error(
new InvalidAgentResponseError(
null, // Uses default code -32006
"Agent returned an error: " + event.errorCode().get(),
null));
}

ImmutableList<Part<?>> parts = EventConverter.contentToParts(event.content());

// Mark all parts as partial if the event is partial.
if (event.partial().orElse(false)) {
parts.forEach(
Expand All @@ -264,8 +320,23 @@ private void process(Event event, TaskUpdater updater) {
metadata.put("adk_partial", true);
});
}

updater.addArtifact(parts, artifactId, null, ImmutableMap.of());
TaskArtifactUpdateEvent initialEvent =
new TaskArtifactUpdateEvent.Builder()
.taskId(ctx.getTaskId())
.contextId(ctx.getContextId())
.artifact(
new Artifact.Builder()
.artifactId(artifactId)
.parts(parts)
.metadata(ImmutableMap.of())
.build())
.build();
Maybe<TaskArtifactUpdateEvent> afterEvent =
callback != null ? callback.call(ctx, initialEvent, event) : Maybe.just(initialEvent);
return afterEvent.doOnSuccess(
evt -> {
eventQueue.enqueueEvent(evt);
});
}
}
}
Loading
Loading