From fe38d7760313cfab9e5d40854b5c6401c9622a82 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 15 Apr 2026 11:48:47 +0200 Subject: [PATCH 1/3] update --- .../transformers/transformer_wan_animate.py | 6 ++- tests/models/testing_utils/common.py | 5 ++ tests/models/testing_utils/quantization.py | 47 +++++++------------ .../test_models_transformer_flux.py | 45 ++++++++++++++++++ 4 files changed, 72 insertions(+), 31 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_animate.py b/src/diffusers/models/transformers/transformer_wan_animate.py index 166b0b4c2721..dfea5a71353d 100644 --- a/src/diffusers/models/transformers/transformer_wan_animate.py +++ b/src/diffusers/models/transformers/transformer_wan_animate.py @@ -445,10 +445,14 @@ def __call__( # B --> batch_size, T --> reduced inference segment len, N --> face_encoder_num_heads + 1, C --> attn.dim B, T, N, C = encoder_hidden_states.shape + # Flatten T and N so the K/V projections see a 3D tensor; BnB int8 matmul only + # accepts 2D/3D inputs and would otherwise fail on this 4D activation. + encoder_hidden_states = encoder_hidden_states.flatten(1, 2) # [B, T, N, C] --> [B, T * N, C] + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) query = query.unflatten(2, (attn.heads, -1)) # [B, S, H * D] --> [B, S, H, D] - key = key.view(B, T, N, attn.heads, -1) # [B, T, N, H * D_kv] --> [B, T, N, H, D_kv] + key = key.view(B, T, N, attn.heads, -1) # [B, T * N, H * D_kv] --> [B, T, N, H, D_kv] value = value.view(B, T, N, attn.heads, -1) query = attn.norm_q(query) diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index 7036bb16203d..ba060b3b120d 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -205,6 +205,11 @@ def pretrained_model_kwargs(self) -> Dict[str, Any]: """Additional kwargs to pass to from_pretrained (e.g., subfolder, variant).""" return {} + @property + def torch_dtype(self) -> torch.dtype: + """Compute dtype used to build dummy inputs and cast inputs where needed.""" + return torch.float32 + @property def output_shape(self) -> Optional[tuple]: """Expected output shape for output validation tests.""" diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 4403cacc6966..1aab0b240148 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -359,15 +359,7 @@ def _test_dequantize(self, config_kwargs): if isinstance(module, torch.nn.Linear): assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()" - # Get model dtype from first parameter - model_dtype = next(model.parameters()).dtype - inputs = self.get_dummy_inputs() - # Cast inputs to model dtype - inputs = { - k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v - for k, v in inputs.items() - } output = model(**inputs, return_dict=False)[0] assert output is not None, "Model output is None after dequantization" assert not torch.isnan(output).any(), "Model output contains NaN after dequantization" @@ -575,33 +567,28 @@ def test_bnb_original_dtype(self): @torch.no_grad() def test_bnb_keep_modules_in_fp32(self): - if not hasattr(self.model_class, "_keep_in_fp32_modules"): - pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules") + fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None) + if not fp32_modules: + pytest.skip(f"{self.model_class.__name__} does not declare _keep_in_fp32_modules") config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"] - original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None) - self.model_class._keep_in_fp32_modules = ["proj_out"] - - try: - model = self._create_quantized_model(config_kwargs) + model = self._create_quantized_model(config_kwargs) + model.to(torch_device) - for name, module in model.named_modules(): - if isinstance(module, torch.nn.Linear): - if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules): - assert module.weight.dtype == torch.float32, ( - f"Module {name} should be FP32 but is {module.weight.dtype}" - ) - else: - assert module.weight.dtype == torch.uint8, ( - f"Module {name} should be uint8 but is {module.weight.dtype}" - ) + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if any(fp32_name in name for fp32_name in fp32_modules): + assert module.weight.dtype == torch.float32, ( + f"Module {name} should be FP32 but is {module.weight.dtype}" + ) + else: + assert module.weight.dtype == torch.uint8, ( + f"Module {name} should be uint8 but is {module.weight.dtype}" + ) - inputs = self.get_dummy_inputs() - _ = model(**inputs) - finally: - if original_fp32_modules is not None: - self.model_class._keep_in_fp32_modules = original_fp32_modules + inputs = self.get_dummy_inputs() + _ = model(**inputs) def test_bnb_modules_to_not_convert(self): """Test that modules_to_not_convert parameter works correctly.""" diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index a15b7be50b97..641902e23071 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -320,6 +320,51 @@ def pretrained_model_name_or_path(self): class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin): """BitsAndBytes quantization tests for Flux Transformer.""" + @property + def torch_dtype(self): + return torch.float16 + + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: + height = width = 4 + num_latent_channels = 4 + num_image_channels = 3 + sequence_length = 48 + embedding_dim = 32 + + return { + "hidden_states": randn_tensor( + (batch_size, height * width, num_latent_channels), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, + ), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, embedding_dim), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, + ), + "pooled_projections": randn_tensor( + (batch_size, embedding_dim), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, + ), + "img_ids": randn_tensor( + (height * width, num_image_channels), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, + ), + "txt_ids": randn_tensor( + (sequence_length, num_image_channels), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, + ), + "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype), + } + class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin): """Quanto quantization tests for Flux Transformer.""" From 80ad468e2b3c13a3d5d08301bce9fd127b35ef48 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 16 Apr 2026 09:40:38 +0200 Subject: [PATCH 2/3] update --- .../test_models_transformer_flux.py | 68 ++++++------------- 1 file changed, 21 insertions(+), 47 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 641902e23071..03c19a4700e0 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -159,21 +159,36 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: return { "hidden_states": randn_tensor( - (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + (batch_size, height * width, num_latent_channels), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, ), "encoder_hidden_states": randn_tensor( - (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + (batch_size, sequence_length, embedding_dim), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, ), "pooled_projections": randn_tensor( - (batch_size, embedding_dim), generator=self.generator, device=torch_device + (batch_size, embedding_dim), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, ), "img_ids": randn_tensor( - (height * width, num_image_channels), generator=self.generator, device=torch_device + (height * width, num_image_channels), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, ), "txt_ids": randn_tensor( - (sequence_length, num_image_channels), generator=self.generator, device=torch_device + (sequence_length, num_image_channels), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, ), - "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype).expand(batch_size), } @@ -324,47 +339,6 @@ class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesT def torch_dtype(self): return torch.float16 - def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: - height = width = 4 - num_latent_channels = 4 - num_image_channels = 3 - sequence_length = 48 - embedding_dim = 32 - - return { - "hidden_states": randn_tensor( - (batch_size, height * width, num_latent_channels), - generator=self.generator, - device=torch_device, - dtype=self.torch_dtype, - ), - "encoder_hidden_states": randn_tensor( - (batch_size, sequence_length, embedding_dim), - generator=self.generator, - device=torch_device, - dtype=self.torch_dtype, - ), - "pooled_projections": randn_tensor( - (batch_size, embedding_dim), - generator=self.generator, - device=torch_device, - dtype=self.torch_dtype, - ), - "img_ids": randn_tensor( - (height * width, num_image_channels), - generator=self.generator, - device=torch_device, - dtype=self.torch_dtype, - ), - "txt_ids": randn_tensor( - (sequence_length, num_image_channels), - generator=self.generator, - device=torch_device, - dtype=self.torch_dtype, - ), - "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype), - } - class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin): """Quanto quantization tests for Flux Transformer.""" From 97ee35f82611f106177f0a0ccd2832c63707ff3d Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 16 Apr 2026 11:52:19 +0200 Subject: [PATCH 3/3] update --- src/diffusers/models/transformers/transformer_wan_vace.py | 4 ++-- tests/models/transformers/test_models_transformer_wan.py | 2 ++ .../transformers/test_models_transformer_wan_animate.py | 5 +++++ .../models/transformers/test_models_transformer_wan_vace.py | 3 +++ 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 7c2e205ee3ed..46caaf579ffd 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -331,7 +331,7 @@ def forward( ) if i in self.config.vace_layers: control_hint, scale = control_hidden_states_list.pop() - hidden_states = hidden_states + control_hint * scale + hidden_states = hidden_states + control_hint.to(hidden_states.device) * scale else: # Prepare VACE hints control_hidden_states_list = [] @@ -346,7 +346,7 @@ def forward( hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) if i in self.config.vace_layers: control_hint, scale = control_hidden_states_list.pop() - hidden_states = hidden_states + control_hint * scale + hidden_states = hidden_states + control_hint.to(hidden_states.device) * scale # 6. Output norm, projection & unpatchify shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py index 26b0ac946434..60bba9dfbe18 100644 --- a/tests/models/transformers/test_models_transformer_wan.py +++ b/tests/models/transformers/test_models_transformer_wan.py @@ -91,11 +91,13 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: (batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device, + dtype=self.torch_dtype, ), "encoder_hidden_states": randn_tensor( (batch_size, sequence_length, text_encoder_embedding_dim), generator=self.generator, device=torch_device, + dtype=self.torch_dtype, ), "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device), } diff --git a/tests/models/transformers/test_models_transformer_wan_animate.py b/tests/models/transformers/test_models_transformer_wan_animate.py index ac0ef0698c63..df67e55c9b5d 100644 --- a/tests/models/transformers/test_models_transformer_wan_animate.py +++ b/tests/models/transformers/test_models_transformer_wan_animate.py @@ -113,27 +113,32 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: (batch_size, 2 * num_channels + 4, num_frames + 1, height, width), generator=self.generator, device=torch_device, + dtype=self.torch_dtype, ), "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device), "encoder_hidden_states": randn_tensor( (batch_size, sequence_length, text_encoder_embedding_dim), generator=self.generator, device=torch_device, + dtype=self.torch_dtype, ), "encoder_hidden_states_image": randn_tensor( (batch_size, clip_seq_len, clip_dim), generator=self.generator, device=torch_device, + dtype=self.torch_dtype, ), "pose_hidden_states": randn_tensor( (batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device, + dtype=self.torch_dtype, ), "face_pixel_values": randn_tensor( (batch_size, 3, inference_segment_length, face_height, face_width), generator=self.generator, device=torch_device, + dtype=self.torch_dtype, ), } diff --git a/tests/models/transformers/test_models_transformer_wan_vace.py b/tests/models/transformers/test_models_transformer_wan_vace.py index 5ab51bbb9003..1cc829f88b9d 100644 --- a/tests/models/transformers/test_models_transformer_wan_vace.py +++ b/tests/models/transformers/test_models_transformer_wan_vace.py @@ -96,16 +96,19 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: (batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device, + dtype=self.torch_dtype, ), "encoder_hidden_states": randn_tensor( (batch_size, sequence_length, text_encoder_embedding_dim), generator=self.generator, device=torch_device, + dtype=self.torch_dtype, ), "control_hidden_states": randn_tensor( (batch_size, vace_in_channels, num_frames, height, width), generator=self.generator, device=torch_device, + dtype=self.torch_dtype, ), "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device), }