Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@

"""Interceptor for collecting Cloud Spanner metrics."""

import inspect
import logging
import re
from typing import Dict
from typing import Any, Dict

import grpc
from grpc_interceptor import ClientInterceptor

from .constants import GOOGLE_CLOUD_RESOURCE_KEY, SPANNER_METHOD_PREFIX
from .spanner_metrics_tracer_factory import SpannerMetricsTracerFactory

logger = logging.getLogger(__name__)


class MetricsInterceptor(ClientInterceptor):
"""Interceptor that collects metrics for Cloud Spanner operations."""
Expand Down Expand Up @@ -115,17 +120,289 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
self._set_metrics_tracer_attributes(resources)

## Format method to be be spanner.<method name>
method_name = call_details.method.removeprefix(SPANNER_METHOD_PREFIX).replace(
"/", "."
)
method_str = call_details.method
if isinstance(method_str, bytes):
method_str = method_str.decode("utf-8")
method_name = method_str.removeprefix(SPANNER_METHOD_PREFIX).replace("/", ".")

tracer.set_method(method_name)
tracer.record_attempt_start()
response = invoked_method(request_or_iterator, call_details)
tracer.record_attempt_completion()

# Process and send GFE metrics if enabled
if tracer.gfe_enabled:
metadata = response.initial_metadata()
return _wrap_response(response, tracer)


