Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/diffusers/models/unets/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,10 +855,11 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float |
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
is_neuron = sample.device.type == "neuron"
if isinstance(timestep, float):
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64
else:
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
dtype = torch.int32 if (is_mps or is_npu or is_neuron) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import types
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Union, get_args, get_origin, get_type_hints
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, get_type_hints

import httpx
import numpy as np
Expand Down Expand Up @@ -68,6 +68,7 @@
is_transformers_version,
logging,
numpy_to_pil,
requires_backends,
)
from ..utils.distributed_utils import is_torch_dist_rank_zero
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
Expand Down Expand Up @@ -2248,7 +2249,6 @@ def _is_pipeline_device_mapped(self):

return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1


class StableDiffusionMixin:
r"""
Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,11 @@ def __call__(
)

# 4. Prepare timesteps
if XLA_AVAILABLE:
# Keep timesteps on CPU for XLA (TPU) and Neuron: both use lazy/XLA execution where
# dynamic-shape ops like .nonzero() and .item() inside scheduler.index_for_timestep()
# are incompatible with static-graph compilation.
is_neuron_device = hasattr(device, "type") and device.type == "neuron"
if XLA_AVAILABLE or is_neuron_device:
timestep_device = "cpu"
else:
timestep_device = device
Expand Down Expand Up @@ -1195,15 +1199,23 @@ def __call__(
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# For Neuron: scale_model_input on CPU to avoid XLA ops outside the compiled UNet region.
# index_for_timestep() uses .nonzero()/.item() which are incompatible with static graphs.
if is_neuron_device:
latent_model_input = self.scheduler.scale_model_input(latent_model_input.to("cpu"), t).to(device)
else:
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
added_cond_kwargs["image_embeds"] = image_embeds
# For Neuron: pre-cast timestep to float32 on device. Neuron XLA does not support
# int64 ops; the compiled UNet graph requires a float32 timestep input on-device.
t_unet = t.to(torch.float32).to(device) if is_neuron_device else t
noise_pred = self.unet(
latent_model_input,
t,
t_unet,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
Expand All @@ -1222,7 +1234,13 @@ def __call__(

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# For Neuron: scheduler.step on CPU to keep scheduler arithmetic off the XLA device.
if is_neuron_device:
latents = self.scheduler.step(
noise_pred.to("cpu"), t, latents.to("cpu"), **extra_step_kwargs, return_dict=False
)[0].to(device)
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
is_timm_available,
is_torch_available,
is_torch_mlu_available,
is_torch_neuronx_available,
is_torch_npu_available,
is_torch_version,
is_torch_xla_available,
Expand Down
10 changes: 10 additions & 0 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b
_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu")
_torch_neuronx_available, _torch_neuronx_version = _is_package_available("torch_neuronx")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
Expand Down Expand Up @@ -249,6 +250,10 @@ def is_torch_mlu_available():
return _torch_mlu_available


def is_torch_neuronx_available():
    """Return whether the `torch_neuronx` package (AWS Neuron SDK) is importable."""
    neuronx_found = _torch_neuronx_available
    return neuronx_found


def is_flax_available():
    """Return whether the `flax` package is importable in this environment."""
    flax_found = _flax_available
    return flax_found

Expand Down Expand Up @@ -579,6 +584,10 @@ def is_av_available():
"""


# Error template used via BACKENDS_MAPPING / `requires_backends` when the
# torch_neuronx backend is missing; "{0}" is filled with the requiring object's name.
TORCH_NEURONX_IMPORT_ERROR = """
{0} requires the torch_neuronx library (AWS Neuron SDK) but it was not found in your environment. Please install it following the AWS Neuron documentation: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/
"""

