Skip to content

Commit 3e1104d

Browse files
committed
fix: preserve optional string tool arguments
1 parent 161834d commit 3e1104d

2 files changed

Lines changed: 53 additions & 2 deletions

File tree

src/mcp/server/mcpserver/utilities/func_metadata.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
from collections.abc import Awaitable, Callable, Sequence
55
from itertools import chain
6-
from types import GenericAlias
6+
from types import GenericAlias, NoneType
77
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
88

99
import anyio
@@ -148,7 +148,7 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
148148
continue
149149

150150
field_info = key_to_field_info[data_key]
151-
if isinstance(data_value, str) and field_info.annotation is not str:
151+
if isinstance(data_value, str) and _should_pre_parse_json(field_info.annotation):
152152
try:
153153
pre_parsed = json.loads(data_value)
154154
except json.JSONDecodeError:
@@ -167,6 +167,22 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
167167
)
168168

169169

170+
def _is_simple_scalar_annotation(annotation: Any) -> bool:
171+
return annotation in {str, int, float, bool, NoneType}
172+
173+
174+
def _should_pre_parse_json(annotation: Any) -> bool:
175+
"""Return whether string input for this annotation should be JSON-decoded."""
176+
if annotation is str:
177+
return False
178+
179+
origin = get_origin(annotation)
180+
if is_union_origin(origin):
181+
return not all(_is_simple_scalar_annotation(arg) for arg in get_args(annotation))
182+
183+
return True
184+
185+
170186
def func_metadata(
171187
func: Callable[..., Any],
172188
skip_names: Sequence[str] = (),

tests/server/mcpserver/test_func_metadata.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,41 @@ def handle_json_payload(payload: str, strict_mode: bool = False) -> str:
551551
assert result == f"Handled payload of length {len(json_array_payload)}"
552552

553553

554+
@pytest.mark.anyio
555+
async def test_optional_str_annotation_preserves_json_string():
556+
def update_task(task_id: str | None = None) -> str:
557+
assert isinstance(task_id, str)
558+
return task_id
559+
560+
meta = func_metadata(update_task)
561+
562+
uuid = "3400e37e-b251-49d9-91b0-f8dd8602ff7e"
563+
json_payload = '{"id": "3400e37e-b251-49d9-91b0-f8dd8602ff7e"}'
564+
565+
assert meta.pre_parse_json({"task_id": uuid})["task_id"] == uuid
566+
assert meta.pre_parse_json({"task_id": json_payload})["task_id"] == json_payload
567+
assert meta.pre_parse_json({"task_id": "[1, 2]"})["task_id"] == "[1, 2]"
568+
569+
result = await meta.call_fn_with_arg_validation(
570+
update_task,
571+
fn_is_async=False,
572+
arguments_to_validate={"task_id": json_payload},
573+
arguments_to_pass_directly=None,
574+
)
575+
576+
assert result == json_payload
577+
578+
579+
def test_str_or_list_still_pre_parses_lists():
580+
def func_with_str_or_list(value: str | list[str]): # pragma: no cover
581+
return value
582+
583+
meta = func_metadata(func_with_str_or_list)
584+
585+
assert meta.pre_parse_json({"value": "hello"})["value"] == "hello"
586+
assert meta.pre_parse_json({"value": '["hello", "world"]'})["value"] == ["hello", "world"]
587+
588+
554589
# Tests for structured output functionality
555590

556591

0 commit comments

Comments
 (0)