model_unets_shared model/pipeline review
Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423
Review performed against the repository review rules.
Files reviewed: unet_1d.py, unet_1d_blocks.py, unet_2d.py, unet_2d_blocks.py, unet_2d_blocks_flax.py, unet_3d_blocks.py, uvit_2d.py, unet_2d_condition.py, unet_2d_condition_flax.py, unet_3d_condition.py.
Duplicate search: searched GitHub Issues/PRs for model_unets_shared, affected class/file names, and the specific failure modes. Exact duplicate found only for the UViT checkpointing failure: #11214. No exact duplicate found for the scalar timestep truncation, UNet3DConditionModel mask drop, or missing coverage; old issue #1890 is related to general attention masking but not this 3D UNet path.
Issue 1: UNet3DConditionModel accepts attention_mask but drops it
Affected code:
diffusers/src/diffusers/models/unets/unet_3d_condition.py, lines 483 to 547 in 0f1abc4:
        attention_mask: torch.Tensor | None = None,
        cross_attention_kwargs: dict[str, Any] | None = None,
        down_block_additional_residuals: tuple[torch.Tensor] | None = None,
        mid_block_additional_residual: torch.Tensor | None = None,
        return_dict: bool = True,
    ) -> UNet3DConditionOutput | tuple[torch.Tensor]:
        r"""
        The [`UNet3DConditionModel`] forward method.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width`.
            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
                through the `self.time_embedding` layer to obtain the timestep embeddings.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
                A tuple of tensors that if specified are added to the residuals of down unet blocks.
            mid_block_additional_residual: (`torch.Tensor`, *optional*):
                A tensor that if specified is added to the residual of the middle unet block.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].

        Returns:
            [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned,
                otherwise a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
diffusers/src/diffusers/models/unets/unet_3d_blocks.py, lines 501 to 532 in 0f1abc4 (CrossAttnDownBlock3D):
    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        num_frames: int = 1,
        cross_attention_kwargs: dict[str, Any] = None,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        # TODO(Patrick, William) - attention mask is not used
        output_states = ()

        for resnet, temp_conv, attn, temp_attn in zip(
            self.resnets, self.temp_convs, self.attentions, self.temp_attentions
        ):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            hidden_states = temp_attn(
                hidden_states,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            output_states += (hidden_states,)
diffusers/src/diffusers/models/unets/unet_3d_blocks.py, lines 728 to 778 in 0f1abc4 (CrossAttnUpBlock3D; excerpt starts mid-signature):
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: tuple[torch.Tensor, ...],
        temb: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        upsample_size: int | None = None,
        attention_mask: torch.Tensor | None = None,
        num_frames: int = 1,
        cross_attention_kwargs: dict[str, Any] = None,
    ) -> torch.Tensor:
        is_freeu_enabled = (
            getattr(self, "s1", None)
            and getattr(self, "s2", None)
            and getattr(self, "b1", None)
            and getattr(self, "b2", None)
        )

        # TODO(Patrick, William) - attention mask is not used
        for resnet, temp_conv, attn, temp_attn in zip(
            self.resnets, self.temp_convs, self.attentions, self.temp_attentions
        ):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            # FreeU: Only operate on the first two stages
            if is_freeu_enabled:
                hidden_states, res_hidden_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_hidden_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            hidden_states = temp_attn(
                hidden_states,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
Problem:
UNet3DConditionModel.forward() documents and prepares attention_mask, but CrossAttnDownBlock3D and CrossAttnUpBlock3D explicitly do not use it. The mid block also declares attention_mask and does not pass it to the spatial transformer. Padding masks therefore silently have no effect.
Impact:
Batched prompts with padding can attend to discarded text tokens in video UNets. The API suggests masking is honored, so users get silently incorrect conditioning.
Reproduction:
import torch
from diffusers import UNet3DConditionModel
torch.manual_seed(0)
model = UNet3DConditionModel(
block_out_channels=(8, 16),
norm_num_groups=4,
down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
cross_attention_dim=8,
attention_head_dim=2,
out_channels=4,
in_channels=4,
layers_per_block=1,
sample_size=16,
).eval()
sample = torch.randn(1, 4, 2, 16, 16)
encoder_hidden_states = torch.randn(1, 4, 8)
with torch.no_grad():
keep = model(sample, torch.tensor([10]), encoder_hidden_states, attention_mask=torch.ones(1, 4)).sample
drop = model(sample, torch.tensor([10]), encoder_hidden_states, attention_mask=torch.zeros(1, 4)).sample
print((keep - drop).abs().max().item()) # 0.0: mask is ignored
Relevant precedent:
diffusers/src/diffusers/models/unets/unet_2d_condition.py, lines 1074 to 1076 in 0f1abc4:
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
diffusers/src/diffusers/models/unets/unet_2d_blocks.py, lines 1264 to 1275 in 0f1abc4:
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
Suggested fix:
# In UNet3DConditionModel.forward(), after num_frames is known:
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
attention_mask = attention_mask.repeat_interleave(
num_frames, dim=0, output_size=attention_mask.shape[0] * num_frames
)
# In 3D cross-attn blocks, pass it to the spatial Transformer2DModel cross-attn mask:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
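Why the repeat_interleave in the suggested fix: the 3D UNet folds frames into the batch dimension before its spatial blocks (sample is reshaped from (batch, channels, frames, height, width) to (batch * frames, channels, height, width)), so the mask bias must be expanded the same way. A minimal shape check of the suggested preparation, in plain PyTorch and independent of diffusers:
import torch

batch, num_frames, key_tokens = 2, 4, 7
attention_mask = torch.ones(batch, key_tokens)

# Same transformation as the suggested fix: mask -> additive bias, broadcastable
# head dim, then one copy per frame to match the folded (batch * frames) batch.
bias = (1 - attention_mask) * -10000.0
bias = bias.unsqueeze(1)                          # (batch, 1, key_tokens)
bias = bias.repeat_interleave(num_frames, dim=0)  # (batch * frames, 1, key_tokens)
print(bias.shape)                                 # torch.Size([8, 1, 7])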
Issue 2: Scalar float timesteps are truncated in unconditional UNets
Affected code:
diffusers/src/diffusers/models/unets/unet_1d.py, lines 228 to 234 in 0f1abc4:
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        timestep_embed = self.time_proj(timesteps)
diffusers/src/diffusers/models/unets/unet_2d.py, lines 278 to 286 in 0f1abc4:
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
diffusers/src/diffusers/models/unets/unet_2d_condition_flax.py, lines 371 to 376 in 0f1abc4:
        # 1. time
        if not isinstance(timesteps, jnp.ndarray):
            timesteps = jnp.array([timesteps], dtype=jnp.int32)
        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
            timesteps = timesteps.astype(dtype=jnp.float32)
            timesteps = jnp.expand_dims(timesteps, 0)
Problem:
UNet1DModel and UNet2DModel type their public timestep as torch.Tensor | float | int, but scalar Python floats are converted with integer dtype. UNet2DModel then divides Fourier outputs by the integer timestep, so 1e-4 becomes 0 and produces non-finite output. The Flax model has the same scalar-float-to-int32 pattern.
Impact:
Calling the public API with a scalar float timestep changes semantics versus passing torch.tensor([float]); VE/NCSN-style small sigma timesteps can produce NaNs.
Reproduction:
import torch
from diffusers import UNet1DModel, UNet2DModel
unet2d = UNet2DModel(
sample_size=8,
in_channels=3,
out_channels=3,
block_out_channels=(8,),
layers_per_block=1,
down_block_types=("DownBlock2D",),
up_block_types=("UpBlock2D",),
norm_num_groups=4,
time_embedding_type="fourier",
).eval()
sample2d = torch.randn(1, 3, 8, 8)
with torch.no_grad():
print(torch.isfinite(unet2d(sample2d, 1e-4).sample).all().item()) # False
print(torch.isfinite(unet2d(sample2d, torch.tensor([1e-4])).sample).all().item()) # True
unet1d = UNet1DModel(
sample_size=16,
in_channels=4,
out_channels=4,
block_out_channels=(8, 8),
down_block_types=("DownResnetBlock1D", "DownResnetBlock1D"),
up_block_types=("UpResnetBlock1D",),
mid_block_type="MidResTemporalBlock1D",
out_block_type="OutConv1DBlock",
time_embedding_type="positional",
use_timestep_embedding=True,
norm_num_groups=4,
act_fn="swish",
).eval()
sample1d = torch.randn(2, 4, 16)
with torch.no_grad():
a = unet1d(sample1d, 0.5).sample
b = unet1d(sample1d, torch.tensor([0.5])).sample
print((a - b).abs().max().item()) # non-zero: scalar path was truncated
Relevant precedent:
diffusers/src/diffusers/models/unets/unet_2d_condition.py, lines 851 to 873 in 0f1abc4:
    def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float | int) -> torch.Tensor | None:
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            is_npu = sample.device.type == "npu"
            if isinstance(timestep, float):
                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
            else:
                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)
        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)
diffusers/src/diffusers/models/unets/unet_3d_condition.py, lines 548 to 569 in 0f1abc4:
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            is_npu = sample.device.type == "npu"
            if isinstance(timestep, float):
                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
            else:
                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        num_frames = sample.shape[2]
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
Suggested fix:
if not torch.is_tensor(timesteps):
is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timesteps, float):
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
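The dtype selection is the substance of the fix. A one-line check in plain PyTorch shows why the current torch.long conversion destroys small float timesteps while the float64 path preserves them:
import torch

# Current scalar path: the float value is silently floored by the integer dtype.
print(torch.tensor([1e-4], dtype=torch.long))     # tensor([0])
# Suggested path: float scalars keep their value.
print(torch.tensor([1e-4], dtype=torch.float64))  # tensor([1.0000e-04], dtype=torch.float64)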
Issue 3: UVit2DModel gradient checkpointing crashes (existing duplicate of #11214)
Affected code:
diffusers/src/diffusers/models/unets/uvit_2d.py, lines 180 to 195 in 0f1abc4:
        for layer in self.transformer_layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def layer_(*args):
                    return checkpoint(layer, *args)

            else:
                layer_ = layer

            hidden_states = layer_(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs={"pooled_text_emb": pooled_text_emb},
            )
Problem:
Duplicate of #11214. UVit2DModel sets _supports_gradient_checkpointing = True, but its checkpoint wrapper only accepts *args and is then called with keyword arguments.
Impact:
Training or fine-tuning with enable_gradient_checkpointing() fails immediately.
Reproduction:
import torch
from diffusers import UVit2DModel
model = UVit2DModel(
hidden_size=8,
cond_embed_dim=4,
micro_cond_encode_dim=2,
micro_cond_embed_dim=4,
encoder_hidden_size=6,
vocab_size=16,
codebook_size=15,
in_channels=4,
block_out_channels=4,
num_res_blocks=1,
block_num_heads=1,
num_hidden_layers=1,
num_attention_heads=1,
intermediate_size=16,
sample_size=2,
)
model.enable_gradient_checkpointing()
model(torch.randint(0, 16, (1, 2, 2)), torch.randn(1, 3, 6), torch.randn(1, 4), torch.tensor([[1.0, 2.0]]))
Relevant precedent:
diffusers/src/diffusers/models/transformers/transformer_flux.py, lines 715 to 725 in 0f1abc4:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    joint_attention_kwargs,
                )

            else:
Suggested fix:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
layer,
hidden_states,
None,
encoder_hidden_states,
None,
None,
cross_attention_kwargs,
None,
{"pooled_text_emb": pooled_text_emb},
)
else:
hidden_states = layer(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs={"pooled_text_emb": pooled_text_emb},
)
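Note on the positional Nones in this sketch: they assume the checkpointed layer is a BasicTransformerBlock whose forward takes (hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, added_cond_kwargs) in that order. Since _gradient_checkpointing_func forwards arguments positionally, verify the order against the checked-out source before applying.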
Issue 4: Missing fast/slow coverage for parts of the family
Affected code:
diffusers/src/diffusers/models/unets/uvit_2d.py, lines 38 to 39 in 0f1abc4:
class UVit2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):
    _supports_gradient_checkpointing = True
diffusers/src/diffusers/models/unets/unet_2d_condition_flax.py, lines 51 to 124 in 0f1abc4:
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods
    implemented for all models (such as downloading or saving).

    This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
    general usage and behavior.

    Inherent JAX features such as the following are supported:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        sample_size (`int`, *optional*):
            The size of the input sample.
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4):
            The number of channels in the output.
        down_block_types (`tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
            The tuple of downsample blocks to use.
        up_block_types (`tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer
            is skipped.
        block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        attention_head_dim (`int` or `tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads.
        num_attention_heads (`int` or `tuple[int]`, *optional*):
            The number of attention heads.
        cross_attention_dim (`int`, *optional*, defaults to 768):
            The dimension of the cross attention features.
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for down, up and bottleneck blocks.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            Enable memory efficient attention as described [here](https://huggingface.co/papers/2112.05682).
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """

    sample_size: int = 32
    in_channels: int = 4
    out_channels: int = 4
    down_block_types: tuple[str, ...] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )
    up_block_types: tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
    mid_block_type: str | None = "UNetMidBlock2DCrossAttn"
    only_cross_attention: bool | tuple[bool] = False
    block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: int | tuple[int, ...] = 8
    num_attention_heads: int | tuple[int, ...] | None = None
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
    dtype: jnp.dtype = jnp.float32
diffusers/tests/models/unets/test_models_unet_3d_condition.py, line 35 in 0f1abc4:
class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
Problem:
UVit2DModel has no model or pipeline tests under tests/. FlaxUNet2DConditionModel is only covered by dummy-object checks, not fast or slow behavior tests. UNet3DConditionModel has fast model tests, but no @slow integration coverage.
Impact:
The UViT checkpointing regression above is currently unguarded; the Flax UNet's serialization and runtime behavior can regress unnoticed even though the class is still publicly exported; and published 3D UNet checkpoint behavior is not covered by slow tests.
Reproduction:
from pathlib import Path
for needle in ["UVit2DModel", "FlaxUNet2DConditionModel", "UNet3DConditionModel"]:
hits = [str(p) for p in Path("tests").rglob("*.py") if needle in p.read_text(encoding="utf-8")]
print(needle, hits)
text = Path("tests/models/unets/test_models_unet_3d_condition.py").read_text(encoding="utf-8")
print("UNet3D slow marker?", "@slow" in text or "pytest.mark.slow" in text)
Relevant precedent:
diffusers/tests/models/unets/test_models_unet_1d.py, line 141 in 0f1abc4:
    @slow
diffusers/tests/models/unets/test_models_unet_2d.py, lines 330 to 344 in 0f1abc4:
    def test_from_pretrained_hub(self):
        model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        inputs = self.dummy_input
        noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
        inputs["sample"] = noise
        image = model(**inputs)

        assert image is not None, "Make sure output is not None"

    @slow
diffusers/tests/models/unets/test_models_unet_2d_condition.py, lines 1161 to 1162 in 0f1abc4:
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
Suggested fix:
# Add focused fast tests:
# - tests/models/unets/test_models_uvit_2d.py with a tiny UVit2DModel forward,
# save/load, and gradient-checkpointing regression.
# - tests/models/unets/test_models_unet_2d_condition_flax.py guarded by require_flax,
# or explicitly remove/limit public Flax coverage if deprecated support is no longer maintained.
# - Add at least one @slow UNet3DConditionModel checkpoint or deprecated text-to-video pipeline regression.
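A minimal sketch of the first suggested test, reusing the tiny configuration from the Issue 3 reproduction; the file name, test name, and the assumption that UVit2DModel.forward returns the logits tensor directly are mine, not existing repo fixtures:
# Hypothetical tests/models/unets/test_models_uvit_2d.py
import torch
from diffusers import UVit2DModel

def test_gradient_checkpointing_backward():
    model = UVit2DModel(
        hidden_size=8, cond_embed_dim=4, micro_cond_encode_dim=2, micro_cond_embed_dim=4,
        encoder_hidden_size=6, vocab_size=16, codebook_size=15, in_channels=4,
        block_out_channels=4, num_res_blocks=1, block_num_heads=1, num_hidden_layers=1,
        num_attention_heads=1, intermediate_size=16, sample_size=2,
    )
    model.enable_gradient_checkpointing()
    logits = model(
        torch.randint(0, 16, (1, 2, 2)),  # input_ids
        torch.randn(1, 3, 6),             # encoder_hidden_states
        torch.randn(1, 4),                # pooled_text_emb
        torch.tensor([[1.0, 2.0]]),       # micro_conds
    )
    logits.float().sum().backward()  # regression guard: must not raise under checkpointing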
Verification: ran the import checks and all reproduction snippets above with .venv. JAX/Flax is not installed in this .venv, so Flax runtime behavior was code-reviewed but not executed.