Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,8 @@ def cuda(self, *args, **kwargs):
def to(self, *args, **kwargs):
from ..hooks.group_offloading import _is_group_offload_enabled

fp32_modules = self._keep_in_fp32_modules or []

device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
dtype_present_in_args = "dtype" in kwargs

Expand All @@ -1528,6 +1530,11 @@ def to(self, *args, **kwargs):
dtype_present_in_args = True
break

# `fp32_modules` was normalized with `self._keep_in_fp32_modules or []`, so it is never None here.
# Check its truthiness instead, so the warning is only emitted when the model really declares
# modules that must stay in float32 (and never prints an empty list).
if dtype_present_in_args and fp32_modules:
    logger.warning(
        f"There are modules in {self.__class__.__name__} that should be kept in float32: {fp32_modules}. Casting directly with `to()` can lead to inconsistent results; set `torch_dtype` in `from_pretrained()` instead to keep these modules in float32."
    )

if getattr(self, "is_quantized", False):
if dtype_present_in_args:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import gc

import numpy as np
import pytest
import torch

from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
Expand Down Expand Up @@ -86,7 +87,12 @@ def get_dummy_inputs(self) -> dict:


class TestConsistencyDecoderVAE(ConsistencyDecoderVAETesterConfig, ModelTesterMixin):
    """Combine the ConsistencyDecoderVAE tester config with `ModelTesterMixin`.

    The save/load dtype inference check is overridden with a skipped, empty test.
    """

    @pytest.mark.skip(
        reason=(
            "The consistency decoder samples noise (`randn_tensor`) during `decode`, so two forward passes "
            "diverge regardless of dtype. This makes a save/load output comparison non-deterministic."
        )
    )
    def test_from_save_pretrained_dtype_inference(self, *args, **kwargs):
        """Intentionally empty: the skip marker on this override carries the explanation."""


class TestConsistencyDecoderVAETraining(ConsistencyDecoderVAETesterConfig, TrainingTesterMixin):
Expand Down
11 changes: 11 additions & 0 deletions tests/models/autoencoders/test_models_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ def get_dummy_inputs(self) -> dict:


class TestVQModel(VQModelTesterConfig, ModelTesterMixin):
@pytest.mark.skipif(
    torch_device not in ["cuda", "xpu"],
    reason="float16 and bfloat16 can only be use for inference with an accelerator",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
    """Run the inherited save/load dtype inference check with a looser absolute tolerance.

    The reference and reloaded models hold identical weights, so any output difference comes
    from half-precision kernel nondeterminism between the two module instances, not from a
    save/load fidelity issue. The default 1e-4 tolerance is too tight for that fp16/bf16
    noise on some GPUs.
    """
    relaxed_atol = 1e-3
    super().test_from_save_pretrained_dtype_inference(tmp_path, dtype, atol=relaxed_atol)

def test_from_pretrained_hub(self):
model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
assert model is not None
Expand Down
4 changes: 4 additions & 0 deletions tests/models/controlnets/test_models_controlnet_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def test_determinism(self):
def test_from_save_pretrained(self):
super().test_from_save_pretrained()

@pytest.mark.skip("Output is a list of tensors; comparison helper calls .shape on it.")
def test_from_save_pretrained_dtype_inference(self, *args, **kwargs):
    """Skipped override; if the marker is removed, it simply forwards to the inherited test."""
    # Forward every argument untouched so the parent implementation stays the single source of truth.
    super().test_from_save_pretrained_dtype_inference(*args, **kwargs)

@pytest.mark.skip("Output is a list of tensors; comparison helper calls .shape on it.")
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
Expand Down
43 changes: 42 additions & 1 deletion tests/models/testing_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,29 @@ def test_keep_in_fp32_modules(self, tmp_path):
else:
assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}"

def test_to_keep_in_fp32_modules_warns(self, caplog):
    """Check that casting with `to(torch.float16)` logs the keep-in-float32 warning.

    The test is skipped when the model class has `_keep_in_fp32_modules` unset or empty.
    Otherwise it instantiates the model, casts it, and asserts that the exact warning text
    appears in the captured log output.
    """
    modules_kept_in_fp32 = self.model_class._keep_in_fp32_modules
    if not modules_kept_in_fp32:
        pytest.skip("Model does not have _keep_in_fp32_modules defined.")

    model = self.model_class(**self.get_init_dict())
    class_name = model.__class__.__name__
    expected_message = (
        f"There are modules in {class_name} that should be kept in float32: "
        f"{modules_kept_in_fp32}. Casting directly with `to()` can lead to inconsistent results; set "
        f"`torch_dtype` in `from_pretrained()` instead to keep these modules in float32."
    )

    # Propagation is enabled only around the cast and switched off again afterwards,
    # even if the cast raises.
    logging.enable_propagation()
    try:
        with caplog.at_level(logging.WARNING, logger="diffusers.models.modeling_utils"):
            caplog.clear()
            model.to(torch.float16)
    finally:
        logging.disable_propagation()

    assert expected_message in caplog.text

