Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/models/testing_utils/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1187,7 +1187,7 @@ def _test_torch_compile(self, config_kwargs):
model.to(torch_device)
model.eval()

model = torch.compile(model, fullgraph=True)
model.compile(fullgraph=True)

with torch._dynamo.config.patch(error_on_recompile=True):
inputs = self.get_dummy_inputs()
Expand Down Expand Up @@ -1219,7 +1219,7 @@ def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False
"use_stream": use_stream,
}
model.enable_group_offload(**group_offload_kwargs)
model = torch.compile(model)
model.compile()

inputs = self.get_dummy_inputs()
output = model(**inputs, return_dict=False)[0]
Expand Down
52 changes: 50 additions & 2 deletions tests/models/transformers/test_models_transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
from typing import Any

import pytest
import torch

from diffusers import FluxTransformer2DModel
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
from diffusers.models.embeddings import ImageProjection
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
from diffusers.utils.torch_utils import randn_tensor
Expand Down Expand Up @@ -440,10 +441,57 @@ class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCo
"""ModelOpt + compile tests for Flux Transformer."""


@pytest.mark.skip(reason="torch.compile is not supported by BitsAndBytes")
class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin):
"""BitsAndBytes + compile tests for Flux Transformer."""

def get_init_dict(self) -> dict[str, int | list[int]]:
# Dims must be multiples of 64 (bnb 4bit blocksize) so single-token activations
# don't trigger the runtime `warn()` inside bnb.matmul_4bit that breaks fullgraph compile.
return {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 32,
"num_attention_heads": 2,
"joint_attention_dim": 64,
"pooled_projection_dim": 64,
"axes_dims_rope": [8, 8, 16],
}

def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
inputs = super().get_dummy_inputs(batch_size=batch_size)
embedding_dim = 64
sequence_length = inputs["encoder_hidden_states"].shape[1]
inputs["encoder_hidden_states"] = randn_tensor(
(batch_size, sequence_length, embedding_dim),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
)
inputs["pooled_projections"] = randn_tensor(
(batch_size, embedding_dim), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
return inputs

def _create_quantized_model(self, config_kwargs, **extra_kwargs):
config_kwargs = {**config_kwargs, "bnb_4bit_compute_dtype": self.torch_dtype}
bnb_config = BitsAndBytesConfig(**config_kwargs)
base_model = self.model_class(**self.get_init_dict()).to(self.torch_dtype)
with tempfile.TemporaryDirectory() as tmp_dir:
base_model.save_pretrained(tmp_dir)
del base_model
return self.model_class.from_pretrained(
tmp_dir, quantization_config=bnb_config, torch_dtype=self.torch_dtype, **extra_kwargs
)

@pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"])
def test_bnb_torch_compile_with_group_offload(self, config_name):
# use_stream=True is required: bnb 4bit kernels read device pointers eagerly, so
# without an explicit prefetch-stream sync we hit "illegal memory access" in
# bnb/csrc/ops.cu. The pipeline-level Bnb4BitCompileTests override does the same.
self._test_torch_compile_with_group_offload(self.BNB_CONFIGS[config_name], use_stream=True)


class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
"""FirstBlockCache tests for Flux Transformer."""
Expand Down
Loading