Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/running_agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ settings so the resumed turn continues in the same server-managed conversation.

Use `call_model_input_filter` to edit the model input right before the model call. The hook receives the current agent, context, and the combined input items (including session history when present) and returns a new `ModelInputData`.

The return value must be a [`ModelInputData`][agents.run.ModelInputData] object. Its `input` field is required and must be a list of input items. Returning any other shape raises a `UserError`.
The return value must be a [`ModelInputData`][agents.run.ModelInputData] object. Its `input` field is required and must be a list of input items. Returning any other shape raises a `UserError`. You may also set `output_schema` on the returned object to override the response format for that model call. When `output_schema` is `None` or omitted, the schema derived from the agent's own `output_type` is used.

```python
from agents import Agent, Runner, RunConfig
Expand Down
6 changes: 5 additions & 1 deletion src/agents/extensions/tool_output_trimmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ def __call__(self, data: CallModelData[Any]) -> ModelInputData:
f"saved ~{chars_saved} chars"
)

return _ModelInputData(input=new_items, instructions=model_data.instructions)
return _ModelInputData(
input=new_items,
instructions=model_data.instructions,
output_schema=model_data.output_schema,
)

def _find_recent_boundary(self, items: list[Any]) -> int:
"""Find the index separating 'old' items from 'recent' items.
Expand Down
5 changes: 5 additions & 0 deletions src/agents/run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

if TYPE_CHECKING:
from .agent import Agent
from .agent_output import AgentOutputSchemaBase
from .run_context import RunContextWrapper
from .sandbox.manifest import Manifest
from .sandbox.session.base_sandbox_session import BaseSandboxSession
Expand Down Expand Up @@ -50,6 +51,10 @@ class ModelInputData:

input: list[TResponseInputItem]
instructions: str | None
output_schema: AgentOutputSchemaBase | None = None
"""Output schema override. When set by a ``call_model_input_filter``, replaces the schema
derived from ``agent.output_type`` for this model call. When ``None``, the agent's schema
is used unchanged."""


@dataclass
Expand Down
12 changes: 8 additions & 4 deletions src/agents/run_internal/run_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,7 +1373,9 @@ def _tool_search_fingerprint(raw_item: Any) -> str:
context_wrapper=context_wrapper,
input_items=input,
system_instructions=system_prompt,
output_schema=output_schema,
)
output_schema = filtered.output_schema if filtered.output_schema is not None else output_schema
if isinstance(filtered.input, list):
filtered.input = deduplicate_input_items_preferring_latest(filtered.input)
hosted_mcp_tool_metadata = collect_mcp_list_tools_metadata(streamed_result._model_input_items)
Expand Down Expand Up @@ -1760,7 +1762,7 @@ async def run_single_turn(
else:
input = _prepare_turn_input_items(original_input, generated_items, reasoning_item_id_policy)

new_response = await get_new_response(
new_response, output_schema = await get_new_response(
bindings,
system_prompt,
input,
Expand Down Expand Up @@ -1811,8 +1813,8 @@ async def get_new_response(
session: Session | None = None,
session_items_to_rewind: list[TResponseInputItem] | None = None,
prompt_cache_key_resolver: PromptCacheKeyResolver | None = None,
) -> ModelResponse:
"""Call the model and return the raw response, handling retries and hooks."""
) -> tuple[ModelResponse, AgentOutputSchemaBase | None]:
"""Call the model and return the raw response and effective output schema after filtering."""
public_agent = bindings.public_agent
execution_agent = bindings.execution_agent
filtered = await maybe_filter_model_input(
Expand All @@ -1821,7 +1823,9 @@ async def get_new_response(
context_wrapper=context_wrapper,
input_items=input,
system_instructions=system_prompt,
output_schema=output_schema,
)
output_schema = filtered.output_schema if filtered.output_schema is not None else output_schema
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Propagate overridden schema to final-output parsing

In non-streamed runs, this reassignment only changes the local output_schema inside get_new_response; after it returns, run_single_turn still passes the original schema to get_single_step_result_from_response. When a call_model_input_filter sets a structured schema for an agent without output_type, the model request uses the override, but the final response is processed with None, so Runner.run returns the raw JSON string instead of validating/parsing it with the override. The streamed path keeps the updated schema in the same function, so this regression is specific to non-streamed runs.

Useful? React with 👍 / 👎.

if isinstance(filtered.input, list):
filtered.input = deduplicate_input_items_preferring_latest(filtered.input)

Expand Down Expand Up @@ -1917,4 +1921,4 @@ async def rewind_model_request() -> None:
hooks.on_llm_end(context_wrapper, public_agent, new_response),
)

return new_response
return new_response, output_schema
8 changes: 7 additions & 1 deletion src/agents/run_internal/turn_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,24 @@ async def maybe_filter_model_input(
context_wrapper: RunContextWrapper[TContext],
input_items: list[TResponseInputItem],
system_instructions: str | None,
output_schema: AgentOutputSchemaBase | None = None,
) -> ModelInputData:
"""Apply optional call_model_input_filter to modify model input."""
effective_instructions = system_instructions
effective_input: list[TResponseInputItem] = input_items

if run_config.call_model_input_filter is None:
return ModelInputData(input=effective_input, instructions=effective_instructions)
return ModelInputData(
input=effective_input,
instructions=effective_instructions,
output_schema=output_schema,
)

try:
model_input = ModelInputData(
input=effective_input.copy(),
instructions=effective_instructions,
output_schema=output_schema,
)
filter_payload: CallModelData[TContext] = CallModelData(
model_data=model_input,
Expand Down
8 changes: 5 additions & 3 deletions tests/test_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2449,7 +2449,8 @@ async def test_conversation_lock_rewind_skips_when_no_snapshot() -> None:
session_items_to_rewind=[],
)

assert isinstance(result, ModelResponse)
response, _ = result
assert isinstance(response, ModelResponse)
assert session.pop_calls == 0


Expand Down Expand Up @@ -2494,8 +2495,9 @@ async def test_get_new_response_uses_agent_retry_settings() -> None:
session_items_to_rewind=[],
)

assert isinstance(result, ModelResponse)
assert result.usage.requests == 2
response, _ = result
assert isinstance(response, ModelResponse)
assert response.usage.requests == 2


@pytest.mark.asyncio
Expand Down
112 changes: 112 additions & 0 deletions tests/test_call_model_input_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from typing import Any, cast

import pytest
from pydantic import BaseModel

from agents import Agent, RunConfig, Runner, TResponseInputItem, UserError
from agents.agent_output import AgentOutputSchema
from agents.run import CallModelData, ModelInputData

from .fake_model import FakeModel
Expand Down Expand Up @@ -167,3 +169,113 @@ async def filter_fn(data: CallModelData[Any]) -> ModelInputData:
]
assert len(outputs) == 1
assert outputs[0]["output"] == "new-value"


# Minimal structured-output model shared by the output_schema tests in this module.
# Kept as a comment rather than a class docstring so the model's generated schema is untouched.
class _Reply(BaseModel):
    # The single string field the fake model's JSON payload ({"answer": ...}) populates.
    answer: str


@pytest.mark.asyncio
async def test_filter_can_override_output_schema_non_streamed() -> None:
    """Regression test for #3563: a filter may swap in its own output_schema (non-streamed).

    The agent is created without an ``output_type``; the filter supplies the schema.
    Two things are checked: the model call is issued with the override schema, and the
    final output is parsed into ``_Reply`` using that same schema.
    """
    schema_override = AgentOutputSchema(_Reply)

    def replace_schema(data: CallModelData[Any]) -> ModelInputData:
        # Forward input and instructions untouched; only the schema is replaced.
        payload = data.model_data
        return ModelInputData(
            input=payload.input,
            instructions=payload.instructions,
            output_schema=schema_override,
        )

    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model)
    fake_model.set_next_output([get_text_message('{"answer": "hi"}')])
    run_config = RunConfig(call_model_input_filter=replace_schema)

    result = await Runner.run(agent, input="start", run_config=run_config)

    # The exact override object must be what the model was called with.
    assert fake_model.last_turn_args["output_schema"] is schema_override
    # The JSON text must have been validated into the override's model type.
    assert isinstance(result.final_output, _Reply)
    assert result.final_output.answer == "hi"


@pytest.mark.asyncio
async def test_filter_can_override_output_schema_streamed() -> None:
    """Regression test for #3563: a filter may swap in its own output_schema (streamed).

    Uses an async filter and drains the event stream before inspecting the arguments
    the fake model was called with.
    """
    schema_override = AgentOutputSchema(_Reply)

    async def replace_schema(data: CallModelData[Any]) -> ModelInputData:
        # Forward input and instructions untouched; only the schema is replaced.
        payload = data.model_data
        return ModelInputData(
            input=payload.input,
            instructions=payload.instructions,
            output_schema=schema_override,
        )

    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model)
    fake_model.set_next_output([get_text_message('{"answer": "hi"}')])
    run_config = RunConfig(call_model_input_filter=replace_schema)

    streamed = Runner.run_streamed(agent, input="start", run_config=run_config)
    # Consume every event so the run has finished before asserting.
    async for _event in streamed.stream_events():
        pass

    # NOTE(review): unlike the non-streamed test, this only checks the schema sent to the
    # model; it does not assert how final_output is parsed. Consider adding that check
    # once the streamed parsing behaviour is confirmed.
    assert fake_model.last_turn_args["output_schema"] is schema_override


@pytest.mark.asyncio
async def test_filter_receives_agent_output_schema() -> None:
    """The filter is handed the agent's output schema on ``data.model_data``.

    The agent declares ``output_type=_Reply``; the filter records the schema it sees
    and returns the model data unchanged.
    """
    seen_schemas: list[Any] = []

    def recording_filter(data: CallModelData[Any]) -> ModelInputData:
        payload = data.model_data
        seen_schemas.append(payload.output_schema)
        # Pass the very same object back so nothing about the call is altered.
        return payload

    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model, output_type=_Reply)
    fake_model.set_next_output([get_text_message('{"answer": "hi"}')])
    run_config = RunConfig(call_model_input_filter=recording_filter)

    await Runner.run(agent, input="start", run_config=run_config)

    # Exactly one model call, hence exactly one recorded schema.
    assert len(seen_schemas) == 1
    recorded = seen_schemas[0]
    assert recorded is not None
    assert recorded.name() == "_Reply"


@pytest.mark.asyncio
async def test_filter_not_setting_output_schema_preserves_agent_schema() -> None:
    """Leaving ``output_schema`` unset in the filter result keeps the agent's schema.

    The filter builds a fresh ``ModelInputData`` without ``output_schema``; the model
    call must still carry the schema named after the agent's ``output_type``.
    """

    def schema_less_filter(data: CallModelData[Any]) -> ModelInputData:
        payload = data.model_data
        # output_schema is deliberately not passed, so it takes the dataclass default.
        return ModelInputData(input=payload.input, instructions=payload.instructions)

    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model, output_type=_Reply)
    fake_model.set_next_output([get_text_message('{"answer": "hi"}')])
    run_config = RunConfig(call_model_input_filter=schema_less_filter)

    await Runner.run(agent, input="start", run_config=run_config)

    assert fake_model.last_turn_args["output_schema"] is not None
    assert fake_model.last_turn_args["output_schema"].name() == "_Reply"
Loading