From 3e1104dc3f87e51cb657ad856eb8ac7ca3708f25 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Fri, 15 May 2026 00:08:46 +0800 Subject: [PATCH] fix: preserve optional string tool arguments --- .../mcpserver/utilities/func_metadata.py | 20 +++++++++-- tests/server/mcpserver/test_func_metadata.py | 35 +++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index 4a76106371..9452af9192 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -3,7 +3,7 @@ import json from collections.abc import Awaitable, Callable, Sequence from itertools import chain -from types import GenericAlias +from types import GenericAlias, NoneType from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints import anyio @@ -148,7 +148,7 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: continue field_info = key_to_field_info[data_key] - if isinstance(data_value, str) and field_info.annotation is not str: + if isinstance(data_value, str) and _should_pre_parse_json(field_info.annotation): try: pre_parsed = json.loads(data_value) except json.JSONDecodeError: @@ -167,6 +167,22 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: ) +def _is_simple_scalar_annotation(annotation: Any) -> bool: + return annotation in {str, int, float, bool, NoneType} + + +def _should_pre_parse_json(annotation: Any) -> bool: + """Return whether string input for this annotation should be JSON-decoded.""" + if annotation is str: + return False + + origin = get_origin(annotation) + if is_union_origin(origin): + return not all(_is_simple_scalar_annotation(arg) for arg in get_args(annotation)) + + return True + + def func_metadata( func: Callable[..., Any], skip_names: Sequence[str] = (), diff --git a/tests/server/mcpserver/test_func_metadata.py b/tests/server/mcpserver/test_func_metadata.py index c57d1ee9f0..fed58e97a7 100644 --- a/tests/server/mcpserver/test_func_metadata.py +++ b/tests/server/mcpserver/test_func_metadata.py @@ -551,6 +551,41 @@ def handle_json_payload(payload: str, strict_mode: bool = False) -> str: assert result == f"Handled payload of length {len(json_array_payload)}" +@pytest.mark.anyio +async def test_optional_str_annotation_preserves_json_string(): + def update_task(task_id: str | None = None) -> str: + assert isinstance(task_id, str) + return task_id + + meta = func_metadata(update_task) + + uuid = "3400e37e-b251-49d9-91b0-f8dd8602ff7e" + json_payload = '{"id": "3400e37e-b251-49d9-91b0-f8dd8602ff7e"}' + + assert meta.pre_parse_json({"task_id": uuid})["task_id"] == uuid + assert meta.pre_parse_json({"task_id": json_payload})["task_id"] == json_payload + assert meta.pre_parse_json({"task_id": "[1, 2]"})["task_id"] == "[1, 2]" + + result = await meta.call_fn_with_arg_validation( + update_task, + fn_is_async=False, + arguments_to_validate={"task_id": json_payload}, + arguments_to_pass_directly=None, + ) + + assert result == json_payload + + +def test_str_or_list_still_pre_parses_lists(): + def func_with_str_or_list(value: str | list[str]): # pragma: no cover + return value + + meta = func_metadata(func_with_str_or_list) + + assert meta.pre_parse_json({"value": "hello"})["value"] == "hello" + assert meta.pre_parse_json({"value": '["hello", "world"]'})["value"] == ["hello", "world"] + + # Tests for structured output functionality