diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py index 9af27e4..7c65132 100644 --- a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py +++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py @@ -48,21 +48,30 @@ class ProxyOptions: use_insecure_credentials: Whether to use insecure gRPC credentials for the proxy server. xla_flags: A list of XLA flags to pass to the proxy server. + sidecar_name: The name of the colocated Python sidecar to register with the + proxy. When set (e.g. to "external"), the proxy passes + ``--sidecar_name=`` so that ``jax.experimental.colocated_python`` + can reach the sidecar containers running on the worker pods. Leave as + ``None`` when no sidecar is deployed. """ use_insecure_credentials: bool = False xla_flags: list[str] = dataclasses.field(default_factory=list) + sidecar_name: str | None = None @classmethod def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions": """Creates a ProxyOptions object from a list of 'key:value' strings.""" use_insecure = False xla_flags = [] + sidecar_name = None for option in options or []: if ":" in option: key, value = option.split(":", 1) key_strip = key.strip().lower() if key_strip == "use_insecure_credentials": use_insecure = value.strip().lower() == "true" + elif key.strip().lower() == "sidecar_name": + sidecar_name = value.strip() elif key_strip == "xla_flags": val_strip = value.strip() if ( @@ -78,7 +87,10 @@ def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions": if xla_flags: validators.validate_xla_flags(xla_flags) - return cls(use_insecure_credentials=use_insecure, xla_flags=xla_flags) + return cls( + use_insecure_credentials=use_insecure, xla_flags=xla_flags, + sidecar_name=sidecar_name, + ) def _deploy_pathways_proxy_server( @@ -134,6 +146,12 @@ def _deploy_pathways_proxy_server( ) proxy_args_str = "\n" + proxy_args_str + sidecar_args_str = "" + if proxy_options.sidecar_name: + sidecar_args_str = ( + f"- --sidecar_name={proxy_options.sidecar_name}" + ) + template = string.Template(yaml_template) substituted_yaml = template.substitute( PROXY_JOB_NAME=proxy_job_name, @@ -145,6 +163,7 @@ def _deploy_pathways_proxy_server( PROXY_SERVER_IMAGE=proxy_server_image, PROXY_ENV=proxy_env_str, PROXY_ARGS=proxy_args_str, + SIDECAR_ARGS=sidecar_args_str, ) _logger.info("Deploying Pathways proxy: %s", proxy_job_name) diff --git a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml index e0ee244..6a55e8c 100644 --- a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml +++ b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml @@ -21,6 +21,7 @@ spec: - --resource_manager_address=${PATHWAYS_HEAD_HOSTNAME}:${PATHWAYS_HEAD_PORT} - --gcs_scratch_location=${GCS_SCRATCH_LOCATION} - --virtual_slices=${EXPECTED_INSTANCES}${PROXY_ARGS} + ${SIDECAR_ARGS} env: ${PROXY_ENV} ports: