
model_autoencoders_shared model/pipeline review #13652

@hlky

Description

Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Full target file list reviewed: autoencoder_asym_kl.py, autoencoder_dc.py, autoencoder_kl.py, autoencoder_kl_kvae.py, autoencoder_kl_kvae_video.py, autoencoder_kl_magvit.py, autoencoder_kl_temporal_decoder.py, autoencoder_oobleck.py, autoencoder_rae.py, autoencoder_tiny.py, autoencoder_vidtok.py, consistency_decoder_vae.py, vae.py, vq_model.py.

Issue 1: AutoencoderKLKVAEVideo discards encoder log variance

Affected code:

def _encode(self, x: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
    # Cached encoder processes by segments
    cache = self._make_encoder_cache()
    split_list = [seg_len + 1]
    n_frames = x.size(2) - (seg_len + 1)
    while n_frames > 0:
        split_list.append(seg_len)
        n_frames -= seg_len
    split_list[-1] += n_frames
    latent = []
    for chunk in torch.split(x, split_list, dim=2):
        l = self.encoder(chunk, cache)
        sample, _ = torch.chunk(l, 2, dim=1)
        latent.append(sample)
    return torch.cat(latent, dim=2)

@apply_forward_hook
def encode(
    self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
    """
    Encode a batch of videos into latents.

    Args:
        x (`torch.Tensor`): Input batch of videos with shape (B, C, T, H, W).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

    Returns:
        The latent representations of the encoded videos.
    """
    if self.use_slicing and x.shape[0] > 1:
        encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
        h = torch.cat(encoded_slices)
    else:
        h = self._encode(x)

    # For cached encoder, we already did the split in _encode
    h_double = torch.cat([h, torch.zeros_like(h)], dim=1)
    posterior = DiagonalGaussianDistribution(h_double)

    if not return_dict:
        return (posterior,)

    return AutoencoderKLOutput(latent_dist=posterior)

Problem:
KVAECachedEncoder3D outputs 2 * z_channels channels (mean and log variance concatenated), but _encode() splits the tensor and keeps only the first half. encode() then reconstructs a fake [mean, zeros] tensor, so the posterior log variance is always zero.

Impact:
sample_posterior=True samples from the wrong distribution, checkpoint log-variance weights are ignored, and parity with the source KVAE 3D VAE is broken.

Reproduction:

import torch
from diffusers import AutoencoderKLKVAEVideo

model = AutoencoderKLKVAEVideo(
    ch=32, ch_mult=(1, 2), num_res_blocks=1, z_channels=4, temporal_compress_times=2
).eval()
x = torch.randn(1, 3, 3, 16, 16)

with torch.no_grad():
    raw = model.encoder(x, model._make_encoder_cache())
    raw_mean, raw_logvar = raw.chunk(2, dim=1)
    posterior = model.encode(x).latent_dist

print(torch.allclose(posterior.mean, raw_mean, atol=1e-5))
print(raw_logvar.abs().max().item())
print(posterior.logvar.abs().max().item())  # always 0.0

Relevant precedent:

def _encode(self, x: torch.Tensor) -> torch.Tensor:
    batch_size, num_channels, height, width = x.shape

    if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
        return self._tiled_encode(x)

    enc = self.encoder(x)
    return enc

@apply_forward_hook
def encode(
    self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
    """
    Encode a batch of images into latents.

    Args:
        x (`torch.Tensor`): Input batch of images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

    Returns:
        The latent representations of the encoded images. If `return_dict` is True, a
        [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
    """
    if self.use_slicing and x.shape[0] > 1:
        encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
        h = torch.cat(encoded_slices)
    else:
        h = self._encode(x)

    posterior = DiagonalGaussianDistribution(h)

    if not return_dict:
        return (posterior,)

    return AutoencoderKLOutput(latent_dist=posterior)

Suggested fix:

# _encode(): keep the full encoder output
latent.append(self.encoder(chunk, cache))

# encode(): do not synthesize a zero logvar half
posterior = DiagonalGaussianDistribution(h)
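
With both changes applied, re-running the reproduction above should show the checkpoint's log variance reaching the posterior. A hedged check, reusing the reproduction's variables and assuming DiagonalGaussianDistribution's usual clamping of logvar to [-30, 20]:

# Expected after the fix: the posterior mean and logvar both track the raw encoder output.
print(torch.allclose(posterior.mean, raw_mean, atol=1e-5))                          # True
print(torch.allclose(posterior.logvar, raw_logvar.clamp(-30.0, 20.0), atol=1e-5))   # True, no longer all zeros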

Issue 2: AutoencoderVidTok.forward(return_dict=False) returns a tensor instead of a tuple

Affected code:

def forward(
    self,
    sample: torch.Tensor,
    sample_posterior: bool = True,
    encoder_mode: bool = False,
    return_dict: bool = True,
    generator: Optional[torch.Generator] = None,
) -> Union[torch.Tensor, DecoderOutput]:
    x = sample
    res = 1 if self.is_causal else 0
    if self.is_causal:
        if x.shape[2] % self.temporal_compression_ratio != res:
            time_padding = self.temporal_compression_ratio - x.shape[2] % self.temporal_compression_ratio + res
            x = self._pad_at_dim(x, (0, time_padding), dim=2, pad_mode="replicate")
        else:
            time_padding = 0
    else:
        if x.shape[2] % self.num_sample_frames_batch_size != res:
            if not encoder_mode:
                time_padding = (
                    self.num_sample_frames_batch_size - x.shape[2] % self.num_sample_frames_batch_size + res
                )
                x = self._pad_at_dim(x, (0, time_padding), dim=2, pad_mode="replicate")
            else:
                assert x.shape[2] >= self.num_sample_frames_batch_size, (
                    f"Too short video. At least {self.num_sample_frames_batch_size} frames."
                )
                x = x[:, :, : x.shape[2] // self.num_sample_frames_batch_size * self.num_sample_frames_batch_size]
        else:
            time_padding = 0
    if self.is_causal:
        x = self._pad_at_dim(x, (self.temporal_compression_ratio - 1, 0), dim=2, pad_mode="replicate")
    if self.regularizer == "kl":
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        if encoder_mode:
            return z
    else:
        z, indices = self.encode(x)
        if encoder_mode:
            return z, indices
    dec = self.decode(z)
    if time_padding != 0:
        dec = dec[:, :, :-time_padding, :, :]
    if not return_dict:
        return dec
    return DecoderOutput(sample=dec)

Problem:
Every other autoencoder forward path returns a tuple when return_dict=False. VidTok returns the raw decoded tensor, so common caller code like model(..., return_dict=False)[0] silently selects the first batch element.

Impact:
Pipeline/model utility code that relies on diffusers’ tuple convention gets the wrong shape and silently drops batch items.

Reproduction:

import torch
from diffusers import AutoencoderVidTok

model = AutoencoderVidTok(
    is_causal=False, ch=8, ch_mult=[1], z_channels=2, double_z=True, num_res_blocks=1, regularizer="kl"
).eval()
x = torch.randn(2, 3, 2, 8, 8)

with torch.no_grad():
    out = model(x, sample_posterior=False, return_dict=False)

print(type(out), out.shape)      # torch.Tensor, not tuple
print(out[0].shape)              # first batch item, not decoded tuple element

Relevant precedent:

def forward(
    self,
    sample: torch.Tensor,
    sample_posterior: bool = False,
    return_dict: bool = True,
    generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
    x = sample
    posterior = self.encode(x).latent_dist
    if sample_posterior:
        z = posterior.sample(generator=generator)
    else:
        z = posterior.mode()
    dec = self.decode(z).sample

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)

Suggested fix:

if not return_dict:
    return (dec,)
return DecoderOutput(sample=dec)
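
Continuing the reproduction above once the fix is in place, a hedged expectation is that the tuple convention is restored:

out = model(x, sample_posterior=False, return_dict=False)
print(isinstance(out, tuple), len(out))  # True 1
print(out[0].shape[0])                   # 2: the full decoded batch, not the first batch element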

Issue 3: AutoencoderVidTok bypasses diffusers attention processors

Affected code:

class VidTokAttnBlock(nn.Module):
    r"""A 3D self-attention block used in VidTok Model."""

    def __init__(self, in_channels: int, is_causal: bool = True):
        super().__init__()
        make_conv_cls = VidTokCausalConv3d if is_causal else nn.Conv3d
        self.norm = VidTokLayerNorm(dim=in_channels, eps=1e-6)
        self.q = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def attention(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""Implement self-attention."""
        hidden_states = self.norm(hidden_states)
        q = self.q(hidden_states)
        k = self.k(hidden_states)
        v = self.v(hidden_states)

        b, c, t, h, w = q.shape
        q, k, v = [x.permute(0, 2, 3, 4, 1).reshape(b, t, -1, c).contiguous() for x in [q, k, v]]
        hidden_states = F.scaled_dot_product_attention(q, k, v)  # scale is dim ** -0.5 per default

        return hidden_states.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3)

class AutoencoderVidTok(ModelMixin, ConfigMixin):
    r"""
    A VAE model for encoding videos into latents and decoding latent representations into videos, supporting both
    continuous and discrete latent representations. Used in [VidTok](https://github.com/microsoft/VidTok).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).

    Args:
        in_channels (`int`, defaults to 3):
            The number of input channels.
        out_channels (`int`, defaults to 3):
            The number of output channels.
        ch (`int`, defaults to 128):
            The number of the basic channel.
        ch_mult (`List[int]`, defaults to `[1, 2, 4, 4]`):
            The multiple of the basic channel for each block.
        z_channels (`int`, defaults to 4):
            The number of latent channels.
        double_z (`bool`, defaults to `True`):
            Whether or not to double the z_channels.
        num_res_blocks (`int`, defaults to 2):
            The number of resblocks.
        spatial_ds (`List`, *optional*, defaults to `None`):
            Spatial downsample layers.
        spatial_us (`List`, *optional*, defaults to `None`):
            Spatial upsample layers.
        tempo_ds (`List`, *optional*, defaults to `None`):
            Temporal downsample layers.
        tempo_us (`List`, *optional*, defaults to `None`):
            Temporal upsample layers.
        dropout (`float`, defaults to 0.0):
            Dropout rate.
        regularizer (`str`, defaults to `"kl"`):
            The regularizer type - "kl" for continuous cases and "fsq" for discrete cases.
        codebook_size (`int`, defaults to 262144):
            The codebook size used only in discrete cases.
        is_causal (`bool`, defaults to `True`):
            Whether it is a causal module.
    """

    _supports_gradient_checkpointing = True
Problem:
VidTokAttnBlock calls F.scaled_dot_product_attention directly, and AutoencoderVidTok does not inherit AttentionMixin. This bypasses the repository’s attention processor/backend path.

Impact:
Users cannot inspect or replace attention processors, attention backend selection is ignored for VidTok attention, and the implementation violates the model attention rule.

Reproduction:

from diffusers import AutoencoderVidTok

model = AutoencoderVidTok(
    is_causal=False, ch=8, ch_mult=[1], z_channels=2, double_z=True, num_res_blocks=1, regularizer="kl"
)

print(hasattr(model, "set_attn_processor"))  # False
print([type(m).__name__ for m in model.modules() if type(m).__name__ == "VidTokAttnBlock"])

Relevant precedent:

self.attention = Attention(
    query_dim=hidden_size,
    heads=num_attention_heads,
    dim_head=hidden_size // num_attention_heads,
    dropout=attention_probs_dropout_prob,
    bias=qkv_bias,
)

class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
    r"""
    Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images.

    This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT decoder to reconstruct
    images from learned representations.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
    all models (such as downloading or saving).

    Args:
        encoder_type (`str`, *optional*, defaults to `"dinov2"`):
            Type of frozen encoder to use. One of `"dinov2"`, `"siglip2"`, or `"mae"`.
        encoder_hidden_size (`int`, *optional*, defaults to `768`):
            Hidden size of the encoder model.
        encoder_patch_size (`int`, *optional*, defaults to `14`):
            Patch size of the encoder model.
        encoder_num_hidden_layers (`int`, *optional*, defaults to `12`):
            Number of hidden layers in the encoder model.
        patch_size (`int`, *optional*, defaults to `16`):
            Decoder patch size (used for unpatchify and decoder head).
        encoder_input_size (`int`, *optional*, defaults to `224`):
            Input size expected by the encoder.
        image_size (`int`, *optional*):
            Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like
            RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size //
            encoder_patch_size) ** 2`.
        num_channels (`int`, *optional*, defaults to `3`):
            Number of input/output channels.
        encoder_norm_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
            Channel-wise mean for encoder input normalization (ImageNet defaults).
        encoder_norm_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
            Channel-wise std for encoder input normalization (ImageNet defaults).
        latents_mean (`list` or `tuple`, *optional*):
            Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable
            lists.
        latents_std (`list` or `tuple`, *optional*):
            Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to
            config-serializable lists.
        noise_tau (`float`, *optional*, defaults to `0.0`):
            Noise level for training (adds noise to latents during training).
        reshape_to_2d (`bool`, *optional*, defaults to `True`):
            Whether to reshape latents to 2D (B, C, H, W) format.
        use_encoder_loss (`bool`, *optional*, defaults to `False`):
            Whether to use encoder hidden states in the loss (for advanced training).
    """

    # NOTE: gradient checkpointing is not wired up for this model yet.
    _supports_gradient_checkpointing = False
    _no_split_modules = ["ViTMAELayer"]

Suggested fix:
Refactor VidTokAttnBlock to use the diffusers attention processor pattern with dispatch_attention_fn, and make AutoencoderVidTok inherit AttentionMixin.
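
A minimal sketch of what that refactor could look like, assuming the existing diffusers Attention module (whose default processor routes through the shared attention dispatch). nn.LayerNorm stands in for VidTokLayerNorm to keep the sketch self-contained, and Attention's built-in output projection plus a residual only approximates the original proj_out wiring:

import torch
import torch.nn as nn

from diffusers.models.attention_processor import Attention


class VidTokAttnBlock(nn.Module):
    r"""Sketch: 3D self-attention routed through the diffusers Attention module."""

    def __init__(self, in_channels: int, is_causal: bool = True):
        super().__init__()
        # In the original block is_causal only selects the conv class for the 1x1 q/k/v
        # projections; with linear projections inside Attention it has no effect here.
        self.norm = nn.LayerNorm(in_channels, eps=1e-6)
        # Single head over the spatial tokens of each frame, matching the original
        # per-frame softmax(QK^T / sqrt(C)) V computation.
        self.attention = Attention(query_dim=in_channels, heads=1, dim_head=in_channels, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        b, c, t, h, w = hidden_states.shape
        # (B, C, T, H, W) -> (B*T, H*W, C): one attention call per frame.
        tokens = hidden_states.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c)
        tokens = self.attention(self.norm(tokens))
        hidden_states = tokens.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3)
        return residual + hidden_states

With the block expressed through Attention, making AutoencoderVidTok inherit AttentionMixin (as AutoencoderRAE does above) exposes attn_processors / set_attn_processor, and backend selection flows through the shared attention dispatch instead of a hard-coded F.scaled_dot_product_attention call.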

Issue 4: AutoencoderTiny fails bfloat16 forward because it creates float32 activations

Affected code:

enc = self.encode(sample).latents
# scale latents to be in [0, 1], then quantize latents to a byte tensor,
# as if we were storing the latents in an RGBA uint8 image.
scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
# unquantize latents back into [0, 1], then unscale latents back to their original range,
# as if we were loading the latents from an RGBA uint8 image.
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
dec = self.decode(unscaled_enc).sample

@unittest.skip(
    "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n"
    "1. Change the forward pass to be dtype agnostic.\n"
    "2. Unskip this test."
)
def test_layerwise_casting_inference(self):
    pass

@unittest.skip(
    "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n"
    "1. Change the forward pass to be dtype agnostic.\n"
    "2. Unskip this test."
)
def test_layerwise_casting_memory(self):
    pass

Problem:
scaled_enc / 255.0 promotes the byte tensor to float32 before decode. With bfloat16 weights, the decoder receives float32 inputs and errors. The fast tests already skip layerwise-casting coverage with this exact reason.

Impact:
Layerwise casting / bfloat16 inference cannot cover AutoencoderTiny, and users hit dtype mismatch errors.

Reproduction:

import torch
from diffusers import AutoencoderTiny

model = AutoencoderTiny(
    encoder_block_out_channels=(8, 8, 8, 8),
    decoder_block_out_channels=(8, 8, 8, 8),
    num_encoder_blocks=(1, 1, 1, 1),
    num_decoder_blocks=(1, 1, 1, 1),
).eval().to(dtype=torch.bfloat16)

x = torch.randn(1, 3, 32, 32, dtype=torch.bfloat16)
model(x)

Relevant precedent:
The skipped tests quoted above already describe the expected fix.

Suggested fix:

unscaled_enc = self.unscale_latents(scaled_enc.to(dtype=enc.dtype) / 255.0)
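
With that cast in place, a hedged expectation for the reproduction above is that the bfloat16 forward pass completes and stays in the weight dtype:

with torch.no_grad():
    out = model(x).sample
print(out.dtype)  # torch.bfloat16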

Issue 5: AutoencoderTiny tiled paths hardcode channel counts and derive scale from out_channels

Affected code:

# only relevant if vae tiling is enabled
self.spatial_scale_factor = 2**out_channels
self.tile_overlap_factor = 0.125
self.tile_sample_min_size = 512
self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor

blend_masks = torch.stack(
    torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
)
blend_masks = blend_masks.clamp(0, 1).to(x.device)
# output array
out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)

blend_masks = torch.stack(
    torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
)
blend_masks = blend_masks.clamp(0, 1).to(x.device)
# output array
out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)

Problem:
Tiling assumes the latent channels are always 4, the output channels are always 3, and the spatial scale factor is 2 ** out_channels, yet all three are configurable constructor values.

Impact:
Non-default but valid AutoencoderTiny configs fail as soon as tiling is enabled.

Reproduction:

import torch
from diffusers import AutoencoderTiny

model = AutoencoderTiny(
    out_channels=1,
    encoder_block_out_channels=(8, 8, 8, 8),
    decoder_block_out_channels=(8, 8, 8, 8),
    num_encoder_blocks=(1, 1, 1, 1),
    num_decoder_blocks=(1, 1, 1, 1),
    latent_channels=4,
)
model.enable_tiling()

x = torch.randn(1, 3, 32, 32)
model.encode(x)

Relevant precedent:
Other tiled autoencoders derive shapes from tensor/model config rather than fixed RGB/latent-channel constants.

Suggested fix:

self.spatial_scale_factor = 2 ** (len(encoder_block_out_channels) - 1)

out = torch.zeros(
    x.shape[0], self.config.latent_channels,
    x.shape[-2] // sf, x.shape[-1] // sf,
    device=x.device, dtype=x.dtype,
)

out = torch.zeros(
    x.shape[0], self.config.out_channels,
    x.shape[-2] * sf, x.shape[-1] * sf,
    device=x.device, dtype=x.dtype,
)
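
Continuing the reproduction above with these changes, a hedged expectation is that the tiled paths follow the config instead of crashing or mis-shaping the output:

latents = model.encode(x).latents
print(latents.shape[1] == model.config.latent_channels)   # True
decoded = model.decode(latents).sample
print(decoded.shape[1] == model.config.out_channels)      # True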

Issue 6: Tuple sample_size breaks VAE tiling comparisons

Affected code:

# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
    self.config.sample_size[0]
    if isinstance(self.config.sample_size, (list, tuple))
    else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
    """
    Disables custom attention processors and sets the default attention implementation.
    """
    if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        processor = AttnAddedKVProcessor()
    elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(processor)

def _encode(self, x: torch.Tensor) -> torch.Tensor:
    batch_size, num_channels, height, width = x.shape

    if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
        return self._tiled_encode(x)

# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
    self.config.sample_size[0]
    if isinstance(self.config.sample_size, (list, tuple))
    else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.ch_mult) - 1)))
self.tile_overlap_factor = 0.25

def _encode(self, x: torch.Tensor) -> torch.Tensor:
    batch_size, num_channels, height, width = x.shape

    if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
        return self._tiled_encode(x)

# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
    self.config.sample_size[0]
    if isinstance(self.config.sample_size, (list, tuple))
    else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
    """
    Disables custom attention processors and sets the default attention implementation.
    """
    if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        processor = AttnAddedKVProcessor()
    elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(processor)

@apply_forward_hook
def encode(
    self, x: torch.Tensor, return_dict: bool = True
) -> ConsistencyDecoderVAEOutput | tuple[DiagonalGaussianDistribution]:
    """
    Encode a batch of images into latents.

    Args:
        x (`torch.Tensor`): Input batch of images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
            instead of a plain tuple.

    Returns:
        The latent representations of the encoded images. If `return_dict` is True, a
        [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a
        plain `tuple` is returned.
    """
    if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
        return self.tiled_encode(x, return_dict=return_dict)

Problem:
The constructors partially support sample_size given as a tuple/list when computing the latent tile size, but they leave self.tile_sample_min_size set to the tuple itself. Later tiling checks then compare an integer dimension against that tuple and fail.

Impact:
Loading or constructing rectangular VAE configs with tuple sample sizes crashes when tiling is enabled.

Reproduction:

import torch
from diffusers import AutoencoderKL

model = AutoencoderKL(
    sample_size=(32, 64),
    block_out_channels=(4,),
    norm_num_groups=1,
    latent_channels=2,
)
model.enable_tiling()
model.encode(torch.randn(1, 3, 65, 65))

Relevant precedent:
AutoencoderDC keeps separate height/width tile thresholds.

Suggested fix:

sample_size = self.config.sample_size
if isinstance(sample_size, (list, tuple)):
    self.tile_sample_min_height, self.tile_sample_min_width = sample_size
else:
    self.tile_sample_min_height = self.tile_sample_min_width = sample_size

if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
    ...
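
Under that change (plus the matching update on the decode side), the reproduction above should no longer hit an int-versus-tuple comparison; a hedged re-check:

posterior = model.encode(torch.randn(1, 3, 65, 65)).latent_dist
print(posterior.mean.shape)  # a valid latent shape instead of a TypeError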

Issue 7: Several target models have no slow model-level tests

Affected code:

class TestAutoencoderDC(AutoencoderDCTesterConfig, ModelTesterMixin):
    base_precision = 1e-2

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
        if dtype == torch.bfloat16 and IS_GITHUB_ACTIONS:
            pytest.skip("Skipping bf16 test inside GitHub Actions environment")
        super().test_from_save_pretrained_dtype_inference(tmp_path, dtype)


class TestAutoencoderDCTraining(AutoencoderDCTesterConfig, TrainingTesterMixin):
    """Training tests for AutoencoderDC."""


class TestAutoencoderDCMemory(AutoencoderDCTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for AutoencoderDC."""

    @pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
    def test_layerwise_casting_memory(self):
        super().test_layerwise_casting_memory()


class TestAutoencoderDCSlicingTiling(AutoencoderDCTesterConfig, NewAutoencoderTesterMixin):
    """Slicing and tiling tests for AutoencoderDC."""

class AutoencoderKLKVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    model_class = AutoencoderKLKVAE
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_kl_kvae_config(self):
        return {
            "in_channels": 3,
            "channels": 32,
            "num_enc_blocks": 1,
            "num_dec_blocks": 1,
            "z_channels": 4,
            "double_z": True,
            "ch_mult": (1, 2),
            "sample_size": 32,
        }

    @property
    def dummy_input(self):
        batch_size = 2
        num_channels = 3
        sizes = (32, 32)
        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_kl_kvae_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "KVAEEncoder2D",
            "KVAEDecoder2D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

class AutoencoderKLKVAEVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    model_class = AutoencoderKLKVAEVideo
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_kl_kvae_video_config(self):
        return {
            "ch": 32,
            "ch_mult": (1, 2),
            "num_res_blocks": 1,
            "in_channels": 3,
            "out_ch": 3,
            "z_channels": 4,
            "temporal_compress_times": 2,
        }

    @property
    def dummy_input(self):
        batch_size = 2
        num_frames = 3  # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2
        num_channels = 3
        sizes = (16, 16)
        video = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
        return {"sample": video}

    @property
    def input_shape(self):
        return (3, 3, 16, 16)

    @property
    def output_shape(self):
        return (3, 3, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_kl_kvae_video_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "KVAECachedEncoder3D",
            "KVAECachedDecoder3D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @unittest.skip("Unsupported test.")
    def test_outputs_equivalence(self):
        pass

    @unittest.skip(
        "Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
    )
    def test_model_parallelism(self):
        pass

    @unittest.skip(
        "Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
    )
    def test_sharded_checkpoints_device_map(self):
        pass

    def _run_nondeterministic(self, fn):
        # reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
        # temporarily relax the requirement for training tests that do backward passes.
        import torch

        torch.use_deterministic_algorithms(False)
        try:
            fn()
        finally:
            torch.use_deterministic_algorithms(True)

    def test_training(self):
        self._run_nondeterministic(super().test_training)

    def test_ema_training(self):
        self._run_nondeterministic(super().test_ema_training)

    @unittest.skip(
        "Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
        "that is mutated during the first forward. On recomputation the cache is already populated, "
        "causing a different execution path and numerically different gradients. "
        "GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation."
    )
    def test_effective_gradient_checkpointing(self):
        pass

    def test_layerwise_casting_training(self):
        self._run_nondeterministic(super().test_layerwise_casting_training)

class AutoencoderKLMagvitTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    model_class = AutoencoderKLMagvit
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_kl_magvit_config(self):
        return {
            "in_channels": 3,
            "latent_channels": 4,
            "out_channels": 3,
            "block_out_channels": [8, 8, 8, 8],
            "down_block_types": [
                "SpatialDownBlock3D",
                "SpatialTemporalDownBlock3D",
                "SpatialTemporalDownBlock3D",
                "SpatialTemporalDownBlock3D",
            ],
            "up_block_types": [
                "SpatialUpBlock3D",
                "SpatialTemporalUpBlock3D",
                "SpatialTemporalUpBlock3D",
                "SpatialTemporalUpBlock3D",
            ],
            "layers_per_block": 1,
            "norm_num_groups": 8,
            "spatial_group_norm": True,
        }

    @property
    def dummy_input(self):
        batch_size = 2
        num_frames = 9
        num_channels = 3
        height = 16
        width = 16
        image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 9, 16, 16)

    @property
    def output_shape(self):
        return (3, 9, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_kl_magvit_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"EasyAnimateEncoder", "EasyAnimateDecoder"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @unittest.skip("Not quite sure why this test fails. Revisit later.")
    def test_effective_gradient_checkpointing(self):
        pass

    @unittest.skip("Unsupported test.")
    def test_forward_with_norm_groups(self):
        pass

    @unittest.skip(
        "Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list."
    )
    def test_enable_disable_slicing(self):
        pass

class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    model_class = AutoencoderKLTemporalDecoder
    main_input_name = "sample"
    base_precision = 1e-2

    @property
    def dummy_input(self):
        batch_size = 3
        num_channels = 3
        sizes = (32, 32)
        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        num_frames = 3
        return {"sample": image, "num_frames": num_frames}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": [32, 64],
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "latent_channels": 4,
            "layers_per_block": 2,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

class AutoencoderVidTokTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AutoencoderVidTok
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_vidtok_config(self):
        return {
            "is_causal": False,
            "in_channels": 3,
            "out_channels": 3,
            "ch": 128,
            "ch_mult": [1, 2, 4, 4, 4],
            "z_channels": 6,
            "double_z": False,
            "num_res_blocks": 2,
            "regularizer": "fsq",
            "codebook_size": 262144,
        }

    @property
    def dummy_input(self):
        batch_size = 4
        num_frames = 16
        num_channels = 3
        sizes = (32, 32)
        image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 16, 32, 32)

    @property
    def output_shape(self):
        return (3, 16, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_vidtok_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_enable_disable_tiling(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        torch.manual_seed(0)
        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        torch.manual_seed(0)
        model.enable_tiling()
        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        self.assertLess(
            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
            0.5,
            "VAE tiling should not affect the inference results",
        )

        torch.manual_seed(0)
        model.disable_tiling()
        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        self.assertEqual(
            output_without_tiling.detach().cpu().numpy().all(),
            output_without_tiling_2.detach().cpu().numpy().all(),
            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
        )

    def test_enable_disable_slicing(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)
        inputs_dict.update({"return_dict": False})

        torch.manual_seed(0)
        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        torch.manual_seed(0)
        model.enable_slicing()
        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        self.assertLess(
            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
            0.5,
            "VAE slicing should not affect the inference results",
        )

        torch.manual_seed(0)
        model.disable_slicing()
        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        self.assertEqual(
            output_without_slicing.detach().cpu().numpy().all(),
            output_without_slicing_2.detach().cpu().numpy().all(),
            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
        )

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "VidTokEncoder3D",
            "VidTokDecoder3D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    def test_forward_with_norm_groups(self):
        r"""VidTok uses layernorm instead of groupnorm."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    @unittest.skip("Unsupported test.")
    def test_outputs_equivalence(self):
        pass

    @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
    def test_layerwise_casting_training(self):
        super().test_layerwise_casting_training()

model_class = VQModel
main_input_name = "sample"

@property
def dummy_input(self, sizes=(32, 32)):
    batch_size = 4
    num_channels = 3

    image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

    return {"sample": image}

@property
def input_shape(self):
    return (3, 32, 32)

@property
def output_shape(self):
    return (3, 32, 32)

def prepare_init_args_and_inputs_for_common(self):
    init_dict = {
        "block_out_channels": [8, 16],
        "norm_num_groups": 8,
        "in_channels": 3,
        "out_channels": 3,
        "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
        "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
        "latent_channels": 3,
    }
    inputs_dict = self.dummy_input
    return init_dict, inputs_dict

@unittest.skip("Test not supported.")
def test_forward_signature(self):
    pass

@unittest.skip("Test not supported.")
def test_training(self):
    pass

def test_from_pretrained_hub(self):
    model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
    self.assertIsNotNone(model)
    self.assertEqual(len(loading_info["missing_keys"]), 0)

    model.to(torch_device)
    image = model(**self.dummy_input)

    assert image is not None, "Make sure output is not None"

def test_output_pretrained(self):
    model = VQModel.from_pretrained("fusing/vqgan-dummy")
    model.to(torch_device).eval()

    torch.manual_seed(0)
    backend_manual_seed(torch_device, 0)

    image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
    image = image.to(torch_device)
    with torch.no_grad():
        output = model(image).sample

    output_slice = output[0, -1, -3:, -3:].flatten().cpu()
    # fmt: off
    expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
    # fmt: on
    self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

def test_loss_pretrained(self):
    model = VQModel.from_pretrained("fusing/vqgan-dummy")
    model.to(torch_device).eval()

    torch.manual_seed(0)
    backend_manual_seed(torch_device, 0)

    image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
    image = image.to(torch_device)
    with torch.no_grad():
        output = model(image).commit_loss.cpu()

    # fmt: off
    expected_output = torch.tensor([0.1936])
    # fmt: on
    self.assertTrue(torch.allclose(output, expected_output, atol=1e-3))

Problem:
These model test files contain fast coverage but no @slow tests for published checkpoints or expected slices.

Impact:
Checkpoint config/loading/parity regressions can land unnoticed, especially for newer KVAE, VidTok, DC, and Magvit autoencoders.

Reproduction:

from pathlib import Path

checks = {
    "AutoencoderDC": "tests/models/autoencoders/test_models_autoencoder_dc.py",
    "AutoencoderKLKVAE": "tests/models/autoencoders/test_models_autoencoder_kl_kvae.py",
    "AutoencoderKLKVAEVideo": "tests/models/autoencoders/test_models_autoencoder_kl_kvae_video.py",
    "AutoencoderKLMagvit": "tests/models/autoencoders/test_models_autoencoder_magvit.py",
    "AutoencoderKLTemporalDecoder": "tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py",
    "AutoencoderVidTok": "tests/models/autoencoders/test_models_autoencoder_vidtok.py",
    "VQModel": "tests/models/autoencoders/test_models_vq.py",
}
for name, file in checks.items():
    print(name, "@slow" in Path(file).read_text())

Relevant precedent:
Existing autoencoder files such as test_models_autoencoder_tiny.py, test_models_autoencoder_oobleck.py, and test_models_consistency_decoder_vae.py include slow integration coverage.

Suggested fix:
Add one slow checkpoint test per missing model where a public or hf-internal-testing checkpoint exists, asserting load, output shape, finite output, and a small expected slice.
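
A hedged skeleton for one such test (the checkpoint id and the expected-slice values are placeholders to be replaced with a real published checkpoint and with numbers taken from a verified run):

import unittest

import torch

from diffusers import AutoencoderDC
from diffusers.utils.testing_utils import slow, torch_device


@slow
class AutoencoderDCIntegrationTests(unittest.TestCase):
    def test_pretrained_checkpoint(self):
        # Placeholder repo id: substitute the actual published AutoencoderDC checkpoint.
        model = AutoencoderDC.from_pretrained("<org>/<autoencoder-dc-checkpoint>").to(torch_device).eval()

        torch.manual_seed(0)
        sample = torch.randn(1, 3, 512, 512, device=torch_device)
        with torch.no_grad():
            output = model(sample, return_dict=False)[0]

        self.assertEqual(output.shape, sample.shape)
        self.assertTrue(torch.isfinite(output).all())
        # expected_slice = torch.tensor([...])  # fill in from a trusted run of the checkpoint
        # self.assertTrue(torch.allclose(output[0, 0, :3, :3].flatten().cpu(), expected_slice, atol=1e-3))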

Duplicate-search status

Searched GitHub issues and PRs in huggingface/diffusers for model_autoencoders_shared, AutoencoderKLKVAEVideo, KVAEVideo, autoencoder_kl_kvae_video.py, AutoencoderVidTok, AutoencoderTiny, AutoencoderTiny bfloat16, AutoencoderTiny layerwise casting, AutoencoderTiny tiling latent_channels, AutoencoderKL sample_size tuple tiling, and the specific failure modes above. No duplicate issue or PR was found. Broad related matches were not duplicates: open issue #13628 is a Marigold review, PR #11261 added VidTok, and PR #10347 added layerwise casting.

Coverage Status

Public top-level imports, diffusers.models lazy imports, and PyTorch dummy objects are present for the target public classes. Fast tests exist for the target models. Slow model-level test gaps are listed in Issue 7. API docs are present for most target public models; AutoencoderKLTemporalDecoder and AutoencoderVidTok do not appear in the English API model toctree:

- local: api/models/asymmetricautoencoderkl
  title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_dc
  title: AutoencoderDC
- local: api/models/autoencoderkl
  title: AutoencoderKL
- local: api/models/autoencoderkl_allegro
  title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox
  title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_cosmos
  title: AutoencoderKLCosmos
- local: api/models/autoencoder_kl_hunyuanimage
  title: AutoencoderKLHunyuanImage
- local: api/models/autoencoder_kl_hunyuanimage_refiner
  title: AutoencoderKLHunyuanImageRefiner
- local: api/models/autoencoder_kl_hunyuan_video
  title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
  title: AutoencoderKLHunyuanVideo15
- local: api/models/autoencoder_kl_kvae
  title: AutoencoderKLKVAE
- local: api/models/autoencoder_kl_kvae_video
  title: AutoencoderKLKVAEVideo
- local: api/models/autoencoderkl_audio_ltx_2
  title: AutoencoderKLLTX2Audio
- local: api/models/autoencoderkl_ltx_2
  title: AutoencoderKLLTX2Video
- local: api/models/autoencoderkl_ltx_video
  title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_magvit
  title: AutoencoderKLMagvit
- local: api/models/autoencoderkl_mochi
  title: AutoencoderKLMochi
- local: api/models/autoencoderkl_qwenimage
  title: AutoencoderKLQwenImage
- local: api/models/autoencoder_kl_wan
  title: AutoencoderKLWan
- local: api/models/autoencoder_rae
  title: AutoencoderRAE
- local: api/models/consistency_decoder_vae
  title: ConsistencyDecoderVAE
- local: api/models/autoencoder_oobleck
  title: Oobleck AutoEncoder
- local: api/models/autoencoder_tiny
  title: Tiny AutoEncoder
- local: api/models/vq