@require_accelerator
@pytest.mark.skipif(
torch_device not in ["cuda", "xpu"],
Expand All @@ -519,7 +542,25 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4,
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules or []

model.to(dtype).save_pretrained(tmp_path)
# Build the reference model with the same mixed-precision layout that `from_pretrained` enforces, so
# the comparison reflects real save/load fidelity:
# - `_keep_in_fp32_modules` stay in fp32 while everything else is cast to `dtype`;
# - non-persistent buffers (e.g. fp32 RoPE `inv_freq`) are left untouched, because they are not part
# of the checkpoint and are regenerated by `__init__` on load. Truncating them here would make the
# reference diverge from the reloaded model for reasons unrelated to save/load.
persistent_tensor_names = {name for name, _ in named_persistent_module_tensors(model, recurse=True)}

def keep_in_fp32(name):
return any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules)

for name, param in model.named_parameters():
param.data = param.data.to(torch.float32 if keep_in_fp32(name) else dtype)
for name, buf in model.named_buffers():
if not buf.is_floating_point() or name not in persistent_tensor_names:
continue
buf.data = buf.data.to(torch.float32 if keep_in_fp32(name) else dtype)

model.save_pretrained(tmp_path)
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)

for name, param in model_loaded.named_parameters():
Expand Down
7 changes: 0 additions & 7 deletions tests/models/transformers/test_models_transformer_anyflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import AnyFlowTransformer3DModel
Expand Down Expand Up @@ -100,12 +99,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestAnyFlowTransformer3D(AnyFlowTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for AnyFlow Transformer 3D (bidirectional variant)."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestAnyFlowTransformer3DMemory(AnyFlowTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AnyFlow Transformer 3D."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,6 @@ def get_dummy_inputs(self) -> dict[str, "torch.Tensor"]:
class TestAnyFlowFARTransformer3D(AnyFlowFARTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for AnyFlow FAR causal Transformer 3D."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestAnyFlowFARTransformer3DMemory(AnyFlowFARTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AnyFlow FAR Transformer 3D."""
Expand Down
6 changes: 0 additions & 6 deletions tests/models/transformers/test_models_transformer_helios.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestHeliosTransformer3D(HeliosTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Helios Transformer 3D."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestHeliosTransformer3DMemory(HeliosTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Helios Transformer 3D."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import Ideogram4Transformer2DModel
Expand Down Expand Up @@ -142,14 +141,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestIdeogram4Transformer(Ideogram4TransformerTesterConfig, ModelTesterMixin):
"""Core model tests for Ideogram 4 Transformer."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: the non-persistent fp32 RoPE inv_freq buffer is truncated to fp16 by the in-memory
# .to(dtype) path but kept fp32 by from_pretrained, so the two outputs diverge well beyond any
# meaningful tolerance. Dtype preservation is already covered by test_from_save_pretrained_dtype
# and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestIdeogram4TransformerMemory(Ideogram4TransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Ideogram 4 Transformer."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import JoyImageEditTransformer3DModel
Expand Down Expand Up @@ -86,9 +85,7 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:


class TestJoyImageEditTransformer(JoyImageEditTransformerTesterConfig, ModelTesterMixin):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
pytest.skip("Tolerance requirements too high for meaningful test")
pass


class TestJoyImageEditTransformerMemory(JoyImageEditTransformerTesterConfig, MemoryTesterMixin):
Expand Down
7 changes: 0 additions & 7 deletions tests/models/transformers/test_models_transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import WanTransformer3DModel
Expand Down Expand Up @@ -106,12 +105,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan Transformer 3D."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Transformer 3D."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,6 @@ def test_output(self, base_model_output):
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(base_model_output, expected_output_shape=expected_output_shape)

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol (~1e-2) to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestWanAnimateTransformer3DMemory(WanAnimateTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Animate Transformer 3D."""
Expand Down
6 changes: 0 additions & 6 deletions tests/models/transformers/test_models_transformer_wan_vace.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestWanVACETransformer3D(WanVACETransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan VACE Transformer 3D."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")

def test_model_parallelism(self, tmp_path):
# Skip: Device mismatch between cuda:0 and cuda:1 in VACE control flow
pytest.skip("Model parallelism not yet supported for WanVACE")
Expand Down
Loading