From 8a77a63b281146fc20775144d90c78a1e89113c7 Mon Sep 17 00:00:00 2001
From: Luke Baumann
Date: Mon, 22 Jun 2026 14:15:26 -0700
Subject: [PATCH] Add GCSFuse support to PathwaysJobSet.

PiperOrigin-RevId: 936246234
---
 pathwaysutils/experimental/gke/jobset.py      | 307 ++++++++++++++++--
 .../test/experimental/gke/jobset_test.py      | 187 ++++++++++
 2 files changed, 476 insertions(+), 18 deletions(-)

diff --git a/pathwaysutils/experimental/gke/jobset.py b/pathwaysutils/experimental/gke/jobset.py
index a29441f..268768d 100644
--- a/pathwaysutils/experimental/gke/jobset.py
+++ b/pathwaysutils/experimental/gke/jobset.py
@@ -9,10 +9,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Pathways JobSet generator and builder (Head Job Config)."""
+"""Pathways JobSet generator and builder (with Worker Job Config)."""
 
+import hashlib
 import json
 import logging
+import math
 from typing import Any, Mapping
 
 from kubernetes import client
@@ -33,6 +35,7 @@
 
 PATHWAYS_PROXY_PORT = 29000
 PATHWAYS_RM_PORT = 29001
+PATHWAYS_WORKER_PORT = 29005
 
 MACHINE_TYPE_TO_TPU_VERSION_MAP = {
     "tpu7x-standard-4t": "tpu7x",
@@ -77,7 +80,7 @@ def __init__(self, data):
 
 
 class PathwaysJobSet:
-  """Generates JobSet configuration for Pathways (with Head Job Config)."""
+  """JobSet configuration generator for Pathways."""
 
   def __init__(
       self,
@@ -90,6 +93,8 @@ def __init__(
       user_pod_template: Mapping[str, Any] | None = None,
       main_container_name: str = "main",
       max_restarts: int = 0,
+      max_slice_restarts: int = 0,
+      termination_grace_period_seconds: int | None = None,
       pathways_version: str = "latest",
       jobset_api_version: str = "v1alpha2",
       elastic_slices: int = 0,
@@ -126,6 +131,19 @@ def __init__(
     if not tpu_version:
       raise ValueError(f"Unsupported TPU type: {tpu_type}")
 
+    gke_accel_type = MACHINE_TYPE_TO_GKE_ACCELERATOR_TYPE_MAP.get(
+        tpu_type.lower()
+    )
+
+    # Calculate VMs.
+    dims = [int(x) for x in topology.split("x")]
+    total_chips = math.prod(dims)
+    chips_per_vm = 8 if tpu_type.lower().endswith("8t") else 4
+    if total_chips < chips_per_vm:
+      num_vms = 1
+    else:
+      num_vms = total_chips // chips_per_vm
+
     instance_type = f"{tpu_version}:{topology}"
     image_tag = pathways_version
 
@@ -140,8 +158,19 @@ def __init__(
         elastic_slices=elastic_slices,
     )
 
-    # Build minimal worker template (placeholder)
-    self._worker_job_template = self._build_minimal_job_template("worker")
+    # Build worker template.
+    self._worker_job_template = self._build_worker_job_template(
+        name=name,
+        pathways_dir=pathways_dir,
+        num_slices=num_slices,
+        num_vms=num_vms,
+        chips_per_vm=chips_per_vm,
+        gke_accel_type=gke_accel_type,
+        topology=topology,
+        image_tag=image_tag,
+        max_slice_restarts=max_slice_restarts,
+        termination_grace_period_seconds=termination_grace_period_seconds,
+    )
 
     self._success_policy = None
     if user_pod_template:
@@ -150,20 +179,6 @@ def __init__(
           "targetReplicatedJobs": [PATHWAYS_HEAD_JOB_NAME],
       }
 
-  def _build_minimal_job_template(self, role: str) -> client.V1JobTemplateSpec:
-    """Builds a minimal job template for a given role."""
-    pod_spec = client.V1PodSpec(
-        containers=[
-            client.V1Container(name=f"placeholder-{role}", image="ubuntu")
-        ]
-    )
-    job_spec = client.V1JobSpec(
-        template=client.V1PodTemplateSpec(
-            metadata=client.V1ObjectMeta(labels={"role": role}), spec=pod_spec
-        )
-    )
-    return client.V1JobTemplateSpec(spec=job_spec)
-
   def _build_head_job_template(
       self,
       pathways_dir: str,
@@ -365,6 +380,262 @@ def _build_head_job_template(
     )
     return head_job_template
 
+  def _build_worker_job_template(
+      self,
+      name: str,
+      pathways_dir: str,
+      num_slices: int,
+      num_vms: int,
+      chips_per_vm: int,
+      gke_accel_type: str,
+      topology: str,
+      image_tag: str,
+      max_slice_restarts: int,
+      termination_grace_period_seconds: int | None,
+  ) -> client.V1JobTemplateSpec:
+    """Builds the job template for the "pathways-worker" container pods."""
+    worker_image = f"{DEFAULT_PATHWAYS_RM_AND_WORKER_IMAGE}:{image_tag}"
+
+    args = [
+        f"--resource_manager_address=$(PATHWAYS_HEAD):{PATHWAYS_RM_PORT}",
+        f"--server_port={PATHWAYS_WORKER_PORT}",
+        f"--gcs_scratch_location={pathways_dir}",
+    ]
+    worker_env = [
+        client.V1EnvVar(name="TPU_MIN_LOG_LEVEL", value="0"),
+        client.V1EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="0"),
+        client.V1EnvVar(name="XCLOUD_ENVIRONMENT", value="GCP"),
+        client.V1EnvVar(name="MEGASCALE_GRPC_ENABLE_XOR_TRACER", value="false"),
+        client.V1EnvVar(
+            name="MEGASCALE_NUM_SLICES",
+            value_from=client.V1EnvVarSource(
+                field_ref=client.V1ObjectFieldSelector(
+                    field_path="metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas']"
+                )
+            ),
+        ),
+        client.V1EnvVar(
+            name="JOBSET_NAME",
+            value_from=client.V1EnvVarSource(
+                field_ref=client.V1ObjectFieldSelector(
+                    field_path=(
+                        "metadata.annotations['jobset.sigs.k8s.io/jobset-name']"
+                    )
+                )
+            ),
+        ),
+        client.V1EnvVar(
+            name="REPLICATED_JOB_NAME",
+            value_from=client.V1EnvVarSource(
+                field_ref=client.V1ObjectFieldSelector(
+                    field_path="metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']"
+                )
+            ),
+        ),
+        client.V1EnvVar(
+            name="MEGASCALE_SLICE_ID",
+            value_from=client.V1EnvVarSource(
+                field_ref=client.V1ObjectFieldSelector(
+                    field_path="metadata.labels['jobset.sigs.k8s.io/job-index']"
+                )
+            ),
+        ),
+        client.V1EnvVar(
+            name="PATHWAYS_HEAD",
+            value_from=client.V1EnvVarSource(
+                field_ref=client.V1ObjectFieldSelector(
+                    field_path=(
+                        "metadata.labels['jobset.sigs.k8s.io/coordinator']"
+                    )
+                )
+            ),
+        ),
+        client.V1EnvVar(
+            name="MEGASCALE_COORDINATOR_ADDRESS",
+            value_from=client.V1EnvVarSource(
+                field_ref=client.V1ObjectFieldSelector(
+                    field_path=(
+                        "metadata.labels['jobset.sigs.k8s.io/coordinator']"
+                    )
+                )
+            ),
+        ),
+    ]
+
+    worker_container = client.V1Container(
+        name="pathways-worker",
+        image=worker_image,
+        image_pull_policy="Always",
+        args=args,
+        env=worker_env,
+        ports=[
+            client.V1ContainerPort(
+                container_port=PATHWAYS_WORKER_PORT, protocol="TCP"
+            ),
+            client.V1ContainerPort(container_port=29006, protocol="TCP"),
+            client.V1ContainerPort(container_port=8471, protocol="TCP"),
+            client.V1ContainerPort(container_port=8080, protocol="TCP"),
+        ],
+        volume_mounts=[
+            client.V1VolumeMount(name="shared-tmp", mount_path="/tmp")
+        ],
+        resources=client.V1ResourceRequirements(
+            limits={"google.com/tpu": str(chips_per_vm)}
+        ),
+    )
+
+    node_selector = {
+        "cloud.google.com/gke-tpu-accelerator": gke_accel_type,
+        "cloud.google.com/gke-tpu-topology": topology,
+    }
+
+    backoff_limit = num_vms * 4
+    if max_slice_restarts > 0:
+      backoff_limit = num_vms * max_slice_restarts
+
+    worker_pod_spec = client.V1PodSpec(
+        containers=[worker_container],
+        node_selector=node_selector,
+        volumes=[
+            client.V1Volume(
+                name="shared-tmp",
+                host_path=client.V1HostPathVolumeSource(
+                    path="/tmp", type="DirectoryOrCreate"
+                ),
+            )
+        ],
+        host_network=True,
+        dns_policy="ClusterFirstWithHostNet",
+        restart_policy="OnFailure",
+    )
+    if termination_grace_period_seconds is not None:
+      worker_pod_spec.termination_grace_period_seconds = (
+          termination_grace_period_seconds
+      )
+
+    worker_job_template = client.V1JobTemplateSpec(
+        metadata=client.V1ObjectMeta(),
+        spec=client.V1JobSpec(
+            backoff_limit=backoff_limit,
+            completion_mode="Indexed",
+            completions=num_vms,
+            parallelism=num_vms,
+            template=client.V1PodTemplateSpec(
+                metadata=client.V1ObjectMeta(
+                    annotations={
+                        "alpha.jobset.sigs.k8s.io/exclusive-topology": (
+                            "cloud.google.com/gke-nodepool"
+                        )
+                    }
+                ),
+                spec=worker_pod_spec,
+            ),
+        ),
+    )
+    return worker_job_template
+
+  def add_gcsfuse(
+      self,
+      containers: str,
+      mount_path: str,
+      bucket: str,
+      read_only: bool = False,
+  ) -> "PathwaysJobSet":
+    """Adds a GCSFuse mount to the selected containers.
+
+    Args:
+      containers: Which job to modify: "head", "worker" or "all".
+      mount_path: Path inside the container where the bucket is mounted.
+      bucket: Name of the GCS bucket to mount.
+      read_only: Whether the volume is mounted read-only.
+
+    Returns:
+      This PathwaysJobSet, so that calls can be chained.
+
+    Raises:
+      ValueError: If `containers` is not "head", "worker" or "all".
+    """
+    if containers not in ("head", "worker", "all"):
+      raise ValueError(
+          f"Unsupported containers value: {containers!r}; expected 'head',"
+          " 'worker' or 'all'."
+      )
+
+    target_templates = []
+    if containers in ("head", "all"):
+      target_templates.append((PATHWAYS_HEAD_JOB_NAME, self._head_job_template))
+    if containers in ("worker", "all"):
+      target_templates.append(("worker", self._worker_job_template))
+
+    # The builtin hash() of a str is salted per process, so it would yield a
+    # different volume name on every run. Use a stable digest instead.
+    bucket_hash = hashlib.sha256(bucket.encode("utf-8")).hexdigest()[:8]
+    volume_name = f"gcsfuse-{bucket_hash}"
+    pathways_containers = ("pathways-rm", "pathways-proxy")
+
+    for job_name, job_template_obj in target_templates:
+      # 1. Add annotation.
+      job_template_metadata = job_template_obj.metadata or client.V1ObjectMeta()
+      job_template_annotations = job_template_metadata.annotations or {}
+      job_template_annotations["gke-gcsfuse/volumes"] = "true"
+      job_template_metadata.annotations = job_template_annotations
+      job_template_obj.metadata = job_template_metadata
+
+      pod_template = job_template_obj.spec.template
+      pod_metadata = pod_template.metadata or client.V1ObjectMeta()
+      pod_annotations = pod_metadata.annotations or {}
+      pod_annotations["gke-gcsfuse/volumes"] = "true"
+      pod_metadata.annotations = pod_annotations
+      pod_template.metadata = pod_metadata
+
+      pod_spec = pod_template.spec
+
+      # 2. Add volume.
+      volumes = pod_spec.volumes or []
+      if not any(v.name == volume_name for v in volumes):
+        volumes.append(
+            client.V1Volume(
+                name=volume_name,
+                csi=client.V1CSIVolumeSource(
+                    driver="gcsfuse.csi.storage.gke.io",
+                    volume_attributes={"bucketName": bucket},
+                ),
+            )
+        )
+      pod_spec.volumes = volumes
+
+      # 3. Add volumeMount to containers.
+      if job_name == PATHWAYS_HEAD_JOB_NAME:
+        # Prefer the user's containers; fall back to the Pathways ones.
+        user_container_names = [
+            c.name
+            for c in pod_spec.containers
+            if c.name not in pathways_containers
+        ]
+        container_names = user_container_names or list(pathways_containers)
+      else:
+        container_names = ["pathways-worker"]
+
+      # Regular containers and initContainers are handled the same way.
+      all_containers = [
+          *pod_spec.containers,
+          *(pod_spec.init_containers or []),
+      ]
+      for container in all_containers:
+        if container.name not in container_names:
+          continue
+        volume_mounts = container.volume_mounts or []
+        volume_mounts.append(
+            client.V1VolumeMount(
+                name=volume_name,
+                mount_path=mount_path,
+                read_only=read_only,
+            )
+        )
+        container.volume_mounts = volume_mounts
+
+    return self
+
   def _compile_config(self) -> dict[str, Any]:
     """Compiles the JobSet configuration into a dictionary."""
     with client.ApiClient() as api_client:
diff --git a/pathwaysutils/test/experimental/gke/jobset_test.py b/pathwaysutils/test/experimental/gke/jobset_test.py
index 2a1a92d..5c839b1 100644
--- a/pathwaysutils/test/experimental/gke/jobset_test.py
+++ b/pathwaysutils/test/experimental/gke/jobset_test.py
@@ -132,6 +132,193 @@ def test_monkeypatch_restart_policy(self):
     )  # pytype: disable=wrong-keyword-args
     self.assertEqual(getattr(c, "restart_policy"), "Always")
 
+  def test_worker_job(self):
+    js = jobset.PathwaysJobSet(
+        name="test-jobset",
+        namespace="default",
+        pathways_dir="gs://test-bucket",
+        tpu_type="v5e",
+        topology="4x8",
+        num_slices=2,
+        max_slice_restarts=3,
+        termination_grace_period_seconds=60,
+    )
+    config = js.to_dict()
+
+    replicated_jobs = config["spec"]["replicatedJobs"]
+    worker_job = next(
+        j for j in replicated_jobs if j["name"] == "pathways-worker"
+    )
+    self.assertEqual(worker_job["replicas"], 2)
+
+    # 4x8 v5e topology has 32 chips. v5e has 4 chips per VM.
+    # Total VMs = 32 / 4 = 8 VMs.
+    job_spec = worker_job["template"]["spec"]
+    self.assertEqual(job_spec["completions"], 8)
+    self.assertEqual(job_spec["parallelism"], 8)
+    # backoffLimit = num_vms * max_slice_restarts = 8 * 3 = 24
+    self.assertEqual(job_spec["backoffLimit"], 24)
+
+    pod_spec = job_spec["template"]["spec"]
+    self.assertTrue(pod_spec["hostNetwork"])
+    self.assertEqual(pod_spec["dnsPolicy"], "ClusterFirstWithHostNet")
+    self.assertEqual(pod_spec["restartPolicy"], "OnFailure")
+    self.assertEqual(pod_spec["terminationGracePeriodSeconds"], 60)
+
+    # Node selector
+    self.assertEqual(
+        pod_spec["nodeSelector"]["cloud.google.com/gke-tpu-accelerator"],
+        "tpu-v5-lite-podslice",
+    )
+    self.assertEqual(
+        pod_spec["nodeSelector"]["cloud.google.com/gke-tpu-topology"], "4x8"
+    )
+
+    # Container limits
+    container = pod_spec["containers"][0]
+    self.assertEqual(container["name"], "pathways-worker")
+    self.assertEqual(container["resources"]["limits"]["google.com/tpu"], "4")
+
+  def test_add_gcsfuse(self):
+    pw_jobset = jobset.PathwaysJobSet(
+        name="test-workload",
+        namespace="default",
+        pathways_dir="gs://bucket/scratch",
+        tpu_type="v5e",
+        topology="2x2",
+        num_slices=1,
+    )
+    pw_jobset.add_gcsfuse(
+        containers="worker",
+        mount_path="/gcs/data",
+        bucket="my-bucket",
+        read_only=True,
+    )
+    config = pw_jobset.to_dict()
+
+    worker_job = next(
+        j
+        for j in config["spec"]["replicatedJobs"]
+        if j["name"] == jobset.PATHWAYS_WORKER_JOB_NAME
+    )
+    pod_template = worker_job["template"]["spec"]["template"]
+    self.assertEqual(
+        pod_template["metadata"]["annotations"]["gke-gcsfuse/volumes"], "true"
+    )
+    self.assertEqual(
+        worker_job["template"]["metadata"]["annotations"][
+            "gke-gcsfuse/volumes"
+        ],
+        "true",
+    )
+
+    volumes = pod_template["spec"]["volumes"]
+    gcs_vol = next(v for v in volumes if v["name"].startswith("gcsfuse-"))
+    self.assertEqual(gcs_vol["csi"]["driver"], "gcsfuse.csi.storage.gke.io")
+    self.assertEqual(
+        gcs_vol["csi"]["volumeAttributes"]["bucketName"], "my-bucket"
+    )
+
+    container = pod_template["spec"]["containers"][0]
+    mount = next(
+        m for m in container["volumeMounts"] if m["name"] == gcs_vol["name"]
+    )
+    self.assertEqual(mount["mountPath"], "/gcs/data")
+    self.assertTrue(mount["readOnly"])
+
+  def test_gcsfuse_with_jax_command(self):
+    jax_command = (
+        "import jax; import pathwaysutils; pathwaysutils.initialize(); assert"
+        " jax.device_count() > 0; print(jax.devices()); x ="
+        " jax.device_put([0], jax.devices()[0]); y = x + 1; assert y[0] == 1;"
+        " jax.block_until_ready(y); print(y);"
+    )
+    user_pod_template = {
+        "spec": {
+            "containers": [{
+                "name": "jax-tpu",
+                "image": "gcr.io/my-project/jax-tpu:latest",
+                "command": ["python3", "-c", jax_command],
+            }]
+        }
+    }
+    pw_jobset = jobset.PathwaysJobSet(
+        name="jax-test-workload",
+        namespace="default",
+        pathways_dir="gs://my-bucket/scratch",
+        tpu_type="v5e",
+        topology="2x2",
+        num_slices=1,
+        user_pod_template=user_pod_template,
+        main_container_name="jax-tpu",
+    )
+    pw_jobset.add_gcsfuse(
+        containers="all",
+        mount_path="/gcs/data",
+        bucket="my-bucket",
+    )
+    config = pw_jobset.to_dict()
+
+    head_job = next(
+        j
+        for j in config["spec"]["replicatedJobs"]
+        if j["name"] == "pathways-head"
+    )
+    head_pod_template = head_job["template"]["spec"]["template"]
+    self.assertEqual(
+        head_pod_template["metadata"]["annotations"]["gke-gcsfuse/volumes"],
+        "true",
+    )
+    head_pod_spec = head_pod_template["spec"]
+    jax_container = next(
+        c for c in head_pod_spec["containers"] if c["name"] == "jax-tpu"
+    )
+    self.assertEqual(jax_container["command"], ["python3", "-c", jax_command])
+    self.assertTrue(
+        any(
+            m["mountPath"] == "/gcs/data" for m in jax_container["volumeMounts"]
+        )
+    )
+
+    worker_job = next(
+        j
+        for j in config["spec"]["replicatedJobs"]
+        if j["name"] == jobset.PATHWAYS_WORKER_JOB_NAME
+    )
+    worker_pod_template = worker_job["template"]["spec"]["template"]
+    self.assertEqual(
+        worker_pod_template["metadata"]["annotations"]["gke-gcsfuse/volumes"],
+        "true",
+    )
+    worker_pod_spec = worker_pod_template["spec"]
+    worker_container = next(
+        c
+        for c in worker_pod_spec["containers"]
+        if c["name"] == "pathways-worker"
+    )
+    self.assertTrue(
+        any(
+            m["mountPath"] == "/gcs/data"
+            for m in worker_container["volumeMounts"]
+        )
+    )
+
+  def test_add_gcsfuse_rejects_unknown_containers(self):
+    pw_jobset = jobset.PathwaysJobSet(
+        name="test-workload",
+        namespace="default",
+        pathways_dir="gs://bucket/scratch",
+        tpu_type="v5e",
+        topology="2x2",
+        num_slices=1,
+    )
+    with self.assertRaises(ValueError):
+      pw_jobset.add_gcsfuse(
+          containers="workers",
+          mount_path="/gcs/data",
+          bucket="my-bucket",
+      )
+
 
 if __name__ == "__main__":
   absltest.main()