diff --git a/pathwaysutils/experimental/gke/jobset.py b/pathwaysutils/experimental/gke/jobset.py index a29441f..149167c 100644 --- a/pathwaysutils/experimental/gke/jobset.py +++ b/pathwaysutils/experimental/gke/jobset.py @@ -9,12 +9,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Pathways JobSet generator and builder (Head Job Config).""" +"""Pathways JobSet generator and builder (with Worker Job Config).""" import json import logging -from typing import Any, Mapping +import math +import time +from typing import Any, Mapping, Sequence from kubernetes import client +from kubernetes import config as k8s_config +import yaml # GKE sidecar containers restartPolicy compatibility placeholder. @@ -33,6 +37,7 @@ PATHWAYS_PROXY_PORT = 29000 PATHWAYS_RM_PORT = 29001 +PATHWAYS_WORKER_PORT = 29005 MACHINE_TYPE_TO_TPU_VERSION_MAP = { "tpu7x-standard-4t": "tpu7x", @@ -77,7 +82,7 @@ def __init__(self, data): class PathwaysJobSet: - """Generates JobSet configuration for Pathways (with Head Job Config).""" + """JobSet configuration generator for Pathways.""" def __init__( self, @@ -90,11 +95,14 @@ def __init__( user_pod_template: Mapping[str, Any] | None = None, main_container_name: str = "main", max_restarts: int = 0, + max_slice_restarts: int = 0, + termination_grace_period_seconds: int | None = None, pathways_version: str = "latest", jobset_api_version: str = "v1alpha2", elastic_slices: int = 0, labels: Mapping[str, str] | None = None, annotations: Mapping[str, str] | None = None, + shared_pathways_service: bool = False, ): """Initializes the instance. @@ -114,6 +122,13 @@ def __init__( labels: Optional labels for the JobSet. annotations: Optional annotations for the JobSet. """ + if shared_pathways_service and user_pod_template: + raise ValueError( + "Cannot enable shared_pathways_service when user_pod_template is" + " provided." + ) + self._shared_pathways_service = shared_pathways_service + self._name = name self._namespace = namespace self._jobset_api_version = jobset_api_version @@ -126,6 +141,19 @@ def __init__( if not tpu_version: raise ValueError(f"Unsupported TPU type: {tpu_type}") + gke_accel_type = MACHINE_TYPE_TO_GKE_ACCELERATOR_TYPE_MAP.get( + tpu_type.lower() + ) + + # Calculate VMs. + dims = [int(x) for x in topology.split("x")] + total_chips = math.prod(dims) + chips_per_vm = 8 if tpu_type.lower().endswith("8t") else 4 + if total_chips < chips_per_vm: + num_vms = 1 + else: + num_vms = total_chips // chips_per_vm + instance_type = f"{tpu_version}:{topology}" image_tag = pathways_version @@ -138,31 +166,37 @@ def __init__( user_pod_template=user_pod_template, main_container_name=main_container_name, elastic_slices=elastic_slices, + shared_pathways_service=shared_pathways_service, ) - # Build minimal worker template (placeholder) - self._worker_job_template = self._build_minimal_job_template("worker") + # Build worker template. + self._worker_job_template = self._build_worker_job_template( + name=name, + pathways_dir=pathways_dir, + num_slices=num_slices, + num_vms=num_vms, + chips_per_vm=chips_per_vm, + gke_accel_type=gke_accel_type, + topology=topology, + image_tag=image_tag, + max_slice_restarts=max_slice_restarts, + termination_grace_period_seconds=termination_grace_period_seconds, + ) self._success_policy = None - if user_pod_template: + if user_pod_template or shared_pathways_service: self._success_policy = { "operator": "All", "targetReplicatedJobs": [PATHWAYS_HEAD_JOB_NAME], } - def _build_minimal_job_template(self, role: str) -> client.V1JobTemplateSpec: - """Builds a minimal job template for a given role.""" - pod_spec = client.V1PodSpec( - containers=[ - client.V1Container(name=f"placeholder-{role}", image="ubuntu") - ] - ) - job_spec = client.V1JobSpec( - template=client.V1PodTemplateSpec( - metadata=client.V1ObjectMeta(labels={"role": role}), spec=pod_spec - ) - ) - return client.V1JobTemplateSpec(spec=job_spec) + @property + def head_job_template(self) -> client.V1JobTemplateSpec: + return self._head_job_template + + @property + def worker_job_template(self) -> client.V1JobTemplateSpec: + return self._worker_job_template def _build_head_job_template( self, @@ -173,6 +207,7 @@ def _build_head_job_template( user_pod_template: Mapping[str, Any] | None, main_container_name: str, elastic_slices: int, + shared_pathways_service: bool, ) -> client.V1JobTemplateSpec: """Builds the head job template for the JobSet. @@ -331,10 +366,13 @@ def _build_head_job_template( labels = user_pod_template.get("metadata", {}).get("labels", {}) else: # Headless mode. + containers = [rm_container] + if not shared_pathways_service: + containers.append(proxy_container) head_pod_spec = client.V1PodSpec( host_network=True, dns_policy="ClusterFirstWithHostNet", - containers=[rm_container, proxy_container], + containers=containers, ) annotations = {} labels = {} @@ -365,6 +403,315 @@ def _build_head_job_template( ) return head_job_template + def _build_worker_job_template( + self, + name: str, + pathways_dir: str, + num_slices: int, + num_vms: int, + chips_per_vm: int, + gke_accel_type: str, + topology: str, + image_tag: str, + max_slice_restarts: int, + termination_grace_period_seconds: int | None, + ) -> client.V1JobTemplateSpec: + worker_image = f"{DEFAULT_PATHWAYS_RM_AND_WORKER_IMAGE}:{image_tag}" + + args = [ + f"--resource_manager_address=$(PATHWAYS_HEAD):{PATHWAYS_RM_PORT}", + f"--server_port={PATHWAYS_WORKER_PORT}", + f"--gcs_scratch_location={pathways_dir}", + ] + worker_env = [ + client.V1EnvVar(name="TPU_MIN_LOG_LEVEL", value="0"), + client.V1EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="0"), + client.V1EnvVar(name="XCLOUD_ENVIRONMENT", value="GCP"), + client.V1EnvVar(name="MEGASCALE_GRPC_ENABLE_XOR_TRACER", value="false"), + client.V1EnvVar( + name="MEGASCALE_NUM_SLICES", + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path="metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas']" + ) + ), + ), + client.V1EnvVar( + name="JOBSET_NAME", + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path=( + "metadata.annotations['jobset.sigs.k8s.io/jobset-name']" + ) + ) + ), + ), + client.V1EnvVar( + name="REPLICATED_JOB_NAME", + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path="metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']" + ) + ), + ), + client.V1EnvVar( + name="MEGASCALE_SLICE_ID", + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path="metadata.labels['jobset.sigs.k8s.io/job-index']" + ) + ), + ), + client.V1EnvVar( + name="PATHWAYS_HEAD", + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path=( + "metadata.labels['jobset.sigs.k8s.io/coordinator']" + ) + ) + ), + ), + client.V1EnvVar( + name="MEGASCALE_COORDINATOR_ADDRESS", + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path=( + "metadata.labels['jobset.sigs.k8s.io/coordinator']" + ) + ) + ), + ), + ] + + worker_container = client.V1Container( + name="pathways-worker", + image=worker_image, + image_pull_policy="Always", + args=args, + env=worker_env, + ports=[ + client.V1ContainerPort( + container_port=PATHWAYS_WORKER_PORT, protocol="TCP" + ), + client.V1ContainerPort(container_port=29006, protocol="TCP"), + client.V1ContainerPort(container_port=8471, protocol="TCP"), + client.V1ContainerPort(container_port=8080, protocol="TCP"), + ], + volume_mounts=[ + client.V1VolumeMount(name="shared-tmp", mount_path="/tmp") + ], + resources=client.V1ResourceRequirements( + limits={"google.com/tpu": str(chips_per_vm)} + ), + ) + + node_selector = { + "cloud.google.com/gke-tpu-accelerator": gke_accel_type, + "cloud.google.com/gke-tpu-topology": topology, + } + + backoff_limit = num_vms * 4 + if max_slice_restarts > 0: + backoff_limit = num_vms * max_slice_restarts + + worker_pod_spec = client.V1PodSpec( + containers=[worker_container], + node_selector=node_selector, + volumes=[ + client.V1Volume( + name="shared-tmp", + host_path=client.V1HostPathVolumeSource( + path="/tmp", type="DirectoryOrCreate" + ), + ) + ], + host_network=True, + dns_policy="ClusterFirstWithHostNet", + restart_policy="OnFailure", + ) + if termination_grace_period_seconds is not None: + worker_pod_spec.termination_grace_period_seconds = ( + termination_grace_period_seconds + ) + + worker_job_template = client.V1JobTemplateSpec( + metadata=client.V1ObjectMeta(), + spec=client.V1JobSpec( + backoff_limit=backoff_limit, + completion_mode="Indexed", + completions=num_vms, + parallelism=num_vms, + template=client.V1PodTemplateSpec( + metadata=client.V1ObjectMeta( + annotations={ + "alpha.jobset.sigs.k8s.io/exclusive-topology": ( + "cloud.google.com/gke-nodepool" + ) + } + ), + spec=worker_pod_spec, + ), + ), + ) + return worker_job_template + + def add_colocated_python(self) -> "PathwaysJobSet": + """Adds colocated python sidecar to the worker pods.""" + pod_spec = self._worker_job_template.spec.template.spec + + # Add shared memory volume if not exists. + volumes = pod_spec.volumes or [] + shm_volume_name = "shared-memory" + shm_exists = any(v.name == shm_volume_name for v in volumes) + if not shm_exists: + volumes.append( + client.V1Volume( + name=shm_volume_name, + empty_dir=client.V1EmptyDirVolumeSource( + medium="Memory", size_limit="100Gi" + ), + ) + ) + pod_spec.volumes = volumes + + # Add colocated python container. + colocated_container = client.V1Container( + name="colocated-python-sidecar", + image="gcr.io/cloud-tpu-multipod-dev/colocated-python:latest", + image_pull_policy="Always", + env=[ + client.V1EnvVar(name="GRPC_SERVER_ADDRESS", value="0.0.0.0:50051") + ], + ports=[client.V1ContainerPort(container_port=50051)], + volume_mounts=[ + client.V1VolumeMount(name="shared-tmp", mount_path="/tmp"), + client.V1VolumeMount(name=shm_volume_name, mount_path="/tmp/ifrt_proxy"), + ], + restart_policy="Always", + ) + + init_containers = pod_spec.init_containers or [] + init_containers.append(colocated_container) + pod_spec.init_containers = init_containers + + # Add volume mount to pathways-worker. + for container in pod_spec.containers: + if container.name == "pathways-worker": + volume_mounts = container.volume_mounts or [] + volume_mounts.append( + client.V1VolumeMount( + name=shm_volume_name, mount_path="/tmp/ifrt_proxy" + ) + ) + container.volume_mounts = volume_mounts + # Add env var for shm dir. + env = container.env or [] + env.append( + client.V1EnvVar( + name="cloud_pathways_sidecar_shm_directory", + value="/tmp/ifrt_proxy", + ) + ) + container.env = env + break + + return self + + def add_gcsfuse( + self, + containers: str, + mount_path: str, + bucket: str, + read_only: bool = False, + ) -> "PathwaysJobSet": + """Adds GCSFuse mount to specified containers.""" + target_templates = [] + if containers in ("head", "all"): + target_templates.append((PATHWAYS_HEAD_JOB_NAME, self._head_job_template)) + if containers in ("worker", "all"): + target_templates.append(("worker", self._worker_job_template)) + + bucket_hash = abs(hash(bucket)) % (10**8) + volume_name = f"gcsfuse-{bucket_hash}" + + for job_name, job_template_obj in target_templates: + # 1. Add annotation. + job_template_metadata = job_template_obj.metadata or client.V1ObjectMeta() + job_template_annotations = job_template_metadata.annotations or {} + job_template_annotations["gke-gcsfuse/volumes"] = "true" + job_template_metadata.annotations = job_template_annotations + job_template_obj.metadata = job_template_metadata + + pod_template = job_template_obj.spec.template + pod_metadata = pod_template.metadata or client.V1ObjectMeta() + pod_annotations = pod_metadata.annotations or {} + pod_annotations["gke-gcsfuse/volumes"] = "true" + pod_metadata.annotations = pod_annotations + pod_template.metadata = pod_metadata + + pod_spec = pod_template.spec + + # 2. Add volume. + volumes = pod_spec.volumes or [] + vol_exists = any(v.name == volume_name for v in volumes) + if not vol_exists: + volumes.append( + client.V1Volume( + name=volume_name, + csi=client.V1CSIVolumeSource( + driver="gcsfuse.csi.storage.gke.io", + volume_attributes={"bucketName": bucket}, + ), + ) + ) + pod_spec.volumes = volumes + + # 3. Add volumeMount to containers. + container_names = [] + if job_name == PATHWAYS_HEAD_JOB_NAME: + has_user_container = len(pod_spec.containers) > 2 or ( + len(pod_spec.containers) == 1 + and pod_spec.containers[0].name != "pathways-rm" + ) + if has_user_container: + container_names = [ + c.name + for c in pod_spec.containers + if c.name not in ("pathways-rm", "pathways-proxy") + ] + else: + container_names = ["pathways-rm", "pathways-proxy"] + else: + container_names = ["pathways-worker"] + + for container in pod_spec.containers: + if container.name in container_names: + volume_mounts = container.volume_mounts or [] + volume_mounts.append( + client.V1VolumeMount( + name=volume_name, + mount_path=mount_path, + read_only=read_only, + ) + ) + container.volume_mounts = volume_mounts + + # Also check initContainers. + for container in pod_spec.init_containers or []: + if container.name in container_names: + volume_mounts = container.volume_mounts or [] + volume_mounts.append( + client.V1VolumeMount( + name=volume_name, + mount_path=mount_path, + read_only=read_only, + ) + ) + container.volume_mounts = volume_mounts + + return self + def _compile_config(self) -> dict[str, Any]: """Compiles the JobSet configuration into a dictionary.""" with client.ApiClient() as api_client: @@ -375,29 +722,43 @@ def _compile_config(self) -> dict[str, Any]: self._worker_job_template ) - replicated_jobs = [ - { - "name": PATHWAYS_HEAD_JOB_NAME, - "replicas": 1, - "template": serialized_head, - }, - { - "name": PATHWAYS_WORKER_JOB_NAME, - "replicas": self._worker_replicas, - "template": serialized_worker, - }, - ] + head_job = { + "name": PATHWAYS_HEAD_JOB_NAME, + "replicas": 1, + "template": serialized_head, + } + worker_job = { + "name": PATHWAYS_WORKER_JOB_NAME, + "replicas": self._worker_replicas, + "template": serialized_worker, + } + + coordinator = { + "replicatedJob": PATHWAYS_HEAD_JOB_NAME, + } + + failure_policy = { + "restartStrategy": "Recreate", + } + if self._max_restarts > 0: + failure_policy["maxRestarts"] = self._max_restarts jobset_config = { - "apiVersion": f"jobset.sigs.k8s.io/{self._jobset_api_version}", + "apiVersion": f"jobset.x-k8s.io/{self._jobset_api_version}", "kind": "JobSet", "metadata": { "name": self._name, "namespace": self._namespace, }, "spec": { - "failurePolicy": {"maxRestarts": self._max_restarts}, - "replicatedJobs": replicated_jobs, + "startupPolicy": {"startupPolicyOrder": "InOrder"}, + "failurePolicy": failure_policy, + "network": { + "enableDNSHostnames": True, + "publishNotReadyAddresses": True, + }, + "coordinator": coordinator, + "replicatedJobs": [head_job, worker_job], }, } if self._labels: @@ -412,3 +773,142 @@ def _compile_config(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]: """Returns the JobSet configuration as a dictionary.""" return self._compile_config() + + def export_yaml(self, filepath: str) -> None: + """Exports the JobSet configuration to a YAML file.""" + with open(filepath, "w") as f: + yaml.dump(self.to_dict(), f, default_flow_style=False) + + @classmethod + def import_yaml(cls, filepath: str) -> "PathwaysJobSet": + """Imports a JobSet configuration from a YAML file.""" + with open(filepath, "r") as f: + config = yaml.safe_load(f) + + cls._validate_config(config) + + instance = cls.__new__(cls) + instance._name = config["metadata"]["name"] + instance._namespace = config["metadata"].get("namespace", "default") + api_version_parts = config.get("apiVersion", "").split("/") + instance._jobset_api_version = ( + api_version_parts[-1] if len(api_version_parts) > 1 else "v1alpha2" + ) + instance._max_restarts = ( + config["spec"].get("failurePolicy", {}).get("maxRestarts", 0) + ) + instance._labels = config["metadata"].get("labels", {}) + instance._annotations = config["metadata"].get("annotations", {}) + + # Extract replicated jobs and deserialize. + instance._head_job_template = None + instance._worker_job_template = None + + with client.ApiClient() as api_client: + for job in config["spec"]["replicatedJobs"]: + if job["name"] == PATHWAYS_HEAD_JOB_NAME: + instance._head_job_template = _deserialize_dict( + api_client, job["template"], client.V1JobTemplateSpec + ) + elif job["name"] in ("worker", PATHWAYS_WORKER_JOB_NAME): + instance._worker_job_template = _deserialize_dict( + api_client, job["template"], client.V1JobTemplateSpec + ) + instance._worker_replicas = job["replicas"] + + instance._success_policy = config["spec"].get("successPolicy") + return instance + + @classmethod + def _validate_config(cls, config: dict[str, Any]) -> None: + """Validates that the config is a valid Pathways JobSet.""" + if config.get("kind") != "JobSet": + raise ValueError("Resource kind is not JobSet") + jobs = { + j["name"]: j for j in config.get("spec", {}).get("replicatedJobs", []) + } + if "head" not in jobs and PATHWAYS_HEAD_JOB_NAME not in jobs: + raise ValueError( + f"Missing head replicated job ('head' or '{PATHWAYS_HEAD_JOB_NAME}')" + ) + if "worker" not in jobs and PATHWAYS_WORKER_JOB_NAME not in jobs: + raise ValueError( + "Missing worker replicated job ('worker' or" + f" '{PATHWAYS_WORKER_JOB_NAME}')" + ) + + def apply( + self, recreate: bool = False, field_manager: str = "pathwaysutils" + ) -> None: + """Applies the JobSet to the GKE cluster.""" + + try: + k8s_config.load_kube_config() + except Exception: # pylint: disable=broad-except + try: + k8s_config.load_incluster_config() + except Exception as e: + raise RuntimeError("Failed to load Kubernetes configuration") from e + + api = client.CustomObjectsApi() + group = "jobset.x-k8s.io" + version = self._jobset_api_version + plural = "jobsets" + + exists = False + try: + api.get_namespaced_custom_object( + group, version, self._namespace, plural, self._name + ) + exists = True + except client.rest.ApiException as e: + if e.status != 404: + raise + + if exists: + if recreate: + _logger.info( + "JobSet %s already exists. Deleting it first...", self._name + ) + api.delete_namespaced_custom_object( + group, version, self._namespace, plural, self._name + ) + + # Poll for deletion. + max_retries = 30 + for i in range(max_retries): + try: + api.get_namespaced_custom_object( + group, version, self._namespace, plural, self._name + ) + _logger.info( + "Waiting for JobSet %s to be deleted... (%d/%d)", + self._name, + i + 1, + max_retries, + ) + time.sleep(2) + except client.rest.ApiException as e: + if e.status == 404: + _logger.info("JobSet %s deleted.", self._name) + break + raise + else: + raise RuntimeError( + f"Timeout waiting for JobSet {self._name} to be deleted" + ) + else: + raise RuntimeError( + f"JobSet {self._name} already exists. Use recreate=True to" + " overwrite." + ) + + _logger.info("Creating JobSet %s...", self._name) + api.create_namespaced_custom_object( + group=group, + version=version, + namespace=self._namespace, + plural=plural, + body=self.to_dict(), + ) + _logger.info("JobSet %s created successfully.", self._name) diff --git a/pathwaysutils/experimental/shared_pathways_service/deploy_pathways_service.py b/pathwaysutils/experimental/shared_pathways_service/deploy_pathways_service.py index bfd6979..92e2b2f 100644 --- a/pathwaysutils/experimental/shared_pathways_service/deploy_pathways_service.py +++ b/pathwaysutils/experimental/shared_pathways_service/deploy_pathways_service.py @@ -5,12 +5,12 @@ import logging import math import os -import string from typing import Any from absl import app from absl import flags from kubernetes import client from kubernetes import config +from pathwaysutils.experimental.gke import jobset import yaml _logger = logging.getLogger(__name__) @@ -45,13 +45,6 @@ "gs://pathways-test-bucket", "GCS bucket name for scratch space", ) -_TEMPLATE_FILE = flags.DEFINE_string( - "template_file", - os.path.join( - os.path.dirname(__file__), "yamls/pw-service.yaml", - ), - "Path to the JobSet YAML template file", -) _DRY_RUN = flags.DEFINE_boolean( "dry_run", False, @@ -149,25 +142,6 @@ def calculate_vms_per_slice(topology: str, chips_per_vm: int) -> int: ) from e -def load_and_substitute_template( - template_path: str, context: dict[str, Any] -) -> dict[str, Any]: - """Loads and substitutes the string.Template from the given path.""" - try: - with open(template_path, "r") as f: - template_str = f.read() - except OSError as err: - raise ValueError( - f"Could not read template file: {template_path}: {err}" - ) from err - - _logger.info("Template file: %s", template_path) - _logger.info("Context: %s", context) - template = string.Template(template_str) - _logger.info("Template: %s", template) - substituted_yaml = template.substitute(context) - return yaml.safe_load(substituted_yaml) - def deploy_jobset(jobset_yaml: dict[str, Any]) -> None: """Deploys the JobSet to the current Kubernetes cluster.""" @@ -198,29 +172,72 @@ def run_deployment( gcs_bucket, server_image, sidecar_image, - template_file, dry_run, deploy_func: Callable[[dict[str, Any]], None] = deploy_jobset, ) -> None: """Executes the deployment logic.""" - tpu_config = get_tpu_config(tpu_type) - vms_per_slice = calculate_vms_per_slice(topology, tpu_config.chips_per_vm) - - context = { - "JOBSET_NAME": jobset_name, - "SERVER_IMAGE": server_image, - "SIDECAR_IMAGE": sidecar_image, - "SIDECAR_SHM_DIR": _SIDECAR_SHM_DIR, - "GCS_SCRATCH_LOCATION": gcs_bucket, - "NUM_SLICES": num_slices, - "INSTANCE_TYPE": f"{tpu_config.instance_prefix}:{topology}", - "VMS_PER_SLICE": vms_per_slice, - "CHIPS_PER_VM": tpu_config.chips_per_vm, - "ACCELERATOR_LABEL": tpu_config.accelerator_label, - "TOPOLOGY": topology, - } + # Use PathwaysJobSet builder instead of YAML template. + pw_jobset = jobset.PathwaysJobSet( + name=jobset_name, + namespace="default", + pathways_dir=gcs_bucket, + tpu_type=tpu_type, + topology=topology, + num_slices=num_slices, + shared_pathways_service=True, + ) - jobset_config = load_and_substitute_template(template_file, context) + # If custom server_image is provided, mutate the templates to use it. + if server_image: + # Mutate head job. + for container in pw_jobset.head_job_template.spec.template.spec.containers: + if container.name == "pathways-rm": + container.image = server_image + # Mutate worker job. + for container in pw_jobset.worker_job_template.spec.template.spec.containers: + if container.name == "pathways-worker": + container.image = server_image + + # Add colocated python sidecar. + pw_jobset.add_colocated_python() + + # Mutate the sidecar configuration to match what HEAD expects. + worker_spec = pw_jobset.worker_job_template.spec.template.spec + + # 1. Update sidecar image and env/mounts + for container in worker_spec.init_containers: + if container.name == "colocated-python-sidecar": + container.image = sidecar_image + container.env = [ + client.V1EnvVar(name="GRPC_SERVER_ADDRESS", value="0.0.0.0:50051"), + client.V1EnvVar(name="CLOUD_PATHWAYS_SIDECAR_SHM_DIRECTORY", value=_SIDECAR_SHM_DIR), + client.V1EnvVar(name="PYTHONUNBUFFERED", value="1"), + client.V1EnvVar(name="LOGLEVEL", value="DEBUG"), + client.V1EnvVar(name="GLOG_minloglevel", value="0"), + client.V1EnvVar(name="GLOG_v", value="5"), + client.V1EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="0"), + client.V1EnvVar(name="TF_CPP_MIN_VLOG_LEVEL", value="5"), + client.V1EnvVar(name="TPU_MIN_LOG_LEVEL", value="0"), + client.V1EnvVar(name="GLOG_vmodule", value="jax_array_handlers=5,type_handlers=5,tensorstore_utils=5"), + ] + for mount in container.volume_mounts: + if mount.name == "shared-memory": + mount.mount_path = _SIDECAR_SHM_DIR + + # 2. Update pathways-worker container + for container in worker_spec.containers: + if container.name == "pathways-worker": + for mount in container.volume_mounts: + if mount.name == "shared-memory": + mount.mount_path = _SIDECAR_SHM_DIR + if container.env: + container.env = [e for e in container.env if e.name != "cloud_pathways_sidecar_shm_directory"] + args = container.args or [] + if not any(a.startswith("--cloud_pathways_sidecar_shm_directory=") for a in args): + args.append(f"--cloud_pathways_sidecar_shm_directory={_SIDECAR_SHM_DIR}") + container.args = args + + jobset_config = pw_jobset.to_dict() _logger.info("--- Generated JobSet YAML ---") _logger.info("\n%s", yaml.dump(jobset_config)) @@ -256,15 +273,10 @@ def main(argv: Sequence[str]) -> None: gcs_bucket=_GCS_BUCKET.value, server_image=server_image, sidecar_image=_SIDECAR_IMAGE.value, - template_file=_TEMPLATE_FILE.value, dry_run=_DRY_RUN.value, ) except ValueError as e: _logger.exception("Error: %s", e) - except FileNotFoundError: - _logger.exception( - "Error: Template file not found at %s", _TEMPLATE_FILE.value - ) if __name__ == "__main__": diff --git a/pathwaysutils/test/experimental/gke/jobset_test.py b/pathwaysutils/test/experimental/gke/jobset_test.py index 2a1a92d..020c115 100644 --- a/pathwaysutils/test/experimental/gke/jobset_test.py +++ b/pathwaysutils/test/experimental/gke/jobset_test.py @@ -1,11 +1,120 @@ +"""Tests for PathwaysJobSet builder (CL 1 + CL 2).""" + +import os +from typing import Any +from unittest import mock from absl.testing import absltest from absl.testing import parameterized from kubernetes import client from pathwaysutils.experimental.gke import jobset +import yaml + + +def normalize_k8s_spec(spec: Any) -> Any: + if isinstance(spec, dict): + result = {} + for k, v in spec.items(): + if k == "env" and isinstance(v, list): + result[k] = sorted( + [normalize_k8s_spec(x) for x in v], key=lambda x: x.get("name", "") + ) + elif k == "ports" and isinstance(v, list): + result[k] = sorted( + [normalize_k8s_spec(x) for x in v], + key=lambda x: x.get("containerPort", 0), + ) + elif k == "volumeMounts" and isinstance(v, list): + result[k] = sorted( + [normalize_k8s_spec(x) for x in v], + key=lambda x: x.get("mountPath", ""), + ) + else: + result[k] = normalize_k8s_spec(v) + return result + elif isinstance(spec, list): + return [normalize_k8s_spec(x) for x in spec] + else: + return spec class PathwaysJobSetTest(parameterized.TestCase): + def setUp(self): + super().setUp() + self.testdata_dir = os.path.join(os.path.dirname(__file__), "testdata") + + def test_golden_v5e_4x8(self): + # Load golden YAML. + golden_path = os.path.join(self.testdata_dir, "model_lite_tpuv5e_4x8.yaml") + with open(golden_path, "r") as f: + golden_config = yaml.safe_load(f) + + # Define user pod template matching the golden. + user_pod_template = { + "metadata": {"labels": {"kueue.x-k8s.io/podset": "pathways-head"}}, + "spec": { + "containers": [{ + "name": "jax-tpu", + "image": "ubuntu:latest", + "imagePullPolicy": "Always", + "command": ["sleep", "infinity"], + "resources": {"limits": {"cpu": "24", "memory": "100G"}}, + "securityContext": {"privileged": True}, + "volumeMounts": [{"name": "shared-tmp", "mountPath": "/tmp"}], + "env": [ + { + "name": "TPU_VMODULE", + "value": "real_program_continuator=1", + }, + {"name": "ENABLE_PATHWAYS_PERSISTENCE", "value": "1"}, + {"name": "ENABLE_PERSISTENCE_API", "value": "1"}, + {"name": "ENABLE_PJRT_COMPATIBILITY", "value": "true"}, + ], + }], + "nodeSelector": {"cloud.google.com/gke-nodepool": "cpu-np"}, + "volumes": [{ + "name": "shared-tmp", + "hostPath": {"path": "/tmp", "type": "DirectoryOrCreate"}, + }], + }, + } + + # Generate config. + pw_jobset = jobset.PathwaysJobSet( + name="lukebaumann-model-v5e", + namespace="default", + pathways_dir="gs://fake-bucket/scratch", + tpu_type="v5e", + topology="4x8", + num_slices=2, + user_pod_template=user_pod_template, + main_container_name="jax-tpu", + elastic_slices=2, # To match --num_elastic_slices=2 in proxy. + max_slice_restarts=4000000, + labels={"kueue.x-k8s.io/queue-name": "multislice-queue"}, + ) + + generated_config = pw_jobset.to_dict() + + # Compare using unified diff of YAML representation. + import difflib + + normalized_gen = normalize_k8s_spec(generated_config) + normalized_golden = normalize_k8s_spec(golden_config) + gen_yaml = yaml.dump(normalized_gen, default_flow_style=False) + golden_yaml = yaml.dump(normalized_golden, default_flow_style=False) + diff = list( + difflib.unified_diff( + golden_yaml.splitlines(), + gen_yaml.splitlines(), + fromfile="golden", + tofile="generated", + lineterm="", + ) + ) + if diff: + self.fail("YAML diff:\n" + "\n".join(diff)) + def test_invalid_tpu_type(self): with self.assertRaisesRegex(ValueError, "Unsupported TPU type"): jobset.PathwaysJobSet( @@ -132,6 +241,535 @@ def test_monkeypatch_restart_policy(self): ) # pytype: disable=wrong-keyword-args self.assertEqual(getattr(c, "restart_policy"), "Always") + def test_worker_job(self): + js = jobset.PathwaysJobSet( + name="test-jobset", + namespace="default", + pathways_dir="gs://test-bucket", + tpu_type="v5e", + topology="4x8", + num_slices=2, + max_slice_restarts=3, + termination_grace_period_seconds=60, + ) + config = js.to_dict() + + replicated_jobs = config["spec"]["replicatedJobs"] + worker_job = next( + j for j in replicated_jobs if j["name"] == "pathways-worker" + ) + self.assertEqual(worker_job["replicas"], 2) + + # 4x8 v5e topology has 32 chips. v5e has 4 chips per VM. + # Total VMs = 32 / 4 = 8 VMs. + job_spec = worker_job["template"]["spec"] + self.assertEqual(job_spec["completions"], 8) + self.assertEqual(job_spec["parallelism"], 8) + # backoffLimit = num_vms * max_slice_restarts = 8 * 3 = 24 + self.assertEqual(job_spec["backoffLimit"], 24) + + pod_spec = job_spec["template"]["spec"] + self.assertTrue(pod_spec["hostNetwork"]) + self.assertEqual(pod_spec["dnsPolicy"], "ClusterFirstWithHostNet") + self.assertEqual(pod_spec["restartPolicy"], "OnFailure") + self.assertEqual(pod_spec["terminationGracePeriodSeconds"], 60) + + # Node selector + self.assertEqual( + pod_spec["nodeSelector"]["cloud.google.com/gke-tpu-accelerator"], + "tpu-v5-lite-podslice", + ) + self.assertEqual( + pod_spec["nodeSelector"]["cloud.google.com/gke-tpu-topology"], "4x8" + ) + + # Container limits + container = pod_spec["containers"][0] + self.assertEqual(container["name"], "pathways-worker") + self.assertEqual(container["resources"]["limits"]["google.com/tpu"], "4") + + def test_add_colocated_python(self): + pw_jobset = jobset.PathwaysJobSet( + name="test-workload", + namespace="default", + pathways_dir="gs://bucket/scratch", + tpu_type="v5e", + topology="2x2", + num_slices=1, + ) + pw_jobset.add_colocated_python() + config = pw_jobset.to_dict() + + worker_job = next( + j + for j in config["spec"]["replicatedJobs"] + if j["name"] == jobset.PATHWAYS_WORKER_JOB_NAME + ) + pod_spec = worker_job["template"]["spec"]["template"]["spec"] + + sidecar = next( + c + for c in pod_spec["initContainers"] + if c["name"] == "colocated-python-sidecar" + ) + self.assertEqual(sidecar["restartPolicy"], "Always") + self.assertTrue( + any( + m["name"] == "shared-memory" and m["mountPath"] == "/tmp/ifrt_proxy" + for m in sidecar["volumeMounts"] + ) + ) + + volumes = pod_spec["volumes"] + self.assertTrue(any(v["name"] == "shared-memory" for v in volumes)) + + worker_container = next( + c for c in pod_spec["containers"] if c["name"] == "pathways-worker" + ) + self.assertTrue( + any( + m["name"] == "shared-memory" and m["mountPath"] == "/tmp/ifrt_proxy" + for m in worker_container["volumeMounts"] + ) + ) + self.assertTrue( + any( + e["name"] == "cloud_pathways_sidecar_shm_directory" + and e["value"] == "/tmp/ifrt_proxy" + for e in worker_container["env"] + ) + ) + + def test_colocated_python_with_jax_command(self): + jax_command = ( + "import jax; import pathwaysutils; pathwaysutils.initialize(); assert" + " jax.device_count() > 0; print(jax.devices()); x =" + " jax.device_put([0], jax.devices()[0]); y = x + 1; assert y[0] == 1;" + " jax.block_until_ready(y); print(y);" + ) + user_pod_template = { + "spec": { + "containers": [{ + "name": "jax-tpu", + "image": "gcr.io/my-project/jax-tpu:latest", + "command": ["python3", "-c", jax_command], + }] + } + } + pw_jobset = jobset.PathwaysJobSet( + name="jax-test-workload", + namespace="default", + pathways_dir="gs://my-bucket/scratch", + tpu_type="v5e", + topology="2x2", + num_slices=1, + user_pod_template=user_pod_template, + main_container_name="jax-tpu", + ) + pw_jobset.add_colocated_python() + config = pw_jobset.to_dict() + + head_job = next( + j + for j in config["spec"]["replicatedJobs"] + if j["name"] == "pathways-head" + ) + pod_spec = head_job["template"]["spec"]["template"]["spec"] + jax_container = next( + c for c in pod_spec["containers"] if c["name"] == "jax-tpu" + ) + self.assertEqual(jax_container["command"], ["python3", "-c", jax_command]) + + worker_job = next( + j + for j in config["spec"]["replicatedJobs"] + if j["name"] == jobset.PATHWAYS_WORKER_JOB_NAME + ) + worker_pod_spec = worker_job["template"]["spec"]["template"]["spec"] + self.assertTrue( + any( + c["name"] == "colocated-python-sidecar" + for c in worker_pod_spec["initContainers"] + ) + ) + + def test_add_gcsfuse(self): + pw_jobset = jobset.PathwaysJobSet( + name="test-workload", + namespace="default", + pathways_dir="gs://bucket/scratch", + tpu_type="v5e", + topology="2x2", + num_slices=1, + ) + pw_jobset.add_gcsfuse( + containers="worker", + mount_path="/gcs/data", + bucket="my-bucket", + read_only=True, + ) + config = pw_jobset.to_dict() + + worker_job = next( + j + for j in config["spec"]["replicatedJobs"] + if j["name"] == jobset.PATHWAYS_WORKER_JOB_NAME + ) + pod_template = worker_job["template"]["spec"]["template"] + self.assertEqual( + pod_template["metadata"]["annotations"]["gke-gcsfuse/volumes"], "true" + ) + self.assertEqual( + worker_job["template"]["metadata"]["annotations"][ + "gke-gcsfuse/volumes" + ], + "true", + ) + + volumes = pod_template["spec"]["volumes"] + gcs_vol = next(v for v in volumes if v["name"].startswith("gcsfuse-")) + self.assertEqual(gcs_vol["csi"]["driver"], "gcsfuse.csi.storage.gke.io") + self.assertEqual( + gcs_vol["csi"]["volumeAttributes"]["bucketName"], "my-bucket" + ) + + container = pod_template["spec"]["containers"][0] + mount = next( + m for m in container["volumeMounts"] if m["name"] == gcs_vol["name"] + ) + self.assertEqual(mount["mountPath"], "/gcs/data") + self.assertTrue(mount["readOnly"]) + + def test_gcsfuse_with_jax_command(self): + jax_command = ( + "import jax; import pathwaysutils; pathwaysutils.initialize(); assert" + " jax.device_count() > 0; print(jax.devices()); x =" + " jax.device_put([0], jax.devices()[0]); y = x + 1; assert y[0] == 1;" + " jax.block_until_ready(y); print(y);" + ) + user_pod_template = { + "spec": { + "containers": [{ + "name": "jax-tpu", + "image": "gcr.io/my-project/jax-tpu:latest", + "command": ["python3", "-c", jax_command], + }] + } + } + pw_jobset = jobset.PathwaysJobSet( + name="jax-test-workload", + namespace="default", + pathways_dir="gs://my-bucket/scratch", + tpu_type="v5e", + topology="2x2", + num_slices=1, + user_pod_template=user_pod_template, + main_container_name="jax-tpu", + ) + pw_jobset.add_gcsfuse( + containers="all", + mount_path="/gcs/data", + bucket="my-bucket", + ) + config = pw_jobset.to_dict() + + head_job = next( + j + for j in config["spec"]["replicatedJobs"] + if j["name"] == "pathways-head" + ) + head_pod_template = head_job["template"]["spec"]["template"] + self.assertEqual( + head_pod_template["metadata"]["annotations"]["gke-gcsfuse/volumes"], + "true", + ) + head_pod_spec = head_pod_template["spec"] + jax_container = next( + c for c in head_pod_spec["containers"] if c["name"] == "jax-tpu" + ) + self.assertEqual(jax_container["command"], ["python3", "-c", jax_command]) + self.assertTrue( + any( + m["mountPath"] == "/gcs/data" for m in jax_container["volumeMounts"] + ) + ) + + worker_job = next( + j + for j in config["spec"]["replicatedJobs"] + if j["name"] == jobset.PATHWAYS_WORKER_JOB_NAME + ) + worker_pod_template = worker_job["template"]["spec"]["template"] + self.assertEqual( + worker_pod_template["metadata"]["annotations"]["gke-gcsfuse/volumes"], + "true", + ) + worker_pod_spec = worker_pod_template["spec"] + worker_container = next( + c + for c in worker_pod_spec["containers"] + if c["name"] == "pathways-worker" + ) + self.assertTrue( + any( + m["mountPath"] == "/gcs/data" + for m in worker_container["volumeMounts"] + ) + ) + + @mock.patch("kubernetes.config.load_kube_config") + @mock.patch("kubernetes.config.load_incluster_config") + @mock.patch("kubernetes.client.CustomObjectsApi") + def test_apply_create( + self, mock_custom_objects_api, mock_load_incluster, mock_load_kube + ): + mock_api = mock_custom_objects_api.return_value + # Mock GET to return 404 (not exists). + from kubernetes.client.rest import ApiException + + mock_api.get_namespaced_custom_object.side_effect = ApiException(status=404) + + pw_jobset = jobset.PathwaysJobSet( + name="test-workload", + namespace="default", + pathways_dir="gs://bucket/scratch", + tpu_type="v5e", + topology="2x2", + num_slices=1, + ) + pw_jobset.apply() + + mock_api.create_namespaced_custom_object.assert_called_once_with( + group="jobset.x-k8s.io", + version="v1alpha2", + namespace="default", + plural="jobsets", + body=pw_jobset.to_dict(), + ) + + @mock.patch("kubernetes.config.load_kube_config") + @mock.patch("kubernetes.config.load_incluster_config") + @mock.patch("kubernetes.client.CustomObjectsApi") + def test_apply_exists_recreate( + self, mock_custom_objects_api, mock_load_incluster, mock_load_kube + ): + mock_api = mock_custom_objects_api.return_value + # Mock GET to return success (exists). + mock_api.get_namespaced_custom_object.return_value = {} + # Mock GET after delete to return 404. + from kubernetes.client.rest import ApiException + + mock_api.get_namespaced_custom_object.side_effect = [ + {}, + ApiException(status=404), + ] + + pw_jobset = jobset.PathwaysJobSet( + name="test-workload", + namespace="default", + pathways_dir="gs://bucket/scratch", + tpu_type="v5e", + topology="2x2", + num_slices=1, + ) + pw_jobset.apply(recreate=True) + + mock_api.delete_namespaced_custom_object.assert_called_once_with( + "jobset.x-k8s.io", "v1alpha2", "default", "jobsets", "test-workload" + ) + mock_api.create_namespaced_custom_object.assert_called_once_with( + group="jobset.x-k8s.io", + version="v1alpha2", + namespace="default", + plural="jobsets", + body=pw_jobset.to_dict(), + ) + + @mock.patch("kubernetes.config.load_kube_config") + @mock.patch("kubernetes.config.load_incluster_config") + @mock.patch("kubernetes.client.CustomObjectsApi") + def test_apply_exists_no_recreate_fails( + self, mock_custom_objects_api, mock_load_incluster, mock_load_kube + ): + mock_api = mock_custom_objects_api.return_value + # Mock GET to return success (exists). + mock_api.get_namespaced_custom_object.return_value = {} + + pw_jobset = jobset.PathwaysJobSet( + name="test-workload", + namespace="default", + pathways_dir="gs://bucket/scratch", + tpu_type="v5e", + topology="2x2", + num_slices=1, + ) + with self.assertRaises(RuntimeError): + pw_jobset.apply(recreate=False) + + def test_export_import_roundtrip(self): + pw_jobset = jobset.PathwaysJobSet( + name="test-workload", + namespace="default", + pathways_dir="gs://bucket/scratch", + tpu_type="v5e", + topology="2x2", + num_slices=1, + ) + pw_jobset.add_colocated_python() + pw_jobset.add_gcsfuse( + containers="worker", mount_path="/tmp/gcs", bucket="my-bucket" + ) + + # Export. + temp_filepath = os.path.join(self.create_tempdir().full_path, "jobset.yaml") + pw_jobset.export_yaml(temp_filepath) + + # Import. + imported_jobset = jobset.PathwaysJobSet.import_yaml(temp_filepath) + + # Verify they are semantically identical. + self.assertEqual( + normalize_k8s_spec(pw_jobset.to_dict()), + normalize_k8s_spec(imported_jobset.to_dict()), + ) + + def test_import_validation_failures(self): + temp_dir = self.create_tempdir().full_path + + # 1. Missing kind. + invalid_config1 = { + "apiVersion": "jobset.x-k8s.io/v1alpha2", + "metadata": {"name": "test"}, + "spec": {"replicatedJobs": []}, + } + path1 = os.path.join(temp_dir, "invalid1.yaml") + with open(path1, "w") as f: + yaml.dump(invalid_config1, f) + with self.assertRaisesRegex(ValueError, "Resource kind is not JobSet"): + jobset.PathwaysJobSet.import_yaml(path1) + + # 2. Missing head. + invalid_config2 = { + "apiVersion": "jobset.x-k8s.io/v1alpha2", + "kind": "JobSet", + "metadata": {"name": "test"}, + "spec": { + "replicatedJobs": [ + {"name": "worker", "replicas": 1, "template": {}} + ] + }, + } + path2 = os.path.join(temp_dir, "invalid2.yaml") + with open(path2, "w") as f: + yaml.dump(invalid_config2, f) + with self.assertRaisesRegex(ValueError, "Missing head replicated job"): + jobset.PathwaysJobSet.import_yaml(path2) + + # 3. Missing worker. + invalid_config3 = { + "apiVersion": "jobset.x-k8s.io/v1alpha2", + "kind": "JobSet", + "metadata": {"name": "test"}, + "spec": { + "replicatedJobs": [ + {"name": "pathways-head", "replicas": 1, "template": {}} + ] + }, + } + path3 = os.path.join(temp_dir, "invalid3.yaml") + with open(path3, "w") as f: + yaml.dump(invalid_config3, f) + with self.assertRaisesRegex(ValueError, "Missing worker replicated job"): + jobset.PathwaysJobSet.import_yaml(path3) + + def test_labels_and_annotations(self): + pw_jobset = jobset.PathwaysJobSet( + name="test-workload", + namespace="default", + pathways_dir="gs://bucket/scratch", + tpu_type="v5e", + topology="2x2", + num_slices=1, + labels={"key1": "val1"}, + annotations={"key2": "val2"}, + ) + config = pw_jobset.to_dict() + self.assertEqual(config["metadata"]["labels"], {"key1": "val1"}) + self.assertEqual(config["metadata"]["annotations"], {"key2": "val2"}) + + def test_direct_mutation(self): + pw_jobset = jobset.PathwaysJobSet( + name="test-workload", + namespace="default", + pathways_dir="gs://bucket/scratch", + tpu_type="v5e", + topology="2x2", + num_slices=1, + ) + head_spec = pw_jobset.head_job_template.spec.template.spec + head_spec.active_deadline_seconds = 100 + + worker_spec = pw_jobset.worker_job_template.spec.template.spec + worker_spec.active_deadline_seconds = 200 + + def test_shared_pathways_service(self): + pw_jobset = jobset.PathwaysJobSet( + name="test-sps", + namespace="default", + pathways_dir="gs://bucket/scratch", + tpu_type="v5e", + topology="2x2", + num_slices=2, + shared_pathways_service=True, + ) + config = pw_jobset.to_dict() + + # Success policy should be set to target head job. + self.assertEqual( + config["spec"]["successPolicy"], + { + "operator": "All", + "targetReplicatedJobs": ["pathways-head"], + }, + ) + + replicated_jobs = config["spec"]["replicatedJobs"] + head_job = next( + j for j in replicated_jobs if j["name"] == "pathways-head" + ) + pod_spec = head_job["template"]["spec"]["template"]["spec"] + + # Head job should only have pathways-rm container, no pathways-proxy. + container_names = [c["name"] for c in pod_spec["containers"]] + self.assertIn("pathways-rm", container_names) + self.assertNotIn("pathways-proxy", container_names) + self.assertEqual(len(pod_spec["containers"]), 1) + + def test_shared_pathways_service_with_user_template_fails(self): + user_pod_template = { + "spec": { + "containers": [{ + "name": "jax-tpu", + "image": "gcr.io/my-project/jax-tpu:latest", + }] + } + } + with self.assertRaisesRegex( + ValueError, + "Cannot enable shared_pathways_service when user_pod_template is" + " provided.", + ): + jobset.PathwaysJobSet( + name="test-sps", + namespace="default", + pathways_dir="gs://bucket/scratch", + tpu_type="v5e", + topology="2x2", + num_slices=2, + shared_pathways_service=True, + user_pod_template=user_pod_template, + ) + if __name__ == "__main__": absltest.main() diff --git a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service.yaml b/pathwaysutils/test/experimental/gke/testdata/model_lite_tpuv5e_4x8.yaml similarity index 55% rename from pathwaysutils/experimental/shared_pathways_service/yamls/pw-service.yaml rename to pathwaysutils/test/experimental/gke/testdata/model_lite_tpuv5e_4x8.yaml index a02750e..ad72f17 100644 --- a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service.yaml +++ b/pathwaysutils/test/experimental/gke/testdata/model_lite_tpuv5e_4x8.yaml @@ -1,13 +1,14 @@ apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: - name: ${JOBSET_NAME} + labels: + kueue.x-k8s.io/queue-name: multislice-queue + name: lukebaumann-model-v5e namespace: default spec: coordinator: replicatedJob: pathways-head failurePolicy: - maxRestarts: 1 restartStrategy: Recreate network: enableDNSHostnames: true @@ -20,7 +21,7 @@ spec: annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname spec: - backoffLimit: 3 + backoffLimit: 0 completionMode: Indexed completions: 1 parallelism: 1 @@ -28,17 +29,59 @@ spec: metadata: annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname + labels: + kueue.x-k8s.io/podset: pathways-head spec: containers: - - name: pathways-rm - image: ${SERVER_IMAGE} + - command: + - sleep + - infinity + env: + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + - name: JAX_PLATFORMS + value: proxy + - name: XCLOUD_ENVIRONMENT + value: GCP + # - name: TPU_STDERR_LOG_LEVEL + # value: "0" + # - name: TPU_MIN_LOG_LEVEL + # value: "0" + # - name: TF_CPP_MIN_LOG_LEVEL + # value: "0" + - name: TPU_VMODULE + value: "real_program_continuator=1" + - name: ENABLE_PATHWAYS_PERSISTENCE + value: "1" + - name: ENABLE_PERSISTENCE_API + value: "1" + - name: ENABLE_PJRT_COMPATIBILITY + value: "true" + - name: JAX_BACKEND_TARGET + value: grpc://$(PATHWAYS_HEAD):29000 + image: ubuntu:latest imagePullPolicy: Always - args: + name: jax-tpu + resources: + limits: + cpu: "24" + memory: 100G + securityContext: + privileged: true + volumeMounts: + - mountPath: /tmp + name: shared-tmp + dnsPolicy: ClusterFirstWithHostNet + hostNetwork: true + initContainers: + - args: - --server_port=29001 - - --gcs_scratch_location=${GCS_SCRATCH_LOCATION} + - --gcs_scratch_location=gs://fake-bucket/scratch - --node_type=resource_manager - - --instance_count=${NUM_SLICES} - - --instance_type=${INSTANCE_TYPE} + - --instance_count=2 + - --instance_type=tpuv5e:4x8 env: - name: REPLICATED_JOB_NAME valueFrom: @@ -54,6 +97,9 @@ spec: fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] - name: TPU_SKIP_MDS_QUERY value: "true" + image: "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest" + imagePullPolicy: Always + name: pathways-rm ports: - containerPort: 29001 protocol: TCP @@ -63,31 +109,55 @@ spec: limits: cpu: "8" memory: 32G - dnsPolicy: ClusterFirstWithHostNet - hostNetwork: true - restartPolicy: OnFailure - - name: worker - replicas: ${NUM_SLICES} + restartPolicy: Always + - args: + - --server_port=29000 + - --resource_manager_address=$(PATHWAYS_HEAD):29001 + - --gcs_scratch_location=gs://fake-bucket/scratch + - --num_elastic_slices=2 + env: + - name: PATHWAYS_HEAD + valueFrom: + fieldRef: + fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + image: "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest" + imagePullPolicy: Always + name: pathways-proxy + ports: + - containerPort: 29000 + protocol: TCP + resources: + limits: + cpu: "16" + memory: 100G + restartPolicy: Always + nodeSelector: + cloud.google.com/gke-nodepool: cpu-np + restartPolicy: Never + volumes: + - hostPath: + path: /tmp + type: DirectoryOrCreate + name: shared-tmp + - name: pathways-worker + replicas: 2 # number of slices template: + metadata: {} spec: - backoffLimit: 1000000 + backoffLimit: 32000000 completionMode: Indexed - completions: ${VMS_PER_SLICE} - parallelism: ${VMS_PER_SLICE} + completions: 8 # number of vms per slice + parallelism: 8 # number of vms per slice template: metadata: annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: containers: - - name: pathways-worker - image: ${SERVER_IMAGE} - imagePullPolicy: Always - args: + - args: + - --resource_manager_address=$(PATHWAYS_HEAD):29001 - --server_port=29005 - - --resource_manager_address=$$(PATHWAYS_HEAD):29001 - - --gcs_scratch_location=${GCS_SCRATCH_LOCATION} - - --cloud_pathways_sidecar_shm_directory=${SIDECAR_SHM_DIR} + - --gcs_scratch_location=gs://fake-bucket/scratch env: - name: TPU_MIN_LOG_LEVEL value: "0" @@ -121,6 +191,9 @@ spec: valueFrom: fieldRef: fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator'] + image: "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest" + imagePullPolicy: Always + name: pathways-worker ports: - containerPort: 29005 protocol: TCP @@ -132,64 +205,24 @@ spec: protocol: TCP resources: limits: - google.com/tpu: "${CHIPS_PER_VM}" - volumeMounts: - - name: shared-tmp - mountPath: /tmp - - name: sidecar-shared-memory - mountPath: ${SIDECAR_SHM_DIR} - initContainers: - - name: colocated-python-sidecar - image: ${SIDECAR_IMAGE} - imagePullPolicy: Always - env: - - name: GRPC_SERVER_ADDRESS - value: '''0.0.0.0:50051''' - - name: CLOUD_PATHWAYS_SIDECAR_SHM_DIRECTORY - value: ${SIDECAR_SHM_DIR} - - name: PYTHONUNBUFFERED - value: '1' - # --- High Verbosity Logging Variables --- - - name: LOGLEVEL - value: 'DEBUG' - - name: GLOG_minloglevel - value: '0' # 0 = INFO level base - - name: GLOG_v - value: '5' # Extreme verbosity for all C++ modules - - name: TF_CPP_MIN_LOG_LEVEL - value: '0' - - name: TF_CPP_MIN_VLOG_LEVEL - value: '5' # TF/XLA verbose logging - - name: TPU_MIN_LOG_LEVEL - value: '0' - - name: GLOG_vmodule - value: 'jax_array_handlers=5,type_handlers=5,tensorstore_utils=5' - # ---------------------------------------- - ports: - - containerPort: 50051 - protocol: TCP - resources: {} - restartPolicy: Always + google.com/tpu: "4" volumeMounts: - - name: shared-tmp - mountPath: /tmp - - name: sidecar-shared-memory - mountPath: ${SIDECAR_SHM_DIR} + - mountPath: /tmp + name: shared-tmp dnsPolicy: ClusterFirstWithHostNet hostNetwork: true nodeSelector: - cloud.google.com/gke-tpu-accelerator: ${ACCELERATOR_LABEL} - cloud.google.com/gke-tpu-topology: ${TOPOLOGY} + cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice + cloud.google.com/gke-tpu-topology: 4x8 restartPolicy: OnFailure volumes: - - name: shared-tmp - hostPath: + - hostPath: path: /tmp type: DirectoryOrCreate - - name: sidecar-shared-memory - emptyDir: - medium: Memory + name: shared-tmp startupPolicy: startupPolicyOrder: InOrder successPolicy: operator: All + targetReplicatedJobs: + - pathways-head