diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index 029fc83..5a54586 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -27,6 +27,7 @@ import fastapi import jax from jax import numpy as jnp +from jax.extend import backend from pathwaysutils import plugin_executable import requests import uvicorn @@ -77,7 +78,7 @@ def call_profile_executable(self) -> None: jax.sharding, "make_single_device_sharding", jax.sharding.SingleDeviceSharding, - )(jax.devices()[0]) + )(backend.get_default_device()) ] else: out_avals = ()