diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py index 93b10c1214..c3dc17c6a5 100644 --- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py @@ -506,6 +506,43 @@ async def test_streaming_agent_run_with_events( events.append(event) assert len(events) == 1 + @pytest.mark.asyncio + async def test_streaming_agent_run_with_events_extracts_user_id_from_headers( + self, + default_instrumentor_builder_mock: mock.Mock, + get_project_id_mock: mock.Mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + app.set_up() + app._tmpl_attrs["in_memory_runner"] = _MockRunner() + + request_json = json.dumps( + { + "message": { + "parts": [{"text": "Hello"}], + "role": "user", + }, + } + ) + headers = { + "X-Goog-Authenticated-User-Email": "test_user_from_header@google.com" + } + + with mock.patch.object(app, "_init_session") as mock_init_session: + mock_session = mock.Mock() + mock_session.id = "mock_session_id" + mock_init_session.return_value = mock_session + + async for _ in app.streaming_agent_run_with_events( + request_json=request_json, headers=headers + ): + pass + + mock_init_session.assert_called_once() + # Assert that the extracted request object correctly pulled the user_id from headers + request_obj = mock_init_session.call_args.kwargs["request"] + assert request_obj.user_id == "test_user_from_header@google.com" + @pytest.mark.asyncio @mock.patch.dict( os.environ, diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index 6bab24f0a2..57cf8aaf88 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -195,9 +195,14 @@ def __init__(self, **kwargs): ) # The authorizations of the user, keyed by authorization ID. - self.user_id: Optional[str] = kwargs.get("user_id") or kwargs.get( - "userId", _DEFAULT_USER_ID - ) + extracted_user_id = kwargs.get("user_id") or kwargs.get("userId") + if not extracted_user_id: + headers = kwargs.get("headers", {}) + extracted_user_id = headers.get( + "X-Goog-Authenticated-User-Email" + ) or headers.get("X-Endpoint-API-UserInfo") + + self.user_id: Optional[str] = extracted_user_id or _DEFAULT_USER_ID # The user ID. self.session_id: Optional[str] = kwargs.get("session_id") or kwargs.get( @@ -1195,7 +1200,9 @@ def stream_query( ): yield _utils.dump_event_for_json(event) - async def streaming_agent_run_with_events(self, request_json: str): + async def streaming_agent_run_with_events( + self, request_json: str, headers: Optional[Dict[str, str]] = None + ): """Streams responses asynchronously from the ADK application. In general, you should use `async_stream_query` instead, as it has a @@ -1206,13 +1213,18 @@ async def streaming_agent_run_with_events(self, request_json: str): Args: request_json (str): Required. The request to stream responses for. + headers (Dict[str, str]): + Optional. The HTTP request headers containing IAM metadata. """ import json from google.genai import types from google.genai.errors import ClientError - request = _StreamRunRequest(**json.loads(request_json)) + payload = json.loads(request_json) + if headers: + payload["headers"] = headers + request = _StreamRunRequest(**payload) if not any( self._tmpl_attrs.get(service) for service in (