BACKENDS_MAPPING = OrderedDict(
[
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
Expand Down Expand Up @@ -609,6 +618,7 @@ def is_av_available():
("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
("torch_neuronx", (is_torch_neuronx_available, TORCH_NEURONX_IMPORT_ERROR)),
]
)

Expand Down
3 changes: 3 additions & 0 deletions src/diffusers/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
is_peft_available,
is_timm_available,
is_torch_available,
is_torch_neuronx_available,
is_torch_version,
is_torchao_available,
is_torchsde_available,
Expand Down Expand Up @@ -113,6 +114,8 @@
torch_device = "cuda"
elif torch.xpu.is_available():
torch_device = "xpu"
elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available():
torch_device = torch.neuron.current_device()
else:
torch_device = "cpu"
is_torch_higher_equal_than_1_12 = version.parse(
Expand Down
25 changes: 22 additions & 3 deletions src/diffusers/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from typing import Callable, ParamSpec, TypeVar

from . import logging
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
from .import_utils import (
is_torch_available,
is_torch_mlu_available,
is_torch_neuronx_available,
is_torch_npu_available,
is_torch_version,
)


T = TypeVar("T")
Expand All @@ -33,54 +39,61 @@
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift

BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
# Per-backend capability/dispatch tables keyed by device-type strings, with a
# "default" fallback entry. "neuron" entries are no-ops or CPU equivalents since
# no CUDA-style cache/memory API is used for Neuron here.
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "neuron": False, "default": True}
# Backend-specific cache flush; None means "nothing to do" for that backend.
BACKEND_EMPTY_CACHE = {
    "cuda": torch.cuda.empty_cache,
    "xpu": torch.xpu.empty_cache,
    "cpu": None,
    "mps": torch.mps.empty_cache,
    "neuron": None,
    "default": None,
}
# Device-count callables. The "neuron" entry resolves torch.neuron.device_count
# lazily at call time and degrades to 0 when the module/attribute is absent.
# NOTE(review): "default" is the int 0, not a callable — confirm the dispatcher
# tolerates non-callable table entries.
BACKEND_DEVICE_COUNT = {
    "cuda": torch.cuda.device_count,
    "xpu": torch.xpu.device_count,
    "cpu": lambda: 0,
    "mps": lambda: 0,
    "neuron": lambda: getattr(getattr(torch, "neuron", None), "device_count", lambda: 0)(),
    "default": 0,
}
# Seeding: Neuron has no dedicated device RNG hook here, so it falls back to the
# global torch.manual_seed like CPU.
BACKEND_MANUAL_SEED = {
    "cuda": torch.cuda.manual_seed,
    "xpu": torch.xpu.manual_seed,
    "cpu": torch.manual_seed,
    "mps": torch.mps.manual_seed,
    "neuron": torch.manual_seed,
    "default": torch.manual_seed,
}
BACKEND_RESET_PEAK_MEMORY_STATS = {
    "cuda": torch.cuda.reset_peak_memory_stats,
    "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
    "cpu": None,
    "mps": None,
    "neuron": None,
    "default": None,
}
# NOTE(review): the "xpu" entry below maps to reset_peak_memory_stats rather
# than a reset_max_memory_allocated counterpart — confirm this is intentional.
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
    "cuda": torch.cuda.reset_max_memory_allocated,
    "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
    "cpu": None,
    "mps": None,
    "neuron": None,
    "default": None,
}
# Peak-memory queries; int 0 stands in where the backend tracks nothing.
BACKEND_MAX_MEMORY_ALLOCATED = {
    "cuda": torch.cuda.max_memory_allocated,
    "xpu": getattr(torch.xpu, "max_memory_allocated", None),
    "cpu": 0,
    "mps": 0,
    "neuron": 0,
    "default": 0,
}
# Synchronization: the "neuron" entry resolves torch.neuron.synchronize lazily
# and degrades to a no-op when unavailable.
BACKEND_SYNCHRONIZE = {
    "cuda": torch.cuda.synchronize,
    "xpu": getattr(torch.xpu, "synchronize", None),
    "cpu": None,
    "mps": None,
    "neuron": lambda: getattr(getattr(torch, "neuron", None), "synchronize", lambda: None)(),
    "default": None,
}
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -169,11 +182,15 @@ def randn_tensor(
layout = layout or torch.strided
device = device or torch.device("cpu")

# Neuron (XLA) does not support creating random tensors directly on device; always use CPU
if device.type == "neuron":
rand_device = torch.device("cpu")

if generator is not None:
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
if gen_device_type != device.type and gen_device_type == "cpu":
rand_device = "cpu"
if device != "mps":
if device.type not in ("mps", "neuron"):
logger.info(
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
Expand Down Expand Up @@ -294,6 +311,8 @@ def get_device():
return "mps"
elif is_torch_mlu_available():
return "mlu"
elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available():
return "neuron"
else:
return "cpu"

Expand Down
10 changes: 8 additions & 2 deletions tests/pipelines/pixart_alpha/test_pixart.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
PixArtTransformer2DModel,
)

from diffusers.utils.import_utils import is_torch_neuronx_available

from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
Expand Down Expand Up @@ -291,7 +293,9 @@ def test_pixart_1024(self):
expected_slice = np.array([0.0742, 0.0835, 0.2114, 0.0295, 0.0784, 0.2361, 0.1738, 0.2251, 0.3589])

max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
self.assertLessEqual(max_diff, 1e-4)
# Neuron uses bfloat16 internally which has lower precision than float16 on CUDA
atol = 1e-2 if is_torch_neuronx_available() else 1e-4
self.assertLessEqual(max_diff, atol)

def test_pixart_512(self):
generator = torch.Generator("cpu").manual_seed(0)
Expand All @@ -307,7 +311,9 @@ def test_pixart_512(self):
expected_slice = np.array([0.3477, 0.3882, 0.4541, 0.3413, 0.3821, 0.4463, 0.4001, 0.4409, 0.4958])

max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
self.assertLessEqual(max_diff, 1e-4)
# Neuron uses bfloat16 internally which has lower precision than float16 on CUDA
atol = 1e-2 if is_torch_neuronx_available() else 1e-4
self.assertLessEqual(max_diff, atol)

def test_pixart_1024_without_resolution_binning(self):
generator = torch.manual_seed(0)
Expand Down
26 changes: 22 additions & 4 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
is_peft_available,
is_timm_available,
is_torch_available,
is_torch_neuronx_available,
is_torch_version,
is_torchao_available,
is_torchsde_available,
Expand Down Expand Up @@ -109,6 +110,8 @@
torch_device = "cuda"
elif torch.xpu.is_available():
torch_device = "xpu"
elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available():
torch_device = torch.neuron.current_device()
else:
torch_device = "cpu"
is_torch_higher_equal_than_1_12 = version.parse(
Expand Down Expand Up @@ -1427,6 +1430,15 @@ def _is_torch_fp64_available(device):
# Behaviour flags
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}

# Neuron device key: torch.neuron.current_device() returns an int (e.g. 0).
# We capture it once at import time if torch_neuronx is available so we can add it
# to all dispatch tables using the same key that torch_device is set to.
# NOTE(review): this key is an int while every other dispatch-table key is a
# string — confirm all lookups use exactly the value torch_device was assigned.
_neuron_device = (
    torch.neuron.current_device()
    if (is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available())
    else None
)

# Function definitions
BACKEND_EMPTY_CACHE = {
"cuda": torch.cuda.empty_cache,
Expand Down Expand Up @@ -1478,13 +1490,19 @@ def _is_torch_fp64_available(device):
"default": None,
}

# Register the Neuron device key in every dispatch table so device-agnostic
# helpers keyed on torch_device resolve for Neuron runs. Guarded: only runs
# when a Neuron device was detected at import time.
if _neuron_device is not None:
    BACKEND_EMPTY_CACHE[_neuron_device] = None  # no cache-flush API used for Neuron
    BACKEND_DEVICE_COUNT[_neuron_device] = torch.neuron.device_count
    BACKEND_MANUAL_SEED[_neuron_device] = torch.manual_seed  # no device RNG hook; global seed
    BACKEND_RESET_PEAK_MEMORY_STATS[_neuron_device] = None
    BACKEND_RESET_MAX_MEMORY_ALLOCATED[_neuron_device] = None
    BACKEND_MAX_MEMORY_ALLOCATED[_neuron_device] = 0
    BACKEND_SYNCHRONIZE[_neuron_device] = torch.neuron.synchronize


# This dispatches a defined function according to the accelerator from the function definitions.
def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable], *args, **kwargs):
if device not in dispatch_table:
return dispatch_table["default"](*args, **kwargs)

fn = dispatch_table[device]
fn = dispatch_table[device] if device in dispatch_table else dispatch_table["default"]

# Some device agnostic functions return values. Need to guard against 'None' instead at
# user level
Expand Down
Loading