diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index acfe93f5..c01fb683 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -354,6 +354,7 @@ env: - name: LAUNCH_SERVICE_TEMPLATE_FOLDER value: "/workspace/model-engine/model_engine_server/infra/gateways/resources/templates" {{- $model_cache := default dict .Values.modelCache }} + {{- $gcp_cloud_provider := and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") }} - name: MODEL_CACHE_ENABLED value: {{ get $model_cache "enabled" | default false | quote }} - name: MODEL_CACHE_MOUNT_PATH @@ -404,6 +405,14 @@ env: - name: SERVICEBUS_NAMESPACE value: {{ .Values.azure.servicebus_namespace }} {{- end }} + {{- if $gcp_cloud_provider }} + - name: GCP_PROJECT_ID + value: {{ (.Values.gcp).project_id | default "" | quote }} + - name: PUBSUB_TOPIC_PREFIX + value: {{ (.Values.gcp).pubsub_topic_prefix | default "" | quote }} + - name: PUBSUB_SUBSCRIPTION_PREFIX + value: {{ (.Values.gcp).pubsub_subscription_prefix | default "" | quote }} + {{- end }} {{- if eq .Values.context "circleci" }} - name: CIRCLECI value: "true" diff --git a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml index 74c8cf44..6449a6b1 100644 --- a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml +++ b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml @@ -5,12 +5,13 @@ {{- $tag := .Values.tag }} {{- $message_broker := .Values.celeryBrokerType }} {{- $num_shards := .Values.celery_autoscaler.num_shards }} +{{- $gcp_cloud_provider := and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") -}} {{- $broker_name := "redis-elasticache-message-broker-master" }} {{- if eq $message_broker 
"sqs" }} {{ $broker_name = "sqs-message-broker-master" }} {{- else if eq $message_broker "servicebus" }} {{ $broker_name = "servicebus-message-broker-master" }} -{{- else if and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") }} +{{- else if $gcp_cloud_provider }} {{ $broker_name = "redis-gcp-memorystore-message-broker-master" }} {{- end }} apiVersion: apps/v1 @@ -89,6 +90,14 @@ spec: - name: SERVICEBUS_NAMESPACE value: {{ .Values.azure.servicebus_namespace }} {{- end }} + {{- if $gcp_cloud_provider }} + - name: GCP_PROJECT_ID + value: {{ (.Values.gcp).project_id | default "" | quote }} + - name: PUBSUB_TOPIC_PREFIX + value: {{ (.Values.gcp).pubsub_topic_prefix | default "" | quote }} + - name: PUBSUB_SUBSCRIPTION_PREFIX + value: {{ (.Values.gcp).pubsub_subscription_prefix | default "" | quote }} + {{- end }} image: "{{ .Values.image.gatewayRepository }}:{{ $tag }}" imagePullPolicy: Always name: main diff --git a/charts/model-engine/values.yaml b/charts/model-engine/values.yaml index 7509a88f..bae62ce3 100644 --- a/charts/model-engine/values.yaml +++ b/charts/model-engine/values.yaml @@ -94,3 +94,9 @@ utilityImages: # Additional GPU tolerations for endpoint pods gpuTolerations: [] + +# GCP configuration for GCP-based deployments +gcp: + project_id: "" + pubsub_topic_prefix: "launch-endpoint-id-" + pubsub_subscription_prefix: "launch-endpoint-id-" diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index b8dc6fb7..dd683a97 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -438,3 +438,9 @@ recommendedHardware: gpu_type: nvidia-hopper-h100 nodes_per_worker: 1 #serviceBuilderQueue: + +# GCP configuration for GCP-based deployments +gcp: + project_id: "your-gcp-project" + pubsub_topic_prefix: "launch-endpoint-id-" + pubsub_subscription_prefix: "launch-endpoint-id-" diff --git 
a/clients/python/setup.py b/clients/python/setup.py index 152ae510..8fbee7c7 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -5,5 +5,8 @@ python_requires=">=3.8", version="0.0.0.beta45", packages=find_packages(), + # types-setuptools 82.0.0+ tightened package_data to _DictLike; the literal dict + # still works at runtime, only the new stub disagrees. Suppress at the call site + # rather than down-pinning the stub (which would mask real future tightenings). package_data={"llmengine": ["py.typed"]}, # type: ignore[arg-type] ) diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index f28425d8..8e68be55 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -95,6 +95,9 @@ from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import ( FakeQueueEndpointResourceDelegate, ) +from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import ( + GcpPubSubQueueEndpointResourceDelegate, +) from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, ) @@ -104,9 +107,6 @@ from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( QueueEndpointResourceDelegate, ) -from model_engine_server.infra.gateways.resources.redis_queue_endpoint_resource_delegate import ( - RedisQueueEndpointResourceDelegate, -) from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( SQSQueueEndpointResourceDelegate, ) @@ -248,8 +248,17 @@ def _get_external_interfaces( elif infra_config().cloud_provider == "azure": queue_delegate = ASBQueueEndpointResourceDelegate() elif infra_config().cloud_provider == "gcp": - # GCP uses Redis (Memorystore) for Celery, so use Redis-based queue delegate - queue_delegate = 
RedisQueueEndpointResourceDelegate(redis_client=redis_client) + # Mirror the SQS_PROFILE env-first pattern: the Helm chart injects GCP_PROJECT_ID as a + # pod env var (from .Values.gcp.project_id), which is a different source than the YAML- + # rendered infra_service_config. Read the env first so the chart value reaches the delegate; + # the infra_config.gcp_project_id field handles setups that wire it via the config YAML. + gcp_project_id = os.getenv("GCP_PROJECT_ID") or infra_config().gcp_project_id + if not gcp_project_id: + raise ValueError( + "cloud_provider=gcp requires GCP_PROJECT_ID env var " + "(via .Values.gcp.project_id) or infra.gcp_project_id in the service config." + ) + queue_delegate = GcpPubSubQueueEndpointResourceDelegate(project_id=gcp_project_id) else: queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index 6886174f..4d837d6e 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -54,6 +54,7 @@ class _InfraConfig: celery_enable_sha256: Optional[bool] = None docker_registry_type: Optional[str] = None debug_mode: Optional[bool] = None + gcp_project_id: Optional[str] = None @dataclass diff --git a/model-engine/model_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py index 36bd8e96..a64d55e9 100644 --- a/model-engine/model_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -27,6 +27,9 @@ from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import ( FakeQueueEndpointResourceDelegate, ) +from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import ( + GcpPubSubQueueEndpointResourceDelegate, +) from 
model_engine_server.infra.gateways.resources.image_cache_gateway import ImageCacheGateway from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, @@ -119,6 +122,16 @@ async def main(args: Any): queue_delegate = OnPremQueueEndpointResourceDelegate() elif infra_config().cloud_provider == "azure": queue_delegate = ASBQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "gcp": + # See dependencies.py for rationale: Helm injects GCP_PROJECT_ID as a pod env var; + # the infra_service_config YAML is a different source. Read the env first. + gcp_project_id = os.getenv("GCP_PROJECT_ID") or infra_config().gcp_project_id + if not gcp_project_id: + raise ValueError( + "cloud_provider=gcp requires GCP_PROJECT_ID env var " + "(via .Values.gcp.project_id) or infra.gcp_project_id in the service config." + ) + queue_delegate = GcpPubSubQueueEndpointResourceDelegate(project_id=gcp_project_id) else: queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) diff --git a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py index 3d659c4a..16b8c46a 100644 --- a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py +++ b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py @@ -34,6 +34,9 @@ from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import ( FakeQueueEndpointResourceDelegate, ) +from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import ( + GcpPubSubQueueEndpointResourceDelegate, +) from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, ) @@ -90,6 +93,16 @@ async def run_batch_job( queue_delegate = OnPremQueueEndpointResourceDelegate() elif 
infra_config().cloud_provider == "azure": queue_delegate = ASBQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "gcp": + # See dependencies.py for rationale: Helm injects GCP_PROJECT_ID as a pod env var; + # the infra_service_config YAML is a different source. Read the env first. + gcp_project_id = os.getenv("GCP_PROJECT_ID") or infra_config().gcp_project_id + if not gcp_project_id: + raise ValueError( + "cloud_provider=gcp requires GCP_PROJECT_ID env var " + "(via .Values.gcp.project_id) or infra.gcp_project_id in the service config." + ) + queue_delegate = GcpPubSubQueueEndpointResourceDelegate(project_id=gcp_project_id) else: queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) @@ -110,6 +123,9 @@ async def run_batch_job( if infra_config().cloud_provider == "azure": inference_task_queue_gateway = servicebus_task_queue_gateway infra_task_queue_gateway = servicebus_task_queue_gateway + elif infra_config().cloud_provider == "gcp": + inference_task_queue_gateway = redis_task_queue_gateway + infra_task_queue_gateway = redis_task_queue_gateway elif infra_config().cloud_provider == "onprem" or infra_config().celery_broker_type_redis: # On-prem uses Redis-based task queues inference_task_queue_gateway = redis_task_queue_gateway diff --git a/model-engine/model_engine_server/infra/gateways/resources/gcp_pubsub_queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/gcp_pubsub_queue_endpoint_resource_delegate.py new file mode 100644 index 00000000..89b280ef --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/resources/gcp_pubsub_queue_endpoint_resource_delegate.py @@ -0,0 +1,151 @@ +from typing import Any, Dict, Optional + +from google.api_core import exceptions as gcp_exceptions +from google.cloud import pubsub_v1 +from google.protobuf import field_mask_pb2 +from model_engine_server.core.loggers import logger_name, make_logger +from 
model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, + QueueInfo, +) + +logger = make_logger(logger_name()) + +GCP_PUBSUB_MAX_ACK_DEADLINE_SECONDS = 600 # Pub/Sub hard limit + + +class GcpPubSubQueueEndpointResourceDelegate(QueueEndpointResourceDelegate): + """ + Using GCP Pub/Sub (topic + subscription per endpoint). + + topic_prefix and subscription_prefix control the GCP resource name prefix. + The logical queue_name returned to callers always uses the canonical + QueueEndpointResourceDelegate.endpoint_id_to_queue_name format, independent + of these prefixes. + """ + + def __init__( + self, + project_id: str, + topic_prefix: str = "launch-endpoint-id-", + subscription_prefix: str = "launch-endpoint-id-", + ) -> None: + if not project_id: + raise ValueError( + "GcpPubSubQueueEndpointResourceDelegate requires a non-empty project_id; " + "set infra.gcp_project_id in the service config." + ) + self.project_id = project_id + self.topic_prefix = topic_prefix + self.subscription_prefix = subscription_prefix + # Lazily-initialized gRPC clients. Construction calls Google ADC which is + # unavailable in unit-test environments, so defer until first real use. + # The clients are then cached for the lifetime of the delegate. 
+ self._publisher_client: Optional[pubsub_v1.PublisherClient] = None + self._subscriber_client: Optional[pubsub_v1.SubscriberClient] = None + + @property + def _publisher(self) -> pubsub_v1.PublisherClient: + if self._publisher_client is None: + self._publisher_client = pubsub_v1.PublisherClient() + return self._publisher_client + + @property + def _subscriber(self) -> pubsub_v1.SubscriberClient: + if self._subscriber_client is None: + self._subscriber_client = pubsub_v1.SubscriberClient() + return self._subscriber_client + + def _topic_id(self, endpoint_id: str) -> str: + return f"{self.topic_prefix}{endpoint_id}" + + def _subscription_id(self, endpoint_id: str) -> str: + return f"{self.subscription_prefix}{endpoint_id}" + + async def create_queue_if_not_exists( + self, + endpoint_id: str, + endpoint_name: str, + endpoint_created_by: str, + endpoint_labels: Dict[str, Any], + queue_message_timeout_seconds: Optional[int] = None, + ) -> QueueInfo: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + topic_path = f"projects/{self.project_id}/topics/{self._topic_id(endpoint_id)}" + subscription_path = ( + f"projects/{self.project_id}/subscriptions/{self._subscription_id(endpoint_id)}" + ) + ack_deadline = min(queue_message_timeout_seconds or 60, GCP_PUBSUB_MAX_ACK_DEADLINE_SECONDS) + + try: + self._publisher.create_topic(name=topic_path) + except gcp_exceptions.AlreadyExists: + pass + + try: + self._subscriber.create_subscription( + name=subscription_path, + topic=topic_path, + ack_deadline_seconds=ack_deadline, + ) + except gcp_exceptions.AlreadyExists: + try: + self._subscriber.update_subscription( + subscription=pubsub_v1.types.Subscription( + name=subscription_path, + ack_deadline_seconds=ack_deadline, + ), + update_mask=field_mask_pb2.FieldMask(paths=["ack_deadline_seconds"]), + ) + except gcp_exceptions.GoogleAPIError as e: + logger.warning( + f"Failed to update ack_deadline for Pub/Sub subscription {subscription_path}: {e}" + ) + 
+ # Pub/Sub has no URL concept analogous to SQS queue URLs + return QueueInfo(queue_name, queue_url=None) + + async def delete_queue(self, endpoint_id: str) -> None: + subscription_path = ( + f"projects/{self.project_id}/subscriptions/{self._subscription_id(endpoint_id)}" + ) + topic_path = f"projects/{self.project_id}/topics/{self._topic_id(endpoint_id)}" + + # Always attempt BOTH deletions so a failure on one doesn't leave the other resource + # orphaned (Greptile P1). NotFound is silent. Other GoogleAPIErrors are collected and + # surfaced together at the end so callers see every cleanup failure, not just the first. + errors: list[tuple[str, str, gcp_exceptions.GoogleAPIError]] = [] + + try: + self._subscriber.delete_subscription(subscription=subscription_path) + except gcp_exceptions.NotFound: + logger.info( + f"Could not find Pub/Sub subscription {subscription_path} for endpoint {endpoint_id}" + ) + except gcp_exceptions.GoogleAPIError as e: + errors.append(("subscription", subscription_path, e)) + + try: + self._publisher.delete_topic(topic=topic_path) + except gcp_exceptions.NotFound: + logger.info(f"Could not find Pub/Sub topic {topic_path} for endpoint {endpoint_id}") + except gcp_exceptions.GoogleAPIError as e: + errors.append(("topic", topic_path, e)) + + if errors: + details = "; ".join( + f"Failed to delete Pub/Sub {kind} {path}: {err}" for kind, path, err in errors + ) + raise EndpointResourceInfraException( + f"Cleanup errors for endpoint {endpoint_id}: {details}" + ) from errors[0][2] + + async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + return { + "name": queue_name, + # Pub/Sub does not expose a synchronous undelivered message count; + # real observability requires the Cloud Monitoring API as a separate concern. 
+ "num_undelivered_messages": -1, + } diff --git a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py index 3a028b57..d0d0ca59 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py @@ -101,6 +101,14 @@ async def get_resources( ) elif "active_message_count" in sqs_attributes: # from ASBQueueEndpointResourceDelegate resources.num_queued_items = int(sqs_attributes["active_message_count"]) + elif ( + "num_undelivered_messages" in sqs_attributes + ): # from GcpPubSubQueueEndpointResourceDelegate + # The delegate reports -1 until num_undelivered_messages is wired to Cloud Monitoring. + # Treat -1 as "unknown" and skip; downstream autoscaling expects non-negative counts. + gcp_count = int(sqs_attributes["num_undelivered_messages"]) + if gcp_count >= 0: + resources.num_queued_items = gcp_count return resources diff --git a/model-engine/requirements.in b/model-engine/requirements.in index 07e849aa..716c463b 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -12,6 +12,12 @@ azure-storage-blob~=12.19.0 # GCP dependencies gcloud-aio-storage~=9.6 google-auth~=2.25.0 +google-cloud-pubsub>=2.18 +# google-cloud-pubsub transitively pulls opentelemetry-sdk, which flips +# common/startup_tracing/correlation.py's OTEL_AVAILABLE to True. Once that's +# True, tracer.py imports from opentelemetry-exporter-otlp-proto-grpc, which +# isn't otherwise a dependency. List it explicitly so the import resolves. 
+opentelemetry-exporter-otlp-proto-grpc google-cloud-artifact-registry~=1.21.0 google-cloud-secret-manager>=2.24.0 google-cloud-storage~=2.14.0 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 9cf220f7..e5a77684 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -190,6 +190,7 @@ google-api-core==2.29.0 # via # google-cloud-artifact-registry # google-cloud-core + # google-cloud-pubsub # google-cloud-secret-manager # google-cloud-storage google-auth==2.25.2 @@ -198,6 +199,7 @@ google-auth==2.25.2 # google-api-core # google-cloud-artifact-registry # google-cloud-core + # google-cloud-pubsub # google-cloud-secret-manager # google-cloud-storage # kubernetes @@ -205,6 +207,8 @@ google-cloud-artifact-registry==1.21.0 # via -r requirements.in google-cloud-core==2.5.0 # via google-cloud-storage +google-cloud-pubsub==2.38.0 + # via -r requirements.in google-cloud-secret-manager==2.24.0 # via -r requirements.in google-cloud-storage==2.14.0 @@ -220,6 +224,7 @@ googleapis-common-protos==1.72.0 # google-api-core # grpc-google-iam-v1 # grpcio-status + # opentelemetry-exporter-otlp-proto-grpc greenlet==3.3.2 # via # -r requirements.in @@ -227,16 +232,21 @@ greenlet==3.3.2 grpc-google-iam-v1==0.14.3 # via # google-cloud-artifact-registry + # google-cloud-pubsub # google-cloud-secret-manager grpcio==1.75.1 # via # google-api-core # google-cloud-artifact-registry + # google-cloud-pubsub # googleapis-common-protos # grpc-google-iam-v1 # grpcio-status + # opentelemetry-exporter-otlp-proto-grpc grpcio-status==1.75.1 - # via google-api-core + # via + # google-api-core + # google-cloud-pubsub gunicorn==23.0.0 # via -r requirements.in h11==0.16.0 @@ -370,8 +380,27 @@ numpy==2.4.4 # transformers oauthlib==3.2.2 # via requests-oauthlib -opentelemetry-api==1.40.0 - # via ddtrace +opentelemetry-api==1.41.1 + # via + # ddtrace + # google-cloud-pubsub + # opentelemetry-exporter-otlp-proto-grpc + # opentelemetry-sdk + # 
opentelemetry-semantic-conventions +opentelemetry-exporter-otlp-proto-common==1.41.1 + # via opentelemetry-exporter-otlp-proto-grpc +opentelemetry-exporter-otlp-proto-grpc==1.41.1 + # via -r requirements.in +opentelemetry-proto==1.41.1 + # via + # opentelemetry-exporter-otlp-proto-common + # opentelemetry-exporter-otlp-proto-grpc +opentelemetry-sdk==1.41.1 + # via + # google-cloud-pubsub + # opentelemetry-exporter-otlp-proto-grpc +opentelemetry-semantic-conventions==0.62b1 + # via opentelemetry-sdk orjson==3.11.7 # via -r requirements.in packaging==23.1 @@ -401,16 +430,19 @@ proto-plus==1.27.1 # via # google-api-core # google-cloud-artifact-registry + # google-cloud-pubsub # google-cloud-secret-manager protobuf==6.33.5 # via # -r requirements.in # google-api-core # google-cloud-artifact-registry + # google-cloud-pubsub # google-cloud-secret-manager # googleapis-common-protos # grpc-google-iam-v1 # grpcio-status + # opentelemetry-proto # proto-plus psycopg2-binary==2.9.11 # via -r requirements.in @@ -607,6 +639,9 @@ typing-extensions==4.15.0 # grpcio # huggingface-hub # opentelemetry-api + # opentelemetry-exporter-otlp-proto-grpc + # opentelemetry-sdk + # opentelemetry-semantic-conventions # pydantic # pydantic-core # sqlalchemy diff --git a/model-engine/tests/unit/api/test_dependencies.py b/model-engine/tests/unit/api/test_dependencies.py index a2712827..62c535ce 100644 --- a/model-engine/tests/unit/api/test_dependencies.py +++ b/model-engine/tests/unit/api/test_dependencies.py @@ -8,8 +8,8 @@ GCSFilesystemGateway, GCSLLMArtifactGateway, ) -from model_engine_server.infra.gateways.resources.redis_queue_endpoint_resource_delegate import ( - RedisQueueEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import ( + GcpPubSubQueueEndpointResourceDelegate, ) from model_engine_server.infra.repositories import ( GARDockerRepository, @@ -119,7 +119,7 @@ def test_gcp_provider_selects_gcp_implementations(): ) 
assert isinstance( external_interfaces.resource_gateway.queue_delegate, - RedisQueueEndpointResourceDelegate, + GcpPubSubQueueEndpointResourceDelegate, ) diff --git a/model-engine/tests/unit/infra/gateways/resources/test_gcp_pubsub_queue_endpoint_resource_delegate.py b/model-engine/tests/unit/infra/gateways/resources/test_gcp_pubsub_queue_endpoint_resource_delegate.py new file mode 100644 index 00000000..2f9340c6 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/resources/test_gcp_pubsub_queue_endpoint_resource_delegate.py @@ -0,0 +1,206 @@ +from unittest.mock import MagicMock, patch + +import pytest +from google.api_core import exceptions as gcp_exceptions +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import ( + GcpPubSubQueueEndpointResourceDelegate, +) + +MODULE_PATH = ( + "model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate" +) + +ENDPOINT_ID = "test_endpoint_id" +PROJECT_ID = "test-project" +TOPIC_PREFIX = "launch-endpoint-id-" +SUBSCRIPTION_PREFIX = "launch-endpoint-id-" +QUEUE_NAME = f"{TOPIC_PREFIX}{ENDPOINT_ID}" + + +@pytest.fixture +def mock_publisher(): + with patch(f"{MODULE_PATH}.pubsub_v1.PublisherClient") as mock_cls: + yield mock_cls.return_value + + +@pytest.fixture +def mock_subscriber(): + with patch(f"{MODULE_PATH}.pubsub_v1.SubscriberClient") as mock_cls: + yield mock_cls.return_value + + +@pytest.fixture +def delegate(mock_publisher, mock_subscriber): + return GcpPubSubQueueEndpointResourceDelegate(project_id=PROJECT_ID) + + +def test_init_empty_project_id_raises(): + with pytest.raises(ValueError, match="non-empty project_id"): + GcpPubSubQueueEndpointResourceDelegate(project_id="") + + +@pytest.mark.asyncio +async def test_create_queue_if_not_exists_new(mock_publisher, mock_subscriber, delegate): + """Both topic and subscription are created when neither exists.""" + 
result = await delegate.create_queue_if_not_exists( + endpoint_id=ENDPOINT_ID, + endpoint_name="test_endpoint", + endpoint_created_by="test_user", + endpoint_labels={"team": "test"}, + ) + + topic_path = f"projects/{PROJECT_ID}/topics/{TOPIC_PREFIX}{ENDPOINT_ID}" + subscription_path = f"projects/{PROJECT_ID}/subscriptions/{SUBSCRIPTION_PREFIX}{ENDPOINT_ID}" + + mock_publisher.create_topic.assert_called_once_with(name=topic_path) + mock_subscriber.create_subscription.assert_called_once_with( + name=subscription_path, + topic=topic_path, + ack_deadline_seconds=60, # default when timeout is None + ) + assert result.queue_name == QUEUE_NAME + assert result.queue_url is None + + +@pytest.mark.asyncio +async def test_create_queue_if_not_exists_topic_already_exists( + mock_publisher, mock_subscriber, delegate +): + """AlreadyExists on topic creation is silenced; subscription still attempts creation.""" + mock_publisher.create_topic.side_effect = gcp_exceptions.AlreadyExists("topic exists") + + result = await delegate.create_queue_if_not_exists( + endpoint_id=ENDPOINT_ID, + endpoint_name="test_endpoint", + endpoint_created_by="test_user", + endpoint_labels={}, + ) + + mock_subscriber.create_subscription.assert_called_once() + assert result.queue_name == QUEUE_NAME + + +@pytest.mark.asyncio +async def test_create_queue_if_not_exists_subscription_already_exists_updates_ack_deadline( + mock_publisher, mock_subscriber, delegate +): + """AlreadyExists on subscription triggers an update_subscription call.""" + mock_subscriber.create_subscription.side_effect = gcp_exceptions.AlreadyExists( + "subscription exists" + ) + + result = await delegate.create_queue_if_not_exists( + endpoint_id=ENDPOINT_ID, + endpoint_name="test_endpoint", + endpoint_created_by="test_user", + endpoint_labels={}, + queue_message_timeout_seconds=120, + ) + + mock_publisher.create_topic.assert_called_once() + mock_subscriber.update_subscription.assert_called_once() + assert result.queue_name == QUEUE_NAME + 
+ +@pytest.mark.asyncio +async def test_create_queue_subscription_already_exists_update_failure_is_warned( + mock_publisher, mock_subscriber, delegate +): + """update_subscription GoogleAPIError is swallowed with a warning (not raised).""" + mock_subscriber.create_subscription.side_effect = gcp_exceptions.AlreadyExists("exists") + mock_subscriber.update_subscription.side_effect = gcp_exceptions.GoogleAPIError("boom") + + # Should not raise + result = await delegate.create_queue_if_not_exists( + endpoint_id=ENDPOINT_ID, + endpoint_name="test_endpoint", + endpoint_created_by="test_user", + endpoint_labels={}, + ) + assert result.queue_name == QUEUE_NAME + + +@pytest.mark.asyncio +async def test_delete_queue_subscription_not_found_silent( + mock_publisher, mock_subscriber, delegate +): + """NotFound on subscription deletion is silenced; topic deletion still attempts.""" + mock_subscriber.delete_subscription.side_effect = gcp_exceptions.NotFound("sub not found") + + await delegate.delete_queue(endpoint_id=ENDPOINT_ID) + + mock_subscriber.delete_subscription.assert_called_once() + mock_publisher.delete_topic.assert_called_once() + + +@pytest.mark.asyncio +async def test_delete_queue_topic_not_found_silent(mock_publisher, mock_subscriber, delegate): + """NotFound on topic deletion is silenced.""" + mock_publisher.delete_topic.side_effect = gcp_exceptions.NotFound("topic not found") + + await delegate.delete_queue(endpoint_id=ENDPOINT_ID) + + mock_subscriber.delete_subscription.assert_called_once() + mock_publisher.delete_topic.assert_called_once() + + +@pytest.mark.asyncio +async def test_delete_queue_subscription_api_error_raises( + mock_publisher, mock_subscriber, delegate +): + """Non-NotFound GoogleAPIError on subscription deletion raises EndpointResourceInfraException.""" + mock_subscriber.delete_subscription.side_effect = gcp_exceptions.GoogleAPIError("network error") + + with pytest.raises( + EndpointResourceInfraException, match="Failed to delete Pub/Sub 
subscription" + ): + await delegate.delete_queue(endpoint_id=ENDPOINT_ID) + + +@pytest.mark.asyncio +async def test_delete_queue_topic_api_error_raises(mock_publisher, mock_subscriber, delegate): + """Non-NotFound GoogleAPIError on topic deletion raises EndpointResourceInfraException.""" + mock_publisher.delete_topic.side_effect = gcp_exceptions.GoogleAPIError("network error") + + with pytest.raises(EndpointResourceInfraException, match="Failed to delete Pub/Sub topic"): + await delegate.delete_queue(endpoint_id=ENDPOINT_ID) + + +@pytest.mark.asyncio +async def test_delete_queue_subscription_failure_does_not_orphan_topic( + mock_publisher, mock_subscriber, delegate +): + """When subscription delete fails, topic delete must still be attempted (no orphan).""" + mock_subscriber.delete_subscription.side_effect = gcp_exceptions.GoogleAPIError("transient") + + with pytest.raises(EndpointResourceInfraException): + await delegate.delete_queue(endpoint_id=ENDPOINT_ID) + + # The key invariant: topic deletion was attempted even though subscription deletion failed. 
+ mock_publisher.delete_topic.assert_called_once() + + +@pytest.mark.asyncio +async def test_delete_queue_subscription_deleted_before_topic( + mock_publisher, mock_subscriber, delegate +): + """Subscription must be deleted before topic (Pub/Sub requirement to avoid race).""" + parent = MagicMock() + parent.attach_mock(mock_subscriber.delete_subscription, "sub_del") + parent.attach_mock(mock_publisher.delete_topic, "topic_del") + + await delegate.delete_queue(endpoint_id=ENDPOINT_ID) + + call_order = [c[0] for c in parent.mock_calls] + assert call_order == ["sub_del", "topic_del"] + + +@pytest.mark.asyncio +async def test_get_queue_attributes_returns_expected_shape(delegate): + """get_queue_attributes returns a dict with 'name' and 'num_undelivered_messages'.""" + attrs = await delegate.get_queue_attributes(endpoint_id=ENDPOINT_ID) + + assert attrs["name"] == QUEUE_NAME + assert "num_undelivered_messages" in attrs + assert attrs["num_undelivered_messages"] == -1