# Patch summary (reviewer notes). Everything after this header is the original diff text, unchanged.
#
# Files touched:
#   * examples/tutorials/10_async/10_temporal/070_open_ai_agents_sdk_tools/tests/test_agent.py
#       test_send_event_and_poll: when a message has streaming_status "DONE" and content_length > 0,
#       the loop now `continue`s instead of `break`ing if a tool_request was seen but no
#       tool_response has been seen yet.
#   * examples/tutorials/10_async/10_temporal/080_open_ai_agents_sdk_human_in_the_loop/tests/test_agent.py
#       test_send_event_and_poll_with_human_approval: same guard added before `break`.
#       NOTE(review): found_final_response is assigned True before the new guard, so it is
#       already True while the loop keeps polling for tool_response.
#   * examples/tutorials/test_utils/async_utils.py
#       - imports ToolRequestContent and ToolResponseContent;
#       - adds _is_tool_lifecycle_message(): False when message.content is None; True when the
#         content is an instance of either imported class or its `type` attribute is
#         "tool_request" / "tool_response";
#       - adds _pending_poll_sort_key(): orders (message, hash) items by
#         (0 for tool lifecycle messages else 1, created_at or datetime.min with tzinfo=UTC);
#       - poll_messages(): within each poll, candidates are appended to `pending` instead of
#         being yielded immediately; `pending` is sorted with the key above, and only then are
#         seen_message_ids (and, when yield_updates is set and a hash was computed,
#         message_content_hashes) updated and the message yielded. Items whose message.id is
#         falsy are skipped. The new_messages_found counter is removed and seen_message_ids
#         gains a set[str] annotation.
#
# NOTE(review): the line structure of this diff appears to have been flattened (many diff lines
# per physical line, leading whitespace collapsed), so it is unlikely to apply as-is; recover the
# patch from the original commit before applying.
# NOTE(review): the unchanged context line `(datetime.now() - start_time).seconds < timeout`
# reads timedelta.seconds (the 0..86399 seconds component), not total_seconds().
diff --git a/examples/tutorials/10_async/10_temporal/070_open_ai_agents_sdk_tools/tests/test_agent.py b/examples/tutorials/10_async/10_temporal/070_open_ai_agents_sdk_tools/tests/test_agent.py index 8bf59cafc..7f3455236 100644 --- a/examples/tutorials/10_async/10_temporal/070_open_ai_agents_sdk_tools/tests/test_agent.py +++ b/examples/tutorials/10_async/10_temporal/070_open_ai_agents_sdk_tools/tests/test_agent.py @@ -118,8 +118,11 @@ async def test_send_event_and_poll(self, client: AsyncAgentex, agent_id: str): content_length = len(str(agent_text)) final_message = message - # Stop when we get DONE status + # Stop when we get DONE status (after tool_response if a tool was used; + # tool rows can appear on a later poll than final text). if message.streaming_status == "DONE" and content_length > 0: + if seen_tool_request and not seen_tool_response: + continue break # Verify we got all the expected pieces diff --git a/examples/tutorials/10_async/10_temporal/080_open_ai_agents_sdk_human_in_the_loop/tests/test_agent.py b/examples/tutorials/10_async/10_temporal/080_open_ai_agents_sdk_human_in_the_loop/tests/test_agent.py index 3377c1ea8..520e0fc0e 100644 --- a/examples/tutorials/10_async/10_temporal/080_open_ai_agents_sdk_human_in_the_loop/tests/test_agent.py +++ b/examples/tutorials/10_async/10_temporal/080_open_ai_agents_sdk_human_in_the_loop/tests/test_agent.py @@ -148,9 +148,12 @@ async def test_send_event_and_poll_with_human_approval(self, client: AsyncAgente if message.content and message.content.type == "text" and message.content.author == "agent": content_length = len(message.content.content) if message.content.content else 0 - # Stop when we get DONE status with actual content + # Stop when we get DONE with content (after tool_response if a tool ran; + # tool rows can appear on a later poll than final text). 
if message.streaming_status == "DONE" and content_length > 0: found_final_response = True + if seen_tool_request and not seen_tool_response: + continue break # Verify that we saw the complete flow: tool_request -> human approval -> tool_response -> final answer diff --git a/examples/tutorials/test_utils/async_utils.py b/examples/tutorials/test_utils/async_utils.py index 2187e98d8..503b6df83 100644 --- a/examples/tutorials/test_utils/async_utils.py +++ b/examples/tutorials/test_utils/async_utils.py @@ -17,6 +17,27 @@ from agentex.types.agent_rpc_params import ParamsSendEventRequest from agentex.types.agent_rpc_result import StreamTaskMessageDone, StreamTaskMessageFull from agentex.types.text_content_param import TextContentParam +from agentex.types.tool_request_content import ToolRequestContent +from agentex.types.tool_response_content import ToolResponseContent + + +def _is_tool_lifecycle_message(message: TaskMessage) -> bool: + """True for tool request/response rows (by model type or ``content.type`` string).""" + c = message.content + if c is None: + return False + if isinstance(c, (ToolRequestContent, ToolResponseContent)): + return True + ctype = getattr(c, "type", None) + return ctype in ("tool_request", "tool_response") + + +def _pending_poll_sort_key(item: tuple[TaskMessage, int | None]) -> tuple[int, datetime]: + """Yield tool lifecycle messages before other content in the same poll batch.""" + m, _ = item + ts = m.created_at if m.created_at else datetime.min.replace(tzinfo=timezone.utc) + phase = 0 if _is_tool_lifecycle_message(m) else 1 + return (phase, ts) async def send_event_and_poll_yielding( @@ -90,10 +111,16 @@ async def poll_messages( If False, only yield each message ID once (default: False) Yields: - TaskMessage objects as they are discovered or updated + TaskMessage objects as they are discovered or updated. 
+ + Within each poll, messages to emit are collected first, then re-ordered so + ``tool_request`` / ``tool_response`` rows are yielded before other content in that + batch. ``tool_response`` can also appear on a later poll than final assistant text; + callers that ``break`` on DONE text should keep polling until they have seen + ``tool_response`` after a ``tool_request`` (see tutorial tests). """ # Keep track of messages we've already yielded - seen_message_ids = set() + seen_message_ids: set[str] = set() # Track message content hashes to detect updates (for streaming) message_content_hashes: dict[str, int] = {} start_time = datetime.now() @@ -102,14 +129,15 @@ async def poll_messages( while (datetime.now() - start_time).seconds < timeout: messages = await client.messages.list(task_id=task_id) - # Sort messages by created_at to ensure chronological order - # Use datetime.min for messages without created_at timestamp sorted_messages = sorted( messages, - key=lambda m: m.created_at if m.created_at else datetime.min.replace(tzinfo=timezone.utc) + key=lambda m: m.created_at if m.created_at else datetime.min.replace(tzinfo=timezone.utc), ) - new_messages_found = 0 + # Collect (message, hash) for this poll without mutating dedupe state yet, then + # yield tool lifecycle rows before streaming text / other updates in the same batch. 
+ pending: list[tuple[TaskMessage, int | None]] = [] + for message in sorted_messages: # Check if message passes timestamp filter if messages_created_after and message.created_at: @@ -141,16 +169,22 @@ async def poll_messages( is_updated = message.id in message_content_hashes and message_content_hashes[message.id] != content_hash if is_new_message or is_updated: - message_content_hashes[message.id] = content_hash - seen_message_ids.add(message.id) - new_messages_found += 1 - yield message + pending.append((message, content_hash)) else: # Original behavior: only yield each message ID once if is_new_message: - seen_message_ids.add(message.id) - new_messages_found += 1 - yield message + pending.append((message, None)) + + pending.sort(key=_pending_poll_sort_key) + + for message, content_hash in pending: + mid = message.id + if not mid: + continue + if yield_updates and content_hash is not None: + message_content_hashes[mid] = content_hash + seen_message_ids.add(mid) + yield message # Sleep before next poll await asyncio.sleep(sleep_interval)