From 038fa8409243df47e97240b6f6bd7a59ebe697ad Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 28 Apr 2026 16:12:32 -0700 Subject: [PATCH] feat: Add mTLS support for telemetry endpoint in adk.py. This change enables the telemetry exporter to use mTLS endpoints when configured, by dynamically determining the correct endpoint and configuring the requests session accordingly. It introduces helper functions to handle client certificate source management. PiperOrigin-RevId: 907230510 --- .../test_agent_engine_templates_adk.py | 88 ++++++++++++++ vertexai/agent_engines/templates/adk.py | 115 +++++++++++++++++- 2 files changed, 197 insertions(+), 6 deletions(-) diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py index 93b10c1214..7d7496be26 100644 --- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py @@ -1167,3 +1167,91 @@ def test_update_default_telemetry_enablement( assert _utils.to_dict(deployment_spec)["env"] == [ {"name": key, "value": value} for key, value in expected_env_vars.items() ] + + +class TestAdkAppMtls: + """Test cases for mTLS functionality in AdkApp.""" + + def test_use_client_cert_effective_with_should_use_client_cert(self): + """Verifies that it respects the google-auth mTLS enablement check.""" + with mock.patch.object(adk_template.mtls, "should_use_client_cert", return_value=True, create=True): + assert adk_template._use_client_cert_effective() is True + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) + def test_use_client_cert_effective_with_env_var_true(self): + """Verifies that it falls back to the environment variable if google-auth check fails.""" + with mock.patch.object(adk_template.mtls, "should_use_client_cert", side_effect=AttributeError, create=True): + assert adk_template._use_client_cert_effective() is True + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}) + def test_use_client_cert_effective_with_env_var_false(self): + """Verifies that it respects the environment variable being set to false.""" + with mock.patch.object(adk_template.mtls, "should_use_client_cert", side_effect=AttributeError, create=True): + assert adk_template._use_client_cert_effective() is False + + def test_get_api_endpoint_default(self): + """Verifies the default telemetry endpoint is returned when no mTLS is configured.""" + assert adk_template._get_api_endpoint() == adk_template._DEFAULT_TELEMETRY_ENDPOINT + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}) + def test_get_api_endpoint_always_with_cert(self): + """Verifies the mTLS endpoint is used when forced and a certificate is available.""" + assert adk_template._get_api_endpoint(client_cert_source=b"cert") == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}) + def test_get_api_endpoint_always_no_cert(self): + """Verifies it falls back to regular endpoint even if forced if no certificate is provided.""" + assert adk_template._get_api_endpoint() == adk_template._DEFAULT_TELEMETRY_ENDPOINT + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}) + def test_get_api_endpoint_never(self): + """Verifies the regular endpoint is used when mTLS is explicitly disabled.""" + assert adk_template._get_api_endpoint(client_cert_source=b"cert") == adk_template._DEFAULT_TELEMETRY_ENDPOINT + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + @mock.patch("opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter") + def test_default_instrumentor_builder_with_mtls( + self, + mock_exporter, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the instrumentor builder with mTLS enabled.""" + # Mocking to enable mTLS + with mock.patch.object(adk_template, "_use_client_cert_effective", return_value=True): + with mock.patch.object(adk_template.mtls, "has_default_client_cert_source", return_value=True): + with mock.patch.object(adk_template.mtls, "default_client_cert_source", return_value=lambda: b"cert"): + adk_template._default_instrumentor_builder(_TEST_PROJECT_ID, enable_tracing=True) + + # Verify the session was configured for mTLS + mock_session_cls.return_value.configure_mtls_channel.assert_called_once() + # Verify the exporter was initialized with the mTLS endpoint + mock_exporter.assert_called_once() + assert mock_exporter.call_args.kwargs["endpoint"] == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + def test_warn_if_telemetry_api_disabled_with_mtls( + self, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the telemetry API check with mTLS enabled.""" + mock_session = mock_session_cls.return_value + mock_session.post.return_value = mock.Mock(text="") + + # Mocking to enable mTLS + with mock.patch.object(adk_template, "_use_client_cert_effective", return_value=True): + with mock.patch.object(adk_template.mtls, "has_default_client_cert_source", return_value=True): + with mock.patch.object(adk_template.mtls, "default_client_cert_source", return_value=lambda: b"cert"): + adk_template._warn_if_telemetry_api_disabled() + + # Verify mTLS channel was configured for the check request + mock_session.configure_mtls_channel.assert_called_once() + # Verify the check was performed against the mTLS endpoint + mock_session.post.assert_called_once_with( + adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT, + data=None + ) + diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index 6bab24f0a2..5d72d00773 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -26,11 +26,17 @@ import asyncio from collections.abc import Awaitable +import enum +import os import queue import sys import threading import warnings +from google.auth import exceptions as auth_exceptions +from google.auth.transport import mtls +from google.auth.transport import requests as requests_auth + if TYPE_CHECKING: try: from google.adk.events.event import Event @@ -106,6 +112,22 @@ "(If you enabled this API recently, you can safely ignore this warning.)" ) +_DEFAULT_TELEMETRY_ENDPOINT = "https://telemetry.googleapis.com/v1/traces" +_DEFAULT_MTLS_TELEMETRY_ENDPOINT = ( + "https://telemetry.mtls.googleapis.com/v1/traces" +) + + +class MtlsEndpoint(enum.Enum): + """Enum for the mTLS endpoint setting.""" + + AUTO = "auto" + ALWAYS = "always" + NEVER = "never" + + +_MutualTLSChannelError = auth_exceptions.MutualTLSChannelError + def get_adk_version() -> Optional[str]: """Returns the version of the ADK package.""" @@ -391,12 +413,24 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]: otlp_http_version = opentelemetry.exporter.otlp.proto.http.version.__version__ user_agent = f"Vertex-Agent-Engine/{vertex_sdk_version} OTel-OTLP-Exporter-Python/{otlp_http_version}" + session = requests_auth.AuthorizedSession(credentials=credentials) + + use_client_cert = _use_client_cert_effective() + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + endpoint = _get_api_endpoint(client_cert_source) + else: + endpoint = _DEFAULT_TELEMETRY_ENDPOINT + span_exporter = ( opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter( - session=google.auth.transport.requests.AuthorizedSession( - credentials=credentials - ), - endpoint="https://telemetry.googleapis.com/v1/traces", + session=session, + endpoint=endpoint, headers={"User-Agent": user_agent}, ) ) @@ -552,11 +586,80 @@ def _warn_if_telemetry_api_disabled(): except (ImportError, AttributeError): return credentials, project = google.auth.default() - session = google.auth.transport.requests.AuthorizedSession(credentials=credentials) - r = session.post("https://telemetry.googleapis.com/v1/traces", data=None) + session = requests_auth.AuthorizedSession(credentials=credentials) + + use_client_cert = _use_client_cert_effective() + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + endpoint = _get_api_endpoint(client_cert_source) + else: + endpoint = _DEFAULT_TELEMETRY_ENDPOINT + r = session.post(endpoint, data=None) if "Telemetry API has not been used in project" in r.text: _warn(_TELEMETRY_API_DISABLED_WARNING % (project, project)) +def _get_api_endpoint(client_cert_source=None): + """Returns the API endpoint based on mTLS configuration and cert availability. + + Args: + client_cert_source (Optional[bytes]): The client certificate source. + + Returns: + str: The API endpoint to be used. + """ + use_mtls_endpoint_str = os.getenv( + "GOOGLE_API_USE_MTLS_ENDPOINT", MtlsEndpoint.AUTO.value + ).lower() + + try: + use_mtls_endpoint = MtlsEndpoint(use_mtls_endpoint_str) + except ValueError: + _warn( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of " + "%s. Defaulting to %s." + % ([e.value for e in MtlsEndpoint], MtlsEndpoint.AUTO.value) + ) + use_mtls_endpoint = MtlsEndpoint.AUTO + + if ( + use_mtls_endpoint in (MtlsEndpoint.ALWAYS, MtlsEndpoint.AUTO) + and client_cert_source + ): + return _DEFAULT_MTLS_TELEMETRY_ENDPOINT + + return _DEFAULT_TELEMETRY_ENDPOINT + + +def _use_client_cert_effective(): + """Returns whether client certificate should be used for mTLS. + + This checks if the google-auth version supports should_use_client_cert + automatic mTLS enablement. Alternatively, it reads from the + GOOGLE_API_USE_CLIENT_CERTIFICATE env var. + + Returns: + bool: whether client certificate should be used for mTLS. + """ + # check if google-auth version supports should_use_client_cert for automatic + # mTLS enablement + try: + return mtls.should_use_client_cert() + except (ImportError, AttributeError): + # if unsupported, fallback to reading from env var + use_client_cert_str = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + if use_client_cert_str not in ("true", "false"): + _warn( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" + ) + return use_client_cert_str == "true" class AdkApp: """An ADK Application."""