
model_unets_shared model/pipeline review #13654

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Files reviewed: unet_1d.py, unet_1d_blocks.py, unet_2d.py, unet_2d_blocks.py, unet_2d_blocks_flax.py, unet_3d_blocks.py, uvit_2d.py, unet_2d_condition.py, unet_2d_condition_flax.py, unet_3d_condition.py.

Duplicate search: searched GitHub Issues/PRs for model_unets_shared, affected class/file names, and the specific failure modes. Exact duplicate found only for the UViT checkpointing failure: #11214. No exact duplicate found for the scalar timestep truncation, UNet3DConditionModel mask drop, or missing coverage; old issue #1890 is related to general attention masking but not this 3D UNet path.

Issue 1: UNet3DConditionModel accepts attention_mask but drops it

Affected code:

        attention_mask: torch.Tensor | None = None,
        cross_attention_kwargs: dict[str, Any] | None = None,
        down_block_additional_residuals: tuple[torch.Tensor] | None = None,
        mid_block_additional_residual: torch.Tensor | None = None,
        return_dict: bool = True,
    ) -> UNet3DConditionOutput | tuple[torch.Tensor]:
        r"""
        The [`UNet3DConditionModel`] forward method.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width)`.
            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
                through the `self.time_embedding` layer to obtain the timestep embeddings.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. The mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
                A tuple of tensors that, if specified, are added to the residuals of the down unet blocks.
            mid_block_additional_residual (`torch.Tensor`, *optional*):
                A tensor that, if specified, is added to the residual of the middle unet block.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
                If `return_dict` is `True`, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned,
                otherwise a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None
        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        num_frames: int = 1,
        cross_attention_kwargs: dict[str, Any] = None,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        # TODO(Patrick, William) - attention mask is not used
        output_states = ()

        for resnet, temp_conv, attn, temp_attn in zip(
            self.resnets, self.temp_convs, self.attentions, self.temp_attentions
        ):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            hidden_states = temp_attn(
                hidden_states,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            output_states += (hidden_states,)

        hidden_states: torch.Tensor,
        res_hidden_states_tuple: tuple[torch.Tensor, ...],
        temb: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        upsample_size: int | None = None,
        attention_mask: torch.Tensor | None = None,
        num_frames: int = 1,
        cross_attention_kwargs: dict[str, Any] = None,
    ) -> torch.Tensor:
        is_freeu_enabled = (
            getattr(self, "s1", None)
            and getattr(self, "s2", None)
            and getattr(self, "b1", None)
            and getattr(self, "b2", None)
        )

        # TODO(Patrick, William) - attention mask is not used
        for resnet, temp_conv, attn, temp_attn in zip(
            self.resnets, self.temp_convs, self.attentions, self.temp_attentions
        ):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            # FreeU: Only operate on the first two stages
            if is_freeu_enabled:
                hidden_states, res_hidden_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_hidden_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,

Problem:
UNet3DConditionModel.forward() documents and prepares attention_mask, but CrossAttnDownBlock3D and CrossAttnUpBlock3D explicitly do not use it. The mid block also declares attention_mask and does not pass it to the spatial transformer. Padding masks therefore silently have no effect.

Impact:
Batched prompts with padding can attend to discarded text tokens in video UNets. The API suggests masking is honored, so users get silently incorrect conditioning.

Reproduction:

import torch
from diffusers import UNet3DConditionModel

torch.manual_seed(0)
model = UNet3DConditionModel(
    block_out_channels=(8, 16),
    norm_num_groups=4,
    down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
    up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
    cross_attention_dim=8,
    attention_head_dim=2,
    out_channels=4,
    in_channels=4,
    layers_per_block=1,
    sample_size=16,
).eval()

sample = torch.randn(1, 4, 2, 16, 16)
encoder_hidden_states = torch.randn(1, 4, 8)

with torch.no_grad():
    keep = model(sample, torch.tensor([10]), encoder_hidden_states, attention_mask=torch.ones(1, 4)).sample
    drop = model(sample, torch.tensor([10]), encoder_hidden_states, attention_mask=torch.zeros(1, 4)).sample

print((keep - drop).abs().max().item())  # 0.0: mask is ignored

Relevant precedent:

# unet_2d_condition.py: the mask is converted to a bias in forward()
if encoder_attention_mask is not None:
    encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
    encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

# unet_2d_blocks.py: the 2D cross-attn block forwards both masks to attn()
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=False,
        )[0]
    else:
        hidden_states = resnet(hidden_states, temb)
        hidden_states = attn(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,

Suggested fix:

# In UNet3DConditionModel.forward(), after num_frames is known:
if attention_mask is not None:
    attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
    attention_mask = attention_mask.unsqueeze(1)
    attention_mask = attention_mask.repeat_interleave(
        num_frames, dim=0, output_size=attention_mask.shape[0] * num_frames
    )

# In 3D cross-attn blocks, pass it to the spatial Transformer2DModel cross-attn mask:
hidden_states = attn(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=attention_mask,
    cross_attention_kwargs=cross_attention_kwargs,
    return_dict=False,
)[0]
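With the mask threaded through, the reproduction above should stop printing 0.0. A minimal check, as a sketch continuing the reproduction script (not verified against a patched build; the repeat over num_frames assumes the 3D blocks fold frames into the batch dimension before spatial attention):

# Continues the Issue 1 reproduction: model, sample, encoder_hidden_states as above.
with torch.no_grad():
    keep = model(sample, torch.tensor([10]), encoder_hidden_states, attention_mask=torch.ones(1, 4)).sample
    drop = model(sample, torch.tensor([10]), encoder_hidden_states, attention_mask=torch.zeros(1, 4)).sample

# An all-zero mask must now change the output.
assert (keep - drop).abs().max().item() > 0.0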

Issue 2: scalar float timesteps are truncated in unconditional UNets

Affected code:

# unet_1d.py
timesteps = timestep
if not torch.is_tensor(timesteps):
    timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
    timesteps = timesteps[None].to(sample.device)

timestep_embed = self.time_proj(timesteps)

# unet_2d.py
timesteps = timestep
if not torch.is_tensor(timesteps):
    timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
    timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

# unet_2d_condition_flax.py
# 1. time
if not isinstance(timesteps, jnp.ndarray):
    timesteps = jnp.array([timesteps], dtype=jnp.int32)
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
    timesteps = timesteps.astype(dtype=jnp.float32)
    timesteps = jnp.expand_dims(timesteps, 0)

Problem:
UNet1DModel and UNet2DModel type their public timestep as torch.Tensor | float | int, but scalar Python floats are converted with integer dtype. UNet2DModel then divides Fourier outputs by the integer timestep, so 1e-4 becomes 0 and produces non-finite output. The Flax model has the same scalar-float-to-int32 pattern.

Impact:
Calling the public API with a scalar float timestep changes semantics versus passing torch.tensor([float]); VE/NCSN-style small sigma timesteps can produce NaNs.

Reproduction:

import torch
from diffusers import UNet1DModel, UNet2DModel

unet2d = UNet2DModel(
    sample_size=8,
    in_channels=3,
    out_channels=3,
    block_out_channels=(8,),
    layers_per_block=1,
    down_block_types=("DownBlock2D",),
    up_block_types=("UpBlock2D",),
    norm_num_groups=4,
    time_embedding_type="fourier",
).eval()

sample2d = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    print(torch.isfinite(unet2d(sample2d, 1e-4).sample).all().item())              # False
    print(torch.isfinite(unet2d(sample2d, torch.tensor([1e-4])).sample).all().item())  # True

unet1d = UNet1DModel(
    sample_size=16,
    in_channels=4,
    out_channels=4,
    block_out_channels=(8, 8),
    down_block_types=("DownResnetBlock1D", "DownResnetBlock1D"),
    up_block_types=("UpResnetBlock1D",),
    mid_block_type="MidResTemporalBlock1D",
    out_block_type="OutConv1DBlock",
    time_embedding_type="positional",
    use_timestep_embedding=True,
    norm_num_groups=4,
    act_fn="swish",
).eval()

sample1d = torch.randn(2, 4, 16)
with torch.no_grad():
    a = unet1d(sample1d, 0.5).sample
    b = unet1d(sample1d, torch.tensor([0.5])).sample
print((a - b).abs().max().item())  # non-zero: scalar path was truncated

Relevant precedent:

# unet_2d_condition.py
def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float | int) -> torch.Tensor | None:
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        is_npu = sample.device.type == "npu"
        if isinstance(timestep, float):
            dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        else:
            dtype = torch.int32 if (is_mps or is_npu) else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # `Timesteps` does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=sample.dtype)

# unet_3d_condition.py
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        is_npu = sample.device.type == "npu"
        if isinstance(timestep, float):
            dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        else:
            dtype = torch.int32 if (is_mps or is_npu) else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    num_frames = sample.shape[2]
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)
    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=self.dtype)

Suggested fix:

if not torch.is_tensor(timesteps):
    is_mps = sample.device.type == "mps"
    is_npu = sample.device.type == "npu"
    if isinstance(timesteps, float):
        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
    else:
        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
    timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
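For context, the truncation mechanism in isolation; a minimal standalone sketch (not diffusers code) of why the dtype matters before the Fourier branch divides by the timestep:

import torch

# Casting a small float timestep to an integer dtype truncates it to zero.
t_trunc = torch.tensor([1e-4], dtype=torch.long)     # tensor([0])
t_exact = torch.tensor([1e-4], dtype=torch.float64)  # value preserved
print(t_trunc.item(), t_exact.item())  # 0 0.0001

# A later division by the truncated timestep then yields non-finite values.
print(torch.ones(1) / t_trunc)  # tensor([inf])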

Issue 3: UVit2DModel gradient checkpointing crashes (duplicate of #11214)

Affected code:

for layer in self.transformer_layers:
    if torch.is_grad_enabled() and self.gradient_checkpointing:

        def layer_(*args):
            return checkpoint(layer, *args)

    else:
        layer_ = layer

    hidden_states = layer_(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        cross_attention_kwargs=cross_attention_kwargs,
        added_cond_kwargs={"pooled_text_emb": pooled_text_emb},
    )

Problem:
Duplicate of #11214. UVit2DModel sets _supports_gradient_checkpointing = True, but its checkpoint wrapper only accepts *args and is then called with keyword arguments.

Impact:
Training or fine-tuning with enable_gradient_checkpointing() fails immediately.
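The failure mode in isolation, as a minimal standalone sketch (fn and fn_ are hypothetical stand-ins for the layer and its wrapper, not diffusers code):

import torch
from torch.utils.checkpoint import checkpoint

def fn(x, scale=1.0):
    return x * scale

def fn_(*args):  # mirrors layer_: accepts positional arguments only
    return checkpoint(fn, *args)

x = torch.randn(2, requires_grad=True)
try:
    fn_(x, scale=2.0)  # keyword arguments never reach *args
except TypeError as e:
    print(e)  # fn_() got an unexpected keyword argument 'scale'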

Reproduction:

import torch
from diffusers import UVit2DModel

model = UVit2DModel(
    hidden_size=8,
    cond_embed_dim=4,
    micro_cond_encode_dim=2,
    micro_cond_embed_dim=4,
    encoder_hidden_size=6,
    vocab_size=16,
    codebook_size=15,
    in_channels=4,
    block_out_channels=4,
    num_res_blocks=1,
    block_num_heads=1,
    num_hidden_layers=1,
    num_attention_heads=1,
    intermediate_size=16,
    sample_size=2,
)
model.enable_gradient_checkpointing()
model(torch.randint(0, 16, (1, 2, 2)), torch.randn(1, 3, 6), torch.randn(1, 4), torch.tensor([[1.0, 2.0]]))

Relevant precedent:

if torch.is_grad_enabled() and self.gradient_checkpointing:
    encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
        block,
        hidden_states,
        encoder_hidden_states,
        temb,
        image_rotary_emb,
        joint_attention_kwargs,
    )
else:

Suggested fix:

if torch.is_grad_enabled() and self.gradient_checkpointing:
    hidden_states = self._gradient_checkpointing_func(
        layer,
        hidden_states,
        None,
        encoder_hidden_states,
        None,
        None,
        cross_attention_kwargs,
        None,
        {"pooled_text_emb": pooled_text_emb},
    )
else:
    hidden_states = layer(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        cross_attention_kwargs=cross_attention_kwargs,
        added_cond_kwargs={"pooled_text_emb": pooled_text_emb},
    )

Issue 4: Missing fast/slow coverage for parts of the family

Affected code:

class UVit2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):
    _supports_gradient_checkpointing = True

class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
    implemented for all models (such as downloading or saving).

    This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
    general usage and behavior.

    Inherent JAX features such as the following are supported:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        sample_size (`int`, *optional*):
            The size of the input sample.
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4):
            The number of channels in the output.
        down_block_types (`tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
            The tuple of downsample blocks to use.
        up_block_types (`tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            Block type for the middle of the UNet; it can only be `UNetMidBlock2DCrossAttn`. If `None`, the mid block
            layer is skipped.
        block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        attention_head_dim (`int` or `tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads.
        num_attention_heads (`int` or `tuple[int]`, *optional*):
            The number of attention heads.
        cross_attention_dim (`int`, *optional*, defaults to 768):
            The dimension of the cross attention features.
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for down, up and bottleneck blocks.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            Enable memory efficient attention as described [here](https://huggingface.co/papers/2112.05682).
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """

    sample_size: int = 32
    in_channels: int = 4
    out_channels: int = 4
    down_block_types: tuple[str, ...] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )
    up_block_types: tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
    mid_block_type: str | None = "UNetMidBlock2DCrossAttn"
    only_cross_attention: bool | tuple[bool] = False
    block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: int | tuple[int, ...] = 8
    num_attention_heads: int | tuple[int, ...] | None = None
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
    dtype: jnp.dtype = jnp.float32

class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):

Problem:
UVit2DModel has no model or pipeline tests under tests/. FlaxUNet2DConditionModel is only covered by dummy-object checks, not fast or slow behavior tests. UNet3DConditionModel has fast model tests, but no @slow integration coverage.

Impact:
The UViT checkpointing regression above is currently unguarded, Flax UNet serialization/runtime can regress unnoticed while still publicly exported, and 3D UNet published-checkpoint behavior is not covered by slow tests.

Reproduction:

from pathlib import Path

for needle in ["UVit2DModel", "FlaxUNet2DConditionModel", "UNet3DConditionModel"]:
    hits = [str(p) for p in Path("tests").rglob("*.py") if needle in p.read_text(encoding="utf-8")]
    print(needle, hits)

text = Path("tests/models/unets/test_models_unet_3d_condition.py").read_text(encoding="utf-8")
print("UNet3D slow marker?", "@slow" in text or "pytest.mark.slow" in text)

Relevant precedent:


    @slow
    def test_from_pretrained_hub(self):
        model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        inputs = self.dummy_input

        noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
        inputs["sample"] = noise

        image = model(**inputs)

        assert image is not None, "Make sure output is not None"

@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):

Suggested fix:

# Add focused fast tests:
# - tests/models/unets/test_models_uvit_2d.py with a tiny UVit2DModel forward,
#   save/load, and gradient-checkpointing regression.
# - tests/models/unets/test_models_unet_2d_condition_flax.py guarded by require_flax,
#   or explicitly remove/limit public Flax coverage if deprecated support is no longer maintained.
# - Add at least one @slow UNet3DConditionModel checkpoint or deprecated text-to-video pipeline regression.
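A sketch of the first item, reusing the tiny config from the Issue 3 reproduction (the file name tests/models/unets/test_models_uvit_2d.py is hypothetical, and the assertion assumes forward returns the logits tensor, as the reproduction's call pattern suggests):

import unittest

import torch
from diffusers import UVit2DModel


class UVit2DModelFastTests(unittest.TestCase):
    def test_gradient_checkpointing(self):
        # Tiny config copied from the Issue 3 reproduction above.
        model = UVit2DModel(
            hidden_size=8, cond_embed_dim=4, micro_cond_encode_dim=2,
            micro_cond_embed_dim=4, encoder_hidden_size=6, vocab_size=16,
            codebook_size=15, in_channels=4, block_out_channels=4,
            num_res_blocks=1, block_num_heads=1, num_hidden_layers=1,
            num_attention_heads=1, intermediate_size=16, sample_size=2,
        )
        model.enable_gradient_checkpointing()
        out = model(
            torch.randint(0, 16, (1, 2, 2)),
            torch.randn(1, 3, 6),
            torch.randn(1, 4),
            torch.tensor([[1.0, 2.0]]),
        )
        # Guards the Issue 3 regression: on the reviewed commit this raises TypeError.
        self.assertTrue(torch.isfinite(out).all())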

Verification: ran the import checks and all reproduction snippets above in the project .venv. JAX/Flax is not installed in that environment, so Flax runtime behavior was verified by code review only, not executed.
