From 4e17a9c3cbba4a9229f9ca2958f7a26257e8e2e7 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 30 Apr 2026 03:44:17 -0700 Subject: [PATCH] fix: avoid caching stale token in async mTLS path PiperOrigin-RevId: 908074886 --- google/genai/_api_client.py | 32 ++++++++++- .../client/test_client_initialization.py | 57 +++++++++++++++++++ 2 files changed, 86 insertions(+), 3 deletions(-) diff --git a/google/genai/_api_client.py b/google/genai/_api_client.py index 6fd98c1ce..1812df911 100644 --- a/google/genai/_api_client.py +++ b/google/genai/_api_client.py @@ -865,6 +865,7 @@ def _use_google_auth_async(self) -> bool: return bool( has_aiohttp and self.vertexai + and hasattr(mtls, 'should_use_client_cert') and mtls.should_use_client_cert() # type: ignore[no-untyped-call] and mtls.has_default_client_cert_source() # type: ignore[no-untyped-call] and not self._http_options.httpx_async_client @@ -877,11 +878,36 @@ async def _get_aiohttp_session( if self._aiohttp_session is None and self._use_google_auth_async(): try: - from google.auth.aio.credentials import StaticCredentials + from google.auth.aio.credentials import Credentials as AsyncCredentials from google.auth.aio.transport.sessions import AsyncAuthorizedSession - async_creds = StaticCredentials(token=self._access_token()) # type: ignore[no-untyped-call] - self._aiohttp_session = AsyncAuthorizedSession(async_creds) # type: ignore[no-untyped-call,assignment] + class _RefreshableAsyncCredentials(AsyncCredentials): # type: ignore[misc, valid-type] + """Adapter to use the client's sync credentials in an AsyncAuthorizedSession.""" + + def __init__(self, client: 'BaseApiClient'): + super().__init__() # type: ignore[no-untyped-call] + self._client = client + + async def before_request( + self, request: Any, method: str, url: str, headers: dict[str, str] + ) -> None: + token = await self._client._async_access_token() + headers['Authorization'] = f'Bearer {token}' + if ( + self._client._credentials + and self._client._credentials.quota_project_id + ): + headers['x-goog-user-project'] = ( + self._client._credentials.quota_project_id + ) + + @property + def valid(self) -> bool: + if not self._client._credentials: + return False + return not self._client._credentials.expired + + self._aiohttp_session = AsyncAuthorizedSession(_RefreshableAsyncCredentials(self)) # type: ignore[no-untyped-call,assignment] return self._aiohttp_session # type: ignore[return-value] except ImportError: pass diff --git a/google/genai/tests/client/test_client_initialization.py b/google/genai/tests/client/test_client_initialization.py index ed0e1a31c..4ad2dba75 100644 --- a/google/genai/tests/client/test_client_initialization.py +++ b/google/genai/tests/client/test_client_initialization.py @@ -21,6 +21,7 @@ import logging import os import ssl +import sys from unittest import mock import certifi @@ -1870,3 +1871,59 @@ async def test_get_aiohttp_session(): assert initial_session is not None session = await client._api_client._get_aiohttp_session() assert session is initial_session + + +@requires_aiohttp +@pytest.mark.asyncio +async def test_async_mtls_uses_refreshable_credentials(monkeypatch): + """Tests that _RefreshableAsyncCredentials is used in async mTLS path.""" + from google.genai import _api_client + + # Ensure _use_google_auth_async returns True + monkeypatch.setattr(_api_client, "has_aiohttp", True) + monkeypatch.setattr(_api_client.mtls, "should_use_client_cert", lambda: True, raising=False) + monkeypatch.setattr( + _api_client.mtls, "has_default_client_cert_source", lambda: True + ) + + # Mock AsyncAuthorizedSession and google.auth.aio modules + mock_session = mock.MagicMock() + mock_auth_aio = mock.MagicMock() + monkeypatch.setitem(sys.modules, "google.auth.aio", mock_auth_aio) + monkeypatch.setitem( + sys.modules, "google.auth.aio.credentials", mock_auth_aio.credentials + ) + monkeypatch.setitem( + sys.modules, "google.auth.aio.transport", mock_auth_aio.transport + ) + monkeypatch.setitem( + sys.modules, + "google.auth.aio.transport.sessions", + mock_auth_aio.transport.sessions, + ) + mock_auth_aio.transport.sessions.AsyncAuthorizedSession = mock_session + mock_auth_aio.credentials.Credentials = mock.MagicMock + + # Mock credentials + mock_creds = mock.MagicMock() + mock_creds.expired = False + mock_creds.token = "initial_token" + monkeypatch.setattr( + google.auth, "default", lambda scopes=None: (mock_creds, "fake-project") + ) + + client = Client(vertexai=True, project="fake-project") + client._api_client._credentials = mock_creds + + # Trigger session creation + await client._api_client._get_aiohttp_session() + + # Verify AsyncAuthorizedSession was called with _RefreshableAsyncCredentials + assert mock_session.call_count == 1 + passed_creds = mock_session.call_args[0][0] + assert type(passed_creds).__name__ == "_RefreshableAsyncCredentials" + + # Verify valid property + assert passed_creds.valid == True + mock_creds.expired = True + assert passed_creds.valid == False