diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 654a9e2f9f..e49b1cce9f 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -204,7 +204,7 @@ class RunAgentRequest(common.BaseModel): app_name: str user_id: str session_id: str - new_message: types.Content + new_message: Optional[types.Content] = None streaming: bool = False state_delta: Optional[dict[str, Any]] = None # for resume long-running functions @@ -371,7 +371,7 @@ def _otel_env_vars_enabled() -> bool: def _setup_gcp_telemetry( - internal_exporters: list[SpanProcessor] = None, + internal_exporters: list[SpanProcessor] | None = None, ): if typing.TYPE_CHECKING: from ..telemetry.setup import OTelHooks @@ -413,7 +413,7 @@ def _setup_gcp_telemetry( def _setup_telemetry_from_env( - internal_exporters: list[SpanProcessor] = None, + internal_exporters: list[SpanProcessor] | None = None, ): from ..telemetry.setup import maybe_set_otel_providers @@ -510,7 +510,7 @@ def __init__( # Internal properties we want to allow being modified from callbacks. self.runners_to_clean: set[str] = set() self.current_app_name_ref: SharedValue[str] = SharedValue(value="") - self.runner_dict = {} + self.runner_dict: dict[str, Runner] = {} self.url_prefix = url_prefix self.auto_create_session = auto_create_session @@ -712,8 +712,8 @@ def get_fast_api_app( A FastAPI app instance. """ # Properties we don't need to modify from callbacks - trace_dict = {} - session_trace_dict = {} + trace_dict: dict[str, Any] = {} + session_trace_dict: dict[str, list[int]] = {} # Set up a file system watcher to detect changes in the agents directory. observer = Observer() setup_observer(observer, self) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 921afee693..db6bec4367 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -23,6 +23,7 @@ from typing import Any from typing import AsyncGenerator from typing import Callable +from typing import cast from typing import Generator from typing import List from typing import Optional @@ -413,7 +414,7 @@ def run( The events generated by the agent. """ run_config = run_config or RunConfig() - event_queue = queue.Queue() + event_queue: queue.Queue[Optional[Event]] = queue.Queue() async def _invoke_run_async(): try: @@ -480,8 +481,8 @@ async def run_async( The events generated by the agent. Raises: - ValueError: If the session is not found; If both invocation_id and - new_message are None. + ValueError: If the session is not found and `auto_create_session` is False, + or if both `invocation_id` and `new_message` are `None`. """ run_config = run_config or RunConfig() @@ -496,6 +497,7 @@ async def _run_with_trace( session = await self._get_or_create_session( user_id=user_id, session_id=session_id ) + if not invocation_id and not new_message: raise ValueError( 'Running an agent requires either a new_message or an ' @@ -1001,7 +1003,7 @@ async def run_live( ) if not session: session = await self._get_or_create_session( - user_id=user_id, session_id=session_id + user_id=cast(str, user_id), session_id=cast(str, session_id) ) invocation_context = self._new_invocation_context_for_live( session, @@ -1320,7 +1322,7 @@ async def _setup_context_for_resumed_invocation( # Step 1: Maybe retrieve a previous user message for the invocation. user_message = new_message or self._find_user_message_for_invocation( - session.events, invocation_id + session.events, cast(str, invocation_id) ) if not user_message: raise ValueError( @@ -1536,12 +1538,7 @@ async def close(self): logger.info('Runner closed.') - if sys.version_info < (3, 11): - Self = 'Runner' # pylint: disable=invalid-name - else: - from typing import Self # pylint: disable=g-import-not-at-top - - async def __aenter__(self) -> Self: + async def __aenter__(self) -> 'Runner': """Async context manager entry.""" return self diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 1ab1d41f47..881cb6bff5 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -48,6 +48,7 @@ from google.adk.sessions.state import State from google.genai import types from pydantic import BaseModel +from pydantic import Field import pytest # Configure logging to help diagnose server startup issues @@ -132,6 +133,7 @@ async def dummy_run_async( run_config: Optional[RunConfig] = None, invocation_id: Optional[str] = None, ): + run_config = run_config or RunConfig() yield _event_1() await asyncio.sleep(0) @@ -154,9 +156,9 @@ class _MockEvalCaseResult(BaseModel): user_id: str session_id: str eval_set_file: str - eval_metric_results: list = {} - overall_eval_metric_results: list = ({},) - eval_metric_result_per_invocation: list = {} + eval_metric_results: list = Field(default_factory=list) + overall_eval_metric_results: list = Field(default_factory=list) + eval_metric_result_per_invocation: list = Field(default_factory=list) ################################################# @@ -1530,6 +1532,31 @@ def test_builder_save_rejects_traversal(builder_test_client, tmp_path): assert not (tmp_path / "app" / "tmp" / "escape.yaml").exists() + +@pytest.mark.parametrize( + "extra_payload", + [ + {}, + {"state_delta": {"some_key": "some_value"}}, + ], + ids=["no_state_delta", "with_state_delta"], +) +def test_agent_run_resume_without_message_success( + test_app, create_test_session, extra_payload +): + """Test that /run allows resuming a session with only an invocation_id.""" + info = create_test_session + url = "/run" + payload = { + "app_name": info["app_name"], + "user_id": info["user_id"], + "session_id": info["session_id"], + "invocation_id": "test_invocation_id", + "streaming": False, + **extra_payload, + } + response = test_app.post(url, json=payload) + assert response.status_code == 200 def test_health_endpoint(test_app): """Test the health endpoint.""" response = test_app.get("/health")