Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions astrbot/core/agent/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,44 @@ def convert_schema(schema: dict) -> dict:
"integer": {"int32", "int64"},
"number": {"float", "double"},
}
support_fields = {
"title",
"description",
"enum",
"minimum",
"maximum",
"maxItems",
"minItems",
"nullable",
"required",
}

if "anyOf" in schema:
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
def apply_supported_fields(result: dict, source: dict) -> None:
for key in support_fields:
if key in source and key not in result:
result[key] = source[key]

for union_key in ("anyOf", "oneOf"):
union_value = schema.get(union_key)
if isinstance(union_value, list):
converted_branches = [
convert_schema(item) if isinstance(item, dict) else item
for item in union_value
]
non_null_branches = [
item
for item in converted_branches
if not (isinstance(item, dict) and item.get("type") == "null")
]
if len(non_null_branches) == 1 and isinstance(
non_null_branches[0], dict
):
result = non_null_branches[0].copy()
if len(converted_branches) > 1:
result["nullable"] = True
apply_supported_fields(result, schema)
return result
return {union_key: converted_branches}

result = {}

Expand All @@ -268,6 +303,12 @@ def convert_schema(schema: dict) -> dict:

if target_type in supported_types:
result["type"] = target_type
if (
isinstance(origin_type, list)
and "null" in origin_type
and target_type != "null"
):
result["nullable"] = True
if "format" in schema and schema["format"] in supported_formats.get(
result["type"],
set(),
Expand All @@ -276,18 +317,7 @@ def convert_schema(schema: dict) -> dict:
else:
result["type"] = "null"

support_fields = {
"title",
"description",
"enum",
"minimum",
"maximum",
"maxItems",
"minItems",
"nullable",
"required",
}
result.update({k: schema[k] for k in support_fields if k in schema})
apply_supported_fields(result, schema)

if "properties" in schema:
properties = {}
Expand Down
128 changes: 128 additions & 0 deletions tests/unit/test_tool_google_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,131 @@ def test_google_schema_fills_missing_array_items_with_string_schema():

assert source_uuids["type"] == "array"
assert source_uuids["items"] == {"type": "string"}


def test_google_schema_collapses_nullable_anyof_property():
tool_module = load_tool_module()
FunctionTool = tool_module.FunctionTool
ToolSet = tool_module.ToolSet

tool = FunctionTool(
name="search_sources",
description="Search sources by recency.",
parameters={
"type": "object",
"properties": {
"time_range": {
"description": "Optional recency filter.",
"anyOf": [
{
"type": "string",
"enum": ["day", "week", "month", "year"],
},
{"type": "null"},
],
"default": None,
}
},
},
)

schema = ToolSet([tool]).google_schema()
time_range = schema["function_declarations"][0]["parameters"]["properties"][
"time_range"
]

assert time_range["type"] == "string"
assert time_range["description"] == "Optional recency filter."
assert time_range["enum"] == ["day", "week", "month", "year"]
assert time_range["nullable"] is True
assert "anyOf" not in time_range
assert "default" not in time_range


def test_google_schema_collapses_single_branch_anyof_property():
tool_module = load_tool_module()
FunctionTool = tool_module.FunctionTool
ToolSet = tool_module.ToolSet

tool = FunctionTool(
name="search_sources",
description="Search sources by query.",
parameters={
"type": "object",
"properties": {
"query": {
"description": "Search query.",
"anyOf": [
{
"type": "string",
}
],
}
},
},
)

schema = ToolSet([tool]).google_schema()
query = schema["function_declarations"][0]["parameters"]["properties"]["query"]

assert query["type"] == "string"
assert query["description"] == "Search query."
assert "nullable" not in query
assert "anyOf" not in query


def test_google_schema_preserves_non_dict_union_branches():
tool_module = load_tool_module()
FunctionTool = tool_module.FunctionTool
ToolSet = tool_module.ToolSet

tool = FunctionTool(
name="search_sources",
description="Search sources by literal value.",
parameters={
"type": "object",
"properties": {
"value": {
"anyOf": [
{"type": "string"},
False,
],
}
},
},
)

schema = ToolSet([tool]).google_schema()
value = schema["function_declarations"][0]["parameters"]["properties"]["value"]

assert value["anyOf"] == [
{"type": "string"},
False,
]


def test_google_schema_marks_type_list_with_null_as_nullable():
tool_module = load_tool_module()
FunctionTool = tool_module.FunctionTool
ToolSet = tool_module.ToolSet

tool = FunctionTool(
name="search_sources",
description="Search sources by recency.",
parameters={
"type": "object",
"properties": {
"query": {
"type": ["string", "null"],
"description": "Optional query.",
}
},
},
)

schema = ToolSet([tool]).google_schema()
query = schema["function_declarations"][0]["parameters"]["properties"]["query"]

assert query["type"] == "string"
assert query["description"] == "Optional query."
assert query["nullable"] is True