Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions python/packages/a2a/agent_framework_a2a/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride]
Keyword Args:
stream: Whether to stream the response. Defaults to False.
session: The conversation session associated with the message(s).
When provided, the session's ``session_id`` is used as the A2A
``context_id`` so that the remote agent can correlate
messages belonging to the same conversation.
function_invocation_kwargs: Present for compatibility with the shared agent interface.
A2AAgent does not use these values directly.
client_kwargs: Present for compatibility with the shared agent interface.
Expand All @@ -284,13 +287,17 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride]
When stream=True: A ResponseStream of AgentResponseUpdate items.
"""
del function_invocation_kwargs, client_kwargs, kwargs
# Derive context_id from session when available so the remote agent
# can correlate messages belonging to the same conversation.
context_id: str | None = session.session_id if session else None

if continuation_token is not None:
a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe(
TaskIdParams(id=continuation_token["task_id"])
)
else:
normalized_messages = normalize_messages(messages)
a2a_message = self._prepare_message_for_a2a(normalized_messages[-1])
a2a_message = self._prepare_message_for_a2a(normalized_messages[-1], context_id=context_id)
a2a_stream = self.client.send_message(a2a_message)

response = ResponseStream(
Expand Down Expand Up @@ -403,7 +410,7 @@ async def poll_task(self, continuation_token: A2AContinuationToken) -> AgentResp
return AgentResponse.from_updates(updates)
return AgentResponse(messages=[], response_id=task.id, raw_representation=task)

def _prepare_message_for_a2a(self, message: Message) -> A2AMessage:
def _prepare_message_for_a2a(self, message: Message, *, context_id: str | None = None) -> A2AMessage:
"""Prepare a Message for the A2A protocol.

Transforms Agent Framework Message objects into A2A protocol Messages by:
Expand All @@ -412,6 +419,14 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage:
- Converting file references (URI/data/hosted_file) to FilePart objects
- Preserving metadata and additional properties from the original message
- Setting the role to 'user' as framework messages are treated as user input

Args:
message: The framework Message to convert.

Keyword Args:
context_id: Optional A2A context ID to associate this message with a
conversation session. When provided, the remote agent can correlate
multiple messages belonging to the same conversation.
"""
parts: list[A2APart] = []
if not message.contents:
Expand Down Expand Up @@ -494,6 +509,7 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage:
parts=parts,
message_id=message.message_id or uuid.uuid4().hex,
metadata=metadata,
context_id=context_id or uuid.uuid4().hex,
)

def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Content]:
Expand Down
93 changes: 93 additions & 0 deletions python/packages/a2a/tests/test_a2a_agent_context_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) Microsoft. All rights reserved.

"""Tests for A2AAgent session/context_id propagation (issue #4663)."""

from __future__ import annotations

from unittest.mock import MagicMock

from a2a.types import Role as A2ARole
from agent_framework import AgentSession, Content, Message

from agent_framework_a2a._agent import A2AAgent


def _make_text_message(text: str = "hello") -> Message:
"""Create a minimal text Message for testing."""
return Message(
role="user",
contents=[Content.from_text(text=text)],
)


def _make_agent() -> A2AAgent:
"""Create an A2AAgent with a mock client for unit testing."""
agent = A2AAgent(url="http://localhost:9999")
agent.client = MagicMock() # replace real client with mock
return agent


class TestContextIdPropagation:
"""Tests verifying that session.session_id is propagated as A2A context_id."""

def test_context_id_set_from_session(self) -> None:
"""When a session is provided, its session_id should become the A2A context_id."""
agent = _make_agent()
session = AgentSession(session_id="my-session-123")
message = _make_text_message()

a2a_msg = agent._prepare_message_for_a2a(message, context_id=session.session_id)

assert a2a_msg.context_id == "my-session-123"

def test_context_id_auto_generated_when_no_session(self) -> None:
"""When no context_id is provided, a random one is generated."""
agent = _make_agent()
message = _make_text_message()

a2a_msg = agent._prepare_message_for_a2a(message)

# Should have a non-empty context_id
assert a2a_msg.context_id is not None
assert len(a2a_msg.context_id) > 0

def test_context_id_none_generates_random(self) -> None:
"""Explicitly passing context_id=None should also auto-generate."""
agent = _make_agent()
message = _make_text_message()

a2a_msg = agent._prepare_message_for_a2a(message, context_id=None)

assert a2a_msg.context_id is not None
assert len(a2a_msg.context_id) > 0

def test_different_sessions_produce_different_context_ids(self) -> None:
"""Different session IDs should produce different context_ids."""
agent = _make_agent()
message = _make_text_message()

msg1 = agent._prepare_message_for_a2a(message, context_id="session-A")
msg2 = agent._prepare_message_for_a2a(message, context_id="session-B")

assert msg1.context_id != msg2.context_id
assert msg1.context_id == "session-A"
assert msg2.context_id == "session-B"

def test_message_role_is_user(self) -> None:
"""Outgoing messages should always have role='user'."""
agent = _make_agent()
message = _make_text_message()

a2a_msg = agent._prepare_message_for_a2a(message, context_id="test")

assert a2a_msg.role == A2ARole.user

def test_message_parts_preserved(self) -> None:
"""Text content should be converted to A2A TextPart."""
agent = _make_agent()
message = _make_text_message("test content")

a2a_msg = agent._prepare_message_for_a2a(message, context_id="test")

assert len(a2a_msg.parts) == 1
assert a2a_msg.parts[0].root.text == "test content"