Skip to content

Commit eb8ff68

Browse files
committed
test(openai_agents): cover real Usage, real response_id, opt-in cache key
Seven new tests in `TestStreamingModelUsageResponseIdAndCacheKey`: - Usage captured from `ResponseCompletedEvent.response.usage` - Usage falls back to zeros when stream ends without a completed event - Usage emitted in span output_data["usage"] - response_id captured from `ResponseCompletedEvent.response.id` - response_id is None (NOT a fabricated UUID) when stream ends without a completed event — guards against the previous footgun where a client-side UUID would be returned and silently break downstream `previous_response_id` chaining - prompt_cache_key resolves to NOT_GIVEN by default (omitted from request body, safe for non-OpenAI endpoints) - prompt_cache_key forwarded when caller opts in via `model_settings.extra_args["prompt_cache_key"]`, and popped from extra_args so it isn't passed twice Pre-existing tests in `TestStreamingModelBasics` (test_responses_api_streaming, test_task_id_threading, test_redis_context_creation) updated to set `response.id=None` on their `MagicMock(spec=ResponseCompletedEvent)` mocks. Without this, the auto-generated MagicMock attribute for `response.id` flows into `ModelResponse.response_id` and trips pydantic's `str | None` validation.
1 parent 125810a commit eb8ff68

1 file changed

Lines changed: 246 additions & 4 deletions

File tree

src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py

