diff --git a/src/agents/run_internal/session_persistence.py b/src/agents/run_internal/session_persistence.py index 25874ad345..bafc420d9c 100644 --- a/src/agents/run_internal/session_persistence.py +++ b/src/agents/run_internal/session_persistence.py @@ -424,58 +424,24 @@ async def rewind_session_items( logger.debug("Rewind target %d (first 300 chars): %s", i, target[:300]) snapshot_serializations = target_serializations.copy() + rewound = await _rewind_session_tail_suffix( + session=session, + pop_item=pop_item, + expected_serializations=target_serializations, + ignore_ids_for_matching=ignore_ids_for_matching, + mismatch_warning=( + "Skipping session rewind because the current tail does not match the retry-owned suffix" + ), + pop_failure_warning="Failed to rewind session item: %s", + ) + if not rewound: + return - remaining = target_serializations.copy() - - while remaining: - try: - result = pop_item() - if inspect.isawaitable(result): - result = await result - except Exception as exc: - logger.warning("Failed to rewind session item: %s", exc) - break - else: - if result is None: - break - - popped_serialized = fingerprint_input_item( - result, ignore_ids_for_matching=ignore_ids_for_matching - ) - - logger.debug("Popped item type during rewind: %s", type(result).__name__) - if popped_serialized: - logger.debug("Popped serialized (first 300 chars): %s", popped_serialized[:300]) - else: - logger.debug("Popped serialized: None") - - logger.debug("Number of remaining targets: %d", len(remaining)) - if remaining and popped_serialized: - logger.debug("First target (first 300 chars): %s", remaining[0][:300]) - logger.debug("Match found: %s", popped_serialized in remaining) - if len(remaining) > 0: - first_target = remaining[0] - if abs(len(first_target) - len(popped_serialized)) < 50: - logger.debug( - "Length comparison - popped: %d, target: %d", - len(popped_serialized), - len(first_target), - ) - - if popped_serialized and popped_serialized in remaining: - remaining.remove(popped_serialized) - - if remaining: - logger.warning( - "Unable to fully rewind session; %d items still unmatched after retry", - len(remaining), - ) - else: - await wait_for_session_cleanup( - session, - snapshot_serializations, - ignore_ids_for_matching=ignore_ids_for_matching, - ) + await wait_for_session_cleanup( + session, + snapshot_serializations, + ignore_ids_for_matching=ignore_ids_for_matching, + ) if session is None or server_tracker is None: return @@ -493,22 +459,36 @@ async def rewind_session_items( if isinstance(latest_id, str) and latest_id in server_tracker.server_item_ids: return - logger.debug("Stripping stray conversation items until we reach a known server item") - while True: - try: - result = pop_item() - if inspect.isawaitable(result): - result = await result - except Exception as exc: - logger.warning("Failed to strip stray session item: %s", exc) - break + try: + session_items = await session.get_items() + except Exception as exc: + logger.debug("Failed to inspect session tail while stripping stray items: %s", exc) + return - if result is None: - break + stray_serializations = _collect_retry_owned_tail_serializations( + session_items, + server_tracker=server_tracker, + ignore_ids_for_matching=ignore_ids_for_matching, + ) + if not stray_serializations: + return - stripped_id = result.get("id") if isinstance(result, dict) else getattr(result, "id", None) - if isinstance(stripped_id, str) and stripped_id in server_tracker.server_item_ids: - break + logger.debug( + "Stripping %d retry-owned conversation items until the session tail reaches " + "a known server item", + len(stray_serializations), + ) + await _rewind_session_tail_suffix( + session=session, + pop_item=pop_item, + expected_serializations=stray_serializations, + ignore_ids_for_matching=ignore_ids_for_matching, + mismatch_warning=( + "Skipping stray session cleanup because the current tail no longer matches " + "retry-owned conversation items" + ), + pop_failure_warning="Failed to strip stray session item: %s", + ) async def wait_for_session_cleanup( @@ -582,6 +562,121 @@ def _fingerprint_or_repr(item: TResponseInputItem, *, ignore_ids_for_matching: b ) +async def _rewind_session_tail_suffix( + *, + session: Session, + pop_item: Any, + expected_serializations: Sequence[str], + ignore_ids_for_matching: bool, + mismatch_warning: str, + pop_failure_warning: str, +) -> bool: + """Remove an exact serialized suffix from the session tail, aborting when the tail diverges.""" + if not expected_serializations: + return True + + try: + tail_items = await session.get_items(limit=len(expected_serializations)) + except Exception as exc: + logger.warning(pop_failure_warning, exc) + return False + + if len(tail_items) != len(expected_serializations): + logger.warning(mismatch_warning) + return False + + tail_serializations: list[str] = [] + for item in tail_items: + serialized = fingerprint_input_item(item, ignore_ids_for_matching=ignore_ids_for_matching) + if not serialized: + logger.warning(mismatch_warning) + return False + tail_serializations.append(serialized) + + if tail_serializations != list(expected_serializations): + logger.warning(mismatch_warning) + return False + + popped_items: list[TResponseInputItem] = [] + for expected in reversed(expected_serializations): + try: + result = pop_item() + if inspect.isawaitable(result): + result = await result + except Exception as exc: + await _restore_popped_session_items(session, popped_items) + logger.warning(pop_failure_warning, exc) + return False + + if result is None: + await _restore_popped_session_items(session, popped_items) + logger.warning(mismatch_warning) + return False + + popped_items.append(result) + popped_serialized = fingerprint_input_item( + result, ignore_ids_for_matching=ignore_ids_for_matching + ) + if popped_serialized != expected: + await _restore_popped_session_items(session, popped_items) + logger.warning(mismatch_warning) + return False + + return True + + +async def _restore_popped_session_items( + session: Session, popped_items: Sequence[TResponseInputItem] +) -> None: + """Best-effort restoration for items popped during a failed rewind attempt.""" + if not popped_items: + return + + add_items = getattr(session, "add_items", None) + if not callable(add_items): + return + + try: + result = add_items(list(reversed(popped_items))) + if inspect.isawaitable(result): + await result + except Exception as exc: + logger.warning("Failed to restore session items after a rewind mismatch: %s", exc) + + +def _collect_retry_owned_tail_serializations( + session_items: Sequence[TResponseInputItem], + *, + server_tracker: OpenAIServerConversationTracker, + ignore_ids_for_matching: bool, +) -> list[str]: + """Return the contiguous retry-owned tail suffix that can be safely stripped.""" + stray_tail: list[str] = [] + + for item in reversed(session_items): + item_id = item.get("id") if isinstance(item, dict) else getattr(item, "id", None) + if isinstance(item_id, str) and item_id in server_tracker.server_item_ids: + return list(reversed(stray_tail)) + + serialized = fingerprint_input_item(item, ignore_ids_for_matching=ignore_ids_for_matching) + if serialized and serialized in server_tracker.sent_item_fingerprints: + stray_tail.append(serialized) + continue + + logger.warning( + "Skipping stray session cleanup because the current tail contains items unrelated " + "to this retry" + ) + return [] + + if stray_tail: + logger.warning( + "Skipping stray session cleanup because no known server item was found before the " + "session boundary" + ) + return [] + + def _session_item_key(item: Any) -> str: """Return a stable representation of a session item for comparison.""" try: diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index 45cdab7711..ce3e547d69 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -70,6 +70,7 @@ from agents.run_internal.run_loop import get_new_response from agents.run_internal.run_steps import NextStepFinalOutput, SingleStepResult from agents.run_internal.session_persistence import ( + _collect_retry_owned_tail_serializations, persist_session_items_for_guardrail_trip, prepare_input_with_session, rewind_session_items, @@ -2364,6 +2365,79 @@ async def test_rewind_handles_id_stripped_sessions() -> None: assert session.saved_items == [] +@pytest.mark.asyncio +async def test_rewind_skips_mismatched_tail_suffix() -> None: + target = cast(TResponseInputItem, {"type": "message", "role": "user", "content": "target"}) + unrelated = cast( + TResponseInputItem, + {"type": "message", "role": "user", "content": "unrelated tail item"}, + ) + session = CountingSession(history=[target, unrelated]) + + await rewind_session_items(session, [target]) + + assert session.pop_calls == 0 + assert session.saved_items == [target, unrelated] + + +@pytest.mark.asyncio +async def test_rewind_preserves_unrelated_tail_items_when_server_tracker_cleanup_runs() -> None: + known_server_item = cast( + TResponseInputItem, + {"id": "msg_server_1", "type": "message", "role": "assistant", "content": "server item"}, + ) + unrelated = cast( + TResponseInputItem, + {"type": "message", "role": "user", "content": "unrelated tail item"}, + ) + target = cast(TResponseInputItem, {"type": "message", "role": "user", "content": "target"}) + session = CountingSession(history=[known_server_item, unrelated, target]) + tracker = OpenAIServerConversationTracker() + tracker.server_item_ids.add("msg_server_1") + + await rewind_session_items(session, [target], tracker) + + assert session.pop_calls == 1 + assert session.saved_items == [known_server_item, unrelated] + + +@pytest.mark.asyncio +async def test_rewind_strips_only_retry_owned_tail_items_before_known_server_item() -> None: + known_server_item = cast( + TResponseInputItem, + {"id": "msg_server_1", "type": "message", "role": "assistant", "content": "server item"}, + ) + retry_owned_tail = cast( + TResponseInputItem, + {"type": "message", "role": "user", "content": "retry-owned local item"}, + ) + target = cast(TResponseInputItem, {"type": "message", "role": "user", "content": "target"}) + session = CountingSession(history=[known_server_item, retry_owned_tail, target]) + tracker = OpenAIServerConversationTracker() + tracker.server_item_ids.add("msg_server_1") + retry_owned_fingerprint = fingerprint_input_item(retry_owned_tail) + assert retry_owned_fingerprint is not None + tracker.sent_item_fingerprints.add(retry_owned_fingerprint) + + await rewind_session_items(session, [target], tracker) + + assert session.pop_calls == 2 + assert session.saved_items == [known_server_item] + + +def test_collect_retry_owned_tail_serializations_returns_empty_for_empty_session() -> None: + tracker = OpenAIServerConversationTracker() + + assert ( + _collect_retry_owned_tail_serializations( + [], + server_tracker=tracker, + ignore_ids_for_matching=False, + ) + == [] + ) + + @pytest.mark.asyncio async def test_save_result_to_session_does_not_increment_counter_when_nothing_saved() -> None: session = SimpleListSession()