diff --git a/src/agents/run.py b/src/agents/run.py index 014271a5ea..55177412d0 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1243,6 +1243,17 @@ def _finalize_result(result: RunResult) -> RunResult: ) ) raise + except Exception: + # A parallel input guardrail (or the model turn) raised a + # non-tripwire error. asyncio.gather does not cancel the + # sibling awaitables when one fails, so cancel the in-flight + # model task to avoid leaking it (mirrors the tripwire + # cleanup above). + if should_cancel_parallel_model_task_on_input_guardrail_trip(): + if not model_task.done(): + model_task.cancel() + await asyncio.gather(model_task, return_exceptions=True) + raise else: turn_result = await model_task diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py index 8f05c38129..395696ee20 100644 --- a/tests/test_guardrails.py +++ b/tests/test_guardrails.py @@ -615,6 +615,56 @@ async def slow_get_response(*args, **kwargs): assert model_cancelled.is_set() is False +@pytest.mark.asyncio +async def test_parallel_guardrail_non_tripwire_error_cancels_model_task(): + # A parallel input guardrail that raises a non-tripwire exception must not leave the + # in-flight model task running. asyncio.gather does not cancel sibling awaitables when + # one fails, so the runner is responsible for cancelling the model task. + model_started = asyncio.Event() + model_cancelled = asyncio.Event() + model_finished = asyncio.Event() + + class GuardrailBoom(RuntimeError): + pass + + @input_guardrail(run_in_parallel=True) + async def raises_after_model_starts( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + await asyncio.wait_for(model_started.wait(), timeout=1) + raise GuardrailBoom("guardrail failed") + + model = FakeModel() + original_get_response = model.get_response + + async def slow_get_response(*args, **kwargs): + model_started.set() + try: + await asyncio.sleep(0.02) + return await original_get_response(*args, **kwargs) + except asyncio.CancelledError: + model_cancelled.set() + raise + finally: + model_finished.set() + + agent = Agent( + name="parallel_guardrail_error_agent", + instructions="Reply with 'hello'", + input_guardrails=[raises_after_model_starts], + model=model, + ) + model.set_next_output([get_text_message("should_not_finish")]) + + with patch.object(model, "get_response", side_effect=slow_get_response): + with pytest.raises(GuardrailBoom): + await Runner.run(agent, "trigger guardrail") + + await asyncio.wait_for(model_finished.wait(), timeout=1) + assert model_started.is_set() is True + assert model_cancelled.is_set() is True + + @pytest.mark.asyncio async def test_parallel_guardrail_may_not_prevent_tool_execution_streaming(): tool_was_executed = False