model_transformers_shared model/pipeline review
Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423
Review performed against the repository review rules. AGENTS.md is referenced by .ai/review-rules.md but is not present in this checkout; all available referenced rule files were read.
Files reviewed:
src/diffusers/models/transformers/dual_transformer_2d.py
src/diffusers/models/transformers/prior_transformer.py
src/diffusers/models/transformers/transformer_2d.py
src/diffusers/models/transformers/transformer_temporal.py
Duplicate-search status: checked GitHub Issues/PRs for model_transformers_shared, all target class/file names, and the specific failures below. No likely duplicate found for Issues 1-4. Shap-E slow-test coverage is already reported in #13593 and is called out in Issue 5.
Issue 1: DualTransformer2DModel rejects UNet encoder-mask calls and drops masks
Affected code:
src/diffusers/models/transformers/dual_transformer_2d.py, lines 96 to 145 at 0f1abc4:

def forward(
    self,
    hidden_states,
    encoder_hidden_states,
    timestep=None,
    attention_mask=None,
    cross_attention_kwargs=None,
    return_dict: bool = True,
):
    """
    Args:
        hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
            When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states.
        encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
            Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
            self-attention.
        timestep ( `torch.long`, *optional*):
            Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
        attention_mask (`torch.Tensor`, *optional*):
            Optional attention mask to be applied in Attention.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
            tuple.

    Returns:
        [`~models.transformers.transformer_2d.Transformer2DModelOutput`] or `tuple`:
        [`~models.transformers.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is the sample tensor.
    """
    input_states = hidden_states

    encoded_states = []
    tokens_start = 0
    # attention_mask is not used yet
    for i in range(2):
        # for each of the two transformers, pass the corresponding condition tokens
        condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
        transformer_index = self.transformer_index_for_condition[i]
        encoded_state = self.transformers[transformer_index](
            input_states,
            encoder_hidden_states=condition_state,
            timestep=timestep,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]
        encoded_states.append(encoded_state - input_states)
Problem:
DualTransformer2DModel.forward() does not accept encoder_attention_mask, but CrossAttnDownBlock2D and related UNet blocks call transformer modules with that keyword. The same method also declares attention_mask but never forwards it to either child transformer.
Impact:
Any UNet block configured with dual_cross_attention=True fails on its first forward call, so neither inference nor training can run. Masked text/image conditioning is also silently ignored when callers invoke the wrapper directly.
Reproduction:
import torch
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D
block = CrossAttnDownBlock2D(
    in_channels=4, out_channels=4, temb_channels=8, num_layers=1,
    transformer_layers_per_block=1, num_attention_heads=1,
    cross_attention_dim=8, dual_cross_attention=True, resnet_groups=1,
)
block(
    torch.randn(1, 4, 4, 4),
    temb=torch.randn(1, 8),
    encoder_hidden_states=torch.randn(1, 77 + 257, 8),
)
Relevant precedent:
Transformer2DModel.forward() accepts and routes encoder_attention_mask:
src/diffusers/models/transformers/transformer_2d.py, line 333 at 0f1abc4:

    encoder_attention_mask: torch.Tensor | None = None,

src/diffusers/models/transformers/transformer_2d.py, line 431 at 0f1abc4:

    encoder_attention_mask=encoder_attention_mask,
Suggested fix:
def forward(..., attention_mask=None, encoder_attention_mask=None, ...):
    ...
    for i in range(2):
        ...
        condition_mask = None
        if encoder_attention_mask is not None:
            condition_mask = encoder_attention_mask[..., tokens_start : tokens_start + self.condition_lengths[i]]
        encoded_state = self.transformers[transformer_index](
            input_states,
            encoder_hidden_states=condition_state,
            timestep=timestep,
            attention_mask=attention_mask,
            encoder_attention_mask=condition_mask,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]
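Once the signature is extended, the reproduction above should pass, including an explicit mask. A verification sketch only; the mask shape and dtype follow Transformer2DModel's existing encoder_attention_mask handling:
import torch
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D

block = CrossAttnDownBlock2D(
    in_channels=4, out_channels=4, temb_channels=8, num_layers=1,
    transformer_layers_per_block=1, num_attention_heads=1,
    cross_attention_dim=8, dual_cross_attention=True, resnet_groups=1,
)
# With the fix, the dual-attention path accepts the same keywords as the
# single-transformer path, so this call should no longer raise a TypeError.
hidden_states, _ = block(
    torch.randn(1, 4, 4, 4),
    temb=torch.randn(1, 8),
    encoder_hidden_states=torch.randn(1, 77 + 257, 8),
    encoder_attention_mask=torch.ones(1, 77 + 257, dtype=torch.long),
)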
Issue 2: DualTransformer2DModel is not available from the top-level lazy import
Affected code:
src/diffusers/models/__init__.py, line 86 at 0f1abc4:

_import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]

src/diffusers/__init__.py, lines 194 to 306 at 0f1abc4:

    "AuraFlowTransformer2DModel",
    "AutoencoderDC",
    "AutoencoderKL",
    "AutoencoderKLAllegro",
    "AutoencoderKLCogVideoX",
    "AutoencoderKLCosmos",
    "AutoencoderKLFlux2",
    "AutoencoderKLHunyuanImage",
    "AutoencoderKLHunyuanImageRefiner",
    "AutoencoderKLHunyuanVideo",
    "AutoencoderKLHunyuanVideo15",
    "AutoencoderKLKVAE",
    "AutoencoderKLKVAEVideo",
    "AutoencoderKLLTX2Audio",
    "AutoencoderKLLTX2Video",
    "AutoencoderKLLTXVideo",
    "AutoencoderKLMagvit",
    "AutoencoderKLMochi",
    "AutoencoderKLQwenImage",
    "AutoencoderKLTemporalDecoder",
    "AutoencoderKLWan",
    "AutoencoderOobleck",
    "AutoencoderRAE",
    "AutoencoderTiny",
    "AutoencoderVidTok",
    "AutoModel",
    "BriaFiboTransformer2DModel",
    "BriaTransformer2DModel",
    "CacheMixin",
    "ChromaTransformer2DModel",
    "ChronoEditTransformer3DModel",
    "CogVideoXTransformer3DModel",
    "CogView3PlusTransformer2DModel",
    "CogView4Transformer2DModel",
    "ConsisIDTransformer3DModel",
    "ConsistencyDecoderVAE",
    "ContextParallelConfig",
    "ControlNetModel",
    "ControlNetUnionModel",
    "ControlNetXSAdapter",
    "CosmosControlNetModel",
    "CosmosTransformer3DModel",
    "DiTTransformer2DModel",
    "EasyAnimateTransformer3DModel",
    "ErnieImageTransformer2DModel",
    "Flux2Transformer2DModel",
    "FluxControlNetModel",
    "FluxMultiControlNetModel",
    "FluxTransformer2DModel",
    "GlmImageTransformer2DModel",
    "HeliosTransformer3DModel",
    "HiDreamImageTransformer2DModel",
    "HunyuanDiT2DControlNetModel",
    "HunyuanDiT2DModel",
    "HunyuanDiT2DMultiControlNetModel",
    "HunyuanImageTransformer2DModel",
    "HunyuanVideo15Transformer3DModel",
    "HunyuanVideoFramepackTransformer3DModel",
    "HunyuanVideoTransformer3DModel",
    "I2VGenXLUNet",
    "Kandinsky3UNet",
    "Kandinsky5Transformer3DModel",
    "LatteTransformer3DModel",
    "LongCatAudioDiTTransformer",
    "LongCatAudioDiTVae",
    "LongCatImageTransformer2DModel",
    "LTX2VideoTransformer3DModel",
    "LTXVideoTransformer3DModel",
    "Lumina2Transformer2DModel",
    "LuminaNextDiT2DModel",
    "MochiTransformer3DModel",
    "ModelMixin",
    "MotionAdapter",
    "MultiAdapter",
    "MultiControlNetModel",
    "NucleusMoEImageTransformer2DModel",
    "OmniGenTransformer2DModel",
    "OvisImageTransformer2DModel",
    "ParallelConfig",
    "PixArtTransformer2DModel",
    "PriorTransformer",
    "PRXTransformer2DModel",
    "QwenImageControlNetModel",
    "QwenImageMultiControlNetModel",
    "QwenImageTransformer2DModel",
    "SanaControlNetModel",
    "SanaTransformer2DModel",
    "SanaVideoTransformer3DModel",
    "SD3ControlNetModel",
    "SD3MultiControlNetModel",
    "SD3Transformer2DModel",
    "SkyReelsV2Transformer3DModel",
    "SparseControlNetModel",
    "StableAudioDiTModel",
    "StableCascadeUNet",
    "T2IAdapter",
    "T5FilmDecoder",
    "Transformer2DModel",
    "TransformerTemporalModel",
    "UNet1DModel",
    "UNet2DConditionModel",
    "UNet2DModel",
    "UNet3DConditionModel",
    "UNetControlNetXSModel",
    "UNetMotionModel",
    "UNetSpatioTemporalConditionModel",
    "UVit2DModel",
    "VQModel",
    "WanAnimateTransformer3DModel",
    "WanTransformer3DModel",
    "WanVACETransformer3DModel",
    "ZImageControlNetModel",
    "ZImageTransformer2DModel",
Problem:
DualTransformer2DModel is exported from diffusers.models and diffusers.models.transformers, but not from top-level diffusers. The local review rules require model classes to be wired through both subpackage and top-level lazy imports.
Impact:
Users can import related shared transformer classes from diffusers, but this one raises ImportError.
Reproduction:
from diffusers.models import DualTransformer2DModel  # works
print(DualTransformer2DModel.__name__)
from diffusers import DualTransformer2DModel  # raises ImportError
Relevant precedent:
The same top-level list already exports PriorTransformer, Transformer2DModel, and TransformerTemporalModel:
src/diffusers/__init__.py, lines 273 to 292 at 0f1abc4:

    "PixArtTransformer2DModel",
    "PriorTransformer",
    "PRXTransformer2DModel",
    "QwenImageControlNetModel",
    "QwenImageMultiControlNetModel",
    "QwenImageTransformer2DModel",
    "SanaControlNetModel",
    "SanaTransformer2DModel",
    "SanaVideoTransformer3DModel",
    "SD3ControlNetModel",
    "SD3MultiControlNetModel",
    "SD3Transformer2DModel",
    "SkyReelsV2Transformer3DModel",
    "SparseControlNetModel",
    "StableAudioDiTModel",
    "StableCascadeUNet",
    "T2IAdapter",
    "T5FilmDecoder",
    "Transformer2DModel",
    "TransformerTemporalModel",
Suggested fix:
# src/diffusers/__init__.py
_import_structure["models"].extend([
    ...
    "DualTransformer2DModel",
    ...
])

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .models import (
        ...
        DualTransformer2DModel,
        ...
    )
Also regenerate/add the matching dummy_pt_objects.py entry for no-torch imports.
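For reference, a sketch of the corresponding entry, following the pattern of existing classes in src/diffusers/utils/dummy_pt_objects.py (this file is normally regenerated by the repo's dummy-object tooling rather than edited by hand):
class DualTransformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])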
Issue 3: Temporal transformer out_channels is serialized but ignored
Affected code:
src/diffusers/models/transformers/transformer_temporal.py, lines 78 to 121 at 0f1abc4 (TransformerTemporalModel.__init__):

    out_channels: int | None = None,
    num_layers: int = 1,
    dropout: float = 0.0,
    norm_num_groups: int = 32,
    cross_attention_dim: int | None = None,
    attention_bias: bool = False,
    sample_size: int | None = None,
    activation_fn: str = "geglu",
    norm_elementwise_affine: bool = True,
    double_self_attention: bool = True,
    positional_embeddings: str | None = None,
    num_positional_embeddings: int | None = None,
):
    super().__init__()
    self.num_attention_heads = num_attention_heads
    self.attention_head_dim = attention_head_dim
    inner_dim = num_attention_heads * attention_head_dim

    self.in_channels = in_channels

    self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    self.proj_in = nn.Linear(in_channels, inner_dim)

    # 3. Define transformers blocks
    self.transformer_blocks = nn.ModuleList(
        [
            BasicTransformerBlock(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
                dropout=dropout,
                cross_attention_dim=cross_attention_dim,
                activation_fn=activation_fn,
                attention_bias=attention_bias,
                double_self_attention=double_self_attention,
                norm_elementwise_affine=norm_elementwise_affine,
                positional_embeddings=positional_embeddings,
                num_positional_embeddings=num_positional_embeddings,
            )
            for d in range(num_layers)
        ]
    )

    self.proj_out = nn.Linear(inner_dim, in_channels)

src/diffusers/models/transformers/transformer_temporal.py, lines 225 to 276 at 0f1abc4 (TransformerSpatioTemporalModel.__init__):

    out_channels: int | None = None,
    num_layers: int = 1,
    cross_attention_dim: int | None = None,
):
    super().__init__()
    self.num_attention_heads = num_attention_heads
    self.attention_head_dim = attention_head_dim

    inner_dim = num_attention_heads * attention_head_dim
    self.inner_dim = inner_dim

    # 2. Define input layers
    self.in_channels = in_channels
    self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
    self.proj_in = nn.Linear(in_channels, inner_dim)

    # 3. Define transformers blocks
    self.transformer_blocks = nn.ModuleList(
        [
            BasicTransformerBlock(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
                cross_attention_dim=cross_attention_dim,
            )
            for d in range(num_layers)
        ]
    )

    time_mix_inner_dim = inner_dim
    self.temporal_transformer_blocks = nn.ModuleList(
        [
            TemporalBasicTransformerBlock(
                inner_dim,
                time_mix_inner_dim,
                num_attention_heads,
                attention_head_dim,
                cross_attention_dim=cross_attention_dim,
            )
            for _ in range(num_layers)
        ]
    )

    time_embed_dim = in_channels * 4
    self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
    self.time_proj = Timesteps(in_channels, True, 0)
    self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")

    # 4. Define output layers
    self.out_channels = in_channels if out_channels is None else out_channels
    # TODO: should use out_channels for continuous projections
    self.proj_out = nn.Linear(inner_dim, in_channels)
Problem:
Both TransformerTemporalModel and TransformerSpatioTemporalModel accept out_channels, and the spatio-temporal class even stores it, but proj_out is hardwired to in_channels. A config can therefore declare out_channels=8 while the module still projects back to in_channels (4 in the reproduction below).
Impact:
Saved configs misrepresent the architecture. Users cannot rely on the public constructor/config contract, and future checkpoint conversion can silently produce the wrong projection shape.
Reproduction:
import torch
from diffusers import TransformerTemporalModel
model = TransformerTemporalModel(
    num_attention_heads=1, attention_head_dim=4,
    in_channels=4, out_channels=8, num_layers=1, norm_num_groups=1,
)
print(model.config.out_channels) # 8
print(model.proj_out.out_features) # 4
x = torch.randn(2, 4, 4, 4)
print(model(x, num_frames=2).sample.shape) # torch.Size([2, 4, 4, 4])
Relevant precedent:
SD3 wires self.out_channels into its output projection:
src/diffusers/models/transformers/transformer_sd3.py, lines 139 to 170 at 0f1abc4:

self.out_channels = out_channels if out_channels is not None else in_channels
self.inner_dim = num_attention_heads * attention_head_dim

self.pos_embed = PatchEmbed(
    height=sample_size,
    width=sample_size,
    patch_size=patch_size,
    in_channels=in_channels,
    embed_dim=self.inner_dim,
    pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
    embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

self.transformer_blocks = nn.ModuleList(
    [
        JointTransformerBlock(
            dim=self.inner_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            context_pre_only=i == num_layers - 1,
            qk_norm=qk_norm,
            use_dual_attention=True if i in dual_attention_layers else False,
        )
        for i in range(num_layers)
    ]
)

self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
Suggested fix:
If different output channels are unsupported, reject them explicitly:
if out_channels is not None and out_channels != in_channels:
    raise ValueError("`out_channels` must be None or equal to `in_channels` for this temporal transformer.")
If support is intended, wire proj_out to self.out_channels and handle the residual path when channel counts differ.
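A minimal sketch of that wiring (fragment only; the frame/spatial reshape between proj_out and the residual add is omitted, and the conditional residual is an assumption about the intended behavior, not current code):
# In __init__: project to the configured output width instead of in_channels.
self.out_channels = in_channels if out_channels is None else out_channels
self.proj_out = nn.Linear(inner_dim, self.out_channels)

# In forward, after projecting back out: only add the residual when channel counts match.
hidden_states = self.proj_out(hidden_states)
if self.out_channels == self.in_channels:
    output = hidden_states + residual
else:
    output = hidden_states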
Issue 4: Transformer2DModel discrete output hardcasts through float64
Affected code:
src/diffusers/models/transformers/transformer_2d.py, lines 514 to 520 at 0f1abc4:

def _get_output_for_vectorized_inputs(self, hidden_states):
    hidden_states = self.norm_out(hidden_states)
    logits = self.out(hidden_states)
    # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
    logits = logits.permute(0, 2, 1)
    # log(p(x_0))
    output = F.log_softmax(logits.double(), dim=1).float()

tests/models/test_layers_utils.py, lines 435 to 436 at 0f1abc4:

@require_torch_accelerator_with_fp64
def test_spatial_transformer_discrete(self):
Problem:
The vectorized/discrete path computes F.log_softmax(logits.double(), dim=1).float(). The local model rules explicitly prohibit unconditional float64 in model forwards because MPS and several NPU backends do not support it. The only discrete-path test is gated behind require_torch_accelerator_with_fp64, so unsupported devices are skipped instead of protected.
Impact:
Discrete Transformer2DModel inference can fail on devices without float64 support, and the test suite encodes that limitation rather than catching it.
Reproduction:
import torch
from diffusers import Transformer2DModel
device = torch.device("mps") # or another backend without float64 forward support
model = Transformer2DModel(
    num_attention_heads=1,
    attention_head_dim=32,
    num_vector_embeds=8,
    sample_size=2,
).to(device)
sample = torch.randint(0, 8, (1, 4), device=device)
model(sample)
Relevant precedent:
The review rules require avoiding unconditional float64 in model code.
Suggested fix:
output = F.log_softmax(logits.float(), dim=1)
This preserves the current float32 output contract without routing through float64.
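An ungated fast test could then cover the discrete path on every backend. A sketch reusing the reproduction's configuration (the expected shape follows the `(batch, num_vector_embeds - 1, num_latent_pixels)` comment in the affected code; the test name is illustrative):
def test_spatial_transformer_discrete_no_fp64(self):
    torch.manual_seed(0)
    model = Transformer2DModel(
        num_attention_heads=1,
        attention_head_dim=32,
        num_vector_embeds=8,
        sample_size=2,
    ).to(torch_device)
    sample = torch.randint(0, 8, (1, 4), device=torch_device)
    with torch.no_grad():
        output = model(sample).sample
    # (batch, num_vector_embeds - 1, num_latent_pixels)
    assert output.shape == (1, 8 - 1, 2 * 2)
    assert output.dtype == torch.float32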
Issue 5: Slow/direct coverage is incomplete for the shared transformer family
Affected code:
tests/models/test_layers_utils.py, lines 324 to 436 at 0f1abc4:

class Transformer2DModelTests(unittest.TestCase):
    def test_spatial_transformer_default(self):
        torch.manual_seed(0)
        backend_manual_seed(torch_device, 0)

        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        spatial_transformer_block = Transformer2DModel(
            in_channels=32,
            num_attention_heads=1,
            attention_head_dim=32,
            dropout=0.0,
            cross_attention_dim=None,
        ).to(torch_device)
        with torch.no_grad():
            attention_scores = spatial_transformer_block(sample).sample

        assert attention_scores.shape == (1, 32, 64, 64)
        output_slice = attention_scores[0, -1, -3:, -3:]

        expected_slice = torch.tensor(
            [-1.9455, -0.0066, -1.3933, -1.5878, 0.5325, -0.6486, -1.8648, 0.7515, -0.9689], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_spatial_transformer_cross_attention_dim(self):
        torch.manual_seed(0)
        backend_manual_seed(torch_device, 0)

        sample = torch.randn(1, 64, 64, 64).to(torch_device)
        spatial_transformer_block = Transformer2DModel(
            in_channels=64,
            num_attention_heads=2,
            attention_head_dim=32,
            dropout=0.0,
            cross_attention_dim=64,
        ).to(torch_device)
        with torch.no_grad():
            context = torch.randn(1, 4, 64).to(torch_device)
            attention_scores = spatial_transformer_block(sample, context).sample

        assert attention_scores.shape == (1, 64, 64, 64)
        output_slice = attention_scores[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_spatial_transformer_timestep(self):
        torch.manual_seed(0)
        backend_manual_seed(torch_device, 0)

        num_embeds_ada_norm = 5

        sample = torch.randn(1, 64, 64, 64).to(torch_device)
        spatial_transformer_block = Transformer2DModel(
            in_channels=64,
            num_attention_heads=2,
            attention_head_dim=32,
            dropout=0.0,
            cross_attention_dim=64,
            num_embeds_ada_norm=num_embeds_ada_norm,
        ).to(torch_device)
        with torch.no_grad():
            timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device)
            timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device)
            attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1).sample
            attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2).sample

        assert attention_scores_1.shape == (1, 64, 64, 64)
        assert attention_scores_2.shape == (1, 64, 64, 64)

        output_slice_1 = attention_scores_1[0, -1, -3:, -3:]
        output_slice_2 = attention_scores_2[0, -1, -3:, -3:]

        expected_slice = torch.tensor(
            [-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703], device=torch_device
        )
        expected_slice_2 = torch.tensor(
            [-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348], device=torch_device
        )

        assert torch.allclose(output_slice_1.flatten(), expected_slice, atol=1e-3)
        assert torch.allclose(output_slice_2.flatten(), expected_slice_2, atol=1e-3)

    def test_spatial_transformer_dropout(self):
        torch.manual_seed(0)
        backend_manual_seed(torch_device, 0)

        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        spatial_transformer_block = (
            Transformer2DModel(
                in_channels=32,
                num_attention_heads=2,
                attention_head_dim=16,
                dropout=0.3,
                cross_attention_dim=None,
            )
            .to(torch_device)
            .eval()
        )
        with torch.no_grad():
            attention_scores = spatial_transformer_block(sample).sample

        assert attention_scores.shape == (1, 32, 64, 64)
        output_slice = attention_scores[0, -1, -3:, -3:]

        expected_slice = torch.tensor(
            [-1.9380, -0.0083, -1.3771, -1.5819, 0.5209, -0.6441, -1.8545, 0.7563, -0.9615], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    @require_torch_accelerator_with_fp64
    def test_spatial_transformer_discrete(self):

tests/models/transformers/test_models_transformer_temporal.py, lines 32 to 33 at 0f1abc4:

class TemporalTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = TransformerTemporalModel

tests/pipelines/kandinsky/test_kandinsky_prior.py, line 171 at 0f1abc4:

class KandinskyPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

tests/pipelines/stable_unclip/test_stable_unclip.py, lines 202 to 204 at 0f1abc4:

@nightly
@require_torch_accelerator
class StableUnCLIPPipelineIntegrationTests(unittest.TestCase):
Problem:
DualTransformer2DModel has no direct test references. Transformer2DModel tests live in test_layers_utils.py with no slow coverage. TransformerTemporalModel has a fast model test but no slow coverage. Prior model slow coverage exists, but the prior-family pipelines are uneven: Kandinsky prior tests are fast-only, Stable UnCLIP is nightly-only, and Shap-E slow coverage is already tracked separately in #13593.
Impact:
The exact regressions above are not covered: dual UNet call compatibility, top-level import parity, temporal out_channels, and non-fp64 discrete transformer execution.
Reproduction:
from pathlib import Path

for path in [
    Path("tests/models/test_layers_utils.py"),
    Path("tests/models/transformers/test_models_prior.py"),
    Path("tests/models/transformers/test_models_transformer_temporal.py"),
]:
    text = path.read_text()
    print(path, "@slow" in text, "@nightly" in text)

print(
    "DualTransformer2DModel test refs:",
    sum("DualTransformer2DModel" in p.read_text(errors="ignore") for p in Path("tests").rglob("test_*.py")),
)
Relevant precedent:
PriorTransformer already has direct slow model coverage:
tests/models/transformers/test_models_prior.py, lines 142 to 174 at 0f1abc4:

@slow
class PriorTransformerIntegrationTests(unittest.TestCase):
    def get_dummy_seed_input(self, batch_size=1, embedding_dim=768, num_embeddings=77, seed=0):
        torch.manual_seed(seed)

        hidden_states = torch.randn((batch_size, embedding_dim)).to(torch_device)

        proj_embedding = torch.randn((batch_size, embedding_dim)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, num_embeddings, embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": 2,
            "proj_embedding": proj_embedding,
            "encoder_hidden_states": encoder_hidden_states,
        }

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    @parameterized.expand(
        [
            # fmt: off
            [13, [-0.5861, 0.1283, -0.0931, 0.0882, 0.4476, 0.1329, -0.0498, 0.0640]],
            [37, [-0.4913, 0.0110, -0.0483, 0.0541, 0.4954, -0.0170, 0.0354, 0.1651]],
            # fmt: on
        ]
    )
    def test_kandinsky_prior(self, seed, expected_slice):
        model = PriorTransformer.from_pretrained("kandinsky-community/kandinsky-2-1-prior", subfolder="prior")
Suggested fix:
Add a dedicated dual-transformer fast test that runs through a dual_cross_attention=True UNet block with masks, add import/export assertions for all public shared classes, add temporal out_channels validation coverage, and add slow or integration coverage for TransformerTemporalModel and the prior-family pipelines that are currently fast-only or nightly-only.
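For the import/export piece, a minimal parity check could look like the sketch below (it assumes the Issue 2 export fix has landed; the class list is illustrative, not exhaustive):
import diffusers
import diffusers.models

# Each shared transformer class should resolve to the same object from both entry points.
for name in ("DualTransformer2DModel", "PriorTransformer", "Transformer2DModel", "TransformerTemporalModel"):
    assert getattr(diffusers, name) is getattr(diffusers.models, name)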