diff --git a/google/genai/_api_client.py b/google/genai/_api_client.py index 6fd98c1ce..4ceb255dd 100644 --- a/google/genai/_api_client.py +++ b/google/genai/_api_client.py @@ -858,7 +858,7 @@ def _use_google_auth_sync(self) -> bool: def _use_google_auth_async(self) -> bool: try: - from google.auth.aio.credentials import StaticCredentials + from google.auth.aio.credentials import Credentials as AsyncCredentials from google.auth.aio.transport.sessions import AsyncAuthorizedSession except ImportError: return False @@ -877,11 +877,36 @@ async def _get_aiohttp_session( if self._aiohttp_session is None and self._use_google_auth_async(): try: - from google.auth.aio.credentials import StaticCredentials + from google.auth.aio.credentials import Credentials as AsyncCredentials from google.auth.aio.transport.sessions import AsyncAuthorizedSession - async_creds = StaticCredentials(token=self._access_token()) # type: ignore[no-untyped-call] - self._aiohttp_session = AsyncAuthorizedSession(async_creds) # type: ignore[no-untyped-call,assignment] + class _RefreshableAsyncCredentials(AsyncCredentials): # type: ignore[misc, valid-type] + """Adapter to use the client's sync credentials in an AsyncAuthorizedSession.""" + + def __init__(self, client: 'BaseApiClient'): + super().__init__() # type: ignore[no-untyped-call] + self._client = client + + async def before_request( + self, request: Any, method: str, url: str, headers: dict[str, str] + ) -> None: + token = await self._client._async_access_token() + headers['Authorization'] = f'Bearer {token}' + if ( + self._client._credentials + and self._client._credentials.quota_project_id + ): + headers['x-goog-user-project'] = ( + self._client._credentials.quota_project_id + ) + + @property + def valid(self) -> bool: + if not self._client._credentials: + return False + return not self._client._credentials.expired + + self._aiohttp_session = AsyncAuthorizedSession(_RefreshableAsyncCredentials(self)) # type: ignore[no-untyped-call,assignment] return self._aiohttp_session # type: ignore[return-value] except ImportError: pass