model_autoencoders_shared model/pipeline review
Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423
Review performed against the repository review rules.
Full target file list reviewed: autoencoder_asym_kl.py, autoencoder_dc.py, autoencoder_kl.py, autoencoder_kl_kvae.py, autoencoder_kl_kvae_video.py, autoencoder_kl_magvit.py, autoencoder_kl_temporal_decoder.py, autoencoder_oobleck.py, autoencoder_rae.py, autoencoder_tiny.py, autoencoder_vidtok.py, consistency_decoder_vae.py, vae.py, vq_model.py.
Issue 1: AutoencoderKLKVAEVideo discards encoder log variance
Affected code (`src/diffusers/models/autoencoders/autoencoder_kl_kvae_video.py`, lines 848–894):

```python
def _encode(self, x: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
    # Cached encoder processes by segments
    cache = self._make_encoder_cache()

    split_list = [seg_len + 1]
    n_frames = x.size(2) - (seg_len + 1)
    while n_frames > 0:
        split_list.append(seg_len)
        n_frames -= seg_len
    split_list[-1] += n_frames

    latent = []
    for chunk in torch.split(x, split_list, dim=2):
        l = self.encoder(chunk, cache)
        sample, _ = torch.chunk(l, 2, dim=1)
        latent.append(sample)

    return torch.cat(latent, dim=2)

@apply_forward_hook
def encode(
    self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
    """
    Encode a batch of videos into latents.

    Args:
        x (`torch.Tensor`): Input batch of videos with shape (B, C, T, H, W).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

    Returns:
        The latent representations of the encoded videos.
    """
    if self.use_slicing and x.shape[0] > 1:
        encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
        h = torch.cat(encoded_slices)
    else:
        h = self._encode(x)

    # For cached encoder, we already did the split in _encode
    h_double = torch.cat([h, torch.zeros_like(h)], dim=1)
    posterior = DiagonalGaussianDistribution(h_double)

    if not return_dict:
        return (posterior,)
    return AutoencoderKLOutput(latent_dist=posterior)
```
Problem:
KVAECachedEncoder3D outputs 2 * z_channels channels (mean and log variance concatenated), but _encode() splits the tensor and keeps only the first half. encode() then reconstructs a fake [mean, zeros] tensor, so the posterior log variance is always zero.
Impact:
sample_posterior=True samples from the wrong distribution, checkpoint log-variance weights are ignored, and parity with the source KVAE 3D VAE is broken.
Reproduction:
```python
import torch
from diffusers import AutoencoderKLKVAEVideo

model = AutoencoderKLKVAEVideo(
    ch=32, ch_mult=(1, 2), num_res_blocks=1, z_channels=4, temporal_compress_times=2
).eval()
x = torch.randn(1, 3, 3, 16, 16)
with torch.no_grad():
    raw = model.encoder(x, model._make_encoder_cache())
    raw_mean, raw_logvar = raw.chunk(2, dim=1)
    posterior = model.encode(x).latent_dist
print(torch.allclose(posterior.mean, raw_mean, atol=1e-5))
print(raw_logvar.abs().max().item())
print(posterior.logvar.abs().max().item())  # always 0.0
```
Relevant precedent (`src/diffusers/models/autoencoders/autoencoder_kl_kvae.py`, lines 592–629):

```python
def _encode(self, x: torch.Tensor) -> torch.Tensor:
    batch_size, num_channels, height, width = x.shape

    if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
        return self._tiled_encode(x)

    enc = self.encoder(x)

    return enc

@apply_forward_hook
def encode(
    self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
    """
    Encode a batch of images into latents.

    Args:
        x (`torch.Tensor`): Input batch of images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

    Returns:
        The latent representations of the encoded images. If `return_dict` is True, a
        [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
    """
    if self.use_slicing and x.shape[0] > 1:
        encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
        h = torch.cat(encoded_slices)
    else:
        h = self._encode(x)

    posterior = DiagonalGaussianDistribution(h)

    if not return_dict:
        return (posterior,)

    return AutoencoderKLOutput(latent_dist=posterior)
```
Suggested fix:
```python
# _encode(): keep the full encoder output
latent.append(self.encoder(chunk, cache))

# encode(): do not synthesize a zero logvar half
posterior = DiagonalGaussianDistribution(h)
```
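For concreteness, a minimal sketch of the corrected `_encode` loop (assuming the cached encoder keeps returning mean and log variance concatenated along dim 1, as it does today):

```python
latent = []
for chunk in torch.split(x, split_list, dim=2):
    # Keep the full [mean, logvar] output instead of discarding the second half.
    latent.append(self.encoder(chunk, cache))
return torch.cat(latent, dim=2)
```

Because `DiagonalGaussianDistribution` splits its input into mean and log variance internally, `decode()` still receives `z_channels`-wide latents after this change.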
Issue 2: AutoencoderVidTok.forward(return_dict=False) returns a tensor instead of a tuple
Affected code (`src/diffusers/models/autoencoders/autoencoder_vidtok.py`, lines 1435–1488):

```python
def forward(
    self,
    sample: torch.Tensor,
    sample_posterior: bool = True,
    encoder_mode: bool = False,
    return_dict: bool = True,
    generator: Optional[torch.Generator] = None,
) -> Union[torch.Tensor, DecoderOutput]:
    x = sample
    res = 1 if self.is_causal else 0
    if self.is_causal:
        if x.shape[2] % self.temporal_compression_ratio != res:
            time_padding = self.temporal_compression_ratio - x.shape[2] % self.temporal_compression_ratio + res
            x = self._pad_at_dim(x, (0, time_padding), dim=2, pad_mode="replicate")
        else:
            time_padding = 0
    else:
        if x.shape[2] % self.num_sample_frames_batch_size != res:
            if not encoder_mode:
                time_padding = (
                    self.num_sample_frames_batch_size - x.shape[2] % self.num_sample_frames_batch_size + res
                )
                x = self._pad_at_dim(x, (0, time_padding), dim=2, pad_mode="replicate")
            else:
                assert x.shape[2] >= self.num_sample_frames_batch_size, (
                    f"Too short video. At least {self.num_sample_frames_batch_size} frames."
                )
                x = x[:, :, : x.shape[2] // self.num_sample_frames_batch_size * self.num_sample_frames_batch_size]
        else:
            time_padding = 0

    if self.is_causal:
        x = self._pad_at_dim(x, (self.temporal_compression_ratio - 1, 0), dim=2, pad_mode="replicate")

    if self.regularizer == "kl":
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        if encoder_mode:
            return z
    else:
        z, indices = self.encode(x)
        if encoder_mode:
            return z, indices

    dec = self.decode(z)
    if time_padding != 0:
        dec = dec[:, :, :-time_padding, :, :]

    if not return_dict:
        return dec
    return DecoderOutput(sample=dec)
```
Problem:
Every other autoencoder forward path returns a tuple when return_dict=False. VidTok returns the raw decoded tensor, so common caller code like model(..., return_dict=False)[0] silently selects the first batch element.
Impact:
Pipeline/model utility code that relies on diffusers’ tuple convention gets the wrong shape and silently drops batch items.
Reproduction:
```python
import torch
from diffusers import AutoencoderVidTok

model = AutoencoderVidTok(
    is_causal=False, ch=8, ch_mult=[1], z_channels=2, double_z=True, num_res_blocks=1, regularizer="kl"
).eval()
x = torch.randn(2, 3, 2, 8, 8)
with torch.no_grad():
    out = model(x, sample_posterior=False, return_dict=False)
print(type(out), out.shape)  # torch.Tensor, not tuple
print(out[0].shape)          # first batch item, not decoded tuple element
```
Relevant precedent (`src/diffusers/models/autoencoders/autoencoder_kl_kvae_video.py`, lines 938–954):

```python
def forward(
    self,
    sample: torch.Tensor,
    sample_posterior: bool = False,
    return_dict: bool = True,
    generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
    x = sample
    posterior = self.encode(x).latent_dist
    if sample_posterior:
        z = posterior.sample(generator=generator)
    else:
        z = posterior.mode()
    dec = self.decode(z).sample
    if not return_dict:
        return (dec,)
    return DecoderOutput(sample=dec)
```
Suggested fix:
```python
if not return_dict:
    return (dec,)
return DecoderOutput(sample=dec)
```
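With the fix applied, the reproduction above behaves per the convention (illustrative):

```python
out = model(x, sample_posterior=False, return_dict=False)
assert isinstance(out, tuple)
assert out[0].shape[0] == x.shape[0]  # out[0] is now the decoded sample; the batch stays intact
```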
Issue 3: AutoencoderVidTok bypasses diffusers attention processors
Affected code:

`src/diffusers/models/autoencoders/autoencoder_vidtok.py`, lines 426–447:

```python
class VidTokAttnBlock(nn.Module):
    r"""A 3D self-attention block used in VidTok Model."""

    def __init__(self, in_channels: int, is_causal: bool = True):
        super().__init__()
        make_conv_cls = VidTokCausalConv3d if is_causal else nn.Conv3d
        self.norm = VidTokLayerNorm(dim=in_channels, eps=1e-6)
        self.q = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def attention(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""Implement self-attention."""
        hidden_states = self.norm(hidden_states)
        q = self.q(hidden_states)
        k = self.k(hidden_states)
        v = self.v(hidden_states)
        b, c, t, h, w = q.shape
        q, k, v = [x.permute(0, 2, 3, 4, 1).reshape(b, t, -1, c).contiguous() for x in [q, k, v]]
        hidden_states = F.scaled_dot_product_attention(q, k, v)  # scale is dim ** -0.5 per default
        return hidden_states.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3)
```
`src/diffusers/models/autoencoders/autoencoder_vidtok.py`, lines 938–979:

```python
class AutoencoderVidTok(ModelMixin, ConfigMixin):
    r"""
    A VAE model for encoding videos into latents and decoding latent representations into videos, supporting both
    continuous and discrete latent representations. Used in [VidTok](https://github.com/microsoft/VidTok).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).

    Args:
        in_channels (`int`, defaults to 3):
            The number of input channels.
        out_channels (`int`, defaults to 3):
            The number of output channels.
        ch (`int`, defaults to 128):
            The number of the basic channel.
        ch_mult (`List[int]`, defaults to `[1, 2, 4, 4]`):
            The multiple of the basic channel for each block.
        z_channels (`int`, defaults to 4):
            The number of latent channels.
        double_z (`bool`, defaults to `True`):
            Whether or not to double the z_channels.
        num_res_blocks (`int`, defaults to 2):
            The number of resblocks.
        spatial_ds (`List`, *optional*, defaults to `None`):
            Spatial downsample layers.
        spatial_us (`List`, *optional*, defaults to `None`):
            Spatial upsample layers.
        tempo_ds (`List`, *optional*, defaults to `None`):
            Temporal downsample layers.
        tempo_us (`List`, *optional*, defaults to `None`):
            Temporal upsample layers.
        dropout (`float`, defaults to 0.0):
            Dropout rate.
        regularizer (`str`, defaults to `"kl"`):
            The regularizer type - "kl" for continuous cases and "fsq" for discrete cases.
        codebook_size (`int`, defaults to 262144):
            The codebook size used only in discrete cases.
        is_causal (`bool`, defaults to `True`):
            Whether it is a causal module.
    """

    _supports_gradient_checkpointing = True
```
Problem:
VidTokAttnBlock calls F.scaled_dot_product_attention directly, and AutoencoderVidTok does not inherit AttentionMixin. This bypasses the repository’s attention processor/backend path.
Impact:
Users cannot inspect or replace attention processors, attention backend selection is ignored for VidTok attention, and the implementation violates the model attention rule.
Reproduction:
```python
from diffusers import AutoencoderVidTok

model = AutoencoderVidTok(
    is_causal=False, ch=8, ch_mult=[1], z_channels=2, double_z=True, num_res_blocks=1, regularizer="kl"
)
print(hasattr(model, "set_attn_processor"))  # False
print([type(m).__name__ for m in model.modules() if type(m).__name__ == "VidTokAttnBlock"])
```
Relevant precedent:

`src/diffusers/models/autoencoders/autoencoder_rae.py`, lines 202–208:

```python
self.attention = Attention(
    query_dim=hidden_size,
    heads=num_attention_heads,
    dim_head=hidden_size // num_attention_heads,
    dropout=attention_probs_dropout_prob,
    bias=qkv_bias,
)
```
`src/diffusers/models/autoencoders/autoencoder_rae.py`, lines 393–442:

```python
class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
    r"""
    Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images.

    This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT decoder to reconstruct
    images from learned representations.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
    all models (such as downloading or saving).

    Args:
        encoder_type (`str`, *optional*, defaults to `"dinov2"`):
            Type of frozen encoder to use. One of `"dinov2"`, `"siglip2"`, or `"mae"`.
        encoder_hidden_size (`int`, *optional*, defaults to `768`):
            Hidden size of the encoder model.
        encoder_patch_size (`int`, *optional*, defaults to `14`):
            Patch size of the encoder model.
        encoder_num_hidden_layers (`int`, *optional*, defaults to `12`):
            Number of hidden layers in the encoder model.
        patch_size (`int`, *optional*, defaults to `16`):
            Decoder patch size (used for unpatchify and decoder head).
        encoder_input_size (`int`, *optional*, defaults to `224`):
            Input size expected by the encoder.
        image_size (`int`, *optional*):
            Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like
            RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size //
            encoder_patch_size) ** 2`.
        num_channels (`int`, *optional*, defaults to `3`):
            Number of input/output channels.
        encoder_norm_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
            Channel-wise mean for encoder input normalization (ImageNet defaults).
        encoder_norm_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
            Channel-wise std for encoder input normalization (ImageNet defaults).
        latents_mean (`list` or `tuple`, *optional*):
            Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable
            lists.
        latents_std (`list` or `tuple`, *optional*):
            Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to
            config-serializable lists.
        noise_tau (`float`, *optional*, defaults to `0.0`):
            Noise level for training (adds noise to latents during training).
        reshape_to_2d (`bool`, *optional*, defaults to `True`):
            Whether to reshape latents to 2D (B, C, H, W) format.
        use_encoder_loss (`bool`, *optional*, defaults to `False`):
            Whether to use encoder hidden states in the loss (for advanced training).
    """

    # NOTE: gradient checkpointing is not wired up for this model yet.
    _supports_gradient_checkpointing = False
    _no_split_modules = ["ViTMAELayer"]
```
Suggested fix:
Refactor VidTokAttnBlock to use the diffusers attention processor pattern with dispatch_attention_fn, and make AutoencoderVidTok inherit AttentionMixin.
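One possible shape for the refactor, as a sketch only (not the final implementation): route the projections and the SDPA call through the standard `Attention` module, whose default processor already goes through the attention dispatch machinery. The existing `norm`, `proj_out`, and residual wiring would stay in the surrounding forward, and `AutoencoderVidTok` would additionally inherit `AttentionMixin` so `attn_processors`/`set_attn_processor` work.

```python
import torch
from torch import nn

from diffusers.models.attention_processor import Attention


class VidTokAttnBlock(nn.Module):
    def __init__(self, in_channels: int, is_causal: bool = True):
        super().__init__()
        # Attention owns the q/k/v/out projections, replacing the four 1x1 convs.
        self.attention = Attention(query_dim=in_channels, heads=1, dim_head=in_channels, bias=True)

    def attention_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        b, c, t, h, w = hidden_states.shape
        # (B, C, T, H, W) -> (B*T, H*W, C): each frame attends over its own spatial
        # positions, matching the semantics of the current SDPA call.
        tokens = hidden_states.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c)
        tokens = self.attention(tokens)
        return tokens.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3)
```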
Issue 4: AutoencoderTiny fails bfloat16 forward because it creates float32 activations
Affected code:

`src/diffusers/models/autoencoders/autoencoder_tiny.py`, lines 302–312:

```python
enc = self.encode(sample).latents

# scale latents to be in [0, 1], then quantize latents to a byte tensor,
# as if we were storing the latents in an RGBA uint8 image.
scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()

# unquantize latents back into [0, 1], then unscale latents back to their original range,
# as if we were loading the latents from an RGBA uint8 image.
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)

dec = self.decode(unscaled_enc).sample
```
`tests/models/autoencoders/test_models_autoencoder_tiny.py`, lines 146–159:

```python
@unittest.skip(
    "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n"
    "1. Change the forward pass to be dtype agnostic.\n"
    "2. Unskip this test."
)
def test_layerwise_casting_inference(self):
    pass


@unittest.skip(
    "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n"
    "1. Change the forward pass to be dtype agnostic.\n"
    "2. Unskip this test."
)
def test_layerwise_casting_memory(self):
    pass
```
Problem:
scaled_enc / 255.0 promotes the byte tensor to float32 before decode. With bfloat16 weights, the decoder receives float32 inputs and errors. The fast tests already skip layerwise-casting coverage with this exact reason.
Impact:
Layerwise casting / bfloat16 inference cannot cover AutoencoderTiny, and users hit dtype mismatch errors.
Reproduction:
```python
import torch
from diffusers import AutoencoderTiny

model = AutoencoderTiny(
    encoder_block_out_channels=(8, 8, 8, 8),
    decoder_block_out_channels=(8, 8, 8, 8),
    num_encoder_blocks=(1, 1, 1, 1),
    num_decoder_blocks=(1, 1, 1, 1),
).eval().to(dtype=torch.bfloat16)
x = torch.randn(1, 3, 32, 32, dtype=torch.bfloat16)
model(x)
```
Relevant precedent:
The skipped tests quoted above already describe the expected fix.
Suggested fix:
```python
unscaled_enc = self.unscale_latents(scaled_enc.to(dtype=enc.dtype) / 255.0)
```
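The cast matters because true division of an integer tensor by a Python float always promotes to PyTorch's default dtype (float32), regardless of the model's compute dtype; a two-line illustration:

```python
import torch

q = torch.tensor([128], dtype=torch.uint8)
print((q / 255.0).dtype)                     # torch.float32 (default dtype)
print((q.to(torch.bfloat16) / 255.0).dtype)  # torch.bfloat16
```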
Issue 5: AutoencoderTiny tiled paths hardcode channel counts and derive scale from out_channels
Affected code:

`src/diffusers/models/autoencoders/autoencoder_tiny.py`, lines 147–151:

```python
# only relevant if vae tiling is enabled
self.spatial_scale_factor = 2**out_channels
self.tile_overlap_factor = 0.125
self.tile_sample_min_size = 512
self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
```
`src/diffusers/models/autoencoders/autoencoder_tiny.py`, lines 190–196:

```python
blend_masks = torch.stack(
    torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
)
blend_masks = blend_masks.clamp(0, 1).to(x.device)

# output array
out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
```
`src/diffusers/models/autoencoders/autoencoder_tiny.py`, lines 238–244:

```python
blend_masks = torch.stack(
    torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
)
blend_masks = blend_masks.clamp(0, 1).to(x.device)

# output array
out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
```
Problem:
Tiling assumes latent channels are always 4, output channels are always 3, and spatial scale is 2 ** out_channels. These are configurable constructor values.
Impact:
Non-default but valid AutoencoderTiny configs fail as soon as tiling is enabled.
Reproduction:
```python
import torch
from diffusers import AutoencoderTiny

model = AutoencoderTiny(
    out_channels=1,
    encoder_block_out_channels=(8, 8, 8, 8),
    decoder_block_out_channels=(8, 8, 8, 8),
    num_encoder_blocks=(1, 1, 1, 1),
    num_decoder_blocks=(1, 1, 1, 1),
    latent_channels=4,
)
model.enable_tiling()
x = torch.randn(1, 3, 32, 32)
model.encode(x)
```
Relevant precedent:
Other tiled autoencoders derive shapes from tensor/model config rather than fixed RGB/latent-channel constants.
Suggested fix:
```python
self.spatial_scale_factor = 2 ** (len(encoder_block_out_channels) - 1)

out = torch.zeros(
    x.shape[0], self.config.latent_channels,
    x.shape[-2] // sf, x.shape[-1] // sf,
    device=x.device, dtype=x.dtype,
)

out = torch.zeros(
    x.shape[0], self.config.out_channels,
    x.shape[-2] * sf, x.shape[-1] * sf,
    device=x.device, dtype=x.dtype,
)
```
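Assuming the fix above, a quick sanity check on the repro config (hypothetical; uses the `latents`/`sample` accessors of the tiny autoencoder's outputs):

```python
model.enable_tiling()
z = model.encode(x).latents
assert z.shape[1] == model.config.latent_channels
dec = model.decode(z).sample
assert dec.shape[1] == model.config.out_channels  # 1 in the repro config
```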
Issue 6: Tuple sample_size breaks VAE tiling comparisons
Affected code:

`src/diffusers/models/autoencoders/autoencoder_kl.py`, lines 132–162:

```python
# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
    self.config.sample_size[0]
    if isinstance(self.config.sample_size, (list, tuple))
    else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
    """
    Disables custom attention processors and sets the default attention implementation.
    """
    if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        processor = AttnAddedKVProcessor()
    elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(processor)

def _encode(self, x: torch.Tensor) -> torch.Tensor:
    batch_size, num_channels, height, width = x.shape

    if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
        return self._tiled_encode(x)
```
`src/diffusers/models/autoencoders/autoencoder_kl_kvae.py`, lines 582–596:

```python
# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
    self.config.sample_size[0]
    if isinstance(self.config.sample_size, (list, tuple))
    else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.ch_mult) - 1)))
self.tile_overlap_factor = 0.25

def _encode(self, x: torch.Tensor) -> torch.Tensor:
    batch_size, num_channels, height, width = x.shape

    if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
        return self._tiled_encode(x)
```
`src/diffusers/models/autoencoders/consistency_decoder_vae.py`, lines 159–204:

```python
# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
    self.config.sample_size[0]
    if isinstance(self.config.sample_size, (list, tuple))
    else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
    """
    Disables custom attention processors and sets the default attention implementation.
    """
    if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        processor = AttnAddedKVProcessor()
    elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(processor)

@apply_forward_hook
def encode(
    self, x: torch.Tensor, return_dict: bool = True
) -> ConsistencyDecoderVAEOutput | tuple[DiagonalGaussianDistribution]:
    """
    Encode a batch of images into latents.

    Args:
        x (`torch.Tensor`): Input batch of images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
            instead of a plain tuple.

    Returns:
        The latent representations of the encoded images. If `return_dict` is True, a
        [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a
        plain `tuple` is returned.
    """
    if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
        return self.tiled_encode(x, return_dict=return_dict)
```
Problem:
The constructors partly support sample_size as a tuple/list when computing latent tile size, but leave self.tile_sample_min_size as the tuple. Later tiling checks compare int > tuple.
Impact:
Loading or constructing rectangular VAE configs with tuple sample sizes crashes when tiling is enabled.
Reproduction:
```python
import torch
from diffusers import AutoencoderKL

model = AutoencoderKL(
    sample_size=(32, 64),
    block_out_channels=(4,),
    norm_num_groups=1,
    latent_channels=2,
)
model.enable_tiling()
model.encode(torch.randn(1, 3, 65, 65))
```
Relevant precedent:
AutoencoderDC keeps separate height/width tile thresholds.
Suggested fix:
```python
sample_size = self.config.sample_size
if isinstance(sample_size, (list, tuple)):
    self.tile_sample_min_height, self.tile_sample_min_width = sample_size
else:
    self.tile_sample_min_height = self.tile_sample_min_width = sample_size

if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
    ...
```
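The crash itself is just Python's ordering rules, visible in isolation:

```python
65 > (32, 64)  # TypeError: '>' not supported between instances of 'int' and 'tuple'
```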
Issue 7: Several target models have no slow model-level tests
Affected code:

`tests/models/autoencoders/test_models_autoencoder_dc.py`, lines 81–104:

```python
class TestAutoencoderDC(AutoencoderDCTesterConfig, ModelTesterMixin):
    base_precision = 1e-2

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
        if dtype == torch.bfloat16 and IS_GITHUB_ACTIONS:
            pytest.skip("Skipping bf16 test inside GitHub Actions environment")
        super().test_from_save_pretrained_dtype_inference(tmp_path, dtype)


class TestAutoencoderDCTraining(AutoencoderDCTesterConfig, TrainingTesterMixin):
    """Training tests for AutoencoderDC."""


class TestAutoencoderDCMemory(AutoencoderDCTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for AutoencoderDC."""

    @pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
    def test_layerwise_casting_memory(self):
        super().test_layerwise_casting_memory()


class TestAutoencoderDCSlicingTiling(AutoencoderDCTesterConfig, NewAutoencoderTesterMixin):
    """Slicing and tiling tests for AutoencoderDC."""
```
`tests/models/autoencoders/test_models_autoencoder_kl_kvae.py`, lines 28–73:

```python
class AutoencoderKLKVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    model_class = AutoencoderKLKVAE
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_kl_kvae_config(self):
        return {
            "in_channels": 3,
            "channels": 32,
            "num_enc_blocks": 1,
            "num_dec_blocks": 1,
            "z_channels": 4,
            "double_z": True,
            "ch_mult": (1, 2),
            "sample_size": 32,
        }

    @property
    def dummy_input(self):
        batch_size = 2
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_kl_kvae_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "KVAEEncoder2D",
            "KVAEDecoder2D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
```
`tests/models/autoencoders/test_models_autoencoder_kl_kvae_video.py`, lines 28–118:

```python
class AutoencoderKLKVAEVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    model_class = AutoencoderKLKVAEVideo
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_kl_kvae_video_config(self):
        return {
            "ch": 32,
            "ch_mult": (1, 2),
            "num_res_blocks": 1,
            "in_channels": 3,
            "out_ch": 3,
            "z_channels": 4,
            "temporal_compress_times": 2,
        }

    @property
    def dummy_input(self):
        batch_size = 2
        num_frames = 3  # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2
        num_channels = 3
        sizes = (16, 16)

        video = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)

        return {"sample": video}

    @property
    def input_shape(self):
        return (3, 3, 16, 16)

    @property
    def output_shape(self):
        return (3, 3, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_kl_kvae_video_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "KVAECachedEncoder3D",
            "KVAECachedDecoder3D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @unittest.skip("Unsupported test.")
    def test_outputs_equivalence(self):
        pass

    @unittest.skip(
        "Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
    )
    def test_model_parallelism(self):
        pass

    @unittest.skip(
        "Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
    )
    def test_sharded_checkpoints_device_map(self):
        pass

    def _run_nondeterministic(self, fn):
        # reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
        # temporarily relax the requirement for training tests that do backward passes.
        import torch

        torch.use_deterministic_algorithms(False)
        try:
            fn()
        finally:
            torch.use_deterministic_algorithms(True)

    def test_training(self):
        self._run_nondeterministic(super().test_training)

    def test_ema_training(self):
        self._run_nondeterministic(super().test_ema_training)

    @unittest.skip(
        "Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
        "that is mutated during the first forward. On recomputation the cache is already populated, "
        "causing a different execution path and numerically different gradients. "
        "GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation."
    )
    def test_effective_gradient_checkpointing(self):
        pass

    def test_layerwise_casting_training(self):
        self._run_nondeterministic(super().test_layerwise_casting_training)
```
`tests/models/autoencoders/test_models_autoencoder_magvit.py`, lines 28–97:

```python
class AutoencoderKLMagvitTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    model_class = AutoencoderKLMagvit
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_kl_magvit_config(self):
        return {
            "in_channels": 3,
            "latent_channels": 4,
            "out_channels": 3,
            "block_out_channels": [8, 8, 8, 8],
            "down_block_types": [
                "SpatialDownBlock3D",
                "SpatialTemporalDownBlock3D",
                "SpatialTemporalDownBlock3D",
                "SpatialTemporalDownBlock3D",
            ],
            "up_block_types": [
                "SpatialUpBlock3D",
                "SpatialTemporalUpBlock3D",
                "SpatialTemporalUpBlock3D",
                "SpatialTemporalUpBlock3D",
            ],
            "layers_per_block": 1,
            "norm_num_groups": 8,
            "spatial_group_norm": True,
        }

    @property
    def dummy_input(self):
        batch_size = 2
        num_frames = 9
        num_channels = 3
        height = 16
        width = 16

        image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 9, 16, 16)

    @property
    def output_shape(self):
        return (3, 9, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_kl_magvit_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"EasyAnimateEncoder", "EasyAnimateDecoder"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @unittest.skip("Not quite sure why this test fails. Revisit later.")
    def test_effective_gradient_checkpointing(self):
        pass

    @unittest.skip("Unsupported test.")
    def test_forward_with_norm_groups(self):
        pass

    @unittest.skip(
        "Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list."
    )
    def test_enable_disable_slicing(self):
        pass
```
`tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py`, lines 32–70:

```python
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    model_class = AutoencoderKLTemporalDecoder
    main_input_name = "sample"
    base_precision = 1e-2

    @property
    def dummy_input(self):
        batch_size = 3
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        num_frames = 3

        return {"sample": image, "num_frames": num_frames}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": [32, 64],
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "latent_channels": 4,
            "layers_per_block": 2,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
```
`tests/models/autoencoders/test_models_autoencoder_vidtok.py`, lines 30–163:

```python
class AutoencoderVidTokTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AutoencoderVidTok
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_vidtok_config(self):
        return {
            "is_causal": False,
            "in_channels": 3,
            "out_channels": 3,
            "ch": 128,
            "ch_mult": [1, 2, 4, 4, 4],
            "z_channels": 6,
            "double_z": False,
            "num_res_blocks": 2,
            "regularizer": "fsq",
            "codebook_size": 262144,
        }

    @property
    def dummy_input(self):
        batch_size = 4
        num_frames = 16
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 16, 32, 32)

    @property
    def output_shape(self):
        return (3, 16, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_vidtok_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_enable_disable_tiling(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        torch.manual_seed(0)
        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        torch.manual_seed(0)
        model.enable_tiling()
        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        self.assertLess(
            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
            0.5,
            "VAE tiling should not affect the inference results",
        )

        torch.manual_seed(0)
        model.disable_tiling()
        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        self.assertEqual(
            output_without_tiling.detach().cpu().numpy().all(),
            output_without_tiling_2.detach().cpu().numpy().all(),
            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
        )

    def test_enable_disable_slicing(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        inputs_dict.update({"return_dict": False})

        torch.manual_seed(0)
        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        torch.manual_seed(0)
        model.enable_slicing()
        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        self.assertLess(
            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
            0.5,
            "VAE slicing should not affect the inference results",
        )

        torch.manual_seed(0)
        model.disable_slicing()
        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        self.assertEqual(
            output_without_slicing.detach().cpu().numpy().all(),
            output_without_slicing_2.detach().cpu().numpy().all(),
            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
        )

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "VidTokEncoder3D",
            "VidTokDecoder3D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    def test_forward_with_norm_groups(self):
        r"""VidTok uses layernorm instead of groupnorm."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    @unittest.skip("Unsupported test.")
    def test_outputs_equivalence(self):
        pass

    @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
    def test_layerwise_casting_training(self):
        super().test_layerwise_casting_training()
```
`tests/models/autoencoders/test_models_vq.py`, lines 31–114 (the excerpt starts inside the VQModel test class):

```python
model_class = VQModel
main_input_name = "sample"

@property
def dummy_input(self, sizes=(32, 32)):
    batch_size = 4
    num_channels = 3

    image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

    return {"sample": image}

@property
def input_shape(self):
    return (3, 32, 32)

@property
def output_shape(self):
    return (3, 32, 32)

def prepare_init_args_and_inputs_for_common(self):
    init_dict = {
        "block_out_channels": [8, 16],
        "norm_num_groups": 8,
        "in_channels": 3,
        "out_channels": 3,
        "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
        "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
        "latent_channels": 3,
    }
    inputs_dict = self.dummy_input
    return init_dict, inputs_dict

@unittest.skip("Test not supported.")
def test_forward_signature(self):
    pass

@unittest.skip("Test not supported.")
def test_training(self):
    pass

def test_from_pretrained_hub(self):
    model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
    self.assertIsNotNone(model)
    self.assertEqual(len(loading_info["missing_keys"]), 0)

    model.to(torch_device)
    image = model(**self.dummy_input)

    assert image is not None, "Make sure output is not None"

def test_output_pretrained(self):
    model = VQModel.from_pretrained("fusing/vqgan-dummy")
    model.to(torch_device).eval()

    torch.manual_seed(0)
    backend_manual_seed(torch_device, 0)

    image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
    image = image.to(torch_device)
    with torch.no_grad():
        output = model(image).sample

    output_slice = output[0, -1, -3:, -3:].flatten().cpu()
    # fmt: off
    expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
    # fmt: on
    self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

def test_loss_pretrained(self):
    model = VQModel.from_pretrained("fusing/vqgan-dummy")
    model.to(torch_device).eval()

    torch.manual_seed(0)
    backend_manual_seed(torch_device, 0)

    image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
    image = image.to(torch_device)
    with torch.no_grad():
        output = model(image).commit_loss.cpu()
    # fmt: off
    expected_output = torch.tensor([0.1936])
    # fmt: on
    self.assertTrue(torch.allclose(output, expected_output, atol=1e-3))
```
Problem:
These model test files contain fast coverage but no @slow tests for published checkpoints or expected slices.
Impact:
Checkpoint config/loading/parity regressions can land unnoticed, especially for newer KVAE, VidTok, DC, and Magvit autoencoders.
Reproduction:
```python
from pathlib import Path

checks = {
    "AutoencoderDC": "tests/models/autoencoders/test_models_autoencoder_dc.py",
    "AutoencoderKLKVAE": "tests/models/autoencoders/test_models_autoencoder_kl_kvae.py",
    "AutoencoderKLKVAEVideo": "tests/models/autoencoders/test_models_autoencoder_kl_kvae_video.py",
    "AutoencoderKLMagvit": "tests/models/autoencoders/test_models_autoencoder_magvit.py",
    "AutoencoderKLTemporalDecoder": "tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py",
    "AutoencoderVidTok": "tests/models/autoencoders/test_models_autoencoder_vidtok.py",
    "VQModel": "tests/models/autoencoders/test_models_vq.py",
}
for name, file in checks.items():
    print(name, "@slow" in Path(file).read_text())
```
Relevant precedent:
Existing autoencoder files such as test_models_autoencoder_tiny.py, test_models_autoencoder_oobleck.py, and test_models_consistency_decoder_vae.py include slow integration coverage.
Suggested fix:
Add one slow checkpoint test per missing model where a public or hf-internal-testing checkpoint exists, asserting load, output shape, finite output, and a small expected slice.
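A hedged skeleton of such a test (the checkpoint id and the expected slice are placeholders, not real values; `slow` and `torch_device` come from `diffusers.utils.testing_utils`):

```python
import unittest

import torch

from diffusers import AutoencoderKLMagvit
from diffusers.utils.testing_utils import slow, torch_device


@slow
class AutoencoderKLMagvitIntegrationTests(unittest.TestCase):
    def test_pretrained_checkpoint(self):
        # Placeholder checkpoint id; substitute a published or hf-internal-testing repo.
        model = AutoencoderKLMagvit.from_pretrained("<checkpoint-id>")
        model.to(torch_device).eval()

        torch.manual_seed(0)
        x = torch.randn(1, 3, 9, 64, 64, device=torch_device)
        with torch.no_grad():
            out = model(x).sample

        self.assertEqual(out.shape, x.shape)
        self.assertTrue(torch.isfinite(out).all())
        # expected_slice = torch.tensor([...])  # fill in from a verified run
        # torch.testing.assert_close(out[0, 0, 0, :3, :3].flatten().cpu(), expected_slice, atol=1e-3, rtol=1e-3)
```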
Duplicate-search status
Searched GitHub issues and PRs in huggingface/diffusers for model_autoencoders_shared, AutoencoderKLKVAEVideo, KVAEVideo, autoencoder_kl_kvae_video.py, AutoencoderVidTok, AutoencoderTiny, AutoencoderTiny bfloat16, AutoencoderTiny layerwise casting, AutoencoderTiny tiling latent_channels, AutoencoderKL sample_size tuple tiling, and the specific failure modes above. No duplicate issue or PR was found. Broad related matches were not duplicates: open issue #13628 is a Marigold review, PR #11261 added VidTok, and PR #10347 added layerwise casting.
Coverage Status
Public top-level imports, diffusers.models lazy imports, and PyTorch dummy objects are present for the target public classes. Fast tests exist for the target models. Slow model-level test gaps are listed in Issue 7. API docs are present for most target public models; AutoencoderKLTemporalDecoder and AutoencoderVidTok do not appear in the English API model toctree at:
`docs/source/en/_toctree.yml`, lines 435–481:

```yaml
- local: api/models/asymmetricautoencoderkl
  title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_dc
  title: AutoencoderDC
- local: api/models/autoencoderkl
  title: AutoencoderKL
- local: api/models/autoencoderkl_allegro
  title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox
  title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_cosmos
  title: AutoencoderKLCosmos
- local: api/models/autoencoder_kl_hunyuanimage
  title: AutoencoderKLHunyuanImage
- local: api/models/autoencoder_kl_hunyuanimage_refiner
  title: AutoencoderKLHunyuanImageRefiner
- local: api/models/autoencoder_kl_hunyuan_video
  title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
  title: AutoencoderKLHunyuanVideo15
- local: api/models/autoencoder_kl_kvae
  title: AutoencoderKLKVAE
- local: api/models/autoencoder_kl_kvae_video
  title: AutoencoderKLKVAEVideo
- local: api/models/autoencoderkl_audio_ltx_2
  title: AutoencoderKLLTX2Audio
- local: api/models/autoencoderkl_ltx_2
  title: AutoencoderKLLTX2Video
- local: api/models/autoencoderkl_ltx_video
  title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_magvit
  title: AutoencoderKLMagvit
- local: api/models/autoencoderkl_mochi
  title: AutoencoderKLMochi
- local: api/models/autoencoderkl_qwenimage
  title: AutoencoderKLQwenImage
- local: api/models/autoencoder_kl_wan
  title: AutoencoderKLWan
- local: api/models/autoencoder_rae
  title: AutoencoderRAE
- local: api/models/consistency_decoder_vae
  title: ConsistencyDecoderVAE
- local: api/models/autoencoder_oobleck
  title: Oobleck AutoEncoder
- local: api/models/autoencoder_tiny
  title: Tiny AutoEncoder
- local: api/models/vq
```
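A quick membership check against the toctree (illustrative; titles in the toctree use the public class names):

```python
from pathlib import Path

toctree = Path("docs/source/en/_toctree.yml").read_text()
for name in ("AutoencoderKLTemporalDecoder", "AutoencoderVidTok"):
    print(name, name in toctree)  # expected: False for both at the tested commit
```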