Lines changed: 246 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ async def test_responses_api_streaming(self, streaming_model, mock_adk_streaming
757757
text_delta_2 = MagicMock(spec=ResponseTextDeltaEvent)
758758
text_delta_2.delta = "world!"
759759
completed = MagicMock(spec=ResponseCompletedEvent)
760-
completed.response = MagicMock(output=[], usage=MagicMock())
760+
completed.response = MagicMock(output=[], usage=MagicMock(), id=None)
761761
mock_stream = AsyncMock()
762762
mock_stream.__aiter__.return_value = iter([item_added, text_delta_1, text_delta_2, completed])
763763
streaming_model.client.responses.create.return_value = mock_stream
@@ -796,7 +796,7 @@ async def test_task_id_threading(self, streaming_model, mock_adk_streaming, _str
796796
item_added.item = MagicMock(type="message")
797797
item_added.output_index = 0
798798
completed = MagicMock(spec=ResponseCompletedEvent)
799-
completed.response = MagicMock(output=[], usage=MagicMock())
799+
completed.response = MagicMock(output=[], usage=MagicMock(), id=None)
800800
mock_stream = AsyncMock()
801801
mock_stream.__aiter__.return_value = iter([item_added, completed])
802802
streaming_model.client.responses.create.return_value = mock_stream
@@ -832,7 +832,7 @@ async def test_redis_context_creation(self, streaming_model, mock_adk_streaming,
832832
reasoning_delta.delta = "Thinking..."
833833
reasoning_delta.summary_index = 0
834834
completed = MagicMock(spec=ResponseCompletedEvent)
835-
completed.response = MagicMock(output=[], usage=MagicMock())
835+
completed.response = MagicMock(output=[], usage=MagicMock(), id=None)
836836
mock_stream = AsyncMock()
837837
mock_stream.__aiter__.return_value = iter([item_added, reasoning_delta, completed])
838838
streaming_model.client.responses.create.return_value = mock_stream
@@ -871,4 +871,246 @@ async def test_missing_task_id_error(self, streaming_model):
871871
output_schema=None,
872872
handoffs=[],
873873
tracing=None,
874-
)
874+
)
875+
876+
877+
class TestStreamingModelUsageResponseIdAndCacheKey:
878+
"""Cover real-Usage capture, real response_id, span emission, and opt-in prompt_cache_key."""
879+
880+
@staticmethod
881+
def _async_iter(events):
882+
async def _gen():
883+
for event in events:
884+
yield event
885+
return _gen()
886+
887+
@staticmethod
888+
def _make_response_completed_event(
889+
*,
890+
input_tokens: int = 0,
891+
output_tokens: int = 0,
892+
total_tokens: int = 0,
893+
cached_tokens: int = 0,
894+
reasoning_tokens: int = 0,
895+
with_usage: bool = True,
896+
response_id: str | None = "resp_real_server_id",
897+
):
898+
usage = MagicMock()
899+
usage.input_tokens = input_tokens
900+
usage.output_tokens = output_tokens
901+
usage.total_tokens = total_tokens
902+
usage.input_tokens_details = MagicMock(cached_tokens=cached_tokens)
903+
usage.output_tokens_details = MagicMock(reasoning_tokens=reasoning_tokens)
904+
905+
response = MagicMock()
906+
response.output = []
907+
response.usage = usage if with_usage else None
908+
response.id = response_id
909+
910+
event = MagicMock(spec=ResponseCompletedEvent)
911+
event.response = response
912+
return event
913+
914+
@pytest.fixture
915+
def mock_span(self):
916+
return MagicMock()
917+
918+
@pytest.fixture
919+
def streaming_model_with_mock_tracer(self, streaming_model, mock_span):
920+
"""A streaming_model whose tracer.trace().span(...) yields a captured mock span."""
921+
async_cm = MagicMock()
922+
async_cm.__aenter__ = AsyncMock(return_value=mock_span)
923+
async_cm.__aexit__ = AsyncMock(return_value=False)
924+
trace_obj = MagicMock()
925+
trace_obj.span = MagicMock(return_value=async_cm)
926+
streaming_model.tracer = MagicMock()
927+
streaming_model.tracer.trace = MagicMock(return_value=trace_obj)
928+
return streaming_model
929+
930+
@pytest.mark.asyncio
931+
async def test_usage_captured_from_completed_event(
932+
self,
933+
streaming_model_with_mock_tracer,
934+
_streaming_context_vars, # noqa: ARG002
935+
):
936+
model = streaming_model_with_mock_tracer
937+
completed = self._make_response_completed_event(
938+
input_tokens=1234, output_tokens=56, total_tokens=1290,
939+
cached_tokens=987, reasoning_tokens=42,
940+
)
941+
model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))
942+
943+
response = await model.get_response(
944+
system_instructions=None,
945+
input="hi",
946+
model_settings=ModelSettings(),
947+
tools=[],
948+
output_schema=None,
949+
handoffs=[],
950+
tracing=None,
951+
)
952+
953+
assert response.usage.input_tokens == 1234
954+
assert response.usage.output_tokens == 56
955+
assert response.usage.total_tokens == 1290
956+
assert response.usage.input_tokens_details.cached_tokens == 987
957+
assert response.usage.output_tokens_details.reasoning_tokens == 42
958+
959+
@pytest.mark.asyncio
960+
async def test_usage_falls_back_when_no_completed_event(
961+
self,
962+
streaming_model_with_mock_tracer,
963+
_streaming_context_vars, # noqa: ARG002
964+
):
965+
"""Stream ending without a ResponseCompletedEvent (error path) → zero Usage."""
966+
model = streaming_model_with_mock_tracer
967+
model.client.responses.create = AsyncMock(return_value=self._async_iter([]))
968+
969+
response = await model.get_response(
970+
system_instructions=None,
971+
input="hi",
972+
model_settings=ModelSettings(),
973+
tools=[],
974+
output_schema=None,
975+
handoffs=[],
976+
tracing=None,
977+
)
978+
979+
assert response.usage.input_tokens == 0
980+
assert response.usage.output_tokens == 0
981+
assert response.usage.total_tokens == 0
982+
assert response.usage.input_tokens_details.cached_tokens == 0
983+
assert response.usage.output_tokens_details.reasoning_tokens == 0
984+
985+
@pytest.mark.asyncio
986+
async def test_usage_emitted_in_span_output(
987+
self,
988+
streaming_model_with_mock_tracer,
989+
_streaming_context_vars, # noqa: ARG002
990+
mock_span,
991+
):
992+
model = streaming_model_with_mock_tracer
993+
completed = self._make_response_completed_event(
994+
input_tokens=100, output_tokens=10, total_tokens=110,
995+
cached_tokens=80, reasoning_tokens=5,
996+
)
997+
model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))
998+
999+
await model.get_response(
1000+
system_instructions=None,
1001+
input="hi",
1002+
model_settings=ModelSettings(),
1003+
tools=[],
1004+
output_schema=None,
1005+
handoffs=[],
1006+
tracing=None,
1007+
)
1008+
1009+
assert isinstance(mock_span.output, dict)
1010+
usage_block = mock_span.output["usage"]
1011+
assert usage_block == {
1012+
"input_tokens": 100,
1013+
"output_tokens": 10,
1014+
"total_tokens": 110,
1015+
"cached_input_tokens": 80,
1016+
"reasoning_tokens": 5,
1017+
}
1018+
1019+
@pytest.mark.asyncio
1020+
async def test_response_id_captured_from_completed_event(
1021+
self,
1022+
streaming_model_with_mock_tracer,
1023+
_streaming_context_vars, # noqa: ARG002
1024+
):
1025+
"""Real server-issued id flows back on ModelResponse.response_id."""
1026+
model = streaming_model_with_mock_tracer
1027+
completed = self._make_response_completed_event(response_id="resp_abcdef123456")
1028+
model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))
1029+
1030+
response = await model.get_response(
1031+
system_instructions=None,
1032+
input="hi",
1033+
model_settings=ModelSettings(),
1034+
tools=[],
1035+
output_schema=None,
1036+
handoffs=[],
1037+
tracing=None,
1038+
)
1039+
1040+
assert response.response_id == "resp_abcdef123456"
1041+
1042+
@pytest.mark.asyncio
1043+
async def test_response_id_is_none_when_no_completed_event(
1044+
self,
1045+
streaming_model_with_mock_tracer,
1046+
_streaming_context_vars, # noqa: ARG002
1047+
):
1048+
"""Stream ending without ResponseCompletedEvent → response_id is None.
1049+
1050+
Critical: must NOT fabricate a UUID. Returning a fake id would cause
1051+
downstream `previous_response_id` chaining to 400 against the server.
1052+
"""
1053+
model = streaming_model_with_mock_tracer
1054+
model.client.responses.create = AsyncMock(return_value=self._async_iter([]))
1055+
1056+
response = await model.get_response(
1057+
system_instructions=None,
1058+
input="hi",
1059+
model_settings=ModelSettings(),
1060+
tools=[],
1061+
output_schema=None,
1062+
handoffs=[],
1063+
tracing=None,
1064+
)
1065+
1066+
assert response.response_id is None
1067+
1068+
@pytest.mark.asyncio
1069+
async def test_prompt_cache_key_not_sent_by_default(
1070+
self,
1071+
streaming_model_with_mock_tracer,
1072+
_streaming_context_vars, # noqa: ARG002
1073+
):
1074+
"""Without an opt-in, prompt_cache_key resolves to NOT_GIVEN (omitted from request)."""
1075+
model = streaming_model_with_mock_tracer
1076+
completed = self._make_response_completed_event()
1077+
model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))
1078+
1079+
await model.get_response(
1080+
system_instructions=None,
1081+
input="hi",
1082+
model_settings=ModelSettings(),
1083+
tools=[],
1084+
output_schema=None,
1085+
handoffs=[],
1086+
tracing=None,
1087+
)
1088+
1089+
kwargs = model.client.responses.create.call_args.kwargs
1090+
assert kwargs["prompt_cache_key"] is NOT_GIVEN
1091+
1092+
@pytest.mark.asyncio
1093+
async def test_prompt_cache_key_forwarded_when_opted_in(
1094+
self,
1095+
streaming_model_with_mock_tracer,
1096+
_streaming_context_vars, # noqa: ARG002
1097+
):
1098+
"""Caller opt-in via model_settings.extra_args is forwarded to responses.create."""
1099+
model = streaming_model_with_mock_tracer
1100+
completed = self._make_response_completed_event()
1101+
model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))
1102+
1103+
await model.get_response(
1104+
system_instructions=None,
1105+
input="hi",
1106+
model_settings=ModelSettings(extra_args={"prompt_cache_key": "my-key"}),
1107+
tools=[],
1108+
output_schema=None,
1109+
handoffs=[],
1110+
tracing=None,
1111+
)
1112+
1113+
kwargs = model.client.responses.create.call_args.kwargs
1114+
assert kwargs["prompt_cache_key"] == "my-key"
1115+
# Must be popped from extra_args so the SDK doesn't see it twice.
1116+
assert list(kwargs).count("prompt_cache_key") == 1

0 commit comments

Comments
 (0)