Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions tests/unit/vertex_adk/test_agent_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,3 +1167,91 @@ def test_update_default_telemetry_enablement(
assert _utils.to_dict(deployment_spec)["env"] == [
{"name": key, "value": value} for key, value in expected_env_vars.items()
]


class TestAdkAppMtls:
"""Test cases for mTLS functionality in AdkApp."""

def test_use_client_cert_effective_with_should_use_client_cert(self):
"""Verifies that it respects the google-auth mTLS enablement check."""
with mock.patch.object(adk_template.mtls, "should_use_client_cert", return_value=True, create=True):
assert adk_template._use_client_cert_effective() is True

@mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"})
def test_use_client_cert_effective_with_env_var_true(self):
"""Verifies that it falls back to the environment variable if google-auth check fails."""
with mock.patch.object(adk_template.mtls, "should_use_client_cert", side_effect=AttributeError, create=True):
assert adk_template._use_client_cert_effective() is True

@mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"})
def test_use_client_cert_effective_with_env_var_false(self):
"""Verifies that it respects the environment variable being set to false."""
with mock.patch.object(adk_template.mtls, "should_use_client_cert", side_effect=AttributeError, create=True):
assert adk_template._use_client_cert_effective() is False

def test_get_api_endpoint_default(self):
"""Verifies the default telemetry endpoint is returned when no mTLS is configured."""
assert adk_template._get_api_endpoint() == adk_template._DEFAULT_TELEMETRY_ENDPOINT

@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"})
def test_get_api_endpoint_always_with_cert(self):
"""Verifies the mTLS endpoint is used when forced and a certificate is available."""
assert adk_template._get_api_endpoint(client_cert_source=b"cert") == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT

@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"})
def test_get_api_endpoint_always_no_cert(self):
"""Verifies it falls back to regular endpoint even if forced if no certificate is provided."""
assert adk_template._get_api_endpoint() == adk_template._DEFAULT_TELEMETRY_ENDPOINT

@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"})
def test_get_api_endpoint_never(self):
"""Verifies the regular endpoint is used when mTLS is explicitly disabled."""
assert adk_template._get_api_endpoint(client_cert_source=b"cert") == adk_template._DEFAULT_TELEMETRY_ENDPOINT

@mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT))
@mock.patch.object(adk_template.requests_auth, "AuthorizedSession")
@mock.patch("opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter")
def test_default_instrumentor_builder_with_mtls(
self,
mock_exporter,
mock_session_cls,
mock_auth_default,
):
"""Integration test for the instrumentor builder with mTLS enabled."""
# Mocking to enable mTLS
with mock.patch.object(adk_template, "_use_client_cert_effective", return_value=True):
with mock.patch.object(adk_template.mtls, "has_default_client_cert_source", return_value=True):
with mock.patch.object(adk_template.mtls, "default_client_cert_source", return_value=lambda: b"cert"):
adk_template._default_instrumentor_builder(_TEST_PROJECT_ID, enable_tracing=True)

# Verify the session was configured for mTLS
mock_session_cls.return_value.configure_mtls_channel.assert_called_once()
# Verify the exporter was initialized with the mTLS endpoint
mock_exporter.assert_called_once()
assert mock_exporter.call_args.kwargs["endpoint"] == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT

@mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT))
@mock.patch.object(adk_template.requests_auth, "AuthorizedSession")
def test_warn_if_telemetry_api_disabled_with_mtls(
self,
mock_session_cls,
mock_auth_default,
):
"""Integration test for the telemetry API check with mTLS enabled."""
mock_session = mock_session_cls.return_value
mock_session.post.return_value = mock.Mock(text="")

# Mocking to enable mTLS
with mock.patch.object(adk_template, "_use_client_cert_effective", return_value=True):
with mock.patch.object(adk_template.mtls, "has_default_client_cert_source", return_value=True):
with mock.patch.object(adk_template.mtls, "default_client_cert_source", return_value=lambda: b"cert"):
adk_template._warn_if_telemetry_api_disabled()

# Verify mTLS channel was configured for the check request
mock_session.configure_mtls_channel.assert_called_once()
# Verify the check was performed against the mTLS endpoint
mock_session.post.assert_called_once_with(
adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT,
data=None
)

115 changes: 109 additions & 6 deletions vertexai/agent_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,17 @@

import asyncio
from collections.abc import Awaitable
import enum
import os
import queue
import sys
import threading
import warnings

from google.auth import exceptions as auth_exceptions
from google.auth.transport import mtls
from google.auth.transport import requests as requests_auth

