Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions charts/model-engine/templates/_helpers.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ env:
- name: LAUNCH_SERVICE_TEMPLATE_FOLDER
value: "/workspace/model-engine/model_engine_server/infra/gateways/resources/templates"
{{- $model_cache := default dict .Values.modelCache }}
{{- $gcp_cloud_provider := and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") }}
- name: MODEL_CACHE_ENABLED
value: {{ get $model_cache "enabled" | default false | quote }}
- name: MODEL_CACHE_MOUNT_PATH
Expand Down Expand Up @@ -404,6 +405,14 @@ env:
- name: SERVICEBUS_NAMESPACE
value: {{ .Values.azure.servicebus_namespace }}
{{- end }}
{{- if $gcp_cloud_provider }}
- name: GCP_PROJECT_ID
value: {{ (.Values.gcp).project_id | default "" | quote }}
- name: PUBSUB_TOPIC_PREFIX
value: {{ (.Values.gcp).pubsub_topic_prefix | default "" | quote }}
- name: PUBSUB_SUBSCRIPTION_PREFIX
value: {{ (.Values.gcp).pubsub_subscription_prefix | default "" | quote }}
{{- end }}
{{- if eq .Values.context "circleci" }}
- name: CIRCLECI
value: "true"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
{{- $tag := .Values.tag }}
{{- $message_broker := .Values.celeryBrokerType }}
{{- $num_shards := .Values.celery_autoscaler.num_shards }}
{{- $gcp_cloud_provider := and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") -}}
{{- $broker_name := "redis-elasticache-message-broker-master" }}
{{- if eq $message_broker "sqs" }}
{{ $broker_name = "sqs-message-broker-master" }}
{{- else if eq $message_broker "servicebus" }}
{{ $broker_name = "servicebus-message-broker-master" }}
{{- else if and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") }}
{{- else if $gcp_cloud_provider }}
{{ $broker_name = "redis-gcp-memorystore-message-broker-master" }}
{{- end }}
apiVersion: apps/v1
Expand Down Expand Up @@ -89,6 +90,14 @@ spec:
- name: SERVICEBUS_NAMESPACE
value: {{ .Values.azure.servicebus_namespace }}
{{- end }}
{{- if $gcp_cloud_provider }}
- name: GCP_PROJECT_ID
value: {{ (.Values.gcp).project_id | default "" | quote }}
- name: PUBSUB_TOPIC_PREFIX
value: {{ (.Values.gcp).pubsub_topic_prefix | default "" | quote }}
- name: PUBSUB_SUBSCRIPTION_PREFIX
value: {{ (.Values.gcp).pubsub_subscription_prefix | default "" | quote }}
{{- end }}
image: "{{ .Values.image.gatewayRepository }}:{{ $tag }}"
imagePullPolicy: Always
name: main
Expand Down
6 changes: 6 additions & 0 deletions charts/model-engine/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,9 @@ utilityImages:

# Additional GPU tolerations for endpoint pods
gpuTolerations: []

# GCP configuration for GCP-based deployments
gcp:
project_id: ""
pubsub_topic_prefix: "launch-endpoint-id-"
pubsub_subscription_prefix: "launch-endpoint-id-"
6 changes: 6 additions & 0 deletions charts/model-engine/values_sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,9 @@ recommendedHardware:
gpu_type: nvidia-hopper-h100
nodes_per_worker: 1
#serviceBuilderQueue:

# GCP configuration for GCP-based deployments
gcp:
project_id: "your-gcp-project"
pubsub_topic_prefix: "launch-endpoint-id-"
pubsub_subscription_prefix: "launch-endpoint-id-"
3 changes: 3 additions & 0 deletions clients/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,8 @@
python_requires=">=3.8",
version="0.0.0.beta45",
packages=find_packages(),
# types-setuptools 82.0.0+ tightened package_data to _DictLike; the literal dict
# still works at runtime, only the new stub disagrees. Suppress at the call site
# rather than down-pinning the stub (which would mask real future tightenings).
package_data={"llmengine": ["py.typed"]}, # type: ignore[arg-type]
)
19 changes: 14 additions & 5 deletions model-engine/model_engine_server/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@
from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import (
FakeQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import (
GcpPubSubQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import (
LiveEndpointResourceGateway,
)
Expand All @@ -104,9 +107,6 @@
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
QueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.redis_queue_endpoint_resource_delegate import (
RedisQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import (
SQSQueueEndpointResourceDelegate,
)
Expand Down Expand Up @@ -248,8 +248,17 @@ def _get_external_interfaces(
elif infra_config().cloud_provider == "azure":
queue_delegate = ASBQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "gcp":
# GCP uses Redis (Memorystore) for Celery, so use Redis-based queue delegate
queue_delegate = RedisQueueEndpointResourceDelegate(redis_client=redis_client)
# Mirror the SQS_PROFILE env-first pattern: the Helm chart injects GCP_PROJECT_ID as a
# pod env var (from .Values.gcp.project_id), which is a different source than the YAML-
# rendered infra_service_config. Read the env first so the chart value reaches the delegate;
# the infra_config.gcp_project_id field handles setups that wire it via the config YAML.
gcp_project_id = os.getenv("GCP_PROJECT_ID") or infra_config().gcp_project_id
if not gcp_project_id:
raise ValueError(
"cloud_provider=gcp requires GCP_PROJECT_ID env var "
"(via .Values.gcp.project_id) or infra.gcp_project_id in the service config."
)
queue_delegate = GcpPubSubQueueEndpointResourceDelegate(project_id=gcp_project_id)
else:
queue_delegate = SQSQueueEndpointResourceDelegate(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
Expand Down
1 change: 1 addition & 0 deletions model-engine/model_engine_server/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class _InfraConfig:
celery_enable_sha256: Optional[bool] = None
docker_registry_type: Optional[str] = None
debug_mode: Optional[bool] = None
gcp_project_id: Optional[str] = None


@dataclass
Expand Down
13 changes: 13 additions & 0 deletions model-engine/model_engine_server/entrypoints/k8s_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import (
FakeQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import (
GcpPubSubQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.image_cache_gateway import ImageCacheGateway
from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import (
LiveEndpointResourceGateway,
Expand Down Expand Up @@ -119,6 +122,16 @@ async def main(args: Any):
queue_delegate = OnPremQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "azure":
queue_delegate = ASBQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "gcp":
# See dependencies.py for rationale: Helm injects GCP_PROJECT_ID as a pod env var;
# the infra_service_config YAML is a different source. Read the env first.
gcp_project_id = os.getenv("GCP_PROJECT_ID") or infra_config().gcp_project_id
if not gcp_project_id:
raise ValueError(
"cloud_provider=gcp requires GCP_PROJECT_ID env var "
"(via .Values.gcp.project_id) or infra.gcp_project_id in the service config."
)
queue_delegate = GcpPubSubQueueEndpointResourceDelegate(project_id=gcp_project_id)
else:
queue_delegate = SQSQueueEndpointResourceDelegate(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import (
FakeQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import (
GcpPubSubQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import (
LiveEndpointResourceGateway,
)
Expand Down Expand Up @@ -90,6 +93,16 @@ async def run_batch_job(
queue_delegate = OnPremQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "azure":
queue_delegate = ASBQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "gcp":
# See dependencies.py for rationale: Helm injects GCP_PROJECT_ID as a pod env var;
# the infra_service_config YAML is a different source. Read the env first.
gcp_project_id = os.getenv("GCP_PROJECT_ID") or infra_config().gcp_project_id
if not gcp_project_id:
raise ValueError(
"cloud_provider=gcp requires GCP_PROJECT_ID env var "
"(via .Values.gcp.project_id) or infra.gcp_project_id in the service config."
)
queue_delegate = GcpPubSubQueueEndpointResourceDelegate(project_id=gcp_project_id)
else:
queue_delegate = SQSQueueEndpointResourceDelegate(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
Expand All @@ -110,6 +123,9 @@ async def run_batch_job(
if infra_config().cloud_provider == "azure":
inference_task_queue_gateway = servicebus_task_queue_gateway
infra_task_queue_gateway = servicebus_task_queue_gateway
elif infra_config().cloud_provider == "gcp":
inference_task_queue_gateway = redis_task_queue_gateway
infra_task_queue_gateway = redis_task_queue_gateway
elif infra_config().cloud_provider == "onprem" or infra_config().celery_broker_type_redis:
# On-prem uses Redis-based task queues
inference_task_queue_gateway = redis_task_queue_gateway
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from typing import Any, Dict, Optional

from google.api_core import exceptions as gcp_exceptions
from google.cloud import pubsub_v1
from google.protobuf import field_mask_pb2
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.exceptions import EndpointResourceInfraException
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
QueueEndpointResourceDelegate,
QueueInfo,
)

logger = make_logger(logger_name())

GCP_PUBSUB_MAX_ACK_DEADLINE_SECONDS = 600 # Pub/Sub hard limit


class GcpPubSubQueueEndpointResourceDelegate(QueueEndpointResourceDelegate):
"""
Using GCP Pub/Sub (topic + subscription per endpoint).

topic_prefix and subscription_prefix control the GCP resource name prefix.
The logical queue_name returned to callers always uses the canonical
QueueEndpointResourceDelegate.endpoint_id_to_queue_name format, independent
of these prefixes.
"""

def __init__(
    self,
    project_id: str,
    topic_prefix: str = "launch-endpoint-id-",
    subscription_prefix: str = "launch-endpoint-id-",
) -> None:
    """Configure the delegate for a GCP project.

    Args:
        project_id: GCP project that owns the Pub/Sub resources. Must be
            non-empty; validated eagerly so misconfiguration fails at
            construction rather than on first API call.
        topic_prefix: Prefix for per-endpoint topic names.
        subscription_prefix: Prefix for per-endpoint subscription names.

    Raises:
        ValueError: If ``project_id`` is empty.
    """
    if not project_id:
        raise ValueError(
            "GcpPubSubQueueEndpointResourceDelegate requires a non-empty project_id; "
            "set infra.gcp_project_id in the service config."
        )
    self.project_id = project_id
    self.topic_prefix = topic_prefix
    self.subscription_prefix = subscription_prefix
    # Client construction triggers Google ADC, which unit-test environments
    # don't have, so both gRPC clients are created lazily on first use and
    # then cached for the delegate's lifetime.
    self._publisher_client: Optional[pubsub_v1.PublisherClient] = None
    self._subscriber_client: Optional[pubsub_v1.SubscriberClient] = None

@property
def _publisher(self) -> pubsub_v1.PublisherClient:
    """Lazily construct the PublisherClient on first access, then reuse it."""
    client = self._publisher_client
    if client is None:
        client = pubsub_v1.PublisherClient()
        self._publisher_client = client
    return client

@property
def _subscriber(self) -> pubsub_v1.SubscriberClient:
    """Lazily construct the SubscriberClient on first access, then reuse it."""
    client = self._subscriber_client
    if client is None:
        client = pubsub_v1.SubscriberClient()
        self._subscriber_client = client
    return client

def _topic_id(self, endpoint_id: str) -> str:
    """Return the Pub/Sub topic ID (prefix + endpoint ID) for an endpoint."""
    return self.topic_prefix + endpoint_id

def _subscription_id(self, endpoint_id: str) -> str:
    """Return the Pub/Sub subscription ID (prefix + endpoint ID) for an endpoint."""
    return self.subscription_prefix + endpoint_id

async def create_queue_if_not_exists(
    self,
    endpoint_id: str,
    endpoint_name: str,
    endpoint_created_by: str,
    endpoint_labels: Dict[str, Any],
    queue_message_timeout_seconds: Optional[int] = None,
) -> QueueInfo:
    """Idempotently create the Pub/Sub topic and subscription for an endpoint.

    Args:
        endpoint_id: Endpoint whose topic/subscription should exist.
        endpoint_name: Accepted for interface compatibility; unused here.
        endpoint_created_by: Accepted for interface compatibility; unused here.
        endpoint_labels: Accepted for interface compatibility; unused here.
        queue_message_timeout_seconds: Desired ack deadline; defaults to 60
            and is clamped to the Pub/Sub-allowed range (see below).

    Returns:
        QueueInfo with the canonical logical queue name and ``queue_url=None``
        (Pub/Sub has no URL concept analogous to SQS queue URLs).
    """
    queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
    topic_path = f"projects/{self.project_id}/topics/{self._topic_id(endpoint_id)}"
    subscription_path = (
        f"projects/{self.project_id}/subscriptions/{self._subscription_id(endpoint_id)}"
    )
    # Pub/Sub requires ack_deadline_seconds in the inclusive range [10, 600].
    # Clamp BOTH ends: without the floor, a valid user-supplied timeout of
    # 1-9 seconds would be forwarded to the API and rejected with
    # INVALID_ARGUMENT.
    min_ack_deadline_seconds = 10  # Pub/Sub hard minimum
    ack_deadline = max(
        min_ack_deadline_seconds,
        min(queue_message_timeout_seconds or 60, GCP_PUBSUB_MAX_ACK_DEADLINE_SECONDS),
    )

    # Creation is idempotent: AlreadyExists means another caller won the race
    # (or the resource predates this call), which is fine.
    try:
        self._publisher.create_topic(name=topic_path)
    except gcp_exceptions.AlreadyExists:
        pass

    try:
        self._subscriber.create_subscription(
            name=subscription_path,
            topic=topic_path,
            ack_deadline_seconds=ack_deadline,
        )
    except gcp_exceptions.AlreadyExists:
        # Subscription already exists: best-effort update of the ack deadline
        # so a changed timeout still takes effect. Failure to update is logged
        # but not fatal -- the subscription itself is usable.
        try:
            self._subscriber.update_subscription(
                subscription=pubsub_v1.types.Subscription(
                    name=subscription_path,
                    ack_deadline_seconds=ack_deadline,
                ),
                update_mask=field_mask_pb2.FieldMask(paths=["ack_deadline_seconds"]),
            )
        except gcp_exceptions.GoogleAPIError as e:
            logger.warning(
                f"Failed to update ack_deadline for Pub/Sub subscription {subscription_path}: {e}"
            )

    # Pub/Sub has no URL concept analogous to SQS queue URLs
    return QueueInfo(queue_name, queue_url=None)

async def delete_queue(self, endpoint_id: str) -> None:
    """Delete the subscription and the topic for an endpoint.

    Both deletions are always attempted so that a failure on one cannot leave
    the other resource orphaned. NotFound is treated as already-deleted and
    only logged. Any other GoogleAPIError is collected; after both attempts,
    all failures are surfaced together in a single
    EndpointResourceInfraException chained to the first underlying error.
    """
    sub_path = (
        f"projects/{self.project_id}/subscriptions/{self._subscription_id(endpoint_id)}"
    )
    top_path = f"projects/{self.project_id}/topics/{self._topic_id(endpoint_id)}"

    failures: list[tuple[str, str, gcp_exceptions.GoogleAPIError]] = []
    # Subscription first, then topic -- same order as the original deletion
    # sequence, expressed as a table so both attempts share one error policy.
    deletions = (
        ("subscription", sub_path, lambda: self._subscriber.delete_subscription(subscription=sub_path)),
        ("topic", top_path, lambda: self._publisher.delete_topic(topic=top_path)),
    )
    for kind, path, attempt_delete in deletions:
        try:
            attempt_delete()
        except gcp_exceptions.NotFound:
            logger.info(f"Could not find Pub/Sub {kind} {path} for endpoint {endpoint_id}")
        except gcp_exceptions.GoogleAPIError as e:
            failures.append((kind, path, e))

    if failures:
        details = "; ".join(
            f"Failed to delete Pub/Sub {kind} {path}: {err}" for kind, path, err in failures
        )
        raise EndpointResourceInfraException(
            f"Cleanup errors for endpoint {endpoint_id}: {details}"
        ) from failures[0][2]

async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]:
    """Return minimal queue metadata for an endpoint.

    Pub/Sub exposes no synchronous undelivered-message count (that requires
    the Cloud Monitoring API, a separate concern), so
    ``num_undelivered_messages`` is reported as -1, meaning "unknown".
    """
    attributes: Dict[str, Any] = {
        "name": QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id),
        "num_undelivered_messages": -1,
    }
    return attributes
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ async def get_resources(
)
elif "active_message_count" in sqs_attributes: # from ASBQueueEndpointResourceDelegate
resources.num_queued_items = int(sqs_attributes["active_message_count"])
elif (
"num_undelivered_messages" in sqs_attributes
): # from GcpPubSubQueueEndpointResourceDelegate
# Pub/Sub returns -1 when num_undelivered_messages is not yet wired to Cloud Monitoring.
# Treat -1 as "unknown" and skip; downstream autoscaling expects non-negative counts.
gcp_count = int(sqs_attributes["num_undelivered_messages"])
if gcp_count >= 0:
resources.num_queued_items = gcp_count

return resources

Expand Down
6 changes: 6 additions & 0 deletions model-engine/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ azure-storage-blob~=12.19.0
# GCP dependencies
gcloud-aio-storage~=9.6
google-auth~=2.25.0
google-cloud-pubsub>=2.18
# google-cloud-pubsub transitively pulls opentelemetry-sdk, which flips
# common/startup_tracing/correlation.py's OTEL_AVAILABLE to True. Once that's
# True, tracer.py imports from opentelemetry-exporter-otlp-proto-grpc, which
# isn't otherwise a dependency. Pin it explicitly so the import resolves.
opentelemetry-exporter-otlp-proto-grpc
google-cloud-artifact-registry~=1.21.0
google-cloud-secret-manager>=2.24.0
google-cloud-storage~=2.14.0
Expand Down
Loading