Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions tests/models/testing_utils/parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import torch.multiprocessing as mp

from diffusers.models._modeling_parallel import ContextParallelConfig
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry

from ...testing_utils import (
is_context_parallel,
Expand Down Expand Up @@ -160,16 +161,21 @@ def _custom_mesh_worker(
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_inference(self, cp_type):
def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")

if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")

world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)

# Move all tensors to CPU for multiprocessing
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
Expand All @@ -194,6 +200,11 @@ def test_context_parallel_inference(self, cp_type):
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)

@pytest.mark.xfail(reason="Context parallel may not support batch_size > 1")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it the case for Flux as well?

Also, let's always require get_dummy_inputs() to have batch_size. So, we can safely remove the inspect stuff from here and elsewhere.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhtmike 👀

Copy link
Contributor Author

@zhtmike zhtmike Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it the case for Flux as well?

Yes. Flux works fine for bs > 2; I will drop the xfail once qwenimage is fixed.

Also, let's always require get_dummy_inputs() to have batch_size. So, we can safely remove the inspect stuff from here and elsewhere.

Done. Added the batch_size argument to the newly refactored models: flux & flux2. The tests pass.

@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_batch_inputs(self, cp_type):
self.test_context_parallel_inference(cp_type, batch_size=2)

@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
Expand All @@ -209,6 +220,11 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names)
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")

world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
Expand Down
3 changes: 1 addition & 2 deletions tests/models/transformers/test_models_transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]:
"axes_dims_rope": [4, 4, 8],
}

def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
height = width = 4
num_latent_channels = 4
num_image_channels = 3
Expand Down
3 changes: 1 addition & 2 deletions tests/models/transformers/test_models_transformer_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]:
"axes_dims_rope": [4, 4, 4, 4],
}

def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
def get_dummy_inputs(self, height: int = 4, width: int = 4, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
Expand Down
18 changes: 10 additions & 8 deletions tests/models/transformers/test_models_transformer_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import warnings

import pytest
import torch

from diffusers import QwenImageTransformer2DModel
Expand Down Expand Up @@ -77,8 +78,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]:
"axes_dims_rope": (8, 4, 4),
}

def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_latent_channels = embedding_dim = 16
height = width = 4
sequence_length = 8
Expand Down Expand Up @@ -106,9 +106,10 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:


class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
def test_infers_text_seq_len_from_mask(self):
@pytest.mark.parametrize("batch_size", [1, 2])
def test_infers_text_seq_len_from_mask(self, batch_size):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
inputs = self.get_dummy_inputs(batch_size=batch_size)
model = self.model_class(**init_dict).to(torch_device)

encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
Expand All @@ -122,7 +123,7 @@ def test_infers_text_seq_len_from_mask(self):
assert isinstance(per_sample_len, torch.Tensor)
assert int(per_sample_len.max().item()) == 2
assert normalized_mask.dtype == torch.bool
assert normalized_mask.sum().item() == 2
assert normalized_mask.sum().item() == 2 * batch_size
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]

inputs["encoder_hidden_states_mask"] = normalized_mask
Expand All @@ -139,7 +140,7 @@ def test_infers_text_seq_len_from_mask(self):
)

assert int(per_sample_len2.max().item()) == 8
assert normalized_mask2.sum().item() == 5
assert normalized_mask2.sum().item() == 5 * batch_size

rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None
Expand All @@ -149,9 +150,10 @@ def test_infers_text_seq_len_from_mask(self):
assert per_sample_len_none is None
assert normalized_mask_none is None

def test_non_contiguous_attention_mask(self):
@pytest.mark.parametrize("batch_size", [1, 2])
def test_non_contiguous_attention_mask(self, batch_size):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
inputs = self.get_dummy_inputs(batch_size=batch_size)
model = self.model_class(**init_dict).to(torch_device)

encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
Expand Down
Loading