diff --git a/server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java b/server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java index c82df7846..beee9bf7e 100644 --- a/server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java +++ b/server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java @@ -74,50 +74,52 @@ */ public class RequestContext { - private @Nullable MessageSendParams params; - private @Nullable String taskId; - private @Nullable String contextId; - private @Nullable Task task; - private List relatedTasks; + private final @Nullable MessageSendParams params; + private final String taskId; + private final String contextId; + private final @Nullable Task task; + private final List relatedTasks; private final @Nullable ServerCallContext callContext; - public RequestContext( + /** + * Constructor with all fields already validated and initialized. + *

+ * Note: Use {@link Builder} instead of calling this constructor directly. + * The builder handles ID generation and validation. + *

+ * + * @param params the message send parameters (can be null for cancel operations) + * @param taskId the task identifier (must not be null) + * @param contextId the context identifier (must not be null) + * @param task the existing task state (null for new conversations) + * @param relatedTasks other tasks in the same context (must not be null, can be empty) + * @param callContext the server call context (can be null) + */ + private RequestContext( @Nullable MessageSendParams params, - @Nullable String taskId, - @Nullable String contextId, + String taskId, + String contextId, @Nullable Task task, - @Nullable List relatedTasks, - @Nullable ServerCallContext callContext) throws InvalidParamsError { + List relatedTasks, + @Nullable ServerCallContext callContext) { this.params = params; this.taskId = taskId; this.contextId = contextId; this.task = task; - this.relatedTasks = relatedTasks == null ? new ArrayList<>() : relatedTasks; + this.relatedTasks = relatedTasks; this.callContext = callContext; - - // If the taskId and contextId were specified, they must match the params - if (params != null) { - if (taskId != null && !taskId.equals(params.message().taskId())) { - throw new InvalidParamsError("bad task id"); - } - this.taskId = checkOrGenerateTaskId(); - if (contextId != null && !contextId.equals(params.message().contextId())) { - throw new InvalidParamsError("bad context id"); - } - this.contextId = checkOrGenerateContextId(); - } } /** * Returns the task identifier. *

- * This is auto-generated (UUID) if not provided by the client in the message parameters. - * It can be null if the context was not created from message parameters. + * This is auto-generated (UUID) by the builder if not provided by the client + * in the message parameters. This value is never null. *

* - * @return the task ID + * @return the task ID (never null) */ - public @Nullable String getTaskId() { + public String getTaskId() { return taskId; } @@ -125,13 +127,13 @@ public RequestContext( * Returns the conversation context identifier. *

* Conversation contexts group related tasks together (e.g., multiple tasks - * in the same user session). This is auto-generated (UUID) if not provided by the client - * in the message parameters. It can be null if the context was not created from message parameters. + * in the same user session). This is auto-generated (UUID) by the builder if + * not provided by the client in the message parameters. This value is never null. *

* - * @return the context ID + * @return the context ID (never null) */ - public @Nullable String getContextId() { + public String getContextId() { return contextId; } @@ -209,6 +211,19 @@ public List getRelatedTasks() { return callContext; } + /** + * Returns the tenant identifier from the request parameters. + *

+ * The tenant is used in multi-tenant environments to identify which + * customer or organization the request belongs to. + *

+ * + * @return the tenant identifier, or null if no params or tenant not set + */ + public @Nullable String getTenant() { + return params != null ? params.tenant() : null; + } + /** * Extracts all text content from the message and joins with the specified delimiter. *

@@ -240,48 +255,20 @@ public String getUserInput(String delimiter) { return getMessageText(params.message(), delimiter); } + /** + * Attaches a related task to this context. + *

+ * This is primarily used by the framework to populate related tasks after + * construction. Agent implementations should use {@link #getRelatedTasks()} + * to access related tasks. + *

+ * + * @param task the task to attach + */ public void attachRelatedTask(Task task) { relatedTasks.add(task); } - private @Nullable String checkOrGenerateTaskId() { - if (params == null) { - return taskId; - } - if (taskId == null && params.message().taskId() == null) { - // Message is immutable, create new one with generated taskId - String generatedTaskId = UUID.randomUUID().toString(); - Message updatedMessage = Message.builder(params.message()) - .taskId(generatedTaskId) - .build(); - params = new MessageSendParams(updatedMessage, params.configuration(), params.metadata()); - return generatedTaskId; - } - if (params.message().taskId() != null) { - return params.message().taskId(); - } - return taskId; - } - - private @Nullable String checkOrGenerateContextId() { - if (params == null) { - return contextId; - } - if (contextId == null && params.message().contextId() == null) { - // Message is immutable, create new one with generated contextId - String generatedContextId = UUID.randomUUID().toString(); - Message updatedMessage = Message.builder(params.message()) - .contextId(generatedContextId) - .build(); - params = new MessageSendParams(updatedMessage, params.configuration(), params.metadata()); - return generatedContextId; - } - if (params.message().contextId() != null) { - return params.message().contextId(); - } - return contextId; - } - private String getMessageText(Message message, String delimiter) { List textParts = getTextParts(message.parts()); return String.join(delimiter, textParts); @@ -295,6 +282,18 @@ private List getTextParts(List> parts) { .collect(Collectors.toList()); } + /** + * Builder for creating {@link RequestContext} instances. + *

+ * The builder handles ID generation and validation automatically: + *

+ *
    + *
  • TaskId and ContextId are auto-generated (UUID) if not provided
  • + *
  • IDs are validated against message parameters if both are present
  • + *
  • Message parameters are updated with generated IDs
  • + *
  • Related tasks list is initialized to empty list if null
  • + *
+ */ public static class Builder { private @Nullable MessageSendParams params; private @Nullable String taskId; @@ -357,8 +356,56 @@ public Builder setServerCallContext(@Nullable ServerCallContext serverCallContex return serverCallContext; } - public RequestContext build() { - return new RequestContext(params, taskId, contextId, task, relatedTasks, serverCallContext); + /** + * Builds the RequestContext with ID generation and validation. + * + * @return the constructed RequestContext + * @throws InvalidParamsError if taskId or contextId don't match message parameters + */ + public RequestContext build() throws InvalidParamsError { + // 1. Initialize relatedTasks to empty list if null + List finalRelatedTasks = relatedTasks != null ? relatedTasks : new ArrayList<>(); + + // 2. Extract message IDs upfront (or null if no params) + String messageTaskId = params != null ? params.message().taskId() : null; + String messageContextId = params != null ? params.message().contextId() : null; + + // 3. Validate: if both builder and message provide an ID, they must match + if (taskId != null && messageTaskId != null && !taskId.equals(messageTaskId)) { + throw new InvalidParamsError("bad task id"); + } + if (contextId != null && messageContextId != null && !contextId.equals(messageContextId)) { + throw new InvalidParamsError("bad context id"); + } + + // 4. Determine final IDs using coalesce pattern: builder → message → generate + String finalTaskId = taskId != null ? taskId : + messageTaskId != null ? messageTaskId : + UUID.randomUUID().toString(); + + String finalContextId = contextId != null ? contextId : + messageContextId != null ? messageContextId : + UUID.randomUUID().toString(); + + // 5. Update params if message needs to be updated with final IDs + MessageSendParams finalParams = params; + if (params != null && (!finalTaskId.equals(messageTaskId) || !finalContextId.equals(messageContextId))) { + Message updatedMessage = Message.builder(params.message()) + .taskId(finalTaskId) + .contextId(finalContextId) + .build(); + // Preserve all original fields including tenant + finalParams = MessageSendParams.builder() + .message(updatedMessage) + .configuration(params.configuration()) + .metadata(params.metadata()) + .tenant(params.tenant()) + .build(); + } + + // 6. Call constructor with finalized values (IDs guaranteed non-null) + return new RequestContext(finalParams, finalTaskId, finalContextId, + task, finalRelatedTasks, serverCallContext); } } diff --git a/server-common/src/main/java/io/a2a/server/tasks/AgentEmitter.java b/server-common/src/main/java/io/a2a/server/tasks/AgentEmitter.java index 7c6c97f20..6088022b6 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/AgentEmitter.java +++ b/server-common/src/main/java/io/a2a/server/tasks/AgentEmitter.java @@ -18,7 +18,6 @@ import io.a2a.spec.TaskStatus; import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; -import io.a2a.util.Assert; import org.jspecify.annotations.Nullable; /** @@ -108,8 +107,8 @@ public class AgentEmitter { */ public AgentEmitter(RequestContext context, EventQueue eventQueue) { this.eventQueue = eventQueue; - this.taskId = Assert.checkNotNullParam("taskId",context.getTaskId()); - this.contextId = Assert.checkNotNullParam("contextId",context.getContextId()); + this.taskId = context.getTaskId(); + this.contextId = context.getContextId(); } private void updateStatus(TaskState taskState) { diff --git a/server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java b/server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java index ba692d97c..0fc9c16ca 100644 --- a/server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java +++ b/server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java @@ -35,10 +35,11 @@ private static MessageSendConfiguration defaultConfiguration() { @Test public void testInitWithoutParams() { - RequestContext context = new RequestContext(null, null, null, null, null, null); + RequestContext context = new RequestContext.Builder().build(); + assertNull(context.getMessage()); - assertNull(context.getTaskId()); - assertNull(context.getContextId()); + assertNotNull(context.getTaskId()); // Generated UUID + assertNotNull(context.getContextId()); // Generated UUID assertNull(context.getTask()); assertTrue(context.getRelatedTasks().isEmpty()); } @@ -56,7 +57,9 @@ public void testInitWithParamsNoIds() { .thenReturn(taskId) .thenReturn(contextId); - RequestContext context = new RequestContext(mockParams, null, null, null, null, null); + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .build(); // getMessage() returns a new Message with generated IDs, not the original assertNotNull(context.getMessage()); @@ -73,7 +76,10 @@ public void testInitWithTaskId() { var mockMessage = Message.builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(taskId).build(); var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); - RequestContext context = new RequestContext(mockParams, taskId, null, null, null, null); + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .setTaskId(taskId) + .build(); assertEquals(taskId, context.getTaskId()); assertEquals(taskId, mockParams.message().taskId()); @@ -84,7 +90,11 @@ public void testInitWithContextId() { String contextId = "context-456"; var mockMessage = Message.builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).contextId(contextId).build(); var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); - RequestContext context = new RequestContext(mockParams, null, contextId, null, null, null); + + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .setContextId(contextId) + .build(); assertEquals(contextId, context.getContextId()); assertEquals(contextId, mockParams.message().contextId()); @@ -96,7 +106,12 @@ public void testInitWithBothIds() { String contextId = "context-456"; var mockMessage = Message.builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(taskId).contextId(contextId).build(); var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); - RequestContext context = new RequestContext(mockParams, taskId, contextId, null, null, null); + + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .setTaskId(taskId) + .setContextId(contextId) + .build(); assertEquals(taskId, context.getTaskId()); assertEquals(taskId, mockParams.message().taskId()); @@ -110,14 +125,17 @@ public void testInitWithTask() { var mockTask = Task.builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); - RequestContext context = new RequestContext(mockParams, null, null, mockTask, null, null); + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .setTask(mockTask) + .build(); assertEquals(mockTask, context.getTask()); } @Test public void testGetUserInputNoParams() { - RequestContext context = new RequestContext(null, null, null, null, null, null); + RequestContext context = new RequestContext.Builder().build(); assertEquals("", context.getUserInput(null)); } @@ -125,7 +143,7 @@ public void testGetUserInputNoParams() { public void testAttachRelatedTask() { var mockTask = Task.builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); - RequestContext context = new RequestContext(null, null, null, null, null, null); + RequestContext context = new RequestContext.Builder().build(); assertEquals(0, context.getRelatedTasks().size()); context.attachRelatedTask(mockTask); @@ -144,7 +162,9 @@ public void testCheckOrGenerateTaskIdWithExistingTaskId() { var mockMessage = Message.builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(existingId).build(); var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null, null); + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .build(); assertEquals(existingId, context.getTaskId()); assertEquals(existingId, mockParams.message().taskId()); @@ -157,7 +177,9 @@ public void testCheckOrGenerateContextIdWithExistingContextId() { var mockMessage = Message.builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).contextId(existingId).build(); var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null, null); + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .build(); assertEquals(existingId, context.getContextId()); assertEquals(existingId, mockParams.message().contextId()); @@ -170,7 +192,11 @@ public void testInitRaisesErrorOnTaskIdMismatch() { var mockTask = Task.builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); InvalidParamsError error = assertThrows(InvalidParamsError.class, () -> - new RequestContext(mockParams, "wrong-task-id", null, mockTask, null, null)); + new RequestContext.Builder() + .setParams(mockParams) + .setTaskId("wrong-task-id") + .setTask(mockTask) + .build()); assertTrue(error.getMessage().contains("bad task id")); } @@ -182,7 +208,12 @@ public void testInitRaisesErrorOnContextIdMismatch() { var mockTask = Task.builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); InvalidParamsError error = assertThrows(InvalidParamsError.class, () -> - new RequestContext(mockParams, mockTask.id(), "wrong-context-id", mockTask, null, null)); + new RequestContext.Builder() + .setParams(mockParams) + .setTaskId(mockTask.id()) + .setContextId("wrong-context-id") + .setTask(mockTask) + .build()); assertTrue(error.getMessage().contains("bad context id")); } @@ -195,7 +226,9 @@ public void testWithRelatedTasksProvided() { relatedTasks.add(mockTask); relatedTasks.add(mock(Task.class)); - RequestContext context = new RequestContext(null, null, null, null, relatedTasks, null); + RequestContext context = new RequestContext.Builder() + .setRelatedTasks(relatedTasks) + .build(); assertEquals(relatedTasks, context.getRelatedTasks()); assertEquals(2, context.getRelatedTasks().size()); @@ -203,7 +236,7 @@ public void testWithRelatedTasksProvided() { @Test public void testMessagePropertyWithoutParams() { - RequestContext context = new RequestContext(null, null, null, null, null, null); + RequestContext context = new RequestContext.Builder().build(); assertNull(context.getMessage()); } @@ -212,7 +245,10 @@ public void testMessagePropertyWithParams() { var mockMessage = Message.builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).build(); var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null, null); + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .build(); + // getMessage() returns a new Message with generated IDs, not the original assertNotNull(context.getMessage()); assertEquals(mockMessage.role(), context.getMessage().role()); @@ -228,7 +264,9 @@ public void testInitWithExistingIdsInMessage() { .taskId(existingTaskId).contextId(existingContextId).build(); var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null, null); + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .build(); assertEquals(existingTaskId, context.getTaskId()); assertEquals(existingContextId, context.getContextId()); @@ -240,8 +278,11 @@ public void testInitWithTaskIdAndExistingTaskIdMatch() { var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); var mockTask = Task.builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); - - RequestContext context = new RequestContext(mockParams, mockTask.id(), null, mockTask, null, null); + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .setTaskId(mockTask.id()) + .setTask(mockTask) + .build(); assertEquals(mockTask.id(), context.getTaskId()); assertEquals(mockTask, context.getTask()); @@ -253,8 +294,12 @@ public void testInitWithContextIdAndExistingContextIdMatch() { var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); var mockTask = Task.builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); - - RequestContext context = new RequestContext(mockParams, mockTask.id(), mockTask.contextId(), mockTask, null, null); + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .setTaskId(mockTask.id()) + .setContextId(mockTask.contextId()) + .setTask(mockTask) + .build(); assertEquals(mockTask.contextId(), context.getContextId()); assertEquals(mockTask, context.getTask()); @@ -265,7 +310,10 @@ void testMessageBuilderGeneratesId() { var mockMessage = Message.builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).build(); var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null, null); + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .build(); + assertNotNull(mockMessage.messageId()); assertFalse(mockMessage.messageId().isEmpty()); } @@ -275,7 +323,110 @@ void testMessageBuilderUsesProvidedId() { var mockMessage = Message.builder().messageId("123").role(Message.Role.USER).parts(List.of(new TextPart(""))).build(); var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null, null); + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .build(); + assertEquals("123", mockMessage.messageId()); } + + @Test + public void testBuilderGeneratesIdsWhenNoParams() { + RequestContext context = new RequestContext.Builder() + .build(); + + assertNotNull(context.getTaskId()); + assertNotNull(context.getContextId()); + assertFalse(context.getTaskId().isEmpty()); + assertFalse(context.getContextId().isEmpty()); + } + + @Test + public void testBuilderPreservesProvidedIdsWhenNoParams() { + String providedTaskId = "my-task-id"; + String providedContextId = "my-context-id"; + + RequestContext context = new RequestContext.Builder() + .setTaskId(providedTaskId) + .setContextId(providedContextId) + .build(); + + assertEquals(providedTaskId, context.getTaskId()); + assertEquals(providedContextId, context.getContextId()); + } + + @Test + public void testBuilderUpdatesMessageWithBuilderIds() { + // Regression test for Gemini review: ensure message gets updated when builder provides IDs + String builderTaskId = "builder-task-id"; + String builderContextId = "builder-context-id"; + + // Message has no IDs, but builder provides them + var mockMessage = Message.builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).build(); + var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); + + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .setTaskId(builderTaskId) + .setContextId(builderContextId) + .build(); + + // Both context and message should have the builder IDs + assertEquals(builderTaskId, context.getTaskId()); + assertEquals(builderContextId, context.getContextId()); + assertEquals(builderTaskId, context.getMessage().taskId()); // KEY: message must be updated + assertEquals(builderContextId, context.getMessage().contextId()); // KEY: message must be updated + } + + @Test + public void testMessageIdsTakePrecedenceWhenBothPresent() { + // When both builder and message provide IDs, they must match (or throw) + String sharedTaskId = "shared-task-id"; + String sharedContextId = "shared-context-id"; + + var mockMessage = Message.builder() + .role(Message.Role.USER) + .parts(List.of(new TextPart(""))) + .taskId(sharedTaskId) + .contextId(sharedContextId) + .build(); + var mockParams = MessageSendParams.builder().message(mockMessage).configuration(defaultConfiguration()).build(); + + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .setTaskId(sharedTaskId) // Same as message + .setContextId(sharedContextId) // Same as message + .build(); + + assertEquals(sharedTaskId, context.getTaskId()); + assertEquals(sharedContextId, context.getContextId()); + assertEquals(sharedTaskId, context.getMessage().taskId()); + assertEquals(sharedContextId, context.getMessage().contextId()); + } + + @Test + public void testBuilderPreservesTenantWhenUpdatingMessage() { + // Regression test for Gemini review: ensure tenant is preserved when message is updated + String tenantId = "customer-123"; + String builderTaskId = "builder-task-id"; + + var mockMessage = Message.builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).build(); + var mockParams = MessageSendParams.builder() + .message(mockMessage) + .configuration(defaultConfiguration()) + .tenant(tenantId) + .build(); + + RequestContext context = new RequestContext.Builder() + .setParams(mockParams) + .setTaskId(builderTaskId) // Forces message update + .build(); + + // Verify the message was updated with builder's task ID + assertNotNull(context.getMessage()); + assertEquals(builderTaskId, context.getMessage().taskId()); + + // KEY: Verify tenant wasn't lost during message update + assertEquals(tenantId, context.getTenant()); + } }