diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java index 0c12727aa..90b645159 100644 --- a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java +++ b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java @@ -20,11 +20,19 @@ import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; import io.a2a.server.tasks.TaskUpdater; +import io.a2a.spec.Artifact; import io.a2a.spec.InvalidAgentResponseError; import io.a2a.spec.Message; import io.a2a.spec.Part; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.disposables.CompositeDisposable; import io.reactivex.rxjava3.disposables.Disposable; import java.util.List; @@ -42,10 +50,8 @@ * use in production code. */ public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor { - private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class); private static final String USER_ID_PREFIX = "A2A_USER_"; - private final Map activeTasks = new ConcurrentHashMap<>(); private final Runner.Builder runnerBuilder; private final AgentExecutorConfig agentExecutorConfig; @@ -136,7 +142,6 @@ public Builder plugins(List plugins) { return this; } - @CanIgnoreReturnValue public AgentExecutor build() { return new AgentExecutor( app, @@ -164,46 +169,95 @@ public void execute(RequestContext ctx, EventQueue eventQueue) { if (message == null) { throw new IllegalArgumentException("Message cannot be null"); } - // Submits a new task if there is no active task. if (ctx.getTask() == null) { updater.submit(); } - // Group all reactive work for this task into one container CompositeDisposable taskDisposables = new CompositeDisposable(); // Check if the task with the task id is already running, put if absent. if (activeTasks.putIfAbsent(ctx.getTaskId(), taskDisposables) != null) { throw new IllegalStateException(String.format("Task %s already running", ctx.getTaskId())); } - EventProcessor p = new EventProcessor(); Content content = PartConverter.messageToContent(message); - Runner runner = runnerBuilder.build(); + Single skipExecution = + agentExecutorConfig.beforeExecuteCallback() != null + ? agentExecutorConfig.beforeExecuteCallback().call(ctx) + : Single.just(false); + Runner runner = runnerBuilder.build(); taskDisposables.add( - prepareSession(ctx, runner.appName(), runner.sessionService()) + skipExecution .flatMapPublisher( - session -> { - updater.startWork(); - return runner.runAsync( - getUserId(ctx), session.id(), content, agentExecutorConfig.runConfig()); + skip -> { + if (skip) { + cancel(ctx, eventQueue); + return Flowable.empty(); + } + return Maybe.defer( + () -> { + return prepareSession(ctx, runner.appName(), runner.sessionService()); + }) + .flatMapPublisher( + session -> { + updater.startWork(); + return runner.runAsync( + getUserId(ctx), + session.id(), + content, + agentExecutorConfig.runConfig()); + }); }) - .subscribe( + .concatMap( event -> { - p.process(event, updater); - }, + return p.process(event, ctx, agentExecutorConfig.afterEventCallback(), eventQueue) + .toFlowable(); + }) + .subscribe( + // Events are already processed and enqueued in concatMap + event -> {}, error -> { logger.error("Runner failed with {}", error); - updater.fail(failedMessage(ctx, error)); - cleanupTask(ctx.getTaskId()); + var unused = + handleExecutionEnd(ctx, error, eventQueue) + .subscribe( + () -> cleanupTask(ctx.getTaskId()), + e -> { + logger.error("Failed to handle execution end", e); + cleanupTask(ctx.getTaskId()); + }); }, () -> { - updater.complete(); - cleanupTask(ctx.getTaskId()); + var unused = + handleExecutionEnd(ctx, null, eventQueue) + .subscribe( + () -> cleanupTask(ctx.getTaskId()), + e -> { + logger.error("Failed to handle execution end", e); + cleanupTask(ctx.getTaskId()); + }); })); } + private Completable handleExecutionEnd( + RequestContext ctx, Throwable error, EventQueue eventQueue) { + TaskState state = error != null ? TaskState.FAILED : TaskState.COMPLETED; + Message message = error != null ? failedMessage(ctx, error) : null; + TaskStatusUpdateEvent initialEvent = + new TaskStatusUpdateEvent.Builder() + .taskId(ctx.getTaskId()) + .contextId(ctx.getContextId()) + .isFinal(true) + .status(new TaskStatus(state, message, null)) + .build(); + Maybe afterExecute = + agentExecutorConfig.afterExecuteCallback() != null + ? agentExecutorConfig.afterExecuteCallback().call(ctx, initialEvent) + : Maybe.just(initialEvent); + return afterExecute.doOnSuccess(event -> eventQueue.enqueueEvent(event)).ignoreElement(); + } + private void cleanupTask(String taskId) { Disposable d = activeTasks.remove(taskId); if (d != null) { @@ -238,24 +292,26 @@ private static Message failedMessage(RequestContext context, Throwable e) { // Processor that will process all events related to the one runner invocation. private static class EventProcessor { + private final String artifactId; // All artifacts related to the invocation should have the same artifact id. private EventProcessor() { artifactId = UUID.randomUUID().toString(); } - private final String artifactId; - - private void process(Event event, TaskUpdater updater) { + private Maybe process( + Event event, + RequestContext ctx, + Callbacks.AfterEventCallback callback, + EventQueue eventQueue) { if (event.errorCode().isPresent()) { - throw new InvalidAgentResponseError( - null, // Uses default code -32006 - "Agent returned an error: " + event.errorCode().get(), - null); + return Maybe.error( + new InvalidAgentResponseError( + null, // Uses default code -32006 + "Agent returned an error: " + event.errorCode().get(), + null)); } - ImmutableList> parts = EventConverter.contentToParts(event.content()); - // Mark all parts as partial if the event is partial. if (event.partial().orElse(false)) { parts.forEach( @@ -264,8 +320,23 @@ private void process(Event event, TaskUpdater updater) { metadata.put("adk_partial", true); }); } - - updater.addArtifact(parts, artifactId, null, ImmutableMap.of()); + TaskArtifactUpdateEvent initialEvent = + new TaskArtifactUpdateEvent.Builder() + .taskId(ctx.getTaskId()) + .contextId(ctx.getContextId()) + .artifact( + new Artifact.Builder() + .artifactId(artifactId) + .parts(parts) + .metadata(ImmutableMap.of()) + .build()) + .build(); + Maybe afterEvent = + callback != null ? callback.call(ctx, initialEvent, event) : Maybe.just(initialEvent); + return afterEvent.doOnSuccess( + evt -> { + eventQueue.enqueueEvent(evt); + }); } } } diff --git a/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java index 350bd6f16..550a97acf 100644 --- a/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java @@ -1,6 +1,11 @@ package com.google.adk.a2a.executor; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; @@ -9,7 +14,23 @@ import com.google.adk.events.Event; import com.google.adk.sessions.InMemorySessionService; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.events.EventQueue; +import io.a2a.spec.Message; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.a2a.spec.TextPart; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.UUID; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -18,13 +39,25 @@ @RunWith(JUnit4.class) public final class AgentExecutorTest { + private EventQueue eventQueue; + private List enqueuedEvents; private TestAgent testAgent; @Before public void setUp() { + enqueuedEvents = new ArrayList<>(); + eventQueue = mock(EventQueue.class); + doAnswer( + invocation -> { + enqueuedEvents.add(invocation.getArgument(0)); + return null; + }) + .when(eventQueue) + .enqueueEvent(any()); testAgent = new TestAgent(); } + @Test public void createAgentExecutor_noAgent_succeeds() { var unused = @@ -77,14 +110,198 @@ public void createAgentExecutor_noAgentExecutorConfig_throwsException() { }); } + @Test + public void execute_withBeforeExecuteCallback_cancelsExecutionOnError() { + // If callback returns error, execution should stop/fail. + Callbacks.BeforeExecuteCallback callback = + ctx -> Single.error(new RuntimeException("Cancelled")); + + AgentExecutorConfig config = + AgentExecutorConfig.builder().beforeExecuteCallback(callback).build(); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(config) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + // Verify error handling triggered cleanup and fail event + // The executor catches the error and emits failed event. + assertThat(enqueuedEvents).isNotEmpty(); + Object lastEvent = enqueuedEvents.get(enqueuedEvents.size() - 1); + assertThat(lastEvent).isInstanceOf(TaskStatusUpdateEvent.class); + TaskStatusUpdateEvent statusEvent = (TaskStatusUpdateEvent) lastEvent; + assertThat(statusEvent.getStatus().state().toString()).isEqualTo("FAILED"); + assertThat(statusEvent.getStatus().message().getParts().get(0)).isInstanceOf(TextPart.class); + TextPart textPart = (TextPart) statusEvent.getStatus().message().getParts().get(0); + assertThat(textPart.getText()).contains("Cancelled"); + } + + @Test + public void execute_withBeforeExecuteCallback_skipsExecutionIfTrue() { + Callbacks.BeforeExecuteCallback callback = ctx -> Single.just(true); + + AgentExecutorConfig config = + AgentExecutorConfig.builder().beforeExecuteCallback(callback).build(); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(config) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + for (Object event : enqueuedEvents) { + if (event instanceof TaskStatusUpdateEvent) { + TaskStatusUpdateEvent statusEvent = (TaskStatusUpdateEvent) event; + System.out.println("statusEvent: " + statusEvent.getStatus().state().toString()); + // assertThat(statusEvent.getStatus().state().toString()).isEqualTo("CANCELLED"); + } + } + + // Filter for artifact events + Optional artifactEvent = + enqueuedEvents.stream() + .filter(e -> e instanceof TaskArtifactUpdateEvent) + .map(e -> (TaskArtifactUpdateEvent) e) + .findFirst(); + + assertThat(artifactEvent.isPresent()).isFalse(); + } + + @Test + public void execute_withAfterEventCallback_modifiesEvent() { + // Agent emits an event. Callback intercepts and modifies it. + Part textPart = Part.builder().text("Hello world").build(); + Event agentEvent = + Event.builder() + .id("event-1") + .author("agent") + .content(Content.builder().role("model").parts(ImmutableList.of(textPart)).build()) + .build(); + testAgent.setEventsToEmit(Flowable.just(agentEvent)); + + Callbacks.AfterEventCallback callback = + (ctx, event, sourceEvent) -> { + // Modify event by adding metadata + return Maybe.just( + new TaskArtifactUpdateEvent.Builder(event) + .metadata(ImmutableMap.of("modified", true)) + .build()); + }; + + AgentExecutorConfig config = AgentExecutorConfig.builder().afterEventCallback(callback).build(); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(config) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + // Filter for artifact events + Optional artifactEvent = + enqueuedEvents.stream() + .filter(e -> e instanceof TaskArtifactUpdateEvent) + .map(e -> (TaskArtifactUpdateEvent) e) + .findFirst(); + + assertThat(artifactEvent).isPresent(); + assertThat(artifactEvent.get().getMetadata()).containsEntry("modified", true); + } + + @Test + public void execute_withAfterExecuteCallback_modifiesStatus() { + testAgent.setEventsToEmit(Flowable.empty()); // Just complete + + Callbacks.AfterExecuteCallback callback = + (ctx, event) -> { + // Modify status to have different message + Message newMessage = + new Message.Builder() + .messageId(UUID.randomUUID().toString()) + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Modified completion"))) + .build(); + + return Maybe.just( + new TaskStatusUpdateEvent.Builder(event) + .status(new io.a2a.spec.TaskStatus(event.getStatus().state(), newMessage, null)) + .build()); + }; + + AgentExecutorConfig config = + AgentExecutorConfig.builder().afterExecuteCallback(callback).build(); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(config) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + // Verify status event + Optional statusEvent = + enqueuedEvents.stream() + .filter(e -> e instanceof TaskStatusUpdateEvent) + .map(e -> (TaskStatusUpdateEvent) e) + .filter(TaskStatusUpdateEvent::isFinal) + .findFirst(); + + assertThat(statusEvent).isPresent(); + assertThat(statusEvent.get().getStatus().message().getParts().get(0)) + .isInstanceOf(TextPart.class); + TextPart textPart = (TextPart) statusEvent.get().getStatus().message().getParts().get(0); + assertThat(textPart.getText()).isEqualTo("Modified completion"); + } + + private RequestContext createRequestContext() { + Message message = + new Message.Builder() + .messageId("msg-1") + .role(Message.Role.USER) + .parts(ImmutableList.of(new TextPart("trigger"))) + .build(); + + // Mock RequestContext properly or use a builder if available. + // RequestContext probably has builder. + // Assuming simple mock for now or constructor. + // Let's check AgentExecutor usage: ctx.getMessage(), ctx.getTaskId(), ctx.getContextId(). + RequestContext ctx = mock(RequestContext.class); + when(ctx.getMessage()).thenReturn(message); + when(ctx.getTaskId()).thenReturn("task-" + UUID.randomUUID()); + when(ctx.getContextId()).thenReturn("ctx-" + UUID.randomUUID()); + return ctx; + } + private static final class TestAgent extends BaseAgent { - private final Flowable eventsToEmit = Flowable.empty(); + private Flowable eventsToEmit = Flowable.empty(); TestAgent() { // BaseAgent constructor: name, description, examples, tools, model super("test_agent", "test", ImmutableList.of(), null, null); } + void setEventsToEmit(Flowable events) { + this.eventsToEmit = events; + } + @Override protected Flowable runAsyncImpl(InvocationContext invocationContext) { return eventsToEmit;