Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 279 additions & 18 deletions pathwaysutils/experimental/gke/jobset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pathways JobSet generator and builder (Head Job Config)."""
"""Pathways JobSet generator and builder (with Worker Job Config)."""

import json
import logging
import math
from typing import Any, Mapping
from kubernetes import client

Expand All @@ -33,6 +34,7 @@

# Port of the Pathways proxy. NOTE(review): not referenced in the visible part
# of this file; presumably consumed by the head job template -- confirm.
PATHWAYS_PROXY_PORT = 29000
# Port of the Pathways resource manager; workers are pointed at it through
# their --resource_manager_address argument.
PATHWAYS_RM_PORT = 29001
# Port the Pathways worker serves on (its --server_port argument and first
# exposed container port).
PATHWAYS_WORKER_PORT = 29005

MACHINE_TYPE_TO_TPU_VERSION_MAP = {
"tpu7x-standard-4t": "tpu7x",
Expand Down Expand Up @@ -77,7 +79,7 @@ def __init__(self, data):


class PathwaysJobSet:
"""Generates JobSet configuration for Pathways (with Head Job Config)."""
"""JobSet configuration generator for Pathways."""

def __init__(
self,
Expand All @@ -90,6 +92,8 @@ def __init__(
user_pod_template: Mapping[str, Any] | None = None,
main_container_name: str = "main",
max_restarts: int = 0,
max_slice_restarts: int = 0,
termination_grace_period_seconds: int | None = None,
pathways_version: str = "latest",
jobset_api_version: str = "v1alpha2",
elastic_slices: int = 0,
Expand Down Expand Up @@ -126,6 +130,19 @@ def __init__(
if not tpu_version:
raise ValueError(f"Unsupported TPU type: {tpu_type}")

gke_accel_type = MACHINE_TYPE_TO_GKE_ACCELERATOR_TYPE_MAP.get(
tpu_type.lower()
)

# Calculate VMs.
dims = [int(x) for x in topology.split("x")]
total_chips = math.prod(dims)
chips_per_vm = 8 if tpu_type.lower().endswith("8t") else 4
if total_chips < chips_per_vm:
num_vms = 1
else:
num_vms = total_chips // chips_per_vm

instance_type = f"{tpu_version}:{topology}"
image_tag = pathways_version

Expand All @@ -140,8 +157,19 @@ def __init__(
elastic_slices=elastic_slices,
)

# Build minimal worker template (placeholder)
self._worker_job_template = self._build_minimal_job_template("worker")
# Build worker template.
self._worker_job_template = self._build_worker_job_template(
name=name,
pathways_dir=pathways_dir,
num_slices=num_slices,
num_vms=num_vms,
chips_per_vm=chips_per_vm,
gke_accel_type=gke_accel_type,
topology=topology,
image_tag=image_tag,
max_slice_restarts=max_slice_restarts,
termination_grace_period_seconds=termination_grace_period_seconds,
)

self._success_policy = None
if user_pod_template:
Expand All @@ -150,20 +178,6 @@ def __init__(
"targetReplicatedJobs": [PATHWAYS_HEAD_JOB_NAME],
}

def _build_minimal_job_template(self, role: str) -> client.V1JobTemplateSpec:
  """Creates a placeholder job template tagged with `role`.

  Args:
    role: Suffix of the placeholder container's name and value of the pod's
      "role" label.

  Returns:
    A V1JobTemplateSpec whose pod runs a single "ubuntu" container.
  """
  placeholder = client.V1Container(name=f"placeholder-{role}", image="ubuntu")
  pod_template = client.V1PodTemplateSpec(
      metadata=client.V1ObjectMeta(labels={"role": role}),
      spec=client.V1PodSpec(containers=[placeholder]),
  )
  return client.V1JobTemplateSpec(spec=client.V1JobSpec(template=pod_template))

def _build_head_job_template(
self,
pathways_dir: str,
Expand Down Expand Up @@ -365,6 +379,253 @@ def _build_head_job_template(
)
return head_job_template

def _build_worker_job_template(
    self,
    name: str,
    pathways_dir: str,
    num_slices: int,
    num_vms: int,
    chips_per_vm: int,
    gke_accel_type: str,
    topology: str,
    image_tag: str,
    max_slice_restarts: int,
    termination_grace_period_seconds: int | None,
) -> client.V1JobTemplateSpec:
  """Assembles the Job template for the Pathways worker pods.

  The result is an Indexed Job with `num_vms` completions running in
  parallel, each pod holding a single "pathways-worker" container.

  Args:
    name: Not used by this method.
    pathways_dir: Forwarded to the worker as --gcs_scratch_location.
    num_slices: Not used by this method.
    num_vms: Completions and parallelism of the Job; also scales the backoff
      limit.
    chips_per_vm: "google.com/tpu" resource limit of the worker container.
    gke_accel_type: Value of the "cloud.google.com/gke-tpu-accelerator" node
      selector.
    topology: Value of the "cloud.google.com/gke-tpu-topology" node selector.
    image_tag: Tag appended to DEFAULT_PATHWAYS_RM_AND_WORKER_IMAGE.
    max_slice_restarts: When positive, the backoff limit is
      `num_vms * max_slice_restarts`; otherwise it is `num_vms * 4`.
    termination_grace_period_seconds: Copied onto the pod spec when not None.

  Returns:
    The worker V1JobTemplateSpec.
  """

  def field_ref_env(env_name, field_path):
    # Env var whose value is resolved from the pod's own metadata.
    selector = client.V1ObjectFieldSelector(field_path=field_path)
    return client.V1EnvVar(
        name=env_name,
        value_from=client.V1EnvVarSource(field_ref=selector),
    )

  coordinator_path = "metadata.labels['jobset.sigs.k8s.io/coordinator']"
  literal_env = (
      ("TPU_MIN_LOG_LEVEL", "0"),
      ("TF_CPP_MIN_LOG_LEVEL", "0"),
      ("XCLOUD_ENVIRONMENT", "GCP"),
      ("MEGASCALE_GRPC_ENABLE_XOR_TRACER", "false"),
  )
  metadata_env = (
      (
          "MEGASCALE_NUM_SLICES",
          "metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas']",
      ),
      (
          "JOBSET_NAME",
          "metadata.annotations['jobset.sigs.k8s.io/jobset-name']",
      ),
      (
          "REPLICATED_JOB_NAME",
          "metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']",
      ),
      (
          "MEGASCALE_SLICE_ID",
          "metadata.labels['jobset.sigs.k8s.io/job-index']",
      ),
      ("PATHWAYS_HEAD", coordinator_path),
      ("MEGASCALE_COORDINATOR_ADDRESS", coordinator_path),
  )
  env_vars = [
      client.V1EnvVar(name=key, value=val) for key, val in literal_env
  ]
  env_vars.extend(field_ref_env(key, path) for key, path in metadata_env)

  exposed_ports = (PATHWAYS_WORKER_PORT, 29006, 8471, 8080)
  container = client.V1Container(
      name="pathways-worker",
      image=f"{DEFAULT_PATHWAYS_RM_AND_WORKER_IMAGE}:{image_tag}",
      image_pull_policy="Always",
      args=[
          f"--resource_manager_address=$(PATHWAYS_HEAD):{PATHWAYS_RM_PORT}",
          f"--server_port={PATHWAYS_WORKER_PORT}",
          f"--gcs_scratch_location={pathways_dir}",
      ],
      env=env_vars,
      ports=[
          client.V1ContainerPort(container_port=port, protocol="TCP")
          for port in exposed_ports
      ],
      volume_mounts=[
          client.V1VolumeMount(name="shared-tmp", mount_path="/tmp")
      ],
      resources=client.V1ResourceRequirements(
          limits={"google.com/tpu": str(chips_per_vm)}
      ),
  )

  # The node's /tmp is shared with the container through a hostPath volume.
  tmp_volume = client.V1Volume(
      name="shared-tmp",
      host_path=client.V1HostPathVolumeSource(
          path="/tmp", type="DirectoryOrCreate"
      ),
  )
  pod_spec = client.V1PodSpec(
      containers=[container],
      node_selector={
          "cloud.google.com/gke-tpu-accelerator": gke_accel_type,
          "cloud.google.com/gke-tpu-topology": topology,
      },
      volumes=[tmp_volume],
      host_network=True,
      dns_policy="ClusterFirstWithHostNet",
      restart_policy="OnFailure",
  )
  if termination_grace_period_seconds is not None:
    pod_spec.termination_grace_period_seconds = (
        termination_grace_period_seconds
    )

  # Default budget is four retries per VM unless a positive override is given.
  retries_per_vm = max_slice_restarts if max_slice_restarts > 0 else 4

  pod_template = client.V1PodTemplateSpec(
      metadata=client.V1ObjectMeta(
          annotations={
              "alpha.jobset.sigs.k8s.io/exclusive-topology": (
                  "cloud.google.com/gke-nodepool"
              )
          }
      ),
      spec=pod_spec,
  )
  job_spec = client.V1JobSpec(
      backoff_limit=num_vms * retries_per_vm,
      completion_mode="Indexed",
      completions=num_vms,
      parallelism=num_vms,
      template=pod_template,
  )
  return client.V1JobTemplateSpec(
      metadata=client.V1ObjectMeta(), spec=job_spec
  )

def add_gcsfuse(
    self,
    containers: str,
    mount_path: str,
    bucket: str,
    read_only: bool = False,
) -> "PathwaysJobSet":
  """Adds a GCSFuse CSI volume for `bucket` and mounts it into containers.

  For every targeted job template this sets the "gke-gcsfuse/volumes"
  annotation on both the job template and its pod template, adds the CSI
  volume to the pod spec, and appends a volume mount to the selected
  containers (regular and init containers alike). Repeating a call with the
  same bucket and mount path does not add duplicate volumes or mounts.

  Args:
    containers: "head", "worker" or "all". Any other value selects no
      template, in which case the call changes nothing.
    mount_path: Path inside the container where the bucket is mounted.
    bucket: Bucket name, passed to the CSI driver as "bucketName".
    read_only: Whether the volume mount is read-only.

  Returns:
    This PathwaysJobSet, so calls can be chained.
  """
  import hashlib  # Function-scope import; only this method needs it.

  target_templates = []
  if containers in ("head", "all"):
    target_templates.append((PATHWAYS_HEAD_JOB_NAME, self._head_job_template))
  if containers in ("worker", "all"):
    target_templates.append(("worker", self._worker_job_template))

  # Built-in hash() is salted per interpreter process for str, which made the
  # volume name differ between runs. A digest keeps the generated config
  # reproducible while preserving the "gcsfuse-<number>" naming scheme.
  bucket_digest = hashlib.sha256(bucket.encode("utf-8")).hexdigest()
  bucket_hash = int(bucket_digest, 16) % (10**8)
  volume_name = f"gcsfuse-{bucket_hash}"

  for job_name, job_template_obj in target_templates:
    # 1. Add annotation to the job template and to the pod template.
    job_template_metadata = job_template_obj.metadata or client.V1ObjectMeta()
    job_template_annotations = job_template_metadata.annotations or {}
    job_template_annotations["gke-gcsfuse/volumes"] = "true"
    job_template_metadata.annotations = job_template_annotations
    job_template_obj.metadata = job_template_metadata

    pod_template = job_template_obj.spec.template
    pod_metadata = pod_template.metadata or client.V1ObjectMeta()
    pod_annotations = pod_metadata.annotations or {}
    pod_annotations["gke-gcsfuse/volumes"] = "true"
    pod_metadata.annotations = pod_annotations
    pod_template.metadata = pod_metadata

    pod_spec = pod_template.spec

    # 2. Add volume, unless an earlier call already added it.
    volumes = pod_spec.volumes or []
    if not any(v.name == volume_name for v in volumes):
      volumes.append(
          client.V1Volume(
              name=volume_name,
              csi=client.V1CSIVolumeSource(
                  driver="gcsfuse.csi.storage.gke.io",
                  volume_attributes={"bucketName": bucket},
              ),
          )
      )
    pod_spec.volumes = volumes

    # 3. Pick the containers that receive the mount.
    if job_name == PATHWAYS_HEAD_JOB_NAME:
      has_user_container = len(pod_spec.containers) > 2 or (
          len(pod_spec.containers) == 1
          and pod_spec.containers[0].name != "pathways-rm"
      )
      if has_user_container:
        container_names = [
            c.name
            for c in pod_spec.containers
            if c.name not in ("pathways-rm", "pathways-proxy")
        ]
      else:
        container_names = ["pathways-rm", "pathways-proxy"]
    else:
      container_names = ["pathways-worker"]

    # 4. Add volumeMount to matching containers and initContainers.
    candidates = list(pod_spec.containers) + list(
        pod_spec.init_containers or []
    )
    for container in candidates:
      if container.name not in container_names:
        continue
      volume_mounts = container.volume_mounts or []
      already_mounted = any(
          m.name == volume_name and m.mount_path == mount_path
          for m in volume_mounts
      )
      # Mirror the volume guard above so a repeated call stays idempotent
      # instead of emitting two mounts with the same mountPath.
      if not already_mounted:
        volume_mounts.append(
            client.V1VolumeMount(
                name=volume_name,
                mount_path=mount_path,
                read_only=read_only,
            )
        )
      container.volume_mounts = volume_mounts

  return self

def _compile_config(self) -> dict[str, Any]:
"""Compiles the JobSet configuration into a dictionary."""
with client.ApiClient() as api_client:
Expand Down
Loading
Loading