def _wrap_response(response: Any, tracer: Any) -> Any:
"""Wraps the response if it is streaming, or records metrics immediately if unary."""
if hasattr(response, "__next__"):
return _StreamingResponseWrapper(response, tracer)
else:
# Unary call: execute completion and record metrics immediately
try:
tracer.record_attempt_completion()
metadata = []
if hasattr(response, "initial_metadata"):
try:
metadata.extend(response.initial_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve initial metadata: {e}")
if hasattr(response, "trailing_metadata"):
try:
metadata.extend(response.trailing_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve trailing metadata: {e}")
tracer.record_gfe_metrics(metadata)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")
return response
Comment thread
sinhasubham marked this conversation as resolved.


class AsyncMetricsInterceptor(
grpc.aio.UnaryUnaryClientInterceptor,
grpc.aio.UnaryStreamClientInterceptor,
grpc.aio.StreamUnaryClientInterceptor,
grpc.aio.StreamStreamClientInterceptor,
):
"""Async Interceptor that collects metrics for Cloud Spanner operations."""

async def intercept_unary_unary(self, continuation, client_call_details, request):
return await self._async_intercept(continuation, client_call_details, request)

async def intercept_unary_stream(self, continuation, client_call_details, request):
return await self._async_intercept(continuation, client_call_details, request)

async def intercept_stream_unary(
self, continuation, client_call_details, request_iterator
):
return await self._async_intercept(
continuation, client_call_details, request_iterator
)

async def intercept_stream_stream(
self, continuation, client_call_details, request_iterator
):
return await self._async_intercept(
continuation, client_call_details, request_iterator
)

async def _async_intercept(
self,
continuation: Any,
call_details: grpc.ClientCallDetails,
request_or_iterator: Any,
) -> Any:
# Implementation for async interceptor
factory = SpannerMetricsTracerFactory()
tracer = SpannerMetricsTracerFactory.get_current_tracer()
if tracer is None or not factory.enabled:
return await continuation(call_details, request_or_iterator)

if not (
tracer.client_attributes.get("project_id")
and tracer.client_attributes.get("instance_id")
and tracer.client_attributes.get("database")
):
resources = MetricsInterceptor._extract_resource_from_path(
call_details.metadata
)
MetricsInterceptor._set_metrics_tracer_attributes(resources)

method_str = call_details.method
if isinstance(method_str, bytes):
method_str = method_str.decode("utf-8")
method_name = method_str.removeprefix(SPANNER_METHOD_PREFIX).replace("/", ".")

tracer.set_method(method_name)
tracer.record_attempt_start()
response = await continuation(call_details, request_or_iterator)

if hasattr(response, "__anext__"):
return _AsyncStreamingResponseWrapper(response, tracer)
else:
return _AsyncUnaryResponseWrapper(response, tracer)


class _StreamingResponseWrapper:
"""Wrapper for streaming RPC response iterators to defer metrics recording."""

def __init__(self, response, tracer):
self._response = response
self._tracer = tracer
self._metrics_recorded = False
self._iterator = None

def __iter__(self):
self._iterator = iter(self._response)
return self

def __next__(self):
if self._iterator is None:
self._iterator = iter(self._response)
try:
return next(self._iterator)
except StopIteration:
self._record_metrics()
raise
except Exception:
self._record_metrics()
raise

def _record_metrics(self):
if self._metrics_recorded:
return
self._metrics_recorded = True
try:
self._tracer.record_attempt_completion()
metadata = []
if hasattr(self._response, "initial_metadata"):
try:
metadata.extend(self._response.initial_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve initial metadata: {e}")
if hasattr(self._response, "trailing_metadata"):
try:
metadata.extend(self._response.trailing_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve trailing metadata: {e}")
self._tracer.record_gfe_metrics(metadata)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")

def __del__(self):
try:
self._record_metrics()
except Exception:
pass

def __getattr__(self, name):
return getattr(self._response, name)


class _AsyncUnaryResponseWrapper(grpc.aio.UnaryUnaryCall):
"""Wrapper for async unary RPC response to defer metrics recording until awaited."""

def __init__(self, response, tracer):
self._response = response
self._tracer = tracer
self._metrics_recorded = False

def __await__(self):
async def _wait():
try:
return await self._response
finally:
await self._record_metrics()

return _wait().__await__()

async def _record_metrics(self):
if self._metrics_recorded:
return
self._metrics_recorded = True
try:
self._tracer.record_attempt_completion()
metadata = []
if hasattr(self._response, "initial_metadata"):
try:
res = self._response.initial_metadata()
if inspect.isawaitable(res):
res = await res
metadata.extend(res or [])
except Exception as e:
logger.warning(f"Failed to retrieve initial metadata: {e}")
if hasattr(self._response, "trailing_metadata"):
try:
res = self._response.trailing_metadata()
if inspect.isawaitable(res):
res = await res
metadata.extend(res or [])
except Exception as e:
logger.warning(f"Failed to retrieve trailing metadata: {e}")
self._tracer.record_gfe_metrics(metadata)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")

def __del__(self):
if not self._metrics_recorded:
self._metrics_recorded = True
try:
self._tracer.record_attempt_completion()
except Exception:
pass

def __getattr__(self, name):
return getattr(self._response, name)


class _AsyncStreamingResponseWrapper(
grpc.aio.UnaryStreamCall,
grpc.aio.StreamUnaryCall,
grpc.aio.StreamStreamCall,
):
"""Wrapper for async streaming RPC response iterators to defer metrics recording."""

def __init__(self, response, tracer):
self._response = response
self._tracer = tracer
self._metrics_recorded = False
self._iterator = None

def __aiter__(self):
if hasattr(self._response, "__aiter__"):
self._iterator = self._response.__aiter__()
else:
self._iterator = self._response
return self

async def __anext__(self):
if self._iterator is None:
if hasattr(self._response, "__aiter__"):
self._iterator = self._response.__aiter__()
else:
self._iterator = self._response
try:
return await self._iterator.__anext__()
except StopAsyncIteration:
await self._record_metrics()
raise
except Exception:
await self._record_metrics()
raise

async def _record_metrics(self):
if self._metrics_recorded:
return
self._metrics_recorded = True
try:
self._tracer.record_attempt_completion()
metadata = []
if hasattr(self._response, "initial_metadata"):
try:
res = self._response.initial_metadata()
if inspect.isawaitable(res):
res = await res
metadata.extend(res or [])
except Exception as e:
logger.warning(f"Failed to retrieve initial metadata: {e}")
if hasattr(self._response, "trailing_metadata"):
try:
res = self._response.trailing_metadata()
if inspect.isawaitable(res):
res = await res
metadata.extend(res or [])
except Exception as e:
logger.warning(f"Failed to retrieve trailing metadata: {e}")
self._tracer.record_gfe_metrics(metadata)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")

def __del__(self):
if not self._metrics_recorded:
self._metrics_recorded = True
try:
self._tracer.record_attempt_completion()
except Exception:
pass

def __getattr__(self, name):
return getattr(self._response, name)
Loading
Loading