diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index c954c90fc0..84bcf0b292 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -267,6 +267,9 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride] Keyword Args: stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). + When provided, the session's ``session_id`` is used as the A2A + ``context_id`` so that the remote agent can correlate + messages belonging to the same conversation. function_invocation_kwargs: Present for compatibility with the shared agent interface. A2AAgent does not use these values directly. client_kwargs: Present for compatibility with the shared agent interface. @@ -284,13 +287,17 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride] When stream=True: A ResponseStream of AgentResponseUpdate items. """ del function_invocation_kwargs, client_kwargs, kwargs + # Derive context_id from session when available so the remote agent + # can correlate messages belonging to the same conversation. + context_id: str | None = session.session_id if session else None + if continuation_token is not None: a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe( TaskIdParams(id=continuation_token["task_id"]) ) else: normalized_messages = normalize_messages(messages) - a2a_message = self._prepare_message_for_a2a(normalized_messages[-1]) + a2a_message = self._prepare_message_for_a2a(normalized_messages[-1], context_id=context_id) a2a_stream = self.client.send_message(a2a_message) response = ResponseStream( @@ -403,7 +410,7 @@ async def poll_task(self, continuation_token: A2AContinuationToken) -> AgentResp return AgentResponse.from_updates(updates) return AgentResponse(messages=[], response_id=task.id, raw_representation=task) - def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: + def _prepare_message_for_a2a(self, message: Message, *, context_id: str | None = None) -> A2AMessage: """Prepare a Message for the A2A protocol. Transforms Agent Framework Message objects into A2A protocol Messages by: @@ -412,6 +419,14 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: - Converting file references (URI/data/hosted_file) to FilePart objects - Preserving metadata and additional properties from the original message - Setting the role to 'user' as framework messages are treated as user input + + Args: + message: The framework Message to convert. + + Keyword Args: + context_id: Optional A2A context ID to associate this message with a + conversation session. When provided, the remote agent can correlate + multiple messages belonging to the same conversation. """ parts: list[A2APart] = [] if not message.contents: @@ -494,6 +509,7 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: parts=parts, message_id=message.message_id or uuid.uuid4().hex, metadata=metadata, + context_id=context_id or uuid.uuid4().hex, ) def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Content]: diff --git a/python/packages/a2a/tests/test_a2a_agent_context_id.py b/python/packages/a2a/tests/test_a2a_agent_context_id.py new file mode 100644 index 0000000000..1c5c21ea20 --- /dev/null +++ b/python/packages/a2a/tests/test_a2a_agent_context_id.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for A2AAgent session/context_id propagation (issue #4663).""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from a2a.types import Role as A2ARole +from agent_framework import AgentSession, Content, Message + +from agent_framework_a2a._agent import A2AAgent + + +def _make_text_message(text: str = "hello") -> Message: + """Create a minimal text Message for testing.""" + return Message( + role="user", + contents=[Content.from_text(text=text)], + ) + + +def _make_agent() -> A2AAgent: + """Create an A2AAgent with a mock client for unit testing.""" + agent = A2AAgent(url="http://localhost:9999") + agent.client = MagicMock() # replace real client with mock + return agent + + +class TestContextIdPropagation: + """Tests verifying that session.session_id is propagated as A2A context_id.""" + + def test_context_id_set_from_session(self) -> None: + """When a session is provided, its session_id should become the A2A context_id.""" + agent = _make_agent() + session = AgentSession(session_id="my-session-123") + message = _make_text_message() + + a2a_msg = agent._prepare_message_for_a2a(message, context_id=session.session_id) + + assert a2a_msg.context_id == "my-session-123" + + def test_context_id_auto_generated_when_no_session(self) -> None: + """When no context_id is provided, a random one is generated.""" + agent = _make_agent() + message = _make_text_message() + + a2a_msg = agent._prepare_message_for_a2a(message) + + # Should have a non-empty context_id + assert a2a_msg.context_id is not None + assert len(a2a_msg.context_id) > 0 + + def test_context_id_none_generates_random(self) -> None: + """Explicitly passing context_id=None should also auto-generate.""" + agent = _make_agent() + message = _make_text_message() + + a2a_msg = agent._prepare_message_for_a2a(message, context_id=None) + + assert a2a_msg.context_id is not None + assert len(a2a_msg.context_id) > 0 + + def test_different_sessions_produce_different_context_ids(self) -> None: + """Different session IDs should produce different context_ids.""" + agent = _make_agent() + message = _make_text_message() + + msg1 = agent._prepare_message_for_a2a(message, context_id="session-A") + msg2 = agent._prepare_message_for_a2a(message, context_id="session-B") + + assert msg1.context_id != msg2.context_id + assert msg1.context_id == "session-A" + assert msg2.context_id == "session-B" + + def test_message_role_is_user(self) -> None: + """Outgoing messages should always have role='user'.""" + agent = _make_agent() + message = _make_text_message() + + a2a_msg = agent._prepare_message_for_a2a(message, context_id="test") + + assert a2a_msg.role == A2ARole.user + + def test_message_parts_preserved(self) -> None: + """Text content should be converted to A2A TextPart.""" + agent = _make_agent() + message = _make_text_message("test content") + + a2a_msg = agent._prepare_message_for_a2a(message, context_id="test") + + assert len(a2a_msg.parts) == 1 + assert a2a_msg.parts[0].root.text == "test content"