if TYPE_CHECKING:
try:
from google.adk.events.event import Event
Expand Down Expand Up @@ -106,6 +112,22 @@
"(If you enabled this API recently, you can safely ignore this warning.)"
)

_DEFAULT_TELEMETRY_ENDPOINT = "https://telemetry.googleapis.com/v1/traces"
_DEFAULT_MTLS_TELEMETRY_ENDPOINT = (
"https://telemetry.mtls.googleapis.com/v1/traces"
)


class MtlsEndpoint(enum.Enum):
"""Enum for the mTLS endpoint setting."""

AUTO = "auto"
ALWAYS = "always"
NEVER = "never"


_MutualTLSChannelError = auth_exceptions.MutualTLSChannelError


def get_adk_version() -> Optional[str]:
"""Returns the version of the ADK package."""
Expand Down Expand Up @@ -391,12 +413,24 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]:
otlp_http_version = opentelemetry.exporter.otlp.proto.http.version.__version__
user_agent = f"Vertex-Agent-Engine/{vertex_sdk_version} OTel-OTLP-Exporter-Python/{otlp_http_version}"

session = requests_auth.AuthorizedSession(credentials=credentials)

use_client_cert = _use_client_cert_effective()
if use_client_cert:
client_cert_source = (
mtls.default_client_cert_source()
if mtls.has_default_client_cert_source()
else None
)
session.configure_mtls_channel()
endpoint = _get_api_endpoint(client_cert_source)
else:
endpoint = _DEFAULT_TELEMETRY_ENDPOINT

span_exporter = (
opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter(
session=google.auth.transport.requests.AuthorizedSession(
credentials=credentials
),
endpoint="https://telemetry.googleapis.com/v1/traces",
session=session,
endpoint=endpoint,
headers={"User-Agent": user_agent},
)
)
Expand Down Expand Up @@ -552,11 +586,80 @@ def _warn_if_telemetry_api_disabled():
except (ImportError, AttributeError):
return
credentials, project = google.auth.default()
session = google.auth.transport.requests.AuthorizedSession(credentials=credentials)
r = session.post("https://telemetry.googleapis.com/v1/traces", data=None)
session = requests_auth.AuthorizedSession(credentials=credentials)

use_client_cert = _use_client_cert_effective()
if use_client_cert:
client_cert_source = (
mtls.default_client_cert_source()
if mtls.has_default_client_cert_source()
else None
)
session.configure_mtls_channel()
endpoint = _get_api_endpoint(client_cert_source)
else:
endpoint = _DEFAULT_TELEMETRY_ENDPOINT
r = session.post(endpoint, data=None)
if "Telemetry API has not been used in project" in r.text:
_warn(_TELEMETRY_API_DISABLED_WARNING % (project, project))

def _get_api_endpoint(client_cert_source=None):
"""Returns the API endpoint based on mTLS configuration and cert availability.

Args:
client_cert_source (Optional[bytes]): The client certificate source.

Returns:
str: The API endpoint to be used.
"""
use_mtls_endpoint_str = os.getenv(
"GOOGLE_API_USE_MTLS_ENDPOINT", MtlsEndpoint.AUTO.value
).lower()

try:
use_mtls_endpoint = MtlsEndpoint(use_mtls_endpoint_str)
except ValueError:
_warn(
"Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of "
"%s. Defaulting to %s."
% ([e.value for e in MtlsEndpoint], MtlsEndpoint.AUTO.value)
)
use_mtls_endpoint = MtlsEndpoint.AUTO

if (
use_mtls_endpoint in (MtlsEndpoint.ALWAYS, MtlsEndpoint.AUTO)
and client_cert_source
):
return _DEFAULT_MTLS_TELEMETRY_ENDPOINT

return _DEFAULT_TELEMETRY_ENDPOINT


def _use_client_cert_effective():
"""Returns whether client certificate should be used for mTLS.

This checks if the google-auth version supports should_use_client_cert
automatic mTLS enablement. Alternatively, it reads from the
GOOGLE_API_USE_CLIENT_CERTIFICATE env var.

Returns:
bool: whether client certificate should be used for mTLS.
"""
# check if google-auth version supports should_use_client_cert for automatic
# mTLS enablement
try:
return mtls.should_use_client_cert()
except (ImportError, AttributeError):
# if unsupported, fallback to reading from env var
use_client_cert_str = os.getenv(
"GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"
).lower()
if use_client_cert_str not in ("true", "false"):
_warn(
"Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be"
" either `true` or `false`"
)
return use_client_cert_str == "true"

class AdkApp:
"""An ADK Application."""
Expand Down
Loading