From 18375758a784852cfccf6b99767e3ddabb29688e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 14 May 2026 17:16:48 +0900 Subject: [PATCH] fix bitsandbytes compile tests for flux. --- tests/models/testing_utils/quantization.py | 4 +- .../test_models_transformer_flux.py | 52 ++++++++++++++++++- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 30d44a92c425..13eaaccdbf82 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -1187,7 +1187,7 @@ def _test_torch_compile(self, config_kwargs): model.to(torch_device) model.eval() - model = torch.compile(model, fullgraph=True) + model.compile(fullgraph=True) with torch._dynamo.config.patch(error_on_recompile=True): inputs = self.get_dummy_inputs() @@ -1219,7 +1219,7 @@ def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False "use_stream": use_stream, } model.enable_group_offload(**group_offload_kwargs) - model = torch.compile(model) + model.compile() inputs = self.get_dummy_inputs() output = model(**inputs, return_dict=False)[0] diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 840eaa338430..e45dc5177c64 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import tempfile from typing import Any import pytest import torch -from diffusers import FluxTransformer2DModel +from diffusers import BitsAndBytesConfig, FluxTransformer2DModel from diffusers.models.embeddings import ImageProjection from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor from diffusers.utils.torch_utils import randn_tensor @@ -440,10 +441,57 @@ class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCo """ModelOpt + compile tests for Flux Transformer.""" -@pytest.mark.skip(reason="torch.compile is not supported by BitsAndBytes") class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin): """BitsAndBytes + compile tests for Flux Transformer.""" + def get_init_dict(self) -> dict[str, int | list[int]]: + # Dims must be multiples of 64 (bnb 4bit blocksize) so single-token activations + # don't trigger the runtime `warn()` inside bnb.matmul_4bit that breaks fullgraph compile. + return { + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 32, + "num_attention_heads": 2, + "joint_attention_dim": 64, + "pooled_projection_dim": 64, + "axes_dims_rope": [8, 8, 16], + } + + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: + inputs = super().get_dummy_inputs(batch_size=batch_size) + embedding_dim = 64 + sequence_length = inputs["encoder_hidden_states"].shape[1] + inputs["encoder_hidden_states"] = randn_tensor( + (batch_size, sequence_length, embedding_dim), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, + ) + inputs["pooled_projections"] = randn_tensor( + (batch_size, embedding_dim), generator=self.generator, device=torch_device, dtype=self.torch_dtype + ) + return inputs + + def _create_quantized_model(self, config_kwargs, **extra_kwargs): + config_kwargs = {**config_kwargs, "bnb_4bit_compute_dtype": self.torch_dtype} + bnb_config = BitsAndBytesConfig(**config_kwargs) + base_model = self.model_class(**self.get_init_dict()).to(self.torch_dtype) + with tempfile.TemporaryDirectory() as tmp_dir: + base_model.save_pretrained(tmp_dir) + del base_model + return self.model_class.from_pretrained( + tmp_dir, quantization_config=bnb_config, torch_dtype=self.torch_dtype, **extra_kwargs + ) + + @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) + def test_bnb_torch_compile_with_group_offload(self, config_name): + # use_stream=True is required: bnb 4bit kernels read device pointers eagerly, so + # without an explicit prefetch-stream sync we hit "illegal memory access" in + # bnb/csrc/ops.cu. The pipeline-level Bnb4BitCompileTests override does the same. + self._test_torch_compile_with_group_offload(self.BNB_CONFIGS[config_name], use_stream=True) + class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin): """FirstBlockCache tests for Flux Transformer."""