diff --git a/src/google/adk/tools/retrieval/__init__.py b/src/google/adk/tools/retrieval/__init__.py index 93fabf08e6..f958ace528 100644 --- a/src/google/adk/tools/retrieval/__init__.py +++ b/src/google/adk/tools/retrieval/__init__.py @@ -13,9 +13,11 @@ # limitations under the License. from .base_retrieval_tool import BaseRetrievalTool +from .callable_retrieval import CallableRetrieval __all__ = [ "BaseRetrievalTool", + "CallableRetrieval", "FilesRetrieval", "LlamaIndexRetrieval", "VertexAiRagRetrieval", diff --git a/src/google/adk/tools/retrieval/callable_retrieval.py b/src/google/adk/tools/retrieval/callable_retrieval.py new file mode 100644 index 0000000000..0cdd0b71fd --- /dev/null +++ b/src/google/adk/tools/retrieval/callable_retrieval.py @@ -0,0 +1,78 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Retrieval tool that wraps a user-provided callable.""" + +from __future__ import annotations + +import inspect +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import Union + +from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool +from google.adk.tools.tool_context import ToolContext +from typing_extensions import override + + +class CallableRetrieval(BaseRetrievalTool): + """Retrieval tool backed by a user-provided function. + + Wraps any callable that accepts a query string and returns results, + making it a first-class retrieval tool in ADK. + + Example: + >>> def search_docs(query: str) -> list[str]: + ... return my_db.search(query) + >>> tool = CallableRetrieval( + ... name="search_docs", + ... description="Search the knowledge base.", + ... retriever=search_docs, + ... ) + + Args: + name: Tool name exposed to the LLM. + description: Tool description exposed to the LLM. + retriever: A sync or async callable. Must accept a ``query`` + string as its first argument. May optionally accept a + ``tool_context`` parameter. + """ + + def __init__( + self, + *, + name: str, + description: str, + retriever: Union[ + Callable[[str], Any], + Callable[[str], Awaitable[Any]], + ], + ): + super().__init__(name=name, description=description) + self._retriever = retriever + self._pass_tool_context = ( + "tool_context" in inspect.signature(retriever).parameters + ) + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + query = args["query"] + kwargs = {"tool_context": tool_context} if self._pass_tool_context else {} + result = self._retriever(query, **kwargs) + if inspect.isawaitable(result): + return await result + return result diff --git a/tests/unittests/tools/retrieval/test_callable_retrieval.py b/tests/unittests/tools/retrieval/test_callable_retrieval.py new file mode 100644 index 0000000000..10531379ff --- /dev/null +++ b/tests/unittests/tools/retrieval/test_callable_retrieval.py @@ -0,0 +1,167 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool +from google.adk.tools.retrieval.callable_retrieval import CallableRetrieval +from google.adk.tools.tool_context import ToolContext +import pytest + + +@pytest.fixture +def mock_tool_context(): + return MagicMock(spec=ToolContext) + + +def test_isinstance_base_retrieval_tool(): + tool = CallableRetrieval( + name="test", + description="A test tool.", + retriever=lambda query: [], + ) + assert isinstance(tool, BaseRetrievalTool) + + +def test_get_declaration(): + tool = CallableRetrieval( + name="my_search", + description="Search docs.", + retriever=lambda query: [], + ) + declaration = tool._get_declaration() + assert declaration.name == "my_search" + assert declaration.description == "Search docs." + + +@pytest.mark.asyncio +async def test_sync_callable(mock_tool_context): + def my_retriever(query: str): + return [f"result for {query}"] + + tool = CallableRetrieval( + name="sync_tool", + description="A sync retrieval tool.", + retriever=my_retriever, + ) + result = await tool.run_async( + args={"query": "hello"}, tool_context=mock_tool_context + ) + assert result == ["result for hello"] + + +@pytest.mark.asyncio +async def test_async_callable(mock_tool_context): + async def my_retriever(query: str): + return [f"async result for {query}"] + + tool = CallableRetrieval( + name="async_tool", + description="An async retrieval tool.", + retriever=my_retriever, + ) + result = await tool.run_async( + args={"query": "world"}, tool_context=mock_tool_context + ) + assert result == ["async result for world"] + + +@pytest.mark.asyncio +async def test_tool_context_passthrough(mock_tool_context): + received_context = {} + + def my_retriever(query: str, tool_context: ToolContext): + received_context["ctx"] = tool_context + return ["with context"] + + tool = CallableRetrieval( + name="ctx_tool", + description="Tool with context.", + retriever=my_retriever, + ) + result = await tool.run_async( + args={"query": "test"}, tool_context=mock_tool_context + ) + assert result == ["with context"] + assert received_context["ctx"] is mock_tool_context + + +@pytest.mark.asyncio +async def test_tool_context_omission(mock_tool_context): + def my_retriever(query: str): + return ["no context needed"] + + tool = CallableRetrieval( + name="no_ctx_tool", + description="Tool without context.", + retriever=my_retriever, + ) + result = await tool.run_async( + args={"query": "test"}, tool_context=mock_tool_context + ) + assert result == ["no context needed"] + + +@pytest.mark.asyncio +async def test_async_callable_with_tool_context(mock_tool_context): + async def my_retriever(query: str, tool_context: ToolContext): + return [f"async {query} with context"] + + tool = CallableRetrieval( + name="async_ctx_tool", + description="Async tool with context.", + retriever=my_retriever, + ) + result = await tool.run_async( + args={"query": "test"}, tool_context=mock_tool_context + ) + assert result == ["async test with context"] + + +@pytest.mark.asyncio +async def test_sync_callable_object(mock_tool_context): + + class MyRetriever: + + def __call__(self, query: str): + return [f"object result for {query}"] + + tool = CallableRetrieval( + name="obj_tool", + description="Callable object tool.", + retriever=MyRetriever(), + ) + result = await tool.run_async( + args={"query": "hello"}, tool_context=mock_tool_context + ) + assert result == ["object result for hello"] + + +@pytest.mark.asyncio +async def test_async_callable_object(mock_tool_context): + + class MyAsyncRetriever: + + async def __call__(self, query: str): + return [f"async object result for {query}"] + + tool = CallableRetrieval( + name="async_obj_tool", + description="Async callable object tool.", + retriever=MyAsyncRetriever(), + ) + result = await tool.run_async( + args={"query": "world"}, tool_context=mock_tool_context + ) + assert result == ["async object result for world"]