diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py
index db9817c86995..2b6aab59a662 100644
--- a/tests/models/testing_utils/parallelism.py
+++ b/tests/models/testing_utils/parallelism.py
@@ -22,6 +22,7 @@
 import torch.multiprocessing as mp
 
 from diffusers.models._modeling_parallel import ContextParallelConfig
+from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
 
 from ...testing_utils import (
     is_context_parallel,
@@ -160,16 +161,21 @@ def _custom_mesh_worker(
 @require_torch_multi_accelerator
 class ContextParallelTesterMixin:
     @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
-    def test_context_parallel_inference(self, cp_type):
+    def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
         if not torch.distributed.is_available():
             pytest.skip("torch.distributed is not available.")
 
         if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
             pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
 
+        if cp_type == "ring_degree":
+            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
+            if active_backend == AttentionBackendName.NATIVE:
+                pytest.skip("Ring attention is not supported with the native attention backend.")
+
         world_size = 2
         init_dict = self.get_init_dict()
-        inputs_dict = self.get_dummy_inputs()
+        inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
 
         # Move all tensors to CPU for multiprocessing
         inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
@@ -194,6 +200,11 @@ def test_context_parallel_inference(self, cp_type):
             f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
         )
 
+    @pytest.mark.xfail(reason="Context parallel may not support batch_size > 1")
+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
+    def test_context_parallel_batch_inputs(self, cp_type):
+        self.test_context_parallel_inference(cp_type, batch_size=2)
+
     @pytest.mark.parametrize(
         "cp_type,mesh_shape,mesh_dim_names",
         [
@@ -209,6 +220,11 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names)
         if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
             pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
 
+        if cp_type == "ring_degree":
+            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
+            if active_backend == AttentionBackendName.NATIVE:
+                pytest.skip("Ring attention is not supported with the native attention backend.")
+
         world_size = 2
         init_dict = self.get_init_dict()
         inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 24be833d0ed2..a15b7be50b97 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -150,8 +150,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]:
             "axes_dims_rope": [4, 4, 8],
         }
 
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        batch_size = 1
+    def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
         height = width = 4
         num_latent_channels = 4
         num_image_channels = 3
diff --git a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py
index a109f603411d..77b5f1b86e59 100644
--- a/tests/models/transformers/test_models_transformer_flux2.py
+++ b/tests/models/transformers/test_models_transformer_flux2.py
@@ -90,8 +90,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]:
             "axes_dims_rope": [4, 4, 4, 4],
         }
 
-    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
-        batch_size = 1
+    def get_dummy_inputs(self, height: int = 4, width: int = 4, batch_size: int = 1) -> dict[str, torch.Tensor]:
         num_latent_channels = 4
         sequence_length = 48
         embedding_dim = 32
diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index 713a1bec70a5..5b45577f2dff 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -14,6 +14,7 @@
 
 import warnings
 
+import pytest
 import torch
 
 from diffusers import QwenImageTransformer2DModel
@@ -77,8 +78,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]:
             "axes_dims_rope": (8, 4, 4),
         }
 
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
-        batch_size = 1
+    def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
         num_latent_channels = embedding_dim = 16
         height = width = 4
         sequence_length = 8
@@ -106,9 +106,10 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
 
 
 class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
-    def test_infers_text_seq_len_from_mask(self):
+    @pytest.mark.parametrize("batch_size", [1, 2])
+    def test_infers_text_seq_len_from_mask(self, batch_size):
         init_dict = self.get_init_dict()
-        inputs = self.get_dummy_inputs()
+        inputs = self.get_dummy_inputs(batch_size=batch_size)
         model = self.model_class(**init_dict).to(torch_device)
 
         encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
@@ -122,7 +123,7 @@ def test_infers_text_seq_len_from_mask(self):
         assert isinstance(per_sample_len, torch.Tensor)
         assert int(per_sample_len.max().item()) == 2
         assert normalized_mask.dtype == torch.bool
-        assert normalized_mask.sum().item() == 2
+        assert normalized_mask.sum().item() == 2 * batch_size
         assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
 
         inputs["encoder_hidden_states_mask"] = normalized_mask
@@ -139,7 +140,7 @@ def test_infers_text_seq_len_from_mask(self):
         )
 
         assert int(per_sample_len2.max().item()) == 8
-        assert normalized_mask2.sum().item() == 5
+        assert normalized_mask2.sum().item() == 5 * batch_size
 
         rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
             inputs["encoder_hidden_states"], None
@@ -149,9 +150,10 @@ def test_infers_text_seq_len_from_mask(self):
         assert per_sample_len_none is None
         assert normalized_mask_none is None
 
-    def test_non_contiguous_attention_mask(self):
+    @pytest.mark.parametrize("batch_size", [1, 2])
+    def test_non_contiguous_attention_mask(self, batch_size):
         init_dict = self.get_init_dict()
-        inputs = self.get_dummy_inputs()
+        inputs = self.get_dummy_inputs(batch_size=batch_size)
         model = self.model_class(**init_dict).to(torch_device)
 
         encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()