Skip to content

Commit dc7de0f

Browse files
committed
test(openai_agents): cover real usage capture, span emission, and prompt_cache_key
Adds five tests in TestStreamingModelUsageAndCacheKey covering the three changes in the prior commit: - Usage is captured from ResponseCompletedEvent.response.usage. - Usage falls back to zeros when the stream ends without a completed event. - The streaming span's output_data carries the usage block. - prompt_cache_key defaults to streaming_task_id contextvar. - Callers can override prompt_cache_key via model_settings.extra_args. The existing TestStreamingModel* tests in this file are unrelated and remain broken for reasons predating this change (incorrect _mock_adk_streaming fixture lookup and unset contextvars).
1 parent 9e86dbd commit dc7de0f

1 file changed

Lines changed: 217 additions & 1 deletion

File tree

src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py

Lines changed: 217 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -845,4 +845,220 @@ async def test_missing_task_id_error(self, streaming_model):
845845
handoffs=[],
846846
tracing=None,
847847
task_id=None # Missing task_id
848-
)
848+
)
849+
850+
851+
class TestStreamingModelUsageAndCacheKey:
852+
"""Tests for real-Usage capture, span output_data["usage"], and prompt_cache_key routing."""
853+
854+
@staticmethod
855+
def _async_iter(events):
856+
async def _gen():
857+
for event in events:
858+
yield event
859+
return _gen()
860+
861+
@staticmethod
862+
def _make_response_completed_event(
863+
*,
864+
input_tokens: int = 0,
865+
output_tokens: int = 0,
866+
total_tokens: int = 0,
867+
cached_tokens: int = 0,
868+
reasoning_tokens: int = 0,
869+
with_usage: bool = True,
870+
):
871+
from openai.types.responses import ResponseCompletedEvent
872+
873+
usage = MagicMock()
874+
usage.input_tokens = input_tokens
875+
usage.output_tokens = output_tokens
876+
usage.total_tokens = total_tokens
877+
usage.input_tokens_details = MagicMock(cached_tokens=cached_tokens)
878+
usage.output_tokens_details = MagicMock(reasoning_tokens=reasoning_tokens)
879+
880+
response = MagicMock()
881+
response.output = []
882+
response.usage = usage if with_usage else None
883+
884+
event = MagicMock(spec=ResponseCompletedEvent)
885+
event.response = response
886+
return event
887+
888+
@pytest.fixture
889+
def context_set(self):
890+
"""Set/reset the streaming contextvars used by get_response."""
891+
from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import (
892+
streaming_task_id,
893+
streaming_trace_id,
894+
streaming_parent_span_id,
895+
)
896+
task_token = streaming_task_id.set("test-task-abc")
897+
trace_token = streaming_trace_id.set("test-trace-123")
898+
parent_token = streaming_parent_span_id.set("test-parent-span")
899+
try:
900+
yield streaming_task_id, streaming_trace_id, streaming_parent_span_id
901+
finally:
902+
streaming_task_id.reset(task_token)
903+
streaming_trace_id.reset(trace_token)
904+
streaming_parent_span_id.reset(parent_token)
905+
906+
@pytest.fixture
907+
def mock_span(self):
908+
return MagicMock()
909+
910+
@pytest.fixture
911+
def streaming_model_with_mock_tracer(self, streaming_model, mock_span):
912+
"""A streaming_model whose tracer.trace().span(...) yields a captured mock span."""
913+
async_cm = MagicMock()
914+
async_cm.__aenter__ = AsyncMock(return_value=mock_span)
915+
async_cm.__aexit__ = AsyncMock(return_value=False)
916+
trace_obj = MagicMock()
917+
trace_obj.span = MagicMock(return_value=async_cm)
918+
streaming_model.tracer = MagicMock()
919+
streaming_model.tracer.trace = MagicMock(return_value=trace_obj)
920+
return streaming_model
921+
922+
@pytest.mark.asyncio
923+
async def test_get_response_captures_usage_from_completed_event(
924+
self,
925+
streaming_model_with_mock_tracer,
926+
context_set, # noqa: ARG002
927+
):
928+
model = streaming_model_with_mock_tracer
929+
completed = self._make_response_completed_event(
930+
input_tokens=1234,
931+
output_tokens=56,
932+
total_tokens=1290,
933+
cached_tokens=987,
934+
reasoning_tokens=42,
935+
)
936+
model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))
937+
938+
response = await model.get_response(
939+
system_instructions=None,
940+
input="hi",
941+
model_settings=ModelSettings(),
942+
tools=[],
943+
output_schema=None,
944+
handoffs=[],
945+
tracing=None,
946+
)
947+
948+
assert response.usage.input_tokens == 1234
949+
assert response.usage.output_tokens == 56
950+
assert response.usage.total_tokens == 1290
951+
assert response.usage.input_tokens_details.cached_tokens == 987
952+
assert response.usage.output_tokens_details.reasoning_tokens == 42
953+
954+
@pytest.mark.asyncio
955+
async def test_get_response_usage_falls_back_when_no_completed_event(
956+
self,
957+
streaming_model_with_mock_tracer,
958+
context_set, # noqa: ARG002
959+
):
960+
model = streaming_model_with_mock_tracer
961+
# Stream ends with no ResponseCompletedEvent
962+
model.client.responses.create = AsyncMock(return_value=self._async_iter([]))
963+
964+
response = await model.get_response(
965+
system_instructions=None,
966+
input="hi",
967+
model_settings=ModelSettings(),
968+
tools=[],
969+
output_schema=None,
970+
handoffs=[],
971+
tracing=None,
972+
)
973+
974+
assert response.usage.input_tokens == 0
975+
assert response.usage.output_tokens == 0
976+
assert response.usage.total_tokens == 0
977+
assert response.usage.input_tokens_details.cached_tokens == 0
978+
assert response.usage.output_tokens_details.reasoning_tokens == 0
979+
980+
@pytest.mark.asyncio
981+
async def test_get_response_emits_usage_in_span_output(
982+
self,
983+
streaming_model_with_mock_tracer,
984+
context_set, # noqa: ARG002
985+
mock_span,
986+
):
987+
model = streaming_model_with_mock_tracer
988+
completed = self._make_response_completed_event(
989+
input_tokens=100,
990+
output_tokens=10,
991+
total_tokens=110,
992+
cached_tokens=80,
993+
reasoning_tokens=5,
994+
)
995+
model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))
996+
997+
await model.get_response(
998+
system_instructions=None,
999+
input="hi",
1000+
model_settings=ModelSettings(),
1001+
tools=[],
1002+
output_schema=None,
1003+
handoffs=[],
1004+
tracing=None,
1005+
)
1006+
1007+
assert isinstance(mock_span.output, dict)
1008+
assert "usage" in mock_span.output
1009+
usage_block = mock_span.output["usage"]
1010+
assert usage_block["input_tokens"] == 100
1011+
assert usage_block["output_tokens"] == 10
1012+
assert usage_block["total_tokens"] == 110
1013+
assert usage_block["cached_input_tokens"] == 80
1014+
assert usage_block["reasoning_tokens"] == 5
1015+
1016+
@pytest.mark.asyncio
1017+
async def test_get_response_passes_prompt_cache_key_from_contextvar(
1018+
self,
1019+
streaming_model_with_mock_tracer,
1020+
context_set, # noqa: ARG002
1021+
):
1022+
model = streaming_model_with_mock_tracer
1023+
completed = self._make_response_completed_event()
1024+
model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))
1025+
1026+
await model.get_response(
1027+
system_instructions=None,
1028+
input="hi",
1029+
model_settings=ModelSettings(),
1030+
tools=[],
1031+
output_schema=None,
1032+
handoffs=[],
1033+
tracing=None,
1034+
)
1035+
1036+
kwargs = model.client.responses.create.call_args.kwargs
1037+
assert kwargs["prompt_cache_key"] == "test-task-abc"
1038+
1039+
@pytest.mark.asyncio
1040+
async def test_get_response_caller_override_for_prompt_cache_key(
1041+
self,
1042+
streaming_model_with_mock_tracer,
1043+
context_set, # noqa: ARG002
1044+
):
1045+
model = streaming_model_with_mock_tracer
1046+
completed = self._make_response_completed_event()
1047+
model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))
1048+
1049+
await model.get_response(
1050+
system_instructions=None,
1051+
input="hi",
1052+
model_settings=ModelSettings(extra_args={"prompt_cache_key": "my-key"}),
1053+
tools=[],
1054+
output_schema=None,
1055+
handoffs=[],
1056+
tracing=None,
1057+
)
1058+
1059+
kwargs = model.client.responses.create.call_args.kwargs
1060+
assert kwargs["prompt_cache_key"] == "my-key"
1061+
# Make sure the override key was popped from extra_args and not double-passed.
1062+
assert "prompt_cache_key" not in {
1063+
k for k in kwargs if k != "prompt_cache_key"
1064+
} or list(kwargs).count("prompt_cache_key") == 1

0 commit comments

Comments
 (0)