Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 38 additions & 23 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ def __init__(

self._guardrail_tasks: set[asyncio.Task[Any]] = set()
self._tool_call_tasks: set[asyncio.Task[Any]] = set()
# Background tasks that emit events from done-callbacks. asyncio only
# holds a weak reference to a task, so we keep a strong reference here
# until it completes to avoid the event being dropped if the task is
# garbage-collected mid-run.
self._event_tasks: set[asyncio.Task[Any]] = set()
self._async_tool_calls: bool = bool(self._run_config.get("async_tool_calls", True))

@property
Expand Down Expand Up @@ -1140,12 +1145,10 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
exception = task.exception()
if exception:
# Create an exception event instead of raising
asyncio.create_task(
self._put_event(
RealtimeError(
info=self._event_info,
error={"message": f"Guardrail task failed: {str(exception)}"},
)
self._emit_event_soon(
RealtimeError(
info=self._event_info,
error={"message": f"Guardrail task failed: {str(exception)}"},
)
)

Expand Down Expand Up @@ -1190,17 +1193,14 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
exception.call_id,
exc_info=exception,
)
asyncio.create_task(
self._put_event(
RealtimeError(
info=self._event_info,
error={
"message": (
"Tool output send failed; cached output will be retried: "
f"{exception}"
)
},
)
self._emit_event_soon(
RealtimeError(
info=self._event_info,
error={
"message": (
f"Tool output send failed; cached output will be retried: {exception}"
)
},
)
)
return
Expand All @@ -1210,12 +1210,10 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
if self._stored_exception is None:
self._stored_exception = exception

asyncio.create_task(
self._put_event(
RealtimeError(
info=self._event_info,
error={"message": f"Tool call task failed: {exception}"},
)
self._emit_event_soon(
RealtimeError(
info=self._event_info,
error={"message": f"Tool call task failed: {exception}"},
)
)

Expand All @@ -1225,6 +1223,22 @@ def _cleanup_tool_call_tasks(self) -> None:
task.cancel()
self._tool_call_tasks.clear()

def _emit_event_soon(self, event: RealtimeSessionEvent) -> None:
    """Hand ``event`` to a background task and return immediately.

    ``self._put_event(event)`` is wrapped in a task that is stored in
    ``self._event_tasks`` while it runs; a done-callback discards it from
    the set once it finishes. Keeping the task in the set holds a strong
    reference to it, so it cannot be garbage-collected before the event
    has been put.
    """
    delivery = asyncio.create_task(self._put_event(event))
    # Done-callbacks are invoked by the event loop after the task finishes,
    # so registering the removal hook before tracking the task is safe.
    delivery.add_done_callback(self._event_tasks.discard)
    self._event_tasks.add(delivery)

def _cleanup_event_tasks(self) -> None:
    """Cancel every unfinished background event task and stop tracking all of them.

    Tasks that are already done are left untouched; the tracking set is
    emptied in either case.
    """
    tracked = self._event_tasks
    # The filter is lazy, so each task's ``done()`` is checked right before
    # that task is cancelled.
    for unfinished in (entry for entry in tracked if not entry.done()):
        unfinished.cancel()
    tracked.clear()

def _wake_event_iterators(self) -> None:
for _ in range(self._event_iterator_waiters):
self._event_queue.put_nowait(_REALTIME_SESSION_CLOSED_SENTINEL)
Expand All @@ -1238,6 +1252,7 @@ async def _cleanup(self) -> None:
# Cancel and cleanup guardrail tasks
self._cleanup_guardrail_tasks()
self._cleanup_tool_call_tasks()
self._cleanup_event_tasks()

# Remove ourselves as a listener
self._model.remove_listener(self)
Expand Down
67 changes: 67 additions & 0 deletions tests/realtime/test_session_event_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

import asyncio
from unittest.mock import AsyncMock, Mock

import pytest

from agents.realtime.events import RealtimeError, RealtimeEventInfo
from agents.realtime.session import RealtimeSession


@pytest.fixture
def fake_agent():
    """Provide a stand-in agent with no tools, no handoffs and a fixed system prompt."""
    return Mock(
        get_all_tools=AsyncMock(return_value=[]),
        get_system_prompt=AsyncMock(return_value="test instructions"),
        handoffs=[],
    )


@pytest.fixture
def fake_model():
    """Provide a bare ``Mock`` to pass as the session's model."""
    model = Mock()
    return model


class TestEmitEventSoon:
    """Background event tasks must be referenced so they are not dropped."""

    @pytest.mark.asyncio
    async def test_emit_event_soon_keeps_task_referenced_until_done(self, fake_model, fake_agent):
        """The scheduled task is tracked while pending and released when done.

        asyncio only keeps a weak reference to a task, so a fire-and-forget
        ``create_task`` can be garbage-collected before it runs. ``_emit_event_soon``
        retains a strong reference until completion and then delivers the event.
        """
        session = RealtimeSession(fake_model, fake_agent, None)
        event = RealtimeError(
            info=RealtimeEventInfo(context=session._context_wrapper),
            error={"message": "boom"},
        )

        session._emit_event_soon(event)

        # While pending, the task is held in the tracking set (strong reference).
        assert len(session._event_tasks) == 1

        # Once it runs, the event reaches the queue and the reference is released.
        delivered = await asyncio.wait_for(session._event_queue.get(), timeout=1)
        assert delivered is event
        await asyncio.sleep(0)
        assert len(session._event_tasks) == 0

    @pytest.mark.asyncio
    async def test_cleanup_event_tasks_cancels_pending(self, fake_model, fake_agent):
        """Cleanup cancels any still-pending background event tasks.

        Checking only that the tracking set is empty would also pass if cleanup
        merely cleared the set, so the task itself is captured beforehand and
        checked for cancellation afterwards.
        """
        session = RealtimeSession(fake_model, fake_agent, None)
        event = RealtimeError(
            info=RealtimeEventInfo(context=session._context_wrapper),
            error={"message": "boom"},
        )

        session._emit_event_soon(event)
        assert len(session._event_tasks) == 1
        # Keep our own handle on the task; the set is about to be emptied.
        (task,) = session._event_tasks
        assert not task.done()

        session._cleanup_event_tasks()
        assert len(session._event_tasks) == 0

        # ``Task.cancel()`` only requests cancellation; the task is marked
        # cancelled on its next turn of the event loop, so yield once.
        await asyncio.sleep(0)
        assert task.cancelled()
        # The task was cancelled before it ran, so nothing was delivered.
        # NOTE(review): assumes ``_event_queue`` is an ``asyncio.Queue`` — confirm.
        assert session._event_queue.empty()