diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 41b0f689d9a4..650ab4c7fcda 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1509,6 +1509,8 @@ def cuda(self, *args, **kwargs): def to(self, *args, **kwargs): from ..hooks.group_offloading import _is_group_offload_enabled + fp32_modules = self._keep_in_fp32_modules or [] + device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs dtype_present_in_args = "dtype" in kwargs @@ -1528,6 +1530,11 @@ def to(self, *args, **kwargs): dtype_present_in_args = True break + if dtype_present_in_args and fp32_modules is not None: + logger.warning( + f"There are modules in {self.__class__.__name__} that should be kept in float32: {fp32_modules}. Casting directly with `to()` can lead to inconsistent results; set `torch_dtype` in `from_pretrained()` instead to keep these modules in float32." + ) + if getattr(self, "is_quantized", False): if dtype_present_in_args: raise ValueError( diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py index 906baa60a9dc..f594b722cd3d 100644 --- a/tests/models/autoencoders/test_models_consistency_decoder_vae.py +++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py @@ -16,6 +16,7 @@ import gc import numpy as np +import pytest import torch from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline @@ -86,7 +87,12 @@ def get_dummy_inputs(self) -> dict: class TestConsistencyDecoderVAE(ConsistencyDecoderVAETesterConfig, ModelTesterMixin): - pass + @pytest.mark.skip( + reason="The consistency decoder samples noise (`randn_tensor`) during `decode`, so two forward passes " + "diverge regardless of dtype. This makes a save/load output comparison non-deterministic." + ) + def test_from_save_pretrained_dtype_inference(self, *args, **kwargs): + pass class TestConsistencyDecoderVAETraining(ConsistencyDecoderVAETesterConfig, TrainingTesterMixin): diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py index ce1606f0e859..bc60b7df5b1b 100644 --- a/tests/models/autoencoders/test_models_vq.py +++ b/tests/models/autoencoders/test_models_vq.py @@ -64,6 +64,17 @@ def get_dummy_inputs(self) -> dict: class TestVQModel(VQModelTesterConfig, ModelTesterMixin): + @pytest.mark.skipif( + torch_device not in ["cuda", "xpu"], + reason="float16 and bfloat16 can only be use for inference with an accelerator", + ) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + # The reference and reloaded models hold identical weights, so any output difference is + # half-precision kernel nondeterminism between the two module instances rather than a save/load + # fidelity issue. The default 1e-4 tolerance is too tight for that fp16/bf16 noise on some GPUs. + super().test_from_save_pretrained_dtype_inference(tmp_path, dtype, atol=1e-3) + def test_from_pretrained_hub(self): model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True) assert model is not None diff --git a/tests/models/controlnets/test_models_controlnet_cosmos.py b/tests/models/controlnets/test_models_controlnet_cosmos.py index e7ea6362213d..870650b1a6bf 100644 --- a/tests/models/controlnets/test_models_controlnet_cosmos.py +++ b/tests/models/controlnets/test_models_controlnet_cosmos.py @@ -251,6 +251,10 @@ def test_determinism(self): def test_from_save_pretrained(self): super().test_from_save_pretrained() + @pytest.mark.skip("Output is a list of tensors; comparison helper calls .shape on it.") + def test_from_save_pretrained_dtype_inference(self, *args, **kwargs): + super().test_from_save_pretrained_dtype_inference(*args, **kwargs) + @pytest.mark.skip("Output is a list of tensors; comparison helper calls .shape on it.") def test_from_save_pretrained_variant(self): super().test_from_save_pretrained_variant() diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index eb120567f3d1..3cb5ac5c8e79 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -507,6 +507,29 @@ def test_keep_in_fp32_modules(self, tmp_path): else: assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}" + def test_to_keep_in_fp32_modules_warns(self, caplog): + fp32_modules = self.model_class._keep_in_fp32_modules + if fp32_modules is None or len(fp32_modules) == 0: + pytest.skip("Model does not have _keep_in_fp32_modules defined.") + + model = self.model_class(**self.get_init_dict()) + + logger_name = "diffusers.models.modeling_utils" + logging.enable_propagation() + try: + with caplog.at_level(logging.WARNING, logger=logger_name): + caplog.clear() + model.to(torch.float16) + finally: + logging.disable_propagation() + + expected_message = ( + f"There are modules in {model.__class__.__name__} that should be kept in float32: " + f"{fp32_modules}. Casting directly with `to()` can lead to inconsistent results; set " + f"`torch_dtype` in `from_pretrained()` instead to keep these modules in float32." + ) + assert expected_message in caplog.text + @require_accelerator @pytest.mark.skipif( torch_device not in ["cuda", "xpu"], @@ -519,7 +542,25 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, model.to(torch_device) fp32_modules = model._keep_in_fp32_modules or [] - model.to(dtype).save_pretrained(tmp_path) + # Build the reference model with the same mixed-precision layout that `from_pretrained` enforces, so + # the comparison reflects real save/load fidelity: + # - `_keep_in_fp32_modules` stay in fp32 while everything else is cast to `dtype`; + # - non-persistent buffers (e.g. fp32 RoPE `inv_freq`) are left untouched, because they are not part + # of the checkpoint and are regenerated by `__init__` on load. Truncating them here would make the + # reference diverge from the reloaded model for reasons unrelated to save/load. + persistent_tensor_names = {name for name, _ in named_persistent_module_tensors(model, recurse=True)} + + def keep_in_fp32(name): + return any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules) + + for name, param in model.named_parameters(): + param.data = param.data.to(torch.float32 if keep_in_fp32(name) else dtype) + for name, buf in model.named_buffers(): + if not buf.is_floating_point() or name not in persistent_tensor_names: + continue + buf.data = buf.data.to(torch.float32 if keep_in_fp32(name) else dtype) + + model.save_pretrained(tmp_path) model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device) for name, param in model_loaded.named_parameters(): diff --git a/tests/models/transformers/test_models_transformer_anyflow.py b/tests/models/transformers/test_models_transformer_anyflow.py index df72567a7455..5011222f17c9 100644 --- a/tests/models/transformers/test_models_transformer_anyflow.py +++ b/tests/models/transformers/test_models_transformer_anyflow.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import torch from diffusers import AnyFlowTransformer3DModel @@ -100,12 +99,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestAnyFlowTransformer3D(AnyFlowTransformer3DTesterConfig, ModelTesterMixin): """Core model tests for AnyFlow Transformer 3D (bidirectional variant).""" - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: fp16/bf16 require very high atol to pass, providing little signal. - # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - class TestAnyFlowTransformer3DMemory(AnyFlowTransformer3DTesterConfig, MemoryTesterMixin): """Memory optimization tests for AnyFlow Transformer 3D.""" diff --git a/tests/models/transformers/test_models_transformer_anyflow_far.py b/tests/models/transformers/test_models_transformer_anyflow_far.py index d7ed471fa875..b1b9d155b752 100644 --- a/tests/models/transformers/test_models_transformer_anyflow_far.py +++ b/tests/models/transformers/test_models_transformer_anyflow_far.py @@ -113,12 +113,6 @@ def get_dummy_inputs(self) -> dict[str, "torch.Tensor"]: class TestAnyFlowFARTransformer3D(AnyFlowFARTransformer3DTesterConfig, ModelTesterMixin): """Core model tests for AnyFlow FAR causal Transformer 3D.""" - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: fp16/bf16 require very high atol to pass, providing little signal. - # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - class TestAnyFlowFARTransformer3DMemory(AnyFlowFARTransformer3DTesterConfig, MemoryTesterMixin): """Memory optimization tests for AnyFlow FAR Transformer 3D.""" diff --git a/tests/models/transformers/test_models_transformer_helios.py b/tests/models/transformers/test_models_transformer_helios.py index c365c258e596..927581b095e8 100644 --- a/tests/models/transformers/test_models_transformer_helios.py +++ b/tests/models/transformers/test_models_transformer_helios.py @@ -135,12 +135,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestHeliosTransformer3D(HeliosTransformer3DTesterConfig, ModelTesterMixin): """Core model tests for Helios Transformer 3D.""" - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: fp16/bf16 require very high atol to pass, providing little signal. - # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - class TestHeliosTransformer3DMemory(HeliosTransformer3DTesterConfig, MemoryTesterMixin): """Memory optimization tests for Helios Transformer 3D.""" diff --git a/tests/models/transformers/test_models_transformer_ideogram4.py b/tests/models/transformers/test_models_transformer_ideogram4.py index d8e7318d501d..0d29f507b7e1 100644 --- a/tests/models/transformers/test_models_transformer_ideogram4.py +++ b/tests/models/transformers/test_models_transformer_ideogram4.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import torch from diffusers import Ideogram4Transformer2DModel @@ -142,14 +141,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestIdeogram4Transformer(Ideogram4TransformerTesterConfig, ModelTesterMixin): """Core model tests for Ideogram 4 Transformer.""" - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: the non-persistent fp32 RoPE inv_freq buffer is truncated to fp16 by the in-memory - # .to(dtype) path but kept fp32 by from_pretrained, so the two outputs diverge well beyond any - # meaningful tolerance. Dtype preservation is already covered by test_from_save_pretrained_dtype - # and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - class TestIdeogram4TransformerMemory(Ideogram4TransformerTesterConfig, MemoryTesterMixin): """Memory optimization tests for Ideogram 4 Transformer.""" diff --git a/tests/models/transformers/test_models_transformer_joyimage.py b/tests/models/transformers/test_models_transformer_joyimage.py index c464a44c29b5..45d15b2d470a 100644 --- a/tests/models/transformers/test_models_transformer_joyimage.py +++ b/tests/models/transformers/test_models_transformer_joyimage.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import torch from diffusers import JoyImageEditTransformer3DModel @@ -86,9 +85,7 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestJoyImageEditTransformer(JoyImageEditTransformerTesterConfig, ModelTesterMixin): - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - pytest.skip("Tolerance requirements too high for meaningful test") + pass class TestJoyImageEditTransformerMemory(JoyImageEditTransformerTesterConfig, MemoryTesterMixin): diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py index 60bba9dfbe18..aacbf542b548 100644 --- a/tests/models/transformers/test_models_transformer_wan.py +++ b/tests/models/transformers/test_models_transformer_wan.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import torch from diffusers import WanTransformer3DModel @@ -106,12 +105,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin): """Core model tests for Wan Transformer 3D.""" - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: fp16/bf16 require very high atol to pass, providing little signal. - # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin): """Memory optimization tests for Wan Transformer 3D.""" diff --git a/tests/models/transformers/test_models_transformer_wan_animate.py b/tests/models/transformers/test_models_transformer_wan_animate.py index bd751974637b..241b6980d0f4 100644 --- a/tests/models/transformers/test_models_transformer_wan_animate.py +++ b/tests/models/transformers/test_models_transformer_wan_animate.py @@ -152,12 +152,6 @@ def test_output(self, base_model_output): expected_output_shape = (1, 4, 21, 16, 16) super().test_output(base_model_output, expected_output_shape=expected_output_shape) - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: fp16/bf16 require very high atol (~1e-2) to pass, providing little signal. - # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - class TestWanAnimateTransformer3DMemory(WanAnimateTransformer3DTesterConfig, MemoryTesterMixin): """Memory optimization tests for Wan Animate Transformer 3D.""" diff --git a/tests/models/transformers/test_models_transformer_wan_vace.py b/tests/models/transformers/test_models_transformer_wan_vace.py index 1cc829f88b9d..503569662b14 100644 --- a/tests/models/transformers/test_models_transformer_wan_vace.py +++ b/tests/models/transformers/test_models_transformer_wan_vace.py @@ -117,12 +117,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestWanVACETransformer3D(WanVACETransformer3DTesterConfig, ModelTesterMixin): """Core model tests for Wan VACE Transformer 3D.""" - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: fp16/bf16 require very high atol to pass, providing little signal. - # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - def test_model_parallelism(self, tmp_path): # Skip: Device mismatch between cuda:0 and cuda:1 in VACE control flow pytest.skip("Model parallelism not yet supported for WanVACE")