From 56e3a55bbe85b1bdf897b5b277a03c4d2a5705ff Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Wed, 24 Jun 2026 16:36:37 -0700 Subject: [PATCH] Fix JAX device compatibility in profiling.py for elasticity. PiperOrigin-RevId: 937626495 --- pathwaysutils/profiling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index 029fc83..5a54586 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -27,6 +27,7 @@ import fastapi import jax from jax import numpy as jnp +from jax.extend import backend from pathwaysutils import plugin_executable import requests import uvicorn @@ -77,7 +78,7 @@ def call_profile_executable(self) -> None: jax.sharding, "make_single_device_sharding", jax.sharding.SingleDeviceSharding, - )(jax.devices()[0]) + )(backend.get_default_device()) ] else: out_avals = ()