|
20 | 20 | import io.a2a.server.agentexecution.RequestContext; |
21 | 21 | import io.a2a.server.events.EventQueue; |
22 | 22 | import io.a2a.server.tasks.TaskUpdater; |
| 23 | +import io.a2a.spec.Artifact; |
23 | 24 | import io.a2a.spec.InvalidAgentResponseError; |
24 | 25 | import io.a2a.spec.Message; |
25 | 26 | import io.a2a.spec.Part; |
| 27 | +import io.a2a.spec.TaskArtifactUpdateEvent; |
| 28 | +import io.a2a.spec.TaskState; |
| 29 | +import io.a2a.spec.TaskStatus; |
| 30 | +import io.a2a.spec.TaskStatusUpdateEvent; |
26 | 31 | import io.a2a.spec.TextPart; |
| 32 | +import io.reactivex.rxjava3.core.Completable; |
| 33 | +import io.reactivex.rxjava3.core.Flowable; |
27 | 34 | import io.reactivex.rxjava3.core.Maybe; |
| 35 | +import io.reactivex.rxjava3.core.Single; |
28 | 36 | import io.reactivex.rxjava3.disposables.CompositeDisposable; |
29 | 37 | import io.reactivex.rxjava3.disposables.Disposable; |
30 | 38 | import java.util.HashMap; |
|
43 | 51 | * use in production code. |
44 | 52 | */ |
45 | 53 | public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor { |
46 | | - |
47 | 54 | private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class); |
48 | 55 | private static final String USER_ID_PREFIX = "A2A_USER_"; |
49 | | - |
50 | 56 | private final Map<String, Disposable> activeTasks = new ConcurrentHashMap<>(); |
51 | 57 | private final Runner.Builder runnerBuilder; |
52 | 58 | private final AgentExecutorConfig agentExecutorConfig; |
@@ -137,7 +143,6 @@ public Builder plugins(List<? extends Plugin> plugins) { |
137 | 143 | return this; |
138 | 144 | } |
139 | 145 |
|
140 | | - @CanIgnoreReturnValue |
141 | 146 | public AgentExecutor build() { |
142 | 147 | return new AgentExecutor( |
143 | 148 | app, |
@@ -165,46 +170,88 @@ public void execute(RequestContext ctx, EventQueue eventQueue) { |
165 | 170 | if (message == null) { |
166 | 171 | throw new IllegalArgumentException("Message cannot be null"); |
167 | 172 | } |
168 | | - |
169 | 173 | // Submits a new task if there is no active task. |
170 | 174 | if (ctx.getTask() == null) { |
171 | 175 | updater.submit(); |
172 | 176 | } |
173 | | - |
174 | 177 | // Group all reactive work for this task into one container |
175 | 178 | CompositeDisposable taskDisposables = new CompositeDisposable(); |
176 | 179 | // Check if the task with the task id is already running, put if absent. |
177 | 180 | if (activeTasks.putIfAbsent(ctx.getTaskId(), taskDisposables) != null) { |
178 | 181 | throw new IllegalStateException(String.format("Task %s already running", ctx.getTaskId())); |
179 | 182 | } |
180 | | - |
181 | 183 | EventProcessor p = new EventProcessor(agentExecutorConfig.outputMode()); |
182 | 184 | Content content = PartConverter.messageToContent(message); |
183 | | - Runner runner = runnerBuilder.build(); |
| 185 | + Single<Boolean> skipExecution = |
| 186 | + agentExecutorConfig.beforeExecuteCallback() != null |
| 187 | + ? agentExecutorConfig.beforeExecuteCallback().call(ctx) |
| 188 | + : Single.just(false); |
184 | 189 |
|
| 190 | + Runner runner = runnerBuilder.build(); |
185 | 191 | taskDisposables.add( |
186 | | - prepareSession(ctx, runner.appName(), runner.sessionService()) |
| 192 | + skipExecution |
187 | 193 | .flatMapPublisher( |
188 | | - session -> { |
189 | | - updater.startWork(); |
190 | | - return runner.runAsync( |
191 | | - getUserId(ctx), session.id(), content, agentExecutorConfig.runConfig()); |
| 194 | + skip -> { |
| 195 | + if (skip) { |
| 196 | + cancel(ctx, eventQueue); |
| 197 | + return Flowable.empty(); |
| 198 | + } |
| 199 | + return Maybe.defer( |
| 200 | + () -> { |
| 201 | + return prepareSession(ctx, runner.appName(), runner.sessionService()); |
| 202 | + }) |
| 203 | + .flatMapPublisher( |
| 204 | + session -> { |
| 205 | + updater.startWork(); |
| 206 | + return runner.runAsync( |
| 207 | + getUserId(ctx), |
| 208 | + session.id(), |
| 209 | + content, |
| 210 | + agentExecutorConfig.runConfig()); |
| 211 | + }); |
192 | 212 | }) |
193 | | - .subscribe( |
| 213 | + .concatMap( |
194 | 214 | event -> { |
195 | | - p.process(event, updater); |
196 | | - }, |
| 215 | + return p.process(event, ctx, agentExecutorConfig.afterEventCallback(), eventQueue) |
| 216 | + .toFlowable(); |
| 217 | + }) |
| 218 | + // Ignore all events from the runner, since they are already processed. |
| 219 | + .ignoreElements() |
| 220 | + .materialize() |
| 221 | + .flatMapCompletable( |
| 222 | + notification -> { |
| 223 | + Throwable error = notification.getError(); |
| 224 | + if (error != null) { |
| 225 | + logger.error("Runner failed to execute", error); |
| 226 | + } |
| 227 | + return handleExecutionEnd(ctx, error, eventQueue); |
| 228 | + }) |
| 229 | + .doFinally(() -> cleanupTask(ctx.getTaskId())) |
| 230 | + .subscribe( |
| 231 | + () -> {}, |
197 | 232 | error -> { |
198 | | - logger.error("Runner failed with {}", error); |
199 | | - updater.fail(failedMessage(ctx, error)); |
200 | | - cleanupTask(ctx.getTaskId()); |
201 | | - }, |
202 | | - () -> { |
203 | | - updater.complete(); |
204 | | - cleanupTask(ctx.getTaskId()); |
| 233 | + logger.error("Failed to handle execution end", error); |
205 | 234 | })); |
206 | 235 | } |
207 | 236 |
|
| 237 | + private Completable handleExecutionEnd( |
| 238 | + RequestContext ctx, Throwable error, EventQueue eventQueue) { |
| 239 | + TaskState state = error != null ? TaskState.FAILED : TaskState.COMPLETED; |
| 240 | + Message message = error != null ? failedMessage(ctx, error) : null; |
| 241 | + TaskStatusUpdateEvent initialEvent = |
| 242 | + new TaskStatusUpdateEvent.Builder() |
| 243 | + .taskId(ctx.getTaskId()) |
| 244 | + .contextId(ctx.getContextId()) |
| 245 | + .isFinal(true) |
| 246 | + .status(new TaskStatus(state, message, null)) |
| 247 | + .build(); |
| 248 | + Maybe<TaskStatusUpdateEvent> afterExecute = |
| 249 | + agentExecutorConfig.afterExecuteCallback() != null |
| 250 | + ? agentExecutorConfig.afterExecuteCallback().call(ctx, initialEvent) |
| 251 | + : Maybe.just(initialEvent); |
| 252 | + return afterExecute.doOnSuccess(event -> eventQueue.enqueueEvent(event)).ignoreElement(); |
| 253 | + } |
| 254 | + |
208 | 255 | private void cleanupTask(String taskId) { |
209 | 256 | Disposable d = activeTasks.remove(taskId); |
210 | 257 | if (d != null) { |
@@ -249,16 +296,19 @@ private EventProcessor(AgentExecutorConfig.OutputMode outputMode) { |
249 | 296 | this.outputMode = outputMode; |
250 | 297 | } |
251 | 298 |
|
252 | | - private void process(Event event, TaskUpdater updater) { |
| 299 | + private Maybe<TaskArtifactUpdateEvent> process( |
| 300 | + Event event, |
| 301 | + RequestContext ctx, |
| 302 | + Callbacks.AfterEventCallback callback, |
| 303 | + EventQueue eventQueue) { |
253 | 304 | if (event.errorCode().isPresent()) { |
254 | | - throw new InvalidAgentResponseError( |
255 | | - null, // Uses default code -32006 |
256 | | - "Agent returned an error: " + event.errorCode().get(), |
257 | | - null); |
| 305 | + return Maybe.error( |
| 306 | + new InvalidAgentResponseError( |
| 307 | + null, // Uses default code -32006 |
| 308 | + "Agent returned an error: " + event.errorCode().get(), |
| 309 | + null)); |
258 | 310 | } |
259 | | - |
260 | 311 | ImmutableList<Part<?>> parts = EventConverter.contentToParts(event.content()); |
261 | | - |
262 | 312 | // Mark all parts as partial if the event is partial. |
263 | 313 | if (event.partial().orElse(false)) { |
264 | 314 | parts.forEach( |
@@ -302,7 +352,26 @@ private void process(Event event, TaskUpdater updater) { |
302 | 352 | } |
303 | 353 | } |
304 | 354 |
|
305 | | - updater.addArtifact(parts, artifactId, null, metadata, append, lastChunk); |
| 355 | + TaskArtifactUpdateEvent initialEvent = |
| 356 | + new TaskArtifactUpdateEvent.Builder() |
| 357 | + .taskId(ctx.getTaskId()) |
| 358 | + .contextId(ctx.getContextId()) |
| 359 | + .lastChunk(lastChunk) |
| 360 | + .append(append) |
| 361 | + .artifact( |
| 362 | + new Artifact.Builder() |
| 363 | + .artifactId(artifactId) |
| 364 | + .parts(parts) |
| 365 | + .metadata(metadata) |
| 366 | + .build()) |
| 367 | + .build(); |
| 368 | + |
| 369 | + Maybe<TaskArtifactUpdateEvent> afterEvent = |
| 370 | + callback != null ? callback.call(ctx, initialEvent, event) : Maybe.just(initialEvent); |
| 371 | + return afterEvent.doOnSuccess( |
| 372 | + finalEvent -> { |
| 373 | + eventQueue.enqueueEvent(finalEvent); |
| 374 | + }); |
306 | 375 | } |
307 | 376 | } |
308 | 377 | } |
0 commit comments