@@ -845,4 +845,220 @@ async def test_missing_task_id_error(self, streaming_model):
845845 handoffs = [],
846846 tracing = None ,
847847 task_id = None # Missing task_id
848- )
848+ )
849+
850+
class TestStreamingModelUsageAndCacheKey:
    """Tests for usage capture, the span's ``output["usage"]`` block, and prompt_cache_key routing.

    Every test replaces ``model.client.responses.create`` with an ``AsyncMock``
    that returns a canned async event stream, then inspects either the value
    returned by ``get_response``, the mocked span, or the keyword arguments
    that ``create`` was called with.
    """

    # Value installed into the ``streaming_task_id`` contextvar by ``context_set``;
    # the prompt_cache_key routing test expects to see it forwarded to ``create``.
    _TASK_ID = "test-task-abc"

    @staticmethod
    def _async_iter(events):
        """Return an async generator that yields each item of *events* in order."""

        async def _gen():
            for event in events:
                yield event

        return _gen()

    @staticmethod
    def _make_response_completed_event(
        *,
        input_tokens: int = 0,
        output_tokens: int = 0,
        total_tokens: int = 0,
        cached_tokens: int = 0,
        reasoning_tokens: int = 0,
        with_usage: bool = True,
    ):
        """Build a mock ``ResponseCompletedEvent`` carrying the given token counts.

        The event is a ``MagicMock`` created with ``spec=ResponseCompletedEvent``.
        Its ``response.output`` is an empty list and ``response.usage`` is either
        a mock usage object populated from the arguments or ``None`` when
        *with_usage* is false.
        """
        from openai.types.responses import ResponseCompletedEvent

        usage = MagicMock()
        usage.input_tokens = input_tokens
        usage.output_tokens = output_tokens
        usage.total_tokens = total_tokens
        usage.input_tokens_details = MagicMock(cached_tokens=cached_tokens)
        usage.output_tokens_details = MagicMock(reasoning_tokens=reasoning_tokens)

        response = MagicMock()
        response.output = []
        response.usage = usage if with_usage else None

        event = MagicMock(spec=ResponseCompletedEvent)
        event.response = response
        return event

    @staticmethod
    async def _call_get_response(model, model_settings=None):
        """Await ``model.get_response`` with the minimal arguments shared by every test.

        *model_settings* defaults to a fresh ``ModelSettings()`` so tests only
        spell it out when they need to override something.
        """
        settings = ModelSettings() if model_settings is None else model_settings
        return await model.get_response(
            system_instructions=None,
            input="hi",
            model_settings=settings,
            tools=[],
            output_schema=None,
            handoffs=[],
            tracing=None,
        )

    @pytest.fixture
    def context_set(self):
        """Set/reset the streaming contextvars used by get_response."""
        from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import (
            streaming_task_id,
            streaming_trace_id,
            streaming_parent_span_id,
        )

        task_token = streaming_task_id.set(self._TASK_ID)
        trace_token = streaming_trace_id.set("test-trace-123")
        parent_token = streaming_parent_span_id.set("test-parent-span")
        try:
            yield streaming_task_id, streaming_trace_id, streaming_parent_span_id
        finally:
            # Always restore the previous values so other tests are unaffected.
            streaming_task_id.reset(task_token)
            streaming_trace_id.reset(trace_token)
            streaming_parent_span_id.reset(parent_token)

    @pytest.fixture
    def mock_span(self):
        """Return the mock span handed out by the mocked tracer."""
        return MagicMock()

    @pytest.fixture
    def streaming_model_with_mock_tracer(self, streaming_model, mock_span):
        """A streaming_model whose tracer.trace().span(...) yields a captured mock span."""
        async_cm = MagicMock()
        async_cm.__aenter__ = AsyncMock(return_value=mock_span)
        async_cm.__aexit__ = AsyncMock(return_value=False)

        trace_obj = MagicMock()
        trace_obj.span = MagicMock(return_value=async_cm)

        streaming_model.tracer = MagicMock()
        streaming_model.tracer.trace = MagicMock(return_value=trace_obj)
        return streaming_model

    @pytest.mark.asyncio
    async def test_get_response_captures_usage_from_completed_event(
        self,
        streaming_model_with_mock_tracer,
        context_set,  # noqa: ARG002
    ):
        """Token counts on the completed event are surfaced on ``response.usage``."""
        model = streaming_model_with_mock_tracer
        completed = self._make_response_completed_event(
            input_tokens=1234,
            output_tokens=56,
            total_tokens=1290,
            cached_tokens=987,
            reasoning_tokens=42,
        )
        model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))

        response = await self._call_get_response(model)

        assert response.usage.input_tokens == 1234
        assert response.usage.output_tokens == 56
        assert response.usage.total_tokens == 1290
        assert response.usage.input_tokens_details.cached_tokens == 987
        assert response.usage.output_tokens_details.reasoning_tokens == 42

    @pytest.mark.asyncio
    async def test_get_response_usage_falls_back_when_no_completed_event(
        self,
        streaming_model_with_mock_tracer,
        context_set,  # noqa: ARG002
    ):
        """With no completed event in the stream, every usage counter is zero."""
        model = streaming_model_with_mock_tracer
        # Stream ends with no ResponseCompletedEvent
        model.client.responses.create = AsyncMock(return_value=self._async_iter([]))

        response = await self._call_get_response(model)

        assert response.usage.input_tokens == 0
        assert response.usage.output_tokens == 0
        assert response.usage.total_tokens == 0
        assert response.usage.input_tokens_details.cached_tokens == 0
        assert response.usage.output_tokens_details.reasoning_tokens == 0

    @pytest.mark.asyncio
    async def test_get_response_emits_usage_in_span_output(
        self,
        streaming_model_with_mock_tracer,
        context_set,  # noqa: ARG002
        mock_span,
    ):
        """The span's ``output`` dict carries a ``usage`` block with the token counts."""
        model = streaming_model_with_mock_tracer
        completed = self._make_response_completed_event(
            input_tokens=100,
            output_tokens=10,
            total_tokens=110,
            cached_tokens=80,
            reasoning_tokens=5,
        )
        model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))

        await self._call_get_response(model)

        assert isinstance(mock_span.output, dict)
        assert "usage" in mock_span.output
        usage_block = mock_span.output["usage"]
        assert usage_block["input_tokens"] == 100
        assert usage_block["output_tokens"] == 10
        assert usage_block["total_tokens"] == 110
        assert usage_block["cached_input_tokens"] == 80
        assert usage_block["reasoning_tokens"] == 5

    @pytest.mark.asyncio
    async def test_get_response_passes_prompt_cache_key_from_contextvar(
        self,
        streaming_model_with_mock_tracer,
        context_set,  # noqa: ARG002
    ):
        """Without an override, ``create`` receives the contextvar task id as prompt_cache_key."""
        model = streaming_model_with_mock_tracer
        completed = self._make_response_completed_event()
        model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))

        await self._call_get_response(model)

        kwargs = model.client.responses.create.call_args.kwargs
        assert kwargs["prompt_cache_key"] == self._TASK_ID

    @pytest.mark.asyncio
    async def test_get_response_caller_override_for_prompt_cache_key(
        self,
        streaming_model_with_mock_tracer,
        context_set,  # noqa: ARG002
    ):
        """A prompt_cache_key supplied via ``extra_args`` wins over the contextvar value."""
        model = streaming_model_with_mock_tracer
        completed = self._make_response_completed_event()
        create_mock = AsyncMock(return_value=self._async_iter([completed]))
        model.client.responses.create = create_mock

        await self._call_get_response(
            model,
            model_settings=ModelSettings(extra_args={"prompt_cache_key": "my-key"}),
        )

        # Exactly one call, so ``call_args`` below cannot be masking an earlier one.
        create_mock.assert_awaited_once()
        kwargs = create_mock.call_args.kwargs
        assert kwargs["prompt_cache_key"] == "my-key"

        # Make sure the override key was popped from extra_args and not double-passed.
        # Top-level keyword names are unique by construction, so a second copy could
        # only travel inside a dict-valued argument; check those explicitly.
        leaked_into = [
            name
            for name, value in kwargs.items()
            if isinstance(value, dict) and "prompt_cache_key" in value
        ]
        assert leaked_into == []
0 commit comments