diff --git a/src/agents/run_config.py b/src/agents/run_config.py index fcc9b01315..ba6f7d32dd 100644 --- a/src/agents/run_config.py +++ b/src/agents/run_config.py @@ -28,7 +28,7 @@ from .sandbox.session.sandbox_client import BaseSandboxClient from .sandbox.session.sandbox_session_state import SandboxSessionState from .sandbox.snapshot import SnapshotBase, SnapshotSpec - +from openai.types.chat.completion_create_params import ResponseFormat DEFAULT_MAX_TURNS = 10 DEFAULT_MAX_MANIFEST_ENTRY_CONCURRENCY = 4 @@ -50,6 +50,7 @@ class ModelInputData: input: list[TResponseInputItem] instructions: str | None + response_format: ResponseFormat | None = None @dataclass diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 45f09c0fa0..3e550cb35d 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -194,6 +194,7 @@ resolve_interrupted_turn, run_final_output_hooks, ) +from openai.types.chat.completion_create_params import ResponseFormat __all__ = [ "extract_tool_call_id", @@ -1366,6 +1367,19 @@ def _tool_search_fingerprint(raw_item: Any) -> str: streamed_result._model_input_items, reasoning_item_id_policy, ) + model_settings = get_model_settings(execution_agent, run_config) + model_settings = maybe_reset_tool_choice(public_agent, tool_use_tracker, model_settings) + + current_response_format: ResponseFormat | None = None + if output_schema and hasattr(output_schema, "is_plain_text") and not output_schema.is_plain_text(): + current_response_format = cast(ResponseFormat, { + "type": "json_schema", + "json_schema": { + "name": "final_output", + "strict": output_schema.is_strict_json_schema(), + "schema": output_schema.json_schema(), + }, + }) filtered = await maybe_filter_model_input( agent=public_agent, @@ -1373,6 +1387,7 @@ def _tool_search_fingerprint(raw_item: Any) -> str: context_wrapper=context_wrapper, input_items=input, system_instructions=system_prompt, + response_format=current_response_format, 
) if isinstance(filtered.input, list): filtered.input = deduplicate_input_items_preferring_latest(filtered.input) @@ -1815,19 +1830,33 @@ async def get_new_response( """Call the model and return the raw response, handling retries and hooks.""" public_agent = bindings.public_agent execution_agent = bindings.execution_agent + + model = get_model(execution_agent, run_config) + model_settings = get_model_settings(execution_agent, run_config) + model_settings = maybe_reset_tool_choice(public_agent, tool_use_tracker, model_settings) + + current_response_format: ResponseFormat | None = None + if output_schema and hasattr(output_schema, "is_plain_text") and not output_schema.is_plain_text(): + current_response_format = cast(ResponseFormat, { + "type": "json_schema", + "json_schema": { + "name": "final_output", + "strict": output_schema.is_strict_json_schema(), + "schema": output_schema.json_schema(), + }, + }) + filtered = await maybe_filter_model_input( agent=public_agent, run_config=run_config, context_wrapper=context_wrapper, input_items=input, system_instructions=system_prompt, + response_format=current_response_format, ) if isinstance(filtered.input, list): filtered.input = deduplicate_input_items_preferring_latest(filtered.input) - model = get_model(execution_agent, run_config) - model_settings = get_model_settings(execution_agent, run_config) - model_settings = maybe_reset_tool_choice(public_agent, tool_use_tracker, model_settings) if server_conversation_tracker is not None: server_conversation_tracker.mark_input_as_sent(filtered.input) diff --git a/src/agents/run_internal/turn_preparation.py b/src/agents/run_internal/turn_preparation.py index 0a79ebd813..2a7602a334 100644 --- a/src/agents/run_internal/turn_preparation.py +++ b/src/agents/run_internal/turn_preparation.py @@ -18,6 +18,7 @@ from ..tool import Tool from ..tracing import SpanError from ..util import _error_tracing +from openai.types.chat.completion_create_params import ResponseFormat __all__ = [ "validate_run_hooks", @@ -55,18 
+56,24 @@ async def maybe_filter_model_input( context_wrapper: RunContextWrapper[TContext], input_items: list[TResponseInputItem], system_instructions: str | None, + response_format: ResponseFormat | None = None, ) -> ModelInputData: """Apply optional call_model_input_filter to modify model input.""" effective_instructions = system_instructions effective_input: list[TResponseInputItem] = input_items if run_config.call_model_input_filter is None: - return ModelInputData(input=effective_input, instructions=effective_instructions) + return ModelInputData( + input=effective_input, + instructions=effective_instructions, + response_format=response_format, + ) try: model_input = ModelInputData( input=effective_input.copy(), instructions=effective_instructions, + response_format=response_format, ) filter_payload: CallModelData[TContext] = CallModelData( model_data=model_input,