diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index deae25899475..b533bef35414 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -855,10 +855,11 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float | # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" is_npu = sample.device.type == "npu" + is_neuron = sample.device.type == "neuron" if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64 else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = torch.int32 if (is_mps or is_npu or is_neuron) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d675f1de04a7..bbee2189c22f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,7 +22,7 @@ import types from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Union, get_args, get_origin, get_type_hints +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, get_type_hints import httpx import numpy as np @@ -68,6 +68,7 @@ is_transformers_version, logging, numpy_to_pil, + requires_backends, ) from ..utils.distributed_utils import is_torch_dist_rank_zero from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card @@ -2248,7 +2249,6 @@ def _is_pipeline_device_mapped(self): return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1 - class StableDiffusionMixin: r""" Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 2f6b105702e8..fdda2547f09e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1092,7 +1092,11 @@ def __call__( ) # 4. Prepare timesteps - if XLA_AVAILABLE: + # Keep timesteps on CPU for XLA (TPU) and Neuron: both use lazy/XLA execution where + # dynamic-shape ops like .nonzero() and .item() inside scheduler.index_for_timestep() + # are incompatible with static-graph compilation. + is_neuron_device = hasattr(device, "type") and device.type == "neuron" + if XLA_AVAILABLE or is_neuron_device: timestep_device = "cpu" else: timestep_device = device @@ -1195,15 +1199,23 @@ def __call__( # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # For Neuron: scale_model_input on CPU to avoid XLA ops outside the compiled UNet region. + # index_for_timestep() uses .nonzero()/.item() which are incompatible with static graphs. + if is_neuron_device: + latent_model_input = self.scheduler.scale_model_input(latent_model_input.to("cpu"), t).to(device) + else: + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds + # For Neuron: pre-cast timestep to float32 on device. Neuron XLA does not support + # int64 ops; the compiled UNet graph requires a float32 timestep input on-device. + t_unet = t.to(torch.float32).to(device) if is_neuron_device else t noise_pred = self.unet( latent_model_input, - t, + t_unet, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, @@ -1222,7 +1234,13 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + # For Neuron: scheduler.step on CPU to keep scheduler arithmetic off the XLA device. + if is_neuron_device: + latents = self.scheduler.step( + noise_pred.to("cpu"), t, latents.to("cpu"), **extra_step_kwargs, return_dict=False + )[0].to(device) + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 23d7ac7c6c2d..8a86cf4f4151 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -110,6 +110,7 @@ is_timm_available, is_torch_available, is_torch_mlu_available, + is_torch_neuronx_available, is_torch_npu_available, is_torch_version, is_torch_xla_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 551fa358a28d..2ce989626b3d 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -193,6 +193,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla") _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu") _torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu") +_torch_neuronx_available, _torch_neuronx_version = _is_package_available("torch_neuronx") _transformers_available, _transformers_version = _is_package_available("transformers") _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub") _kernels_available, _kernels_version = _is_package_available("kernels") @@ -249,6 +250,10 @@ def is_torch_mlu_available(): return _torch_mlu_available +def is_torch_neuronx_available(): + return _torch_neuronx_available + + def is_flax_available(): return _flax_available @@ -579,6 +584,10 @@ def is_av_available(): """ +TORCH_NEURONX_IMPORT_ERROR = """ +{0} requires the torch_neuronx library (AWS Neuron SDK) but it was not found in your environment. Please install it following the AWS Neuron documentation: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/ +""" + BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), @@ -609,6 +618,7 @@ def is_av_available(): ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)), ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)), ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)), + ("torch_neuronx", (is_torch_neuronx_available, TORCH_NEURONX_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 619a37034949..eefe52c477a6 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -46,6 +46,7 @@ is_peft_available, is_timm_available, is_torch_available, + is_torch_neuronx_available, is_torch_version, is_torchao_available, is_torchsde_available, @@ -113,6 +114,8 @@ torch_device = "cuda" elif torch.xpu.is_available(): torch_device = "xpu" + elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available(): + torch_device = torch.neuron.current_device() else: torch_device = "cpu" is_torch_higher_equal_than_1_12 = version.parse( diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 8a48316bf3dd..55fee1d3249e 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -22,7 +22,13 @@ from typing import Callable, ParamSpec, TypeVar from . import logging -from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version +from .import_utils import ( + is_torch_available, + is_torch_mlu_available, + is_torch_neuronx_available, + is_torch_npu_available, + is_torch_version, +) T = TypeVar("T") @@ -33,12 +39,13 @@ import torch from torch.fft import fftn, fftshift, ifftn, ifftshift - BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True} + BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "neuron": False, "default": True} BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, "xpu": torch.xpu.empty_cache, "cpu": None, "mps": torch.mps.empty_cache, + "neuron": None, "default": None, } BACKEND_DEVICE_COUNT = { @@ -46,6 +53,7 @@ "xpu": torch.xpu.device_count, "cpu": lambda: 0, "mps": lambda: 0, + "neuron": lambda: getattr(getattr(torch, "neuron", None), "device_count", lambda: 0)(), "default": 0, } BACKEND_MANUAL_SEED = { @@ -53,6 +61,7 @@ "xpu": torch.xpu.manual_seed, "cpu": torch.manual_seed, "mps": torch.mps.manual_seed, + "neuron": torch.manual_seed, "default": torch.manual_seed, } BACKEND_RESET_PEAK_MEMORY_STATS = { @@ -60,6 +69,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "neuron": None, "default": None, } BACKEND_RESET_MAX_MEMORY_ALLOCATED = { @@ -67,6 +77,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "neuron": None, "default": None, } BACKEND_MAX_MEMORY_ALLOCATED = { @@ -74,6 +85,7 @@ "xpu": getattr(torch.xpu, "max_memory_allocated", None), "cpu": 0, "mps": 0, + "neuron": 0, "default": 0, } BACKEND_SYNCHRONIZE = { @@ -81,6 +93,7 @@ "xpu": getattr(torch.xpu, "synchronize", None), "cpu": None, "mps": None, + "neuron": lambda: getattr(getattr(torch, "neuron", None), "synchronize", lambda: None)(), "default": None, } logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -169,11 +182,15 @@ def randn_tensor( layout = layout or torch.strided device = device or torch.device("cpu") + # Neuron (XLA) does not support creating random tensors directly on device; always use CPU + if device.type == "neuron": + rand_device = torch.device("cpu") + if generator is not None: gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type if gen_device_type != device.type and gen_device_type == "cpu": rand_device = "cpu" - if device != "mps": + if device.type not in ("mps", "neuron"): logger.info( f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" @@ -294,6 +311,8 @@ def get_device(): return "mps" elif is_torch_mlu_available(): return "mlu" + elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available(): + return "neuron" else: return "cpu" diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py index 037a9f44f31e..0aa6812c6b25 100644 --- a/tests/pipelines/pixart_alpha/test_pixart.py +++ b/tests/pipelines/pixart_alpha/test_pixart.py @@ -28,6 +28,8 @@ PixArtTransformer2DModel, ) +from diffusers.utils.import_utils import is_torch_neuronx_available + from ...testing_utils import ( backend_empty_cache, enable_full_determinism, @@ -291,7 +293,9 @@ def test_pixart_1024(self): expected_slice = np.array([0.0742, 0.0835, 0.2114, 0.0295, 0.0784, 0.2361, 0.1738, 0.2251, 0.3589]) max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice) - self.assertLessEqual(max_diff, 1e-4) + # Neuron uses bfloat16 internally which has lower precision than float16 on CUDA + atol = 1e-2 if is_torch_neuronx_available() else 1e-4 + self.assertLessEqual(max_diff, atol) def test_pixart_512(self): generator = torch.Generator("cpu").manual_seed(0) @@ -307,7 +311,9 @@ def test_pixart_512(self): expected_slice = np.array([0.3477, 0.3882, 0.4541, 0.3413, 0.3821, 0.4463, 0.4001, 0.4409, 0.4958]) max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice) - self.assertLessEqual(max_diff, 1e-4) + # Neuron uses bfloat16 internally which has lower precision than float16 on CUDA + atol = 1e-2 if is_torch_neuronx_available() else 1e-4 + self.assertLessEqual(max_diff, atol) def test_pixart_1024_without_resolution_binning(self): generator = torch.manual_seed(0) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 53c1b8aa26ce..778381cf31e0 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -45,6 +45,7 @@ is_peft_available, is_timm_available, is_torch_available, + is_torch_neuronx_available, is_torch_version, is_torchao_available, is_torchsde_available, @@ -109,6 +110,8 @@ torch_device = "cuda" elif torch.xpu.is_available(): torch_device = "xpu" + elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available(): + torch_device = torch.neuron.current_device() else: torch_device = "cpu" is_torch_higher_equal_than_1_12 = version.parse( @@ -1427,6 +1430,15 @@ def _is_torch_fp64_available(device): # Behaviour flags BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True} + # Neuron device key: torch.neuron.current_device() returns an int (e.g. 0). + # We capture it once at import time if torch_neuronx is available so we can add it + # to all dispatch tables using the same key that torch_device is set to. + _neuron_device = ( + torch.neuron.current_device() + if (is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available()) + else None + ) + # Function definitions BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, @@ -1478,13 +1490,19 @@ def _is_torch_fp64_available(device): "default": None, } + if _neuron_device is not None: + BACKEND_EMPTY_CACHE[_neuron_device] = None + BACKEND_DEVICE_COUNT[_neuron_device] = torch.neuron.device_count + BACKEND_MANUAL_SEED[_neuron_device] = torch.manual_seed + BACKEND_RESET_PEAK_MEMORY_STATS[_neuron_device] = None + BACKEND_RESET_MAX_MEMORY_ALLOCATED[_neuron_device] = None + BACKEND_MAX_MEMORY_ALLOCATED[_neuron_device] = 0 + BACKEND_SYNCHRONIZE[_neuron_device] = torch.neuron.synchronize + # This dispatches a defined function according to the accelerator from the function definitions. def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable], *args, **kwargs): - if device not in dispatch_table: - return dispatch_table["default"](*args, **kwargs) - - fn = dispatch_table[device] + fn = dispatch_table[device] if device in dispatch_table else dispatch_table["default"] # Some device agnostic functions return values. Need to guard against 'None' instead at # user level