diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py index 93b10c1214..7d7496be26 100644 --- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py @@ -1167,3 +1167,91 @@ def test_update_default_telemetry_enablement( assert _utils.to_dict(deployment_spec)["env"] == [ {"name": key, "value": value} for key, value in expected_env_vars.items() ] + + +class TestAdkAppMtls: + """Test cases for mTLS functionality in AdkApp.""" + + def test_use_client_cert_effective_with_should_use_client_cert(self): + """Verifies that it respects the google-auth mTLS enablement check.""" + with mock.patch.object(adk_template.mtls, "should_use_client_cert", return_value=True, create=True): + assert adk_template._use_client_cert_effective() is True + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) + def test_use_client_cert_effective_with_env_var_true(self): + """Verifies that it falls back to the environment variable if google-auth check fails.""" + with mock.patch.object(adk_template.mtls, "should_use_client_cert", side_effect=AttributeError, create=True): + assert adk_template._use_client_cert_effective() is True + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}) + def test_use_client_cert_effective_with_env_var_false(self): + """Verifies that it respects the environment variable being set to false.""" + with mock.patch.object(adk_template.mtls, "should_use_client_cert", side_effect=AttributeError, create=True): + assert adk_template._use_client_cert_effective() is False + + def test_get_api_endpoint_default(self): + """Verifies the default telemetry endpoint is returned when no mTLS is configured.""" + assert adk_template._get_api_endpoint() == adk_template._DEFAULT_TELEMETRY_ENDPOINT + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}) + def test_get_api_endpoint_always_with_cert(self): + """Verifies the mTLS endpoint is used when forced and a certificate is available.""" + assert adk_template._get_api_endpoint(client_cert_source=b"cert") == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}) + def test_get_api_endpoint_always_no_cert(self): + """Verifies it falls back to regular endpoint even if forced if no certificate is provided.""" + assert adk_template._get_api_endpoint() == adk_template._DEFAULT_TELEMETRY_ENDPOINT + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}) + def test_get_api_endpoint_never(self): + """Verifies the regular endpoint is used when mTLS is explicitly disabled.""" + assert adk_template._get_api_endpoint(client_cert_source=b"cert") == adk_template._DEFAULT_TELEMETRY_ENDPOINT + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + @mock.patch("opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter") + def test_default_instrumentor_builder_with_mtls( + self, + mock_exporter, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the instrumentor builder with mTLS enabled.""" + # Mocking to enable mTLS + with mock.patch.object(adk_template, "_use_client_cert_effective", return_value=True): + with mock.patch.object(adk_template.mtls, "has_default_client_cert_source", return_value=True): + with mock.patch.object(adk_template.mtls, "default_client_cert_source", return_value=lambda: b"cert"): + adk_template._default_instrumentor_builder(_TEST_PROJECT_ID, enable_tracing=True) + + # Verify the session was configured for mTLS + mock_session_cls.return_value.configure_mtls_channel.assert_called_once() + # Verify the exporter was initialized with the mTLS endpoint + mock_exporter.assert_called_once() + assert mock_exporter.call_args.kwargs["endpoint"] == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + def test_warn_if_telemetry_api_disabled_with_mtls( + self, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the telemetry API check with mTLS enabled.""" + mock_session = mock_session_cls.return_value + mock_session.post.return_value = mock.Mock(text="") + + # Mocking to enable mTLS + with mock.patch.object(adk_template, "_use_client_cert_effective", return_value=True): + with mock.patch.object(adk_template.mtls, "has_default_client_cert_source", return_value=True): + with mock.patch.object(adk_template.mtls, "default_client_cert_source", return_value=lambda: b"cert"): + adk_template._warn_if_telemetry_api_disabled() + + # Verify mTLS channel was configured for the check request + mock_session.configure_mtls_channel.assert_called_once() + # Verify the check was performed against the mTLS endpoint + mock_session.post.assert_called_once_with( + adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT, + data=None + ) + diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index 6bab24f0a2..5d72d00773 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -26,11 +26,17 @@ import asyncio from collections.abc import Awaitable +import enum +import os import queue import sys import threading import warnings +from google.auth import exceptions as auth_exceptions +from google.auth.transport import mtls +from google.auth.transport import requests as requests_auth + if TYPE_CHECKING: try: from google.adk.events.event import Event @@ -106,6 +112,22 @@ "(If you enabled this API recently, you can safely ignore this warning.)" ) +_DEFAULT_TELEMETRY_ENDPOINT = "https://telemetry.googleapis.com/v1/traces" +_DEFAULT_MTLS_TELEMETRY_ENDPOINT = ( + "https://telemetry.mtls.googleapis.com/v1/traces" +) + + +class MtlsEndpoint(enum.Enum): + """Enum for the mTLS endpoint setting.""" + + AUTO = "auto" + ALWAYS = "always" + NEVER = "never" + + +_MutualTLSChannelError = auth_exceptions.MutualTLSChannelError + def get_adk_version() -> Optional[str]: """Returns the version of the ADK package.""" @@ -391,12 +413,24 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]: otlp_http_version = opentelemetry.exporter.otlp.proto.http.version.__version__ user_agent = f"Vertex-Agent-Engine/{vertex_sdk_version} OTel-OTLP-Exporter-Python/{otlp_http_version}" + session = requests_auth.AuthorizedSession(credentials=credentials) + + use_client_cert = _use_client_cert_effective() + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + endpoint = _get_api_endpoint(client_cert_source) + else: + endpoint = _DEFAULT_TELEMETRY_ENDPOINT + span_exporter = ( opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter( - session=google.auth.transport.requests.AuthorizedSession( - credentials=credentials - ), - endpoint="https://telemetry.googleapis.com/v1/traces", + session=session, + endpoint=endpoint, headers={"User-Agent": user_agent}, ) ) @@ -552,11 +586,80 @@ def _warn_if_telemetry_api_disabled(): except (ImportError, AttributeError): return credentials, project = google.auth.default() - session = google.auth.transport.requests.AuthorizedSession(credentials=credentials) - r = session.post("https://telemetry.googleapis.com/v1/traces", data=None) + session = requests_auth.AuthorizedSession(credentials=credentials) + + use_client_cert = _use_client_cert_effective() + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + endpoint = _get_api_endpoint(client_cert_source) + else: + endpoint = _DEFAULT_TELEMETRY_ENDPOINT + r = session.post(endpoint, data=None) if "Telemetry API has not been used in project" in r.text: _warn(_TELEMETRY_API_DISABLED_WARNING % (project, project)) +def _get_api_endpoint(client_cert_source=None): + """Returns the API endpoint based on mTLS configuration and cert availability. + + Args: + client_cert_source (Optional[bytes]): The client certificate source. + + Returns: + str: The API endpoint to be used. + """ + use_mtls_endpoint_str = os.getenv( + "GOOGLE_API_USE_MTLS_ENDPOINT", MtlsEndpoint.AUTO.value + ).lower() + + try: + use_mtls_endpoint = MtlsEndpoint(use_mtls_endpoint_str) + except ValueError: + _warn( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of " + "%s. Defaulting to %s." + % ([e.value for e in MtlsEndpoint], MtlsEndpoint.AUTO.value) + ) + use_mtls_endpoint = MtlsEndpoint.AUTO + + if ( + use_mtls_endpoint in (MtlsEndpoint.ALWAYS, MtlsEndpoint.AUTO) + and client_cert_source + ): + return _DEFAULT_MTLS_TELEMETRY_ENDPOINT + + return _DEFAULT_TELEMETRY_ENDPOINT + + +def _use_client_cert_effective(): + """Returns whether client certificate should be used for mTLS. + + This checks if the google-auth version supports should_use_client_cert + automatic mTLS enablement. Alternatively, it reads from the + GOOGLE_API_USE_CLIENT_CERTIFICATE env var. + + Returns: + bool: whether client certificate should be used for mTLS. + """ + # check if google-auth version supports should_use_client_cert for automatic + # mTLS enablement + try: + return mtls.should_use_client_cert() + except (ImportError, AttributeError): + # if unsupported, fallback to reading from env var + use_client_cert_str = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + if use_client_cert_str not in ("true", "false"): + _warn( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" + ) + return use_client_cert_str == "true" class AdkApp: """An ADK Application."""