diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index ca809dd9c4..cf745919a2 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -14,6 +14,7 @@ FunctionToolLookupKey, get_function_tool_lookup_key_for_tool, get_function_tool_namespace, + get_tool_trace_name_for_tool, ) from ..agent import Agent from ..exceptions import UserError @@ -24,6 +25,9 @@ from ..run_context import RunContextWrapper, TContext from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE, FunctionTool, invoke_function_tool from ..tool_context import ToolContext +from ..tracing import Span, agent_span +from ..tracing.span_data import AgentSpanData +from ..tracing.spans import NoOpSpan from ..util._approvals import evaluate_needs_approval_setting from .agent import RealtimeAgent from .config import RealtimeRunConfig, RealtimeSessionModelSettings, RealtimeUserInput @@ -193,6 +197,7 @@ def __init__( self._guardrail_tasks: set[asyncio.Task[Any]] = set() self._tool_call_tasks: set[asyncio.Task[Any]] = set() self._async_tool_calls: bool = bool(self._run_config.get("async_tool_calls", True)) + self._current_agent_span: Span[AgentSpanData] | None = None @property def model(self) -> RealtimeModel: @@ -203,27 +208,59 @@ async def __aenter__(self) -> RealtimeSession: """Start the session by connecting to the model. After this, you will be able to stream events from the model and send messages and audio to the model. """ - # Add ourselves as a listener - self._model.add_listener(self) + # Create the agent span. Do not install it as the current ContextVar span: + # asyncio tasks inherit a snapshot of their parent's context, so a bg task + # cannot update the main task's context var. Installing the span would leave a + # stale (finished) span as "current" after any handoff that runs in a bg task. + # Agent spans are emitted as children of the enclosing trace without being set + # as current, which is correct and avoids all cross-task ContextVar management. + self._current_agent_span = self._make_agent_span(self._current_agent) + self._current_agent_span.start(mark_as_current=False) - model_config = self._model_config.copy() - model_config["initial_model_settings"] = await self._get_updated_model_settings_from_agent( - starting_settings=self._model_config.get("initial_model_settings", None), - agent=self._current_agent, - ) - - # Connect to the model - await self._model.connect(model_config) - - # Emit initial history update - await self._put_event( - RealtimeHistoryUpdated( - history=self._history, - info=self._event_info, + try: + # Add ourselves as a listener + self._model.add_listener(self) + + model_config = self._model_config.copy() + ( + initial_settings, + resolved_tools, + enabled_handoffs, + ) = await self._get_updated_model_settings_from_agent( + starting_settings=self._model_config.get("initial_model_settings", None), + agent=self._current_agent, + ) + model_config["initial_model_settings"] = initial_settings + + # Reuse the resolved tools/handoffs returned above — avoids a second call and + # ensures span metadata matches what was actually sent to the model, including + # any overrides applied by starting_settings. + if not isinstance(self._current_agent_span, NoOpSpan): + self._current_agent_span.span_data.tools = [ + n for t in resolved_tools if (n := get_tool_trace_name_for_tool(t)) is not None + ] or None + self._current_agent_span.span_data.handoffs = [ + h.agent_name for h in enabled_handoffs + ] or None + + # Connect to the model + await self._model.connect(model_config) + + # Emit initial history update + await self._put_event( + RealtimeHistoryUpdated( + history=self._history, + info=self._event_info, + ) ) - ) - return self + return self + except BaseException: + # __aexit__ is not called when __aenter__ raises, so clean up the span here. + if self._current_agent_span is not None: + self._current_agent_span.finish(reset_current=False) + self._current_agent_span = None + raise async def enter(self) -> RealtimeSession: """Enter the async context manager. We strongly recommend using the async context manager @@ -278,13 +315,31 @@ async def interrupt(self) -> None: async def update_agent(self, agent: RealtimeAgent) -> None: """Update the active agent for this session and apply its settings to the model.""" - self._current_agent = agent + # Finish the outgoing agent span before switching agents, mirroring the handoff path. + if self._current_agent_span is not None: + self._current_agent_span.finish(reset_current=False) - updated_settings = await self._get_updated_model_settings_from_agent( + self._current_agent = agent + self._current_agent_span = self._make_agent_span(self._current_agent) + self._current_agent_span.start(mark_as_current=False) + + ( + updated_settings, + resolved_tools, + enabled_handoffs, + ) = await self._get_updated_model_settings_from_agent( starting_settings=None, agent=self._current_agent, ) + if not isinstance(self._current_agent_span, NoOpSpan): + self._current_agent_span.span_data.tools = [ + n for t in resolved_tools if (n := get_tool_trace_name_for_tool(t)) is not None + ] or None + self._current_agent_span.span_data.handoffs = [ + h.agent_name for h in enabled_handoffs + ] or None + await self._model.send_event( RealtimeModelSendSessionUpdate(session_settings=updated_settings) ) @@ -815,15 +870,43 @@ async def _handle_tool_call( # Store previous agent for event previous_agent = agent + # Finish the span for the outgoing agent. Use reset_current=False because this + # runs inside an asyncio background task; resetting a token from a different + # context raises ValueError. + if self._current_agent_span is not None: + self._current_agent_span.finish(reset_current=False) + # Update current agent self._current_agent = result - # Get updated model settings from new agent - updated_settings = await self._get_updated_model_settings_from_agent( + # Create the incoming agent span. Because we never install agent spans as + # current (see __aenter__), this background task's context already holds the + # trace root as the current span — provider.create_span() will parent the new + # span to the trace root, making it a sibling of the outgoing agent span. + self._current_agent_span = self._make_agent_span(self._current_agent) + self._current_agent_span.start(mark_as_current=False) + + # Get updated model settings from new agent; reuse resolved tools and + # handoffs for span metadata to avoid a redundant second call. + ( + updated_settings, + resolved_tools, + enabled_handoffs, + ) = await self._get_updated_model_settings_from_agent( starting_settings=None, agent=self._current_agent, ) + if not isinstance(self._current_agent_span, NoOpSpan): + self._current_agent_span.span_data.tools = [ + n + for t in resolved_tools + if (n := get_tool_trace_name_for_tool(t)) is not None + ] or None + self._current_agent_span.span_data.handoffs = [ + h.agent_name for h in enabled_handoffs + ] or None + # Send handoff event await self._put_event( RealtimeHandoffEvent( @@ -1235,6 +1318,11 @@ async def _cleanup(self) -> None: self._wake_event_iterators() return + # Finish the active agent span. + if self._current_agent_span is not None: + self._current_agent_span.finish(reset_current=False) + self._current_agent_span = None + # Cancel and cleanup guardrail tasks self._cleanup_guardrail_tasks() self._cleanup_tool_call_tasks() @@ -1253,11 +1341,28 @@ async def _cleanup(self) -> None: self._closed = True self._wake_event_iterators() + def _make_agent_span(self, agent: RealtimeAgent) -> Span[AgentSpanData]: + """Create a new agent span for the given agent, respecting tracing_disabled. + + Tool and handoff names are intentionally omitted here. Callers must populate + span_data.tools and span_data.handoffs from the tuple returned by + _get_updated_model_settings_from_agent() so that metadata reflects what was + actually sent to the model (after is_enabled filtering and any model_config overrides). + """ + disabled: bool = bool(self._run_config.get("tracing_disabled", False)) + return agent_span(name=agent.name, disabled=disabled) + async def _get_updated_model_settings_from_agent( self, starting_settings: RealtimeSessionModelSettings | None, agent: RealtimeAgent, - ) -> RealtimeSessionModelSettings: + ) -> tuple[RealtimeSessionModelSettings, list[Any], list[Any]]: + """Return (settings, final_tools, final_handoffs). + + final_tools and final_handoffs reflect the values in the returned settings after + starting_settings overrides are applied. Callers must use these for span metadata + to ensure the span reports exactly what was sent to the model. + """ # Start with the merged base settings from run and model configuration. updated_settings = self._base_model_settings.copy() @@ -1273,7 +1378,7 @@ async def _get_updated_model_settings_from_agent( updated_settings["tools"] = tools or [] updated_settings["handoffs"] = handoffs or [] - # Apply starting settings (from model config) next + # Apply starting_settings (from model config) — may override tools and handoffs. if starting_settings: updated_settings.update(starting_settings) @@ -1281,7 +1386,10 @@ async def _get_updated_model_settings_from_agent( if disable_tracing: updated_settings["tracing"] = None - return updated_settings + # Return the final tools/handoffs AFTER overrides so span metadata matches the model. + final_tools = list(updated_settings.get("tools") or []) + final_handoffs = list(updated_settings.get("handoffs") or []) + return updated_settings, final_tools, final_handoffs @classmethod async def _get_handoffs( diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 03148c739a..d4816f3a32 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -325,9 +325,11 @@ async def close(self): @pytest.fixture def mock_agent(): agent = Mock(spec=RealtimeAgent) + agent.name = "mock_agent" agent.get_all_tools = AsyncMock(return_value=[]) type(agent).handoffs = PropertyMock(return_value=[]) + type(agent).tools = PropertyMock(return_value=[]) type(agent).output_guardrails = PropertyMock(return_value=[]) return agent @@ -2463,9 +2465,11 @@ async def test_session_gets_model_settings_from_agent_during_connection(self): # Create agent with specific settings agent = Mock(spec=RealtimeAgent) + agent.name = "test_agent" agent.get_system_prompt = AsyncMock(return_value="Test agent instructions") agent.get_all_tools = AsyncMock(return_value=[{"type": "function", "name": "test_tool"}]) agent.handoffs = [] + agent.tools = [] session = RealtimeSession(mock_model, agent, None) @@ -2492,9 +2496,11 @@ async def test_model_config_overrides_model_settings_not_agent(self): mock_model.add_listener = Mock() agent = Mock(spec=RealtimeAgent) + agent.name = "test_agent" agent.get_system_prompt = AsyncMock(return_value="Agent instructions") agent.get_all_tools = AsyncMock(return_value=[{"type": "function", "name": "agent_tool"}]) agent.handoffs = [] + agent.tools = [] # Provide model config with settings model_config: RealtimeModelConfig = { @@ -2530,8 +2536,10 @@ async def test_handoffs_are_included_in_model_settings(self): # Create agent with handoffs agent = Mock(spec=RealtimeAgent) + agent.name = "test_agent" agent.get_system_prompt = AsyncMock(return_value="Agent with handoffs") agent.get_all_tools = AsyncMock(return_value=[]) + agent.tools = [] # Create a mock handoff handoff_agent = Mock(spec=RealtimeAgent) @@ -2619,7 +2627,7 @@ async def mock_get_handoffs(cls, agent, context_wrapper): m.setattr("agents.realtime.session.RealtimeSession._get_handoffs", mock_get_handoffs) # Test the method directly - model_settings = await session._get_updated_model_settings_from_agent( + model_settings, _, _ = await session._get_updated_model_settings_from_agent( starting_settings=model_config_initial_settings, agent=agent ) @@ -2669,7 +2677,7 @@ async def mock_get_handoffs(cls, agent, context_wrapper): with pytest.MonkeyPatch().context() as m: m.setattr("agents.realtime.session.RealtimeSession._get_handoffs", mock_get_handoffs) - model_settings = await session._get_updated_model_settings_from_agent( + model_settings, _, _ = await session._get_updated_model_settings_from_agent( starting_settings=None, # No initial settings agent=agent, ) @@ -2715,7 +2723,7 @@ async def mock_get_handoffs(cls, agent, context_wrapper): with pytest.MonkeyPatch().context() as m: m.setattr("agents.realtime.session.RealtimeSession._get_handoffs", mock_get_handoffs) - model_settings = await session._get_updated_model_settings_from_agent( + model_settings, _, _ = await session._get_updated_model_settings_from_agent( starting_settings=model_config_settings, agent=agent ) @@ -2762,7 +2770,7 @@ async def mock_get_handoffs(cls, agent, context_wrapper): mock_get_handoffs, ) - model_settings = await session._get_updated_model_settings_from_agent( + model_settings, _, _ = await session._get_updated_model_settings_from_agent( starting_settings=None, agent=agent, ) diff --git a/tests/realtime/test_session_exceptions.py b/tests/realtime/test_session_exceptions.py index da93902368..f306761154 100644 --- a/tests/realtime/test_session_exceptions.py +++ b/tests/realtime/test_session_exceptions.py @@ -89,9 +89,11 @@ async def interrupt(self) -> None: def fake_agent(): """Create a fake agent for testing.""" agent = Mock() + agent.name = "fake_agent" agent.get_all_tools = AsyncMock(return_value=[]) agent.get_system_prompt = AsyncMock(return_value="test instructions") agent.handoffs = [] + agent.tools = [] return agent diff --git a/tests/realtime/test_session_spans.py b/tests/realtime/test_session_spans.py new file mode 100644 index 0000000000..9d038e9753 --- /dev/null +++ b/tests/realtime/test_session_spans.py @@ -0,0 +1,402 @@ +"""Tests that RealtimeSession creates agent spans for SDK-level tracing.""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest + +from agents.realtime.agent import RealtimeAgent +from agents.realtime.model import RealtimeModel, RealtimeModelConfig, RealtimeModelListener +from agents.realtime.model_events import RealtimeModelEvent, RealtimeModelToolCallEvent +from agents.realtime.session import RealtimeSession +from agents.tracing import trace +from agents.tracing.scope import Scope +from agents.tracing.span_data import AgentSpanData +from tests.testing_processor import SPAN_PROCESSOR_TESTING + + +class _FakeRealtimeModel(RealtimeModel): + """Minimal fake that never sends events and succeeds immediately.""" + + def __init__(self) -> None: + self._listeners: list[RealtimeModelListener] = [] + + def add_listener(self, listener: RealtimeModelListener) -> None: + self._listeners.append(listener) + + def remove_listener(self, listener: RealtimeModelListener) -> None: + if listener in self._listeners: + self._listeners.remove(listener) + + async def connect(self, options: RealtimeModelConfig) -> None: + pass + + async def close(self) -> None: + pass + + async def send_event(self, event: Any) -> None: + pass + + async def send_message( + self, message: Any, other_event_data: dict[str, Any] | None = None + ) -> None: + pass + + async def send_audio(self, audio: bytes, *, commit: bool = False) -> None: + pass + + async def send_tool_output(self, tool_call: Any, output: str, start_response: bool) -> None: + pass + + async def interrupt(self) -> None: + pass + + async def dispatch(self, event: RealtimeModelEvent) -> None: + """Send an event to all listeners (test helper).""" + for listener in self._listeners: + await listener.on_event(event) + + +def _make_session( + agent: RealtimeAgent, + model: _FakeRealtimeModel | None = None, + *, + tracing_disabled: bool = False, +) -> RealtimeSession: + return RealtimeSession( + model=model or _FakeRealtimeModel(), + agent=agent, + context=None, + run_config={"tracing_disabled": tracing_disabled} if tracing_disabled else {}, + ) + + +@pytest.mark.asyncio +async def test_session_creates_agent_span_on_enter(): + """Entering a RealtimeSession context must create an agent span.""" + agent = RealtimeAgent(name="greeter") + session = _make_session(agent) + + with trace("test"): + async with session: + pass + + spans = SPAN_PROCESSOR_TESTING.get_ordered_spans() + agent_spans = [s for s in spans if isinstance(s.span_data, AgentSpanData)] + assert len(agent_spans) == 1, f"Expected 1 agent span, got {len(agent_spans)}" + + +@pytest.mark.asyncio +async def test_session_agent_span_has_correct_name(): + """The agent span name must match the RealtimeAgent name.""" + agent = RealtimeAgent(name="support_bot") + session = _make_session(agent) + + with trace("test"): + async with session: + pass + + spans = SPAN_PROCESSOR_TESTING.get_ordered_spans() + agent_spans = [s for s in spans if isinstance(s.span_data, AgentSpanData)] + assert agent_spans[0].span_data.name == "support_bot" + + +@pytest.mark.asyncio +async def test_session_agent_span_finished_after_close(): + """The agent span must be finished (exported) once the session closes.""" + agent = RealtimeAgent(name="closer") + session = _make_session(agent) + + with trace("test"): + async with session: + pass + + spans = SPAN_PROCESSOR_TESTING.get_ordered_spans() + agent_spans = [s for s in spans if isinstance(s.span_data, AgentSpanData)] + assert agent_spans[0].ended_at is not None + + +@pytest.mark.asyncio +async def test_session_span_includes_tool_names(): + """The agent span records the names of tools available to the agent.""" + from agents.tool import function_tool + + @function_tool + def my_tool() -> str: + """A test tool.""" + return "ok" + + agent = RealtimeAgent(name="tool_agent", tools=[my_tool]) + session = _make_session(agent) + + with trace("test"): + async with session: + pass + + spans = SPAN_PROCESSOR_TESTING.get_ordered_spans() + agent_spans = [s for s in spans if isinstance(s.span_data, AgentSpanData)] + assert agent_spans[0].span_data.tools == ["my_tool"] + + +@pytest.mark.asyncio +async def test_session_span_includes_handoff_names(): + """The agent span records the names of handoff targets.""" + child = RealtimeAgent(name="specialist") + agent = RealtimeAgent(name="router", handoffs=[child]) + session = _make_session(agent) + + with trace("test"): + async with session: + pass + + spans = SPAN_PROCESSOR_TESTING.get_ordered_spans() + agent_spans = [s for s in spans if isinstance(s.span_data, AgentSpanData)] + assert agent_spans[0].span_data.handoffs == ["specialist"] + + +@pytest.mark.asyncio +async def test_tracing_disabled_creates_no_agent_spans(): + """When tracing_disabled=True, no agent spans should be emitted.""" + agent = RealtimeAgent(name="silent") + session = _make_session(agent, tracing_disabled=True) + + with trace("test"): + async with session: + pass + + spans = SPAN_PROCESSOR_TESTING.get_ordered_spans() + agent_spans = [s for s in spans if isinstance(s.span_data, AgentSpanData)] + assert len(agent_spans) == 0, f"Expected 0 agent spans, got {len(agent_spans)}" + + +@pytest.mark.asyncio +async def test_no_active_trace_does_not_poison_span_context(): + """Without an outer trace(), the session must not alter the ambient span context. + + Convention: RealtimeSession never installs agent spans as the ContextVar current span, + so the context is always unchanged before and after the session regardless of whether + a real trace exists. + """ + span_before = Scope.get_current_span() + agent = RealtimeAgent(name="agent") + session = _make_session(agent) + + # Enter/exit WITHOUT any enclosing trace. + async with session: + pass + + span_after = Scope.get_current_span() + assert span_before is span_after, "Session must not permanently alter the current span context." + + +@pytest.mark.asyncio +async def test_disabled_handoff_excluded_from_span_metadata(): + """Handoffs with is_enabled=False must not appear in span handoff metadata. + + Convention: span metadata must reflect what was actually sent to the model. + _get_handoffs() filters by is_enabled; raw agent.handoffs must not be used. + """ + from agents.realtime.handoffs import realtime_handoff + + specialist = RealtimeAgent(name="specialist") + disabled_handoff = realtime_handoff(specialist, is_enabled=False) + agent = RealtimeAgent(name="router", handoffs=[disabled_handoff]) + session = _make_session(agent) + + with trace("test"): + async with session: + pass + + spans = SPAN_PROCESSOR_TESTING.get_ordered_spans() + agent_spans = [s for s in spans if isinstance(s.span_data, AgentSpanData)] + assert agent_spans[0].span_data.handoffs is None, ( + f"Disabled handoff should not appear in span metadata, " + f"got: {agent_spans[0].span_data.handoffs}" + ) + + +@pytest.mark.asyncio +async def test_cleanup_from_different_task_does_not_raise(): + """_cleanup called from a different asyncio task must not raise and must close the session. + + close() is public and __aiter__ also calls _cleanup when _stored_exception is set. + Both can run in a different asyncio task than __aenter__. + """ + agent = RealtimeAgent(name="agent") + session = _make_session(agent) + + with trace("test"): + await session.enter() + + async def close_from_other_task() -> None: + await session._cleanup() + + await asyncio.create_task(close_from_other_task()) + + assert session._closed is True + + +@pytest.mark.asyncio +async def test_span_context_unchanged_after_close_called_directly(): + """Ambient span context must be unchanged whether exited via async with or close(). + + Convention: RealtimeSession never installs agent spans as the ContextVar current span, + so close() has no context cleanup to perform; state before and after is identical. + """ + span_before = Scope.get_current_span() + agent = RealtimeAgent(name="agent") + session = _make_session(agent) + + with trace("test"): + await session.enter() + await session.close() + + span_after = Scope.get_current_span() + assert span_before is span_after, "Calling close() directly must not alter the span context." + + +@pytest.mark.asyncio +async def test_handoff_span_is_sibling_not_child_of_initial_span(): + """After a handoff the new agent span must be a sibling of the first, not its child. + + Convention: the incoming agent span's parent_id must not equal the outgoing agent + span's span_id. Both should be direct children of the trace root (parent_id=None). + """ + specialist = RealtimeAgent(name="specialist") + router = RealtimeAgent(name="router", handoffs=[specialist]) + model = _FakeRealtimeModel() + session = _make_session(router, model) + + with trace("test"): + async with session: + # Fire the handoff tool call that the model would send. + await model.dispatch( + RealtimeModelToolCallEvent( + name="transfer_to_specialist", + call_id="call_001", + arguments="{}", + ) + ) + # Let the background task spawned by async_tool_calls complete. + await asyncio.sleep(0.05) + + spans = SPAN_PROCESSOR_TESTING.get_ordered_spans() + agent_spans = [s for s in spans if isinstance(s.span_data, AgentSpanData)] + assert len(agent_spans) == 2, ( + f"Expected 2 agent spans (router + specialist), got {len(agent_spans)}" + ) + + router_span = next(s for s in agent_spans if s.span_data.name == "router") + specialist_span = next(s for s in agent_spans if s.span_data.name == "specialist") + + assert specialist_span.parent_id != router_span.span_id, ( + "Specialist span must not be a child of the router span. " + f"specialist.parent_id={specialist_span.parent_id}, router.span_id={router_span.span_id}" + ) + + +@pytest.mark.asyncio +async def test_aenter_failure_finishes_span(): + """If __aenter__ raises after the span is started, the span must still be finished. + + Python does not call __aexit__ when __aenter__ raises, so the except BaseException + block in __aenter__ is the only cleanup path. Verify no unfinished span is leaked. + """ + + class _FailingConnectModel(_FakeRealtimeModel): + async def connect(self, options: Any) -> None: + raise RuntimeError("simulated connection failure") + + agent = RealtimeAgent(name="agent") + session = RealtimeSession( + model=_FailingConnectModel(), + agent=agent, + context=None, + run_config={}, + ) + + with trace("test"): + with pytest.raises(RuntimeError): + async with session: + pass + + spans = SPAN_PROCESSOR_TESTING.get_ordered_spans() + agent_spans = [s for s in spans if isinstance(s.span_data, AgentSpanData)] + assert len(agent_spans) == 1, f"Expected 1 agent span, got {len(agent_spans)}" + assert agent_spans[0].ended_at is not None, ( + "Agent span must be finished (not leaked) when __aenter__ raises." + ) + + +@pytest.mark.asyncio +async def test_span_tool_metadata_reflects_model_config_override(): + """model_config.initial_model_settings tool override must be reflected in span metadata. + + Convention: span metadata must match what was actually sent to the model. When + initial_model_settings overrides tools (e.g. to empty), the span must show the + override — not the agent's default tool list. + """ + from agents.tool import function_tool + + @function_tool + def my_tool() -> str: + """A test tool.""" + return "ok" + + agent = RealtimeAgent(name="tool_agent", tools=[my_tool]) + # model_config overrides tools with an empty list, wiping the agent's tool. + session = RealtimeSession( + model=_FakeRealtimeModel(), + agent=agent, + context=None, + model_config={"initial_model_settings": {"tools": []}}, + run_config={}, + ) + + with trace("test"): + async with session: + pass + + spans = SPAN_PROCESSOR_TESTING.get_ordered_spans() + agent_spans = [s for s in spans if isinstance(s.span_data, AgentSpanData)] + assert agent_spans[0].span_data.tools is None, ( + f"model_config tool override must clear tools from span, " + f"got: {agent_spans[0].span_data.tools}" + ) + + +@pytest.mark.asyncio +async def test_update_agent_finishes_old_span_and_starts_new_one(): + """update_agent() must finish the outgoing span and emit a new span for the incoming agent. + + Convention: update_agent() is the public API equivalent of a handoff. It must mirror + the handoff path: finish the current agent span, then create and start a new one for + the incoming agent. Without this, activity after the switch is attributed to the wrong + agent and no span is emitted for the new agent. + """ + original = RealtimeAgent(name="original_agent") + replacement = RealtimeAgent(name="replacement_agent") + model = _FakeRealtimeModel() + session = _make_session(original, model) + + with trace("test"): + async with session: + await session.update_agent(replacement) + + spans = SPAN_PROCESSOR_TESTING.get_ordered_spans() + agent_spans = [s for s in spans if isinstance(s.span_data, AgentSpanData)] + assert len(agent_spans) == 2, ( + f"Expected 2 agent spans (original + replacement), got {len(agent_spans)}" + ) + + names = {s.span_data.name for s in agent_spans} + assert names == {"original_agent", "replacement_agent"}, ( + f"Expected spans for both agents, got: {names}" + ) + + original_span = next(s for s in agent_spans if s.span_data.name == "original_agent") + assert original_span.ended_at is not None, ( + "Original agent span must be finished after update_agent()" + ) diff --git a/tests/realtime/test_tracing.py b/tests/realtime/test_tracing.py index bacde6703c..ff6d297ec4 100644 --- a/tests/realtime/test_tracing.py +++ b/tests/realtime/test_tracing.py @@ -257,7 +257,7 @@ async def test_tracing_disabled_prevents_tracing(self, mock_websocket): ) # Test the _get_updated_model_settings_from_agent method directly - model_settings = await session._get_updated_model_settings_from_agent( + model_settings, _, _ = await session._get_updated_model_settings_from_agent( starting_settings=None, agent=agent )