ENCODER_ARCHS: Dict[str, Type] = {}


def register_encoder(cls: Optional[Type] = None, *, name: Optional[str] = None) -> Union[Callable[[Type], Type], Type]:
    """Class decorator registering an encoder implementation in `ENCODER_ARCHS`.

    Works both bare (`@register_encoder`) and parameterized
    (`@register_encoder(name="dinov2")`). Registering a different class under
    an existing name raises `ValueError`; re-registering the same class is a
    no-op.
    """

    def decorator(inner_cls: Type) -> Type:
        encoder_name = name or inner_cls.__name__
        if encoder_name in ENCODER_ARCHS and ENCODER_ARCHS[encoder_name] is not inner_cls:
            raise ValueError(f"Encoder '{encoder_name}' is already registered.")
        ENCODER_ARCHS[encoder_name] = inner_cls
        return inner_cls

    if cls is None:
        return decorator
    return decorator(cls)


@register_encoder(name="dinov2")
class Dinov2Encoder(nn.Module):
    """Frozen DINOv2-with-registers feature encoder used by RAE."""

    def __init__(self, encoder_name_or_path: str = "facebook/dinov2-with-registers-base"):
        super().__init__()
        from transformers import Dinov2WithRegistersModel

        self.model = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path)
        # The representation encoder is frozen in RAE.
        self.model.requires_grad_(False)
        # Drop the affine parameters of the final layernorm: the latent space
        # is the normalized-but-unscaled representation.
        self.model.layernorm.elementwise_affine = False
        self.model.layernorm.weight = None
        self.model.layernorm.bias = None

        self.patch_size = self.model.config.patch_size
        self.hidden_size = self.model.config.hidden_size

    @torch.no_grad()
    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        images is of shape (B, C, H, W)
        where B is batch size, C is number of channels, H and W are height and width of the image
        """
        outputs = self.model(images, output_hidden_states=True)
        unused_token_num = 5  # 1 CLS + 4 register tokens
        image_features = outputs.last_hidden_state[:, unused_token_num:]
        return image_features


@register_encoder(name="siglip2")
class Siglip2Encoder(nn.Module):
    """Frozen SigLIP2 vision-tower feature encoder used by RAE."""

    def __init__(self, encoder_name_or_path: str = "google/siglip2-base-patch16-256"):
        super().__init__()
        from transformers import SiglipModel

        self.model = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
        # Freeze the encoder, consistent with Dinov2Encoder (RAE uses frozen
        # representation encoders).
        self.model.requires_grad_(False)
        # remove the affine of final layernorm
        self.model.post_layernorm.elementwise_affine = False
        # remove the param
        self.model.post_layernorm.weight = None
        self.model.post_layernorm.bias = None
        self.hidden_size = self.model.config.hidden_size
        self.patch_size = self.model.config.patch_size

    @torch.no_grad()
    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        images is of shape (B, C, H, W)
        where B is batch size, C is number of channels, H and W are height and width of the image
        """
        outputs = self.model(images, output_hidden_states=True, interpolate_pos_encoding=True)
        image_features = outputs.last_hidden_state
        return image_features
@register_encoder(name="mae")
class MAEEncoder(nn.Module):
    """Frozen ViT-MAE feature encoder used by RAE (masking disabled)."""

    def __init__(self, encoder_name_or_path: str = "facebook/vit-mae-base"):
        super().__init__()
        from transformers import ViTMAEForPreTraining

        self.model = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit
        # Freeze the encoder, consistent with Dinov2Encoder/Siglip2Encoder
        # (RAE uses frozen representation encoders).
        self.model.requires_grad_(False)
        # remove the affine of final layernorm
        self.model.layernorm.elementwise_affine = False
        # remove the param
        self.model.layernorm.weight = None
        self.model.layernorm.bias = None
        self.hidden_size = self.model.config.hidden_size
        self.patch_size = self.model.config.patch_size
        self.model.config.mask_ratio = 0.0  # no masking

    @torch.no_grad()
    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        images is of shape (B, C, H, W)
        where B is batch size, C is number of channels, H and W are height and width of the image
        """
        h, w = images.shape[2], images.shape[3]
        patch_num = int(h * w // self.patch_size**2)
        if patch_num * self.patch_size**2 != h * w:
            # Validate explicitly: `assert` is stripped under `python -O`.
            raise ValueError("image size should be divisible by patch size")
        # With mask_ratio == 0 this identity ordering keeps every patch; the
        # explicit "noise" makes the (non-)shuffling deterministic.
        noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0], -1).to(images.device).to(images.dtype)
        outputs = self.model(images, noise, interpolate_pos_encoding=True)
        image_features = outputs.last_hidden_state[:, 1:]  # remove cls token
        return image_features


@dataclass
class AutoencoderRAEOutput(BaseOutput):
    """
    Output of AutoencoderRAE encoding method.

    Args:
        latent (`torch.Tensor`):
            Encoded outputs of the encoder (frozen representation encoder).
            Shape: (batch_size, hidden_size, latent_height, latent_width)
    """

    latent: torch.Tensor


@dataclass
class RAEDecoderOutput(BaseOutput):
    """
    Output of RAEDecoder.

    Args:
        sample (`torch.Tensor`):
            Decoded output from decoder. Shape: (batch_size, num_channels, image_height, image_width)
    """

    # NOTE(review): the original declared this field without `@dataclass`,
    # unlike its sibling AutoencoderRAEOutput; the decorator is required for
    # the annotation to become an actual field on a BaseOutput subclass.
    sample: torch.Tensor
+ """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["ViTMAEDecoderLayer"] + + @register_to_config + def __init__( + self, + encoder_cls: str = "dinov2", + encoder_name_or_path: str = "facebook/dinov2-with-registers-base", + decoder_config: str = None, + num_patches: int = 196, + patch_size: int = 16, + encoder_input_size: int = 224, + image_size: int = 256, + num_channels: int = 3, + latent_mean: Optional[torch.Tensor] = None, + latent_var: Optional[torch.Tensor] = None, + noise_tau: float = 0.0, + reshape_to_2d: bool = True, + use_encoder_loss: bool = False, + ): + super().__init__() + diff --git a/tests/models/autoencoders/test_models_autoencoder_rae.py b/tests/models/autoencoders/test_models_autoencoder_rae.py new file mode 100644 index 000000000000..500e7257df5b --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_rae.py @@ -0,0 +1,58 @@ +import os +import unittest + +import torch + +from diffusers.models.autoencoders.autoencoder_rae import Dinov2Encoder, Siglip2Encoder, MAEEncoder +from ...testing_utils import ( + backend_empty_cache, + enable_full_determinism, + floats_tensor, + load_hf_numpy, + slow, + torch_all_close, + torch_device, +) + +enable_full_determinism() + + +class AutoencoderRAEEncoderUnitTests(unittest.TestCase): + + + def test_dinov2_encoder_forward_shape(self): + dino_path = os.environ.get("DINO_PATH", "/home/hadoop-mtaigc-live/dolphinfs_hdd_hadoop-mtaigc-live/wangyuqi/models/dinov2-with-registers-base") + + encoder = Dinov2Encoder(encoder_name_or_path=dino_path).to(torch_device) + x = torch.rand(1, 3, 224, 224, device=torch_device) + y = encoder(x) + + self.assertEqual(y.ndim, 3) + self.assertEqual(y.shape[0], 1) + self.assertEqual(y.shape[1], 256) + self.assertEqual(y.shape[2], encoder.hidden_size) + + def test_siglip2_encoder_forward_shape(self): + siglip2_path = "/home/hadoop-mtaigc-live/dolphinfs_hdd_hadoop-mtaigc-live/wangyuqi/models/siglip2-base-patch16-256" + + encoder = 
Siglip2Encoder(encoder_name_or_path=siglip2_path).to(torch_device) + x = torch.rand(1, 3, 224, 224, device=torch_device) + y = encoder(x) + + self.assertEqual(y.ndim, 3) + self.assertEqual(y.shape[0], 1) + self.assertEqual(y.shape[1], 196) + self.assertEqual(y.shape[2], encoder.hidden_size) + + def test_mae_encoder_forward_shape(self): + mae_path = "/home/hadoop-mtaigc-live/dolphinfs_hdd_hadoop-mtaigc-live/wangyuqi/models/vit-mae-base" + + encoder = MAEEncoder(encoder_name_or_path=mae_path).to(torch_device) + x = torch.rand(1, 3, 224, 224, device=torch_device) + y = encoder(x) + + self.assertEqual(y.ndim, 3) + self.assertEqual(y.shape[0], 1) + self.assertEqual(y.shape[1], 196) + self.assertEqual(y.shape[2], encoder.hidden_size) + \ No newline at end of file From f82cecc29828e18b69ec586f1c13c22506f04437 Mon Sep 17 00:00:00 2001 From: wangyuqi Date: Wed, 28 Jan 2026 20:19:31 +0800 Subject: [PATCH 02/30] feat: finish first version of autoencoder_rae --- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/autoencoders/__init__.py | 1 + .../models/autoencoders/autoencoder_rae.py | 588 +++++++++++++++++- .../test_models_autoencoder_rae.py | 121 +++- 5 files changed, 658 insertions(+), 56 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 24b9c12db6d4..5272d5e7f76e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -203,6 +203,7 @@ "AutoencoderKLTemporalDecoder", "AutoencoderKLWan", "AutoencoderOobleck", + "AutoencoderRAE", "AutoencoderTiny", "AutoModel", "BriaFiboTransformer2DModel", @@ -962,6 +963,7 @@ AutoencoderKLTemporalDecoder, AutoencoderKLWan, AutoencoderOobleck, + AutoencoderRAE, AutoencoderTiny, AutoModel, BriaFiboTransformer2DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 4d1db36a7352..cae97f9ffb7a 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -49,6 +49,7 @@ 
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] + _import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] @@ -166,6 +167,7 @@ AutoencoderKLTemporalDecoder, AutoencoderKLWan, AutoencoderOobleck, + AutoencoderRAE, AutoencoderTiny, ConsistencyDecoderVAE, VQModel, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 8e7a9c81d2ad..23665ee0532e 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -18,6 +18,7 @@ from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_kl_wan import AutoencoderKLWan from .autoencoder_oobleck import AutoencoderOobleck +from .autoencoder_rae import AutoencoderRAE from .autoencoder_tiny import AutoencoderTiny from .consistency_decoder_vae import ConsistencyDecoderVAE from .vq_model import VQModel diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index b8b440ea270c..0aa826e9d452 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from math import sqrt -from typing import Dict, Type, Optional, Callable, Union, Tuple +from types import SimpleNamespace +from typing import Callable, Dict, Optional, Tuple, Type, Union import numpy as np import torch @@ -12,9 +13,8 @@ from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import BaseOutput from 
def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    """Sin/cos embedding of a 1-D array of positions.

    Returns an `(M, embed_dim)` array whose first half is `sin(pos * freq)`
    and whose second half is `cos(pos * freq)`.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    half_dim = embed_dim // 2
    # Frequencies 1 / 10000^(2i/d), the original Transformer schedule.
    freqs = np.arange(half_dim, dtype=np.float64)
    freqs /= embed_dim / 2.0
    freqs = 1.0 / (10000**freqs)  # (D/2,)

    angles = pos.reshape(-1)[:, None] * freqs[None, :]  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


def _get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, add_cls_token: bool = False) -> np.ndarray:
    """
    Returns:
        pos_embed: (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    coords = np.arange(grid_size, dtype=np.float64)
    # Row-major flattening of the grid: the column (w) coordinate varies
    # fastest and feeds the first half of the embedding.
    col_coords = np.tile(coords, grid_size)
    row_coords = np.repeat(coords, grid_size)

    emb_first = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, col_coords)  # (H*W, D/2)
    emb_second = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, row_coords)  # (H*W, D/2)
    pos_embed = np.concatenate([emb_first, emb_second], axis=1)  # (H*W, D)

    if add_cls_token:
        cls_slot = np.zeros([1, embed_dim], dtype=np.float64)
        pos_embed = np.concatenate([cls_slot, pos_embed], axis=0)

    return pos_embed
# Activation lookup used by the decoder MLP.
ACT2FN = {
    "gelu": F.gelu,
    "relu": F.relu,
    "silu": F.silu,
    "swish": F.silu,
}


class ViTMAESelfAttention(nn.Module):
    """Multi-head scaled-dot-product self-attention with separate q/k/v
    projections, named to stay compatible with ViT-MAE checkpoints."""

    def __init__(self, hidden_size: int, num_attention_heads: int, qkv_bias: bool = True, attn_dropout: float = 0.0):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}")

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Submodule names (query/key/value/dropout) match checkpoint keys.
        self.query = nn.Linear(hidden_size, self.all_head_size, bias=qkv_bias)
        self.key = nn.Linear(hidden_size, self.all_head_size, bias=qkv_bias)
        self.value = nn.Linear(hidden_size, self.all_head_size, bias=qkv_bias)
        self.dropout = nn.Dropout(attn_dropout)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (B, N, H*D) -> (B, H, N, D)
        split = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return split.transpose(1, 2)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        queries = self.transpose_for_scores(self.query(hidden_states))
        keys = self.transpose_for_scores(self.key(hidden_states))
        values = self.transpose_for_scores(self.value(hidden_states))

        scores = (queries @ keys.transpose(-1, -2)) / (self.attention_head_size**0.5)
        probs = self.dropout(scores.softmax(dim=-1))

        context = (probs @ values).transpose(1, 2).contiguous()
        # (B, N, H, D) -> (B, N, H*D)
        return context.reshape(*context.size()[:-2], self.all_head_size)


class ViTMAESelfOutput(nn.Module):
    """Output projection applied after self-attention (residual is added by the caller)."""

    def __init__(self, hidden_size: int, hidden_dropout_prob: float = 0.0):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.dense(hidden_states))


class ViTMAEAttention(nn.Module):
    """Self-attention followed by its output projection."""

    def __init__(self, hidden_size: int, num_attention_heads: int, qkv_bias: bool = True, attn_dropout: float = 0.0):
        super().__init__()
        self.attention = ViTMAESelfAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            qkv_bias=qkv_bias,
            attn_dropout=attn_dropout,
        )
        self.output = ViTMAESelfOutput(hidden_size=hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.output(self.attention(hidden_states))


class ViTMAEIntermediate(nn.Module):
    """First MLP projection (expansion) plus activation."""

    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)
        self.intermediate_act_fn = ACT2FN.get(hidden_act, None)
        if self.intermediate_act_fn is None:
            raise ValueError(f"Unsupported hidden_act={hidden_act}")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))


class ViTMAEOutput(nn.Module):
    """Second MLP projection (contraction) with residual connection."""

    def __init__(self, hidden_size: int, intermediate_size: int, hidden_dropout_prob: float = 0.0):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor


class ViTMAELayer(nn.Module):
    """
    This matches the naming/parameter structure used in RAE-main (ViTMAE decoder block).
    """

    def __init__(
        self,
        *,
        hidden_size: int,
        num_attention_heads: int,
        intermediate_size: int,
        qkv_bias: bool = True,
        layer_norm_eps: float = 1e-12,
        hidden_dropout_prob: float = 0.0,
        attention_probs_dropout_prob: float = 0.0,
        hidden_act: str = "gelu",
    ):
        super().__init__()
        # Creation order matters for reproducible parameter initialization.
        self.attention = ViTMAEAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            qkv_bias=qkv_bias,
            attn_dropout=attention_probs_dropout_prob,
        )
        self.intermediate = ViTMAEIntermediate(hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act)
        self.output = ViTMAEOutput(hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_dropout_prob=hidden_dropout_prob)
        self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Pre-norm attention with residual.
        hidden_states = hidden_states + self.attention(self.layernorm_before(hidden_states))
        # Pre-norm MLP; the residual is added inside ViTMAEOutput.
        mlp_hidden = self.intermediate(self.layernorm_after(hidden_states))
        return self.output(mlp_hidden, hidden_states)
class GeneralDecoder(nn.Module):
    """
    Decoder implementation ported from RAE-main to keep checkpoint compatibility.

    Key attributes (must match checkpoint keys):
    - decoder_embed
    - decoder_pos_embed
    - decoder_layers
    - decoder_norm
    - decoder_pred
    - trainable_cls_token
    """

    def __init__(self, config, num_patches: int):
        # `config` is a namespace-like object providing hidden_size,
        # decoder_* sizes, patch_size, image_size, num_channels, qkv_bias,
        # dropout probabilities, layer_norm_eps and hidden_act.
        super().__init__()
        # Project encoder tokens into the decoder width.
        self.decoder_embed = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True)
        # Fixed (non-trainable) sin/cos positional table for CLS + patches.
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, config.decoder_hidden_size), requires_grad=False)

        self.decoder_layers = nn.ModuleList(
            [
                ViTMAELayer(
                    hidden_size=config.decoder_hidden_size,
                    num_attention_heads=config.decoder_num_attention_heads,
                    intermediate_size=config.decoder_intermediate_size,
                    qkv_bias=config.qkv_bias,
                    layer_norm_eps=config.layer_norm_eps,
                    hidden_dropout_prob=config.hidden_dropout_prob,
                    attention_probs_dropout_prob=config.attention_probs_dropout_prob,
                    hidden_act=config.hidden_act,
                )
                for _ in range(config.decoder_num_hidden_layers)
            ]
        )

        self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
        # One output row per patch: patch_size^2 * num_channels pixel logits.
        self.decoder_pred = nn.Linear(config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True)
        self.gradient_checkpointing = False
        self.config = config
        self.num_patches = num_patches

        self._initialize_weights(num_patches)
        self.set_trainable_cls_token()

    def set_trainable_cls_token(self, tensor: Optional[torch.Tensor] = None):
        # Install (or replace) the learnable CLS token; defaults to zeros.
        tensor = torch.zeros(1, 1, self.config.decoder_hidden_size) if tensor is None else tensor
        self.trainable_cls_token = nn.Parameter(tensor)

    def _initialize_weights(self, num_patches: int):
        # Fill the frozen positional table with 2D sin/cos values.
        # NOTE(review): assumes num_patches is a perfect square — confirm at call sites.
        grid_size = int(num_patches**0.5)
        pos_embed = _get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], grid_size, add_cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

    def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
        # Resample the (flattened) patch positional table to the sequence
        # length of `embeddings` (excluding CLS), keeping CLS untouched.
        embeddings_positions = embeddings.shape[1] - 1
        num_positions = self.decoder_pos_embed.shape[1] - 1

        class_pos_embed = self.decoder_pos_embed[:, 0, :]
        patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
        dim = self.decoder_pos_embed.shape[-1]

        # Interpolate along the token axis only (height dimension stays 1).
        patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2)
        patch_pos_embed = F.interpolate(
            patch_pos_embed,
            scale_factor=(1, embeddings_positions / num_positions),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def interpolate_latent(self, x: torch.Tensor) -> torch.Tensor:
        # Bilinearly resample a (B, L, C) token grid to exactly
        # `self.num_patches` tokens; no-op when the length already matches.
        # NOTE(review): assumes L is a perfect square — confirm upstream.
        b, l, c = x.shape
        if l == self.num_patches:
            return x
        h = w = int(l**0.5)
        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
        target_size = (int(self.num_patches**0.5), int(self.num_patches**0.5))
        x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
        x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c)
        return x

    def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: Optional[Tuple[int, int]] = None):
        # Reassemble per-patch pixel rows (B, N, p*p*C) into full images
        # (B, C, H, W). Raises ValueError on a patch-count mismatch.
        patch_size, num_channels = self.config.patch_size, self.config.num_channels
        original_image_size = (
            original_image_size if original_image_size is not None else (self.config.image_size, self.config.image_size)
        )
        original_height, original_width = original_image_size
        num_patches_h = original_height // patch_size
        num_patches_w = original_width // patch_size
        if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
            raise ValueError(
                f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
            )

        batch_size = patchified_pixel_values.shape[0]
        patchified_pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_patches_h,
            num_patches_w,
            patch_size,
            patch_size,
            num_channels,
        )
        # (B, h, w, p, q, C) -> (B, C, h, p, w, q), then merge patch axes.
        patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
        pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_channels,
            num_patches_h * patch_size,
            num_patches_w * patch_size,
        )
        return pixel_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        interpolate_pos_encoding: bool = False,
        drop_cls_token: bool = False,
        return_dict: bool = True,
    ) -> Union[RAEDecoderOutput, Tuple[torch.Tensor]]:
        # Project encoder tokens, optionally strip the incoming CLS token,
        # then resample the patch grid to the decoder's native token count.
        x = self.decoder_embed(hidden_states)
        if drop_cls_token:
            x_ = x[:, 1:, :]
            x_ = self.interpolate_latent(x_)
        else:
            x_ = self.interpolate_latent(x)

        # Always prepend the decoder's own learnable CLS token.
        cls_token = self.trainable_cls_token.expand(x_.shape[0], -1, -1)
        x = torch.cat([cls_token, x_], dim=1)

        if interpolate_pos_encoding:
            if not drop_cls_token:
                raise ValueError("interpolate_pos_encoding only supports drop_cls_token=True")
            decoder_pos_embed = self.interpolate_pos_encoding(x)
        else:
            decoder_pos_embed = self.decoder_pos_embed
        hidden_states = x + decoder_pos_embed.to(device=x.device, dtype=x.dtype)
        for layer_module in self.decoder_layers:
            hidden_states = layer_module(hidden_states)

        hidden_states = self.decoder_norm(hidden_states)
        logits = self.decoder_pred(hidden_states)
        # Drop the CLS row: only patch tokens map to pixels.
        logits = logits[:, 1:, :]

        if not return_dict:
            return (logits,)
        return RAEDecoderOutput(logits=logits)


# Backward-compatible alias: keep `RAEDecoder` name used by `AutoencoderRAE`
class RAEDecoder(GeneralDecoder):
    pass
encoder_name_or_path (`str`, *optional*, defaults to `"facebook/dinov2-with-registers-base"`): Path to pretrained encoder model or model identifier from huggingface.co/models. - decoder_config (`ViTMAEDecoderConfig`, *optional*): - Configuration for the decoder. If None, a default config will be used. - num_patches (`int`, *optional*, defaults to `196`): - Number of patches in the latent space (14x14 = 196 for 224x224 image with patch size 16). patch_size (`int`, *optional*, defaults to `16`): - Patch size for both encoder and decoder. + Decoder patch size (used for unpatchify and decoder head). encoder_input_size (`int`, *optional*, defaults to `224`): Input size expected by the encoder. - image_size (`int`, *optional*, defaults to `256`): - Output image size. + image_size (`int`, *optional*): + Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like RAE-main: + `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size // encoder.patch_size) ** 2`. num_channels (`int`, *optional*, defaults to `3`): Number of input/output channels. latent_mean (`torch.Tensor`, *optional*): @@ -187,25 +506,238 @@ class AutoencoderRAE( Whether to use encoder hidden states in the loss (for advanced training). """ - _supports_gradient_checkpointing = True - _no_split_modules = ["ViTMAEDecoderLayer"] + # NOTE: gradient checkpointing is not wired up for this model yet. 
+ _supports_gradient_checkpointing = False + _no_split_modules = ["ViTMAELayer"] @register_to_config def __init__( self, encoder_cls: str = "dinov2", encoder_name_or_path: str = "facebook/dinov2-with-registers-base", - decoder_config: str = None, - num_patches: int = 196, + decoder_hidden_size: int = 512, + decoder_num_hidden_layers: int = 8, + decoder_num_attention_heads: int = 16, + decoder_intermediate_size: int = 2048, patch_size: int = 16, encoder_input_size: int = 224, - image_size: int = 256, + image_size: Optional[int] = None, num_channels: int = 3, latent_mean: Optional[torch.Tensor] = None, latent_var: Optional[torch.Tensor] = None, noise_tau: float = 0.0, reshape_to_2d: bool = True, use_encoder_loss: bool = False, + scaling_factor: float = 1.0, ): super().__init__() + if encoder_cls not in ENCODER_ARCHS: + raise ValueError(f"Unknown encoder_cls='{encoder_cls}'. Available: {sorted(ENCODER_ARCHS.keys())}") + + self.encoder_input_size = encoder_input_size + self.noise_tau = float(noise_tau) + self.reshape_to_2d = bool(reshape_to_2d) + self.use_encoder_loss = bool(use_encoder_loss) + self.scaling_factor = float(scaling_factor) + + # Frozen representation encoder + self.encoder: nn.Module = ENCODER_ARCHS[encoder_cls](encoder_name_or_path=encoder_name_or_path) + + # RAE-main: base_patches = (encoder_input_size // encoder_patch_size) ** 2 + encoder_patch_size = getattr(self.encoder, "patch_size", None) + if encoder_patch_size is None: + raise ValueError(f"Encoder '{encoder_cls}' must define `.patch_size` attribute.") + encoder_patch_size = int(encoder_patch_size) + if self.encoder_input_size % encoder_patch_size != 0: + raise ValueError( + f"encoder_input_size={self.encoder_input_size} must be divisible by encoder.patch_size={encoder_patch_size}." + ) + num_patches = (self.encoder_input_size // encoder_patch_size) ** 2 + + # Decoder patch size is independent from encoder patch size. 
+ decoder_patch_size = int(patch_size) + if decoder_patch_size <= 0: + raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).") + + grid = int(sqrt(num_patches)) + if grid * grid != num_patches: + raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.") + + derived_image_size = decoder_patch_size * grid + if image_size is None: + image_size = derived_image_size + else: + image_size = int(image_size) + if image_size != derived_image_size: + raise ValueError( + f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} " + f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}." + ) + + # Normalization stats from the encoder's image processor + # RAE-main uses AutoImageProcessor mean/std; we follow the same. + encoder_mean = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32).view(1, 3, 1, 1) + encoder_std = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32).view(1, 3, 1, 1) + try: + from transformers import AutoImageProcessor + + try: + proc = AutoImageProcessor.from_pretrained(encoder_name_or_path, local_files_only=True) + except Exception: + proc = AutoImageProcessor.from_pretrained(encoder_name_or_path, local_files_only=False) + encoder_mean = torch.tensor(proc.image_mean, dtype=torch.float32).view(1, 3, 1, 1) + encoder_std = torch.tensor(proc.image_std, dtype=torch.float32).view(1, 3, 1, 1) + except Exception: + # Keep default 0.5/0.5 if processor is unavailable. 
+ pass + + self.register_buffer("encoder_mean", encoder_mean, persistent=True) + self.register_buffer("encoder_std", encoder_std, persistent=True) + + # Optional latent normalization (RAE-main uses mean/var) + self.do_latent_normalization = latent_mean is not None or latent_var is not None + if latent_mean is not None: + self.register_buffer("latent_mean", latent_mean.detach().clone(), persistent=True) + else: + self.latent_mean = None + if latent_var is not None: + self.register_buffer("latent_var", latent_var.detach().clone(), persistent=True) + else: + self.latent_var = None + + # ViT-MAE style decoder + encoder_hidden_size = getattr(self.encoder, "hidden_size", None) + if encoder_hidden_size is None: + raise ValueError(f"Encoder '{encoder_cls}' must define `.hidden_size` attribute.") + + decoder_config = SimpleNamespace( + hidden_size=int(encoder_hidden_size), + decoder_hidden_size=int(decoder_hidden_size), + decoder_num_hidden_layers=int(decoder_num_hidden_layers), + decoder_num_attention_heads=int(decoder_num_attention_heads), + decoder_intermediate_size=int(decoder_intermediate_size), + patch_size=int(decoder_patch_size), + image_size=int(image_size), + num_channels=int(num_channels), + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + layer_norm_eps=1e-12, + hidden_act="gelu", + ) + self.decoder = RAEDecoder(decoder_config, num_patches=int(num_patches)) + self.num_patches = int(num_patches) + self.decoder_patch_size = int(decoder_patch_size) + self.decoder_image_size = int(image_size) + + # Slicing support (batch dimension) similar to other diffusers autoencoders + self.use_slicing = False + + def _noising(self, x: torch.Tensor) -> torch.Tensor: + # Per-sample random sigma in [0, noise_tau] + noise_sigma = self.noise_tau * torch.rand((x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype) + return x + noise_sigma * torch.randn_like(x) + + def _maybe_resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor: + _, 
_, h, w = x.shape + if h != self.encoder_input_size or w != self.encoder_input_size: + x = F.interpolate( + x, size=(self.encoder_input_size, self.encoder_input_size), mode="bicubic", align_corners=False + ) + mean = self.encoder_mean.to(device=x.device, dtype=x.dtype) + std = self.encoder_std.to(device=x.device, dtype=x.dtype) + return (x - mean) / std + + def _maybe_denormalize_image(self, x: torch.Tensor) -> torch.Tensor: + mean = self.encoder_mean.to(device=x.device, dtype=x.dtype) + std = self.encoder_std.to(device=x.device, dtype=x.dtype) + return x * std + mean + + def _maybe_normalize_latents(self, z: torch.Tensor) -> torch.Tensor: + if not self.do_latent_normalization: + return z + latent_mean = self.latent_mean.to(device=z.device, dtype=z.dtype) if self.latent_mean is not None else 0 + latent_var = self.latent_var.to(device=z.device, dtype=z.dtype) if self.latent_var is not None else 1 + return (z - latent_mean) / torch.sqrt(latent_var + 1e-5) + + def _maybe_denormalize_latents(self, z: torch.Tensor) -> torch.Tensor: + if not self.do_latent_normalization: + return z + latent_mean = self.latent_mean.to(device=z.device, dtype=z.dtype) if self.latent_mean is not None else 0 + latent_var = self.latent_var.to(device=z.device, dtype=z.dtype) if self.latent_var is not None else 1 + return z * torch.sqrt(latent_var + 1e-5) + latent_mean + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + x = self._maybe_resize_and_normalize(x) + + # Encoder is frozen; many encoders already run under no_grad + tokens = self.encoder(x) # (B, N, C) + + if self.training and self.noise_tau > 0: + tokens = self._noising(tokens) + + if self.reshape_to_2d: + b, n, c = tokens.shape + side = int(sqrt(n)) + if side * side != n: + raise ValueError(f"Token length n={n} is not a perfect square; cannot reshape to 2D.") + z = tokens.transpose(1, 2).contiguous().view(b, c, side, side) # (B, C, h, w) + else: + z = tokens + + z = self._maybe_normalize_latents(z) + + # Follow diffusers 
convention: optionally scale latents for diffusion + if self.scaling_factor != 1.0: + z = z * self.scaling_factor + + return z + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutput, Tuple[torch.Tensor]]: + if self.use_slicing and x.shape[0] > 1: + latents = torch.cat([self._encode(x_slice) for x_slice in x.split(1)], dim=0) + else: + latents = self._encode(x) + + if not return_dict: + return (latents,) + return EncoderOutput(latent=latents) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + # Undo scaling factor if applied at encode time + if self.scaling_factor != 1.0: + z = z / self.scaling_factor + + z = self._maybe_denormalize_latents(z) + + if self.reshape_to_2d: + b, c, h, w = z.shape + tokens = z.view(b, c, h * w).transpose(1, 2).contiguous() # (B, N, C) + else: + tokens = z + + logits = self.decoder(tokens, return_dict=True).logits + x_rec = self.decoder.unpatchify(logits) + x_rec = self._maybe_denormalize_image(x_rec) + return x_rec + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]: + if self.use_slicing and z.shape[0] > 1: + decoded = torch.cat([self._decode(z_slice) for z_slice in z.split(1)], dim=0) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def forward(self, sample: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]: + latents = self.encode(sample, return_dict=False)[0] + decoded = self.decode(latents, return_dict=False)[0] + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + diff --git a/tests/models/autoencoders/test_models_autoencoder_rae.py b/tests/models/autoencoders/test_models_autoencoder_rae.py index 500e7257df5b..3aa9641f922f 100644 --- a/tests/models/autoencoders/test_models_autoencoder_rae.py +++ b/tests/models/autoencoders/test_models_autoencoder_rae.py @@ 
-1,58 +1,123 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc import os import unittest import torch -from diffusers.models.autoencoders.autoencoder_rae import Dinov2Encoder, Siglip2Encoder, MAEEncoder -from ...testing_utils import ( - backend_empty_cache, - enable_full_determinism, - floats_tensor, - load_hf_numpy, - slow, - torch_all_close, - torch_device, -) +from diffusers.models.autoencoders.autoencoder_rae import AutoencoderRAE, Dinov2Encoder, MAEEncoder, Siglip2Encoder + +from ...testing_utils import backend_empty_cache, enable_full_determinism, slow, torch_device + enable_full_determinism() -class AutoencoderRAEEncoderUnitTests(unittest.TestCase): +def _get_required_local_path(env_name: str) -> str: + path = os.environ.get(env_name) + assert path is not None and len(path) > 0, f"Please set `{env_name}` to a local pretrained model directory." 
+ assert os.path.exists(path), f"Path from `{env_name}` does not exist: {path}" + return path + +@slow +class AutoencoderRAEEncoderIntegrationTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) def test_dinov2_encoder_forward_shape(self): - dino_path = os.environ.get("DINO_PATH", "/home/hadoop-mtaigc-live/dolphinfs_hdd_hadoop-mtaigc-live/wangyuqi/models/dinov2-with-registers-base") + dino_path = _get_required_local_path("DINO_PATH") encoder = Dinov2Encoder(encoder_name_or_path=dino_path).to(torch_device) x = torch.rand(1, 3, 224, 224, device=torch_device) y = encoder(x) - self.assertEqual(y.ndim, 3) - self.assertEqual(y.shape[0], 1) - self.assertEqual(y.shape[1], 256) - self.assertEqual(y.shape[2], encoder.hidden_size) + assert y.ndim == 3 + assert y.shape[0] == 1 + assert y.shape[1] == 256 + assert y.shape[2] == encoder.hidden_size def test_siglip2_encoder_forward_shape(self): - siglip2_path = "/home/hadoop-mtaigc-live/dolphinfs_hdd_hadoop-mtaigc-live/wangyuqi/models/siglip2-base-patch16-256" + siglip2_path = _get_required_local_path("SIGLIP2_PATH") encoder = Siglip2Encoder(encoder_name_or_path=siglip2_path).to(torch_device) x = torch.rand(1, 3, 224, 224, device=torch_device) y = encoder(x) - - self.assertEqual(y.ndim, 3) - self.assertEqual(y.shape[0], 1) - self.assertEqual(y.shape[1], 196) - self.assertEqual(y.shape[2], encoder.hidden_size) + + assert y.ndim == 3 + assert y.shape[0] == 1 + assert y.shape[1] == 196 + assert y.shape[2] == encoder.hidden_size def test_mae_encoder_forward_shape(self): - mae_path = "/home/hadoop-mtaigc-live/dolphinfs_hdd_hadoop-mtaigc-live/wangyuqi/models/vit-mae-base" + mae_path = _get_required_local_path("MAE_PATH") encoder = MAEEncoder(encoder_name_or_path=mae_path).to(torch_device) x = torch.rand(1, 3, 224, 224, device=torch_device) y = encoder(x) - self.assertEqual(y.ndim, 3) - self.assertEqual(y.shape[0], 1) - self.assertEqual(y.shape[1], 196) - 
self.assertEqual(y.shape[2], encoder.hidden_size) - \ No newline at end of file + assert y.ndim == 3 + assert y.shape[0] == 1 + assert y.shape[1] == 196 + assert y.shape[2] == encoder.hidden_size + + +@slow +class AutoencoderRAEIntegrationTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_autoencoder_rae_encode_decode_forward_shapes_dinov2(self): + # This is a shape & numerical-sanity test. The decoder is randomly initialized unless you load trained weights. + dino_path = _get_required_local_path("DINO_PATH") + + encoder_input_size = 224 + decoder_patch_size = 16 + # dinov2 patch=14 -> (224/14)^2 = 256 tokens -> decoder output 256 for patch 16 + image_size = 256 + + model = AutoencoderRAE( + encoder_cls="dinov2", + encoder_name_or_path=dino_path, + image_size=image_size, + encoder_input_size=encoder_input_size, + patch_size=decoder_patch_size, + # keep the decoder lightweight for test runtime + decoder_hidden_size=128, + decoder_num_hidden_layers=1, + decoder_num_attention_heads=4, + decoder_intermediate_size=256, + ).to(torch_device) + model.eval() + + x = torch.rand(1, 3, encoder_input_size, encoder_input_size, device=torch_device) + + with torch.no_grad(): + latents = model.encode(x).latent + assert latents.ndim == 4 + assert latents.shape[0] == 1 + + decoded = model.decode(latents).sample + assert decoded.shape == (1, 3, image_size, image_size) + + recon = model(x).sample + assert recon.shape == (1, 3, image_size, image_size) + assert torch.isfinite(recon).all().item() \ No newline at end of file From 0850c8cdc974af36cb4e188e09b96cd8865cf0e0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Feb 2026 22:39:59 +0000 Subject: [PATCH 03/30] fix formatting --- .../models/autoencoders/autoencoder_rae.py | 162 +++++++++++------- .../test_models_autoencoder_rae.py | 150 ++++++++++++++-- 2 files changed, 241 insertions(+), 71 deletions(-) diff --git 
a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index 0aa826e9d452..95f65cedd001 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -1,3 +1,17 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass from math import sqrt from types import SimpleNamespace @@ -11,12 +25,20 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin -from ...utils import BaseOutput +from ...utils import BaseOutput, logging from ...utils.accelerate_utils import apply_forward_hook from ..modeling_utils import ModelMixin from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput + ENCODER_ARCHS: Dict[str, Type] = {} +ENCODER_DEFAULT_NAME_OR_PATH = { + "dinov2": "facebook/dinov2-with-registers-base", + "siglip2": "google/siglip2-base-patch16-256", + "mae": "facebook/vit-mae-base", +} +logger = logging.get_logger(__name__) + def register_encoder(cls: Optional[Type] = None, *, name: Optional[str] = None) -> Union[Callable[[Type], Type], Type]: def decorator(inner_cls: Type) -> Type: @@ -33,10 +55,7 @@ def decorator(inner_cls: Type) -> Type: @register_encoder(name="dinov2") class Dinov2Encoder(nn.Module): - def __init__( - self, - encoder_name_or_path: 
str = "facebook/dinov2-with-registers-base" - ): + def __init__(self, encoder_name_or_path: str = "facebook/dinov2-with-registers-base"): super().__init__() from transformers import Dinov2WithRegistersModel @@ -52,8 +71,7 @@ def __init__( @torch.no_grad() def forward(self, images: torch.Tensor) -> torch.Tensor: """ - images is of shape (B, C, H, W) - where B is batch size, C is number of channels, H and W are height and + images is of shape (B, C, H, W) where B is batch size, C is number of channels, H and W are height and """ outputs = self.model(images, output_hidden_states=True) unused_token_num = 5 # 1 CLS + 4 register tokens @@ -63,13 +81,12 @@ def forward(self, images: torch.Tensor) -> torch.Tensor: @register_encoder(name="siglip2") class Siglip2Encoder(nn.Module): - def __init__( - self, - encoder_name_or_path: str = "google/siglip2-base-patch16-256" - ): + def __init__(self, encoder_name_or_path: str = "google/siglip2-base-patch16-256"): super().__init__() from transformers import SiglipModel + self.model = SiglipModel.from_pretrained(encoder_name_or_path).vision_model + self.model.requires_grad_(False) # remove the affine of final layernorm self.model.post_layernorm.elementwise_affine = False # remove the param @@ -81,10 +98,9 @@ def __init__( @torch.no_grad() def forward(self, images: torch.Tensor) -> torch.Tensor: """ - images is of shape (B, C, H, W) - where B is batch size, C is number of channels, H and W are height and + images is of shape (B, C, H, W) where B is batch size, C is number of channels, H and W are height and """ - outputs = self.model(images, output_hidden_states=True, interpolate_pos_encoding = True) + outputs = self.model(images, output_hidden_states=True, interpolate_pos_encoding=True) image_features = outputs.last_hidden_state return image_features @@ -94,7 +110,9 @@ class MAEEncoder(nn.Module): def __init__(self, encoder_name_or_path: str = "facebook/vit-mae-base"): super().__init__() from transformers import ViTMAEForPreTraining + 
self.model = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit + self.model.requires_grad_(False) # remove the affine of final layernorm self.model.layernorm.elementwise_affine = False # remove the param @@ -102,19 +120,21 @@ def __init__(self, encoder_name_or_path: str = "facebook/vit-mae-base"): self.model.layernorm.bias = None self.hidden_size = self.model.config.hidden_size self.patch_size = self.model.config.patch_size - self.model.config.mask_ratio = 0. # no masking + self.model.config.mask_ratio = 0.0 # no masking + @torch.no_grad() def forward(self, images: torch.Tensor) -> torch.Tensor: """ - images is of shape (B, C, H, W) - where B is batch size, C is number of channels, H and W are height and width of the image + images is of shape (B, C, H, W) where B is batch size, C is number of channels, H and W are height and width of + the image """ - h,w = images.shape[2], images.shape[3] - patch_num = int(h * w // self.patch_size ** 2) - assert patch_num * self.patch_size ** 2 == h * w, 'image size should be divisible by patch size' - noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0],-1).to(images.device).to(images.dtype) - outputs = self.model(images, noise, interpolate_pos_encoding = True) - image_features = outputs.last_hidden_state[:, 1:] # remove cls token + h, w = images.shape[2], images.shape[3] + patch_num = int(h * w // self.patch_size**2) + if patch_num * self.patch_size**2 != h * w: + raise ValueError("Image size should be divisible by patch size.") + noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0], -1).to(images.device).to(images.dtype) + outputs = self.model(images, noise, interpolate_pos_encoding=True) + image_features = outputs.last_hidden_state[:, 1:] # remove cls token return image_features @@ -183,7 +203,9 @@ class ViTMAESelfAttention(nn.Module): def __init__(self, hidden_size: int, num_attention_heads: int, qkv_bias: bool = True, attn_dropout: float = 0.0): super().__init__() if hidden_size % 
num_attention_heads != 0: - raise ValueError(f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}") + raise ValueError( + f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}" + ) self.num_attention_heads = num_attention_heads self.attention_head_size = int(hidden_size / num_attention_heads) @@ -297,8 +319,12 @@ def __init__( qkv_bias=qkv_bias, attn_dropout=attention_probs_dropout_prob, ) - self.intermediate = ViTMAEIntermediate(hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act) - self.output = ViTMAEOutput(hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_dropout_prob=hidden_dropout_prob) + self.intermediate = ViTMAEIntermediate( + hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act + ) + self.output = ViTMAEOutput( + hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_dropout_prob=hidden_dropout_prob + ) self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps) self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps) @@ -328,7 +354,9 @@ class GeneralDecoder(nn.Module): def __init__(self, config, num_patches: int): super().__init__() self.decoder_embed = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True) - self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, config.decoder_hidden_size), requires_grad=False) + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, config.decoder_hidden_size), requires_grad=False + ) self.decoder_layers = nn.ModuleList( [ @@ -347,7 +375,9 @@ def __init__(self, config, num_patches: int): ) self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps) - self.decoder_pred = nn.Linear(config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True) + self.decoder_pred = nn.Linear( + config.decoder_hidden_size, config.patch_size**2 * 
config.num_channels, bias=True + ) self.gradient_checkpointing = False self.config = config self.num_patches = num_patches @@ -396,7 +426,9 @@ def interpolate_latent(self, x: torch.Tensor) -> torch.Tensor: def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: Optional[Tuple[int, int]] = None): patch_size, num_channels = self.config.patch_size, self.config.num_channels original_image_size = ( - original_image_size if original_image_size is not None else (self.config.image_size, self.config.image_size) + original_image_size + if original_image_size is not None + else (self.config.image_size, self.config.image_size) ) original_height, original_width = original_image_size num_patches_h = original_height // patch_size @@ -468,30 +500,30 @@ class RAEDecoder(GeneralDecoder): pass -class AutoencoderRAE( - ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin -): +class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): r""" Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images. - This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT decoder - to reconstruct images from learned representations. + This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT decoder to reconstruct + images from learned representations. - This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods - implemented for all models (such as downloading or saving). + This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for + all models (such as downloading or saving). Args: encoder_cls (`str`, *optional*, defaults to `"dinov2"`): Type of frozen encoder to use. One of `"dinov2"`, `"siglip2"`, or `"mae"`. 
- encoder_name_or_path (`str`, *optional*, defaults to `"facebook/dinov2-with-registers-base"`): - Path to pretrained encoder model or model identifier from huggingface.co/models. + encoder_name_or_path (`str`, *optional*): + Path to pretrained encoder model or model identifier from huggingface.co/models. If not provided, uses an + encoder-specific default model id. patch_size (`int`, *optional*, defaults to `16`): Decoder patch size (used for unpatchify and decoder head). encoder_input_size (`int`, *optional*, defaults to `224`): Input size expected by the encoder. image_size (`int`, *optional*): - Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like RAE-main: - `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size // encoder.patch_size) ** 2`. + Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like + RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size // + encoder.patch_size) ** 2`. num_channels (`int`, *optional*, defaults to `3`): Number of input/output channels. latent_mean (`torch.Tensor`, *optional*): @@ -514,7 +546,7 @@ class AutoencoderRAE( def __init__( self, encoder_cls: str = "dinov2", - encoder_name_or_path: str = "facebook/dinov2-with-registers-base", + encoder_name_or_path: Optional[str] = None, decoder_hidden_size: int = 512, decoder_num_hidden_layers: int = 8, decoder_num_attention_heads: int = 16, @@ -534,12 +566,13 @@ def __init__( if encoder_cls not in ENCODER_ARCHS: raise ValueError(f"Unknown encoder_cls='{encoder_cls}'. 
Available: {sorted(ENCODER_ARCHS.keys())}") + if encoder_name_or_path is None: + encoder_name_or_path = ENCODER_DEFAULT_NAME_OR_PATH[encoder_cls] self.encoder_input_size = encoder_input_size self.noise_tau = float(noise_tau) self.reshape_to_2d = bool(reshape_to_2d) self.use_encoder_loss = bool(use_encoder_loss) - self.scaling_factor = float(scaling_factor) # Frozen representation encoder self.encoder: nn.Module = ENCODER_ARCHS[encoder_cls](encoder_name_or_path=encoder_name_or_path) @@ -588,23 +621,35 @@ def __init__( proc = AutoImageProcessor.from_pretrained(encoder_name_or_path, local_files_only=False) encoder_mean = torch.tensor(proc.image_mean, dtype=torch.float32).view(1, 3, 1, 1) encoder_std = torch.tensor(proc.image_std, dtype=torch.float32).view(1, 3, 1, 1) - except Exception: + except (OSError, ValueError): # Keep default 0.5/0.5 if processor is unavailable. - pass + logger.warning( + "Falling back to encoder mean/std [0.5, 0.5, 0.5] for `%s` because AutoImageProcessor could not be loaded.", + encoder_name_or_path, + ) self.register_buffer("encoder_mean", encoder_mean, persistent=True) self.register_buffer("encoder_std", encoder_std, persistent=True) # Optional latent normalization (RAE-main uses mean/var) + def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Optional[torch.Tensor]: + if value is None: + return None + if isinstance(value, torch.Tensor): + return value.detach().clone() + return torch.tensor(value, dtype=torch.float32) + + latent_mean_tensor = _as_optional_tensor(latent_mean) + latent_var_tensor = _as_optional_tensor(latent_var) self.do_latent_normalization = latent_mean is not None or latent_var is not None - if latent_mean is not None: - self.register_buffer("latent_mean", latent_mean.detach().clone(), persistent=True) + if latent_mean_tensor is not None: + self.register_buffer("_latent_mean", latent_mean_tensor, persistent=True) else: - self.latent_mean = None - if latent_var is not None: - 
self.register_buffer("latent_var", latent_var.detach().clone(), persistent=True) + self._latent_mean = None + if latent_var_tensor is not None: + self.register_buffer("_latent_var", latent_var_tensor, persistent=True) else: - self.latent_var = None + self._latent_var = None # ViT-MAE style decoder encoder_hidden_size = getattr(self.encoder, "hidden_size", None) @@ -657,15 +702,15 @@ def _maybe_denormalize_image(self, x: torch.Tensor) -> torch.Tensor: def _maybe_normalize_latents(self, z: torch.Tensor) -> torch.Tensor: if not self.do_latent_normalization: return z - latent_mean = self.latent_mean.to(device=z.device, dtype=z.dtype) if self.latent_mean is not None else 0 - latent_var = self.latent_var.to(device=z.device, dtype=z.dtype) if self.latent_var is not None else 1 + latent_mean = self._latent_mean.to(device=z.device, dtype=z.dtype) if self._latent_mean is not None else 0 + latent_var = self._latent_var.to(device=z.device, dtype=z.dtype) if self._latent_var is not None else 1 return (z - latent_mean) / torch.sqrt(latent_var + 1e-5) def _maybe_denormalize_latents(self, z: torch.Tensor) -> torch.Tensor: if not self.do_latent_normalization: return z - latent_mean = self.latent_mean.to(device=z.device, dtype=z.dtype) if self.latent_mean is not None else 0 - latent_var = self.latent_var.to(device=z.device, dtype=z.dtype) if self.latent_var is not None else 1 + latent_mean = self._latent_mean.to(device=z.device, dtype=z.dtype) if self._latent_mean is not None else 0 + latent_var = self._latent_var.to(device=z.device, dtype=z.dtype) if self._latent_var is not None else 1 return z * torch.sqrt(latent_var + 1e-5) + latent_mean def _encode(self, x: torch.Tensor) -> torch.Tensor: @@ -689,8 +734,8 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor: z = self._maybe_normalize_latents(z) # Follow diffusers convention: optionally scale latents for diffusion - if self.scaling_factor != 1.0: - z = z * self.scaling_factor + if self.config.scaling_factor != 1.0: + z = z * 
self.config.scaling_factor return z @@ -707,8 +752,8 @@ def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutp def _decode(self, z: torch.Tensor) -> torch.Tensor: # Undo scaling factor if applied at encode time - if self.scaling_factor != 1.0: - z = z / self.scaling_factor + if self.config.scaling_factor != 1.0: + z = z / self.config.scaling_factor z = self._maybe_denormalize_latents(z) @@ -740,4 +785,3 @@ def forward(self, sample: torch.Tensor, return_dict: bool = True) -> Union[Decod if not return_dict: return (decoded,) return DecoderOutput(sample=decoded) - diff --git a/tests/models/autoencoders/test_models_autoencoder_rae.py b/tests/models/autoencoders/test_models_autoencoder_rae.py index 3aa9641f922f..432df61725cc 100644 --- a/tests/models/autoencoders/test_models_autoencoder_rae.py +++ b/tests/models/autoencoders/test_models_autoencoder_rae.py @@ -14,12 +14,18 @@ # limitations under the License. import gc -import os import unittest import torch +import torch.nn.functional as F -from diffusers.models.autoencoders.autoencoder_rae import AutoencoderRAE, Dinov2Encoder, MAEEncoder, Siglip2Encoder +from diffusers.models.autoencoders.autoencoder_rae import ( + AutoencoderRAE, + Dinov2Encoder, + MAEEncoder, + Siglip2Encoder, + register_encoder, +) from ...testing_utils import backend_empty_cache, enable_full_determinism, slow, torch_device @@ -27,11 +33,131 @@ enable_full_determinism() -def _get_required_local_path(env_name: str) -> str: - path = os.environ.get(env_name) - assert path is not None and len(path) > 0, f"Please set `{env_name}` to a local pretrained model directory." 
- assert os.path.exists(path), f"Path from `{env_name}` does not exist: {path}" - return path +DINO_MODEL_ID = "facebook/dinov2-with-registers-base" +SIGLIP2_MODEL_ID = "google/siglip2-base-patch16-256" +MAE_MODEL_ID = "facebook/vit-mae-base" + + +@register_encoder(name="tiny_test") +class TinyTestEncoder(torch.nn.Module): + def __init__(self, encoder_name_or_path: str = "unused"): + super().__init__() + self.patch_size = 8 + self.hidden_size = 16 + + def forward(self, images: torch.Tensor) -> torch.Tensor: + pooled = F.avg_pool2d(images.mean(dim=1, keepdim=True), kernel_size=self.patch_size, stride=self.patch_size) + tokens = pooled.flatten(2).transpose(1, 2).contiguous() + return tokens.repeat(1, 1, self.hidden_size) + + +class AutoencoderRAETests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def _make_model(self, **overrides) -> AutoencoderRAE: + config = { + "encoder_cls": "tiny_test", + "encoder_name_or_path": "unused", + "encoder_input_size": 32, + "patch_size": 4, + "image_size": 16, + "decoder_hidden_size": 32, + "decoder_num_hidden_layers": 1, + "decoder_num_attention_heads": 4, + "decoder_intermediate_size": 64, + "num_channels": 3, + "noise_tau": 0.0, + "reshape_to_2d": True, + "scaling_factor": 1.0, + } + config.update(overrides) + return AutoencoderRAE(**config).to(torch_device) + + def test_fast_encode_decode_and_forward_shapes(self): + model = self._make_model().eval() + x = torch.rand(2, 3, 32, 32, device=torch_device) + + with torch.no_grad(): + z = model.encode(x).latent + decoded = model.decode(z).sample + recon = model(x).sample + + self.assertEqual(z.shape, (2, 16, 4, 4)) + self.assertEqual(decoded.shape, (2, 3, 16, 16)) + self.assertEqual(recon.shape, (2, 3, 16, 16)) + self.assertTrue(torch.isfinite(recon).all().item()) + + def test_fast_scaling_factor_encode_and_decode_consistency(self): + torch.manual_seed(0) + model_base = self._make_model(scaling_factor=1.0).eval() + 
torch.manual_seed(0) + model_scaled = self._make_model(scaling_factor=2.0).eval() + + x = torch.rand(2, 3, 32, 32, device=torch_device) + with torch.no_grad(): + z_base = model_base.encode(x).latent + z_scaled = model_scaled.encode(x).latent + recon_base = model_base.decode(z_base).sample + recon_scaled = model_scaled.decode(z_scaled).sample + + self.assertTrue(torch.allclose(z_scaled, z_base * 2.0, atol=1e-5, rtol=1e-4)) + self.assertTrue(torch.allclose(recon_scaled, recon_base, atol=1e-5, rtol=1e-4)) + + def test_fast_latent_normalization_matches_formula(self): + latent_mean = torch.full((1, 16, 1, 1), 0.25, dtype=torch.float32) + latent_var = torch.full((1, 16, 1, 1), 4.0, dtype=torch.float32) + + model_raw = self._make_model().eval() + model_norm = self._make_model(latent_mean=latent_mean, latent_var=latent_var).eval() + x = torch.rand(1, 3, 32, 32, device=torch_device) + + with torch.no_grad(): + z_raw = model_raw.encode(x).latent + z_norm = model_norm.encode(x).latent + + expected = (z_raw - latent_mean.to(z_raw.device, z_raw.dtype)) / torch.sqrt( + latent_var.to(z_raw.device, z_raw.dtype) + 1e-5 + ) + self.assertTrue(torch.allclose(z_norm, expected, atol=1e-5, rtol=1e-4)) + + def test_fast_slicing_matches_non_slicing(self): + model = self._make_model().eval() + x = torch.rand(3, 3, 32, 32, device=torch_device) + + with torch.no_grad(): + model.use_slicing = False + z_no_slice = model.encode(x).latent + out_no_slice = model.decode(z_no_slice).sample + + model.use_slicing = True + z_slice = model.encode(x).latent + out_slice = model.decode(z_slice).sample + + self.assertTrue(torch.allclose(z_slice, z_no_slice, atol=1e-6, rtol=1e-5)) + self.assertTrue(torch.allclose(out_slice, out_no_slice, atol=1e-6, rtol=1e-5)) + + def test_fast_noise_tau_applies_only_in_train(self): + model = self._make_model(noise_tau=0.5).to(torch_device) + x = torch.rand(2, 3, 32, 32, device=torch_device) + + model.train() + torch.manual_seed(0) + z_train_1 = model.encode(x).latent + 
torch.manual_seed(1) + z_train_2 = model.encode(x).latent + + model.eval() + torch.manual_seed(0) + z_eval_1 = model.encode(x).latent + torch.manual_seed(1) + z_eval_2 = model.encode(x).latent + + self.assertEqual(z_train_1.shape, z_eval_1.shape) + self.assertFalse(torch.allclose(z_train_1, z_train_2)) + self.assertTrue(torch.allclose(z_eval_1, z_eval_2, atol=1e-6, rtol=1e-5)) @slow @@ -42,7 +168,7 @@ def tearDown(self): backend_empty_cache(torch_device) def test_dinov2_encoder_forward_shape(self): - dino_path = _get_required_local_path("DINO_PATH") + dino_path = DINO_MODEL_ID encoder = Dinov2Encoder(encoder_name_or_path=dino_path).to(torch_device) x = torch.rand(1, 3, 224, 224, device=torch_device) @@ -54,7 +180,7 @@ def test_dinov2_encoder_forward_shape(self): assert y.shape[2] == encoder.hidden_size def test_siglip2_encoder_forward_shape(self): - siglip2_path = _get_required_local_path("SIGLIP2_PATH") + siglip2_path = SIGLIP2_MODEL_ID encoder = Siglip2Encoder(encoder_name_or_path=siglip2_path).to(torch_device) x = torch.rand(1, 3, 224, 224, device=torch_device) @@ -66,7 +192,7 @@ def test_siglip2_encoder_forward_shape(self): assert y.shape[2] == encoder.hidden_size def test_mae_encoder_forward_shape(self): - mae_path = _get_required_local_path("MAE_PATH") + mae_path = MAE_MODEL_ID encoder = MAEEncoder(encoder_name_or_path=mae_path).to(torch_device) x = torch.rand(1, 3, 224, 224, device=torch_device) @@ -87,7 +213,7 @@ def tearDown(self): def test_autoencoder_rae_encode_decode_forward_shapes_dinov2(self): # This is a shape & numerical-sanity test. The decoder is randomly initialized unless you load trained weights. 
- dino_path = _get_required_local_path("DINO_PATH") + dino_path = DINO_MODEL_ID encoder_input_size = 224 decoder_patch_size = 16 @@ -120,4 +246,4 @@ def test_autoencoder_rae_encode_decode_forward_shapes_dinov2(self): recon = model(x).sample assert recon.shape == (1, 3, image_size, image_size) - assert torch.isfinite(recon).all().item() \ No newline at end of file + assert torch.isfinite(recon).all().item() From 24acab0bcc92bfbe5b54e6a80a9aa961e6021329 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Feb 2026 22:44:16 +0000 Subject: [PATCH 04/30] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 4e402921aa5f..2b71910840b5 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -656,6 +656,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderRAE(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderTiny(metaclass=DummyObject): _backends = ["torch"] From 25bc9e334ca7871e5fea5e8fc32519a87e4b32c3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Feb 2026 22:44:46 +0000 Subject: [PATCH 05/30] initial doc --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/models/autoencoder_rae.md | 55 ++++++++++++++++++++ 2 files changed, 57 insertions(+) create mode 100644 docs/source/en/api/models/autoencoder_rae.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 64a4222845b0..ac255bd4cdef 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -460,6 +460,8 @@ title: 
ConsistencyDecoderVAE - local: api/models/autoencoder_oobleck title: Oobleck AutoEncoder + - local: api/models/autoencoder_rae + title: AutoencoderRAE - local: api/models/autoencoder_tiny title: Tiny AutoEncoder - local: api/models/vq diff --git a/docs/source/en/api/models/autoencoder_rae.md b/docs/source/en/api/models/autoencoder_rae.md new file mode 100644 index 000000000000..27269ccd674f --- /dev/null +++ b/docs/source/en/api/models/autoencoder_rae.md @@ -0,0 +1,55 @@ + + +# AutoencoderRAE + +`AutoencoderRAE` is a representation autoencoder that combines a frozen vision encoder (DINOv2, SigLIP2, or MAE) with a ViT-MAE-style decoder. + +Paper: [Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690). + +The model follows the standard diffusers autoencoder API: +- `encode(...)` returns an `EncoderOutput` with a `latent` tensor. +- `decode(...)` returns a `DecoderOutput` with a `sample` tensor. + +## Usage + +```python +import torch +from diffusers import AutoencoderRAE + +model = AutoencoderRAE( + encoder_cls="dinov2", + encoder_name_or_path="facebook/dinov2-with-registers-base", + encoder_input_size=224, + patch_size=16, + image_size=256, +).to("cuda").eval() + +# Encode and decode +x = torch.randn(1, 3, 256, 256, device="cuda") +with torch.no_grad(): + latents = model.encode(x).latent + recon = model.decode(latents).sample +``` + +`encoder_cls` supports `"dinov2"`, `"siglip2"`, and `"mae"`. 
+ +## AutoencoderRAE class + +[[autodoc]] AutoencoderRAE + - encode + - decode + - all + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput From f06ea7a9012500a6ee82fd3cb1286973bcc0f37a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Feb 2026 22:51:36 +0000 Subject: [PATCH 06/30] fix latent_mean / latent_var init types to accept config-friendly inputs --- .../models/autoencoders/autoencoder_rae.py | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index 95f65cedd001..4901016fdc03 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from math import sqrt from types import SimpleNamespace -from typing import Callable, Dict, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union import numpy as np import torch @@ -526,10 +526,12 @@ class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMode encoder.patch_size) ** 2`. num_channels (`int`, *optional*, defaults to `3`): Number of input/output channels. - latent_mean (`torch.Tensor`, *optional*): - Optional mean for latent normalization. - latent_var (`torch.Tensor`, *optional*): - Optional variance for latent normalization. + latent_mean (`list` or `tuple`, *optional*): + Optional mean for latent normalization. Tensor inputs are accepted for backward compatibility and converted + to config-serializable lists. + latent_var (`list` or `tuple`, *optional*): + Optional variance for latent normalization. Tensor inputs are accepted for backward compatibility and + converted to config-serializable lists. noise_tau (`float`, *optional*, defaults to `0.0`): Noise level for training (adds noise to latents during training). 
reshape_to_2d (`bool`, *optional*, defaults to `True`): @@ -555,8 +557,8 @@ def __init__( encoder_input_size: int = 224, image_size: Optional[int] = None, num_channels: int = 3, - latent_mean: Optional[torch.Tensor] = None, - latent_var: Optional[torch.Tensor] = None, + latent_mean: Optional[Union[list, tuple, torch.Tensor]] = None, + latent_var: Optional[Union[list, tuple, torch.Tensor]] = None, noise_tau: float = 0.0, reshape_to_2d: bool = True, use_encoder_loss: bool = False, @@ -569,6 +571,22 @@ def __init__( if encoder_name_or_path is None: encoder_name_or_path = ENCODER_DEFAULT_NAME_OR_PATH[encoder_cls] + def _to_config_compatible(value: Any) -> Any: + if isinstance(value, torch.Tensor): + return value.detach().cpu().tolist() + if isinstance(value, tuple): + return [_to_config_compatible(v) for v in value] + if isinstance(value, list): + return [_to_config_compatible(v) for v in value] + return value + + # Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors. 
+ self.register_to_config( + encoder_name_or_path=encoder_name_or_path, + latent_mean=_to_config_compatible(latent_mean), + latent_var=_to_config_compatible(latent_var), + ) + self.encoder_input_size = encoder_input_size self.noise_tau = float(noise_tau) self.reshape_to_2d = bool(reshape_to_2d) From d7cb12470b908b5715539c70f7e99d546491eac7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Feb 2026 22:57:02 +0000 Subject: [PATCH 07/30] use mean and std convention --- docs/source/en/api/models/autoencoder_rae.md | 2 + .../models/autoencoders/autoencoder_rae.py | 75 ++++++++++++------- .../test_models_autoencoder_rae.py | 8 +- 3 files changed, 53 insertions(+), 32 deletions(-) diff --git a/docs/source/en/api/models/autoencoder_rae.md b/docs/source/en/api/models/autoencoder_rae.md index 27269ccd674f..e69b6451c3c0 100644 --- a/docs/source/en/api/models/autoencoder_rae.md +++ b/docs/source/en/api/models/autoencoder_rae.md @@ -43,6 +43,8 @@ with torch.no_grad(): `encoder_cls` supports `"dinov2"`, `"siglip2"`, and `"mae"`. +For latent normalization, use `latents_mean` and `latents_std` (matching other diffusers autoencoders). + ## AutoencoderRAE class [[autodoc]] AutoencoderRAE diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index 4901016fdc03..56800fc3235f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -526,12 +526,16 @@ class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMode encoder.patch_size) ** 2`. num_channels (`int`, *optional*, defaults to `3`): Number of input/output channels. - latent_mean (`list` or `tuple`, *optional*): + latents_mean (`list` or `tuple`, *optional*): Optional mean for latent normalization. Tensor inputs are accepted for backward compatibility and converted to config-serializable lists. 
+ latents_std (`list` or `tuple`, *optional*): + Optional standard deviation for latent normalization. Tensor inputs are accepted for backward compatibility + and converted to config-serializable lists. + latent_mean (`list` or `tuple`, *optional*): + Deprecated alias of `latents_mean`. latent_var (`list` or `tuple`, *optional*): - Optional variance for latent normalization. Tensor inputs are accepted for backward compatibility and - converted to config-serializable lists. + Deprecated alias of latent variance. If provided, it is converted to `latents_std = sqrt(latent_var + 1e-5)`. noise_tau (`float`, *optional*, defaults to `0.0`): Noise level for training (adds noise to latents during training). reshape_to_2d (`bool`, *optional*, defaults to `True`): @@ -557,6 +561,8 @@ def __init__( encoder_input_size: int = 224, image_size: Optional[int] = None, num_channels: int = 3, + latents_mean: Optional[Union[list, tuple, torch.Tensor]] = None, + latents_std: Optional[Union[list, tuple, torch.Tensor]] = None, latent_mean: Optional[Union[list, tuple, torch.Tensor]] = None, latent_var: Optional[Union[list, tuple, torch.Tensor]] = None, noise_tau: float = 0.0, @@ -580,11 +586,34 @@ def _to_config_compatible(value: Any) -> Any: return [_to_config_compatible(v) for v in value] return value + if latents_mean is not None and latent_mean is not None: + raise ValueError("Please provide only one of `latents_mean` or deprecated `latent_mean`.") + if latents_std is not None and latent_var is not None: + raise ValueError("Please provide only one of `latents_std` or deprecated `latent_var`.") + + if latents_mean is None: + latents_mean = latent_mean + + def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Optional[torch.Tensor]: + if value is None: + return None + if isinstance(value, torch.Tensor): + return value.detach().clone() + return torch.tensor(value, dtype=torch.float32) + + latents_std_tensor = _as_optional_tensor(latents_std) + latent_var_tensor = 
_as_optional_tensor(latent_var) + if latents_std_tensor is None and latent_var_tensor is not None: + latents_std_tensor = torch.sqrt(latent_var_tensor + 1e-5) + latents_std = latents_std_tensor + # Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors. self.register_to_config( encoder_name_or_path=encoder_name_or_path, - latent_mean=_to_config_compatible(latent_mean), - latent_var=_to_config_compatible(latent_var), + latents_mean=_to_config_compatible(latents_mean), + latents_std=_to_config_compatible(latents_std), + latent_mean=None, + latent_var=None, ) self.encoder_input_size = encoder_input_size @@ -650,24 +679,16 @@ def _to_config_compatible(value: Any) -> Any: self.register_buffer("encoder_std", encoder_std, persistent=True) # Optional latent normalization (RAE-main uses mean/var) - def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Optional[torch.Tensor]: - if value is None: - return None - if isinstance(value, torch.Tensor): - return value.detach().clone() - return torch.tensor(value, dtype=torch.float32) - - latent_mean_tensor = _as_optional_tensor(latent_mean) - latent_var_tensor = _as_optional_tensor(latent_var) - self.do_latent_normalization = latent_mean is not None or latent_var is not None - if latent_mean_tensor is not None: - self.register_buffer("_latent_mean", latent_mean_tensor, persistent=True) + latents_mean_tensor = _as_optional_tensor(latents_mean) + self.do_latent_normalization = latents_mean is not None or latents_std is not None or latent_var is not None + if latents_mean_tensor is not None: + self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True) else: - self._latent_mean = None - if latent_var_tensor is not None: - self.register_buffer("_latent_var", latent_var_tensor, persistent=True) + self._latents_mean = None + if latents_std_tensor is not None: + self.register_buffer("_latents_std", latents_std_tensor, persistent=True) else: - self._latent_var = 
None + self._latents_std = None # ViT-MAE style decoder encoder_hidden_size = getattr(self.encoder, "hidden_size", None) @@ -720,16 +741,16 @@ def _maybe_denormalize_image(self, x: torch.Tensor) -> torch.Tensor: def _maybe_normalize_latents(self, z: torch.Tensor) -> torch.Tensor: if not self.do_latent_normalization: return z - latent_mean = self._latent_mean.to(device=z.device, dtype=z.dtype) if self._latent_mean is not None else 0 - latent_var = self._latent_var.to(device=z.device, dtype=z.dtype) if self._latent_var is not None else 1 - return (z - latent_mean) / torch.sqrt(latent_var + 1e-5) + latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype) if self._latents_mean is not None else 0 + latents_std = self._latents_std.to(device=z.device, dtype=z.dtype) if self._latents_std is not None else 1 + return (z - latents_mean) / (latents_std + 1e-5) def _maybe_denormalize_latents(self, z: torch.Tensor) -> torch.Tensor: if not self.do_latent_normalization: return z - latent_mean = self._latent_mean.to(device=z.device, dtype=z.dtype) if self._latent_mean is not None else 0 - latent_var = self._latent_var.to(device=z.device, dtype=z.dtype) if self._latent_var is not None else 1 - return z * torch.sqrt(latent_var + 1e-5) + latent_mean + latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype) if self._latents_mean is not None else 0 + latents_std = self._latents_std.to(device=z.device, dtype=z.dtype) if self._latents_std is not None else 1 + return z * (latents_std + 1e-5) + latents_mean def _encode(self, x: torch.Tensor) -> torch.Tensor: x = self._maybe_resize_and_normalize(x) diff --git a/tests/models/autoencoders/test_models_autoencoder_rae.py b/tests/models/autoencoders/test_models_autoencoder_rae.py index 432df61725cc..8cf2c0bb1114 100644 --- a/tests/models/autoencoders/test_models_autoencoder_rae.py +++ b/tests/models/autoencoders/test_models_autoencoder_rae.py @@ -108,19 +108,17 @@ def 
test_fast_scaling_factor_encode_and_decode_consistency(self): def test_fast_latent_normalization_matches_formula(self): latent_mean = torch.full((1, 16, 1, 1), 0.25, dtype=torch.float32) - latent_var = torch.full((1, 16, 1, 1), 4.0, dtype=torch.float32) + latents_std = torch.full((1, 16, 1, 1), 2.0, dtype=torch.float32) model_raw = self._make_model().eval() - model_norm = self._make_model(latent_mean=latent_mean, latent_var=latent_var).eval() + model_norm = self._make_model(latents_mean=latent_mean, latents_std=latents_std).eval() x = torch.rand(1, 3, 32, 32, device=torch_device) with torch.no_grad(): z_raw = model_raw.encode(x).latent z_norm = model_norm.encode(x).latent - expected = (z_raw - latent_mean.to(z_raw.device, z_raw.dtype)) / torch.sqrt( - latent_var.to(z_raw.device, z_raw.dtype) + 1e-5 - ) + expected = (z_raw - latent_mean.to(z_raw.device, z_raw.dtype)) / (latents_std.to(z_raw.device, z_raw.dtype) + 1e-5) self.assertTrue(torch.allclose(z_norm, expected, atol=1e-5, rtol=1e-4)) def test_fast_slicing_matches_non_slicing(self): From 0d59b2273280d2c2ca9ace17221967c7e31fb33d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Feb 2026 23:19:13 +0000 Subject: [PATCH 08/30] cleanup --- docs/source/en/_toctree.yml | 4 +-- .../models/autoencoders/autoencoder_rae.py | 30 ++++--------------- .../test_models_autoencoder_rae.py | 10 ++++--- 3 files changed, 13 insertions(+), 31 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ac255bd4cdef..e7cb92ef9f2e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -456,12 +456,12 @@ title: AutoencoderKLQwenImage - local: api/models/autoencoder_kl_wan title: AutoencoderKLWan + - local: api/models/autoencoder_rae + title: AutoencoderRAE - local: api/models/consistency_decoder_vae title: ConsistencyDecoderVAE - local: api/models/autoencoder_oobleck title: Oobleck AutoEncoder - - local: api/models/autoencoder_rae - title: AutoencoderRAE - local: 
api/models/autoencoder_tiny title: Tiny AutoEncoder - local: api/models/vq diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index 56800fc3235f..042949168a50 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -527,15 +527,11 @@ class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMode num_channels (`int`, *optional*, defaults to `3`): Number of input/output channels. latents_mean (`list` or `tuple`, *optional*): - Optional mean for latent normalization. Tensor inputs are accepted for backward compatibility and converted - to config-serializable lists. + Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable + lists. latents_std (`list` or `tuple`, *optional*): - Optional standard deviation for latent normalization. Tensor inputs are accepted for backward compatibility - and converted to config-serializable lists. - latent_mean (`list` or `tuple`, *optional*): - Deprecated alias of `latents_mean`. - latent_var (`list` or `tuple`, *optional*): - Deprecated alias of latent variance. If provided, it is converted to `latents_std = sqrt(latent_var + 1e-5)`. + Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to + config-serializable lists. noise_tau (`float`, *optional*, defaults to `0.0`): Noise level for training (adds noise to latents during training). 
reshape_to_2d (`bool`, *optional*, defaults to `True`): @@ -563,8 +559,6 @@ def __init__( num_channels: int = 3, latents_mean: Optional[Union[list, tuple, torch.Tensor]] = None, latents_std: Optional[Union[list, tuple, torch.Tensor]] = None, - latent_mean: Optional[Union[list, tuple, torch.Tensor]] = None, - latent_var: Optional[Union[list, tuple, torch.Tensor]] = None, noise_tau: float = 0.0, reshape_to_2d: bool = True, use_encoder_loss: bool = False, @@ -586,14 +580,6 @@ def _to_config_compatible(value: Any) -> Any: return [_to_config_compatible(v) for v in value] return value - if latents_mean is not None and latent_mean is not None: - raise ValueError("Please provide only one of `latents_mean` or deprecated `latent_mean`.") - if latents_std is not None and latent_var is not None: - raise ValueError("Please provide only one of `latents_std` or deprecated `latent_var`.") - - if latents_mean is None: - latents_mean = latent_mean - def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Optional[torch.Tensor]: if value is None: return None @@ -602,18 +588,12 @@ def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Op return torch.tensor(value, dtype=torch.float32) latents_std_tensor = _as_optional_tensor(latents_std) - latent_var_tensor = _as_optional_tensor(latent_var) - if latents_std_tensor is None and latent_var_tensor is not None: - latents_std_tensor = torch.sqrt(latent_var_tensor + 1e-5) - latents_std = latents_std_tensor # Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors. 
self.register_to_config( encoder_name_or_path=encoder_name_or_path, latents_mean=_to_config_compatible(latents_mean), latents_std=_to_config_compatible(latents_std), - latent_mean=None, - latent_var=None, ) self.encoder_input_size = encoder_input_size @@ -680,7 +660,7 @@ def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Op # Optional latent normalization (RAE-main uses mean/var) latents_mean_tensor = _as_optional_tensor(latents_mean) - self.do_latent_normalization = latents_mean is not None or latents_std is not None or latent_var is not None + self.do_latent_normalization = latents_mean is not None or latents_std is not None if latents_mean_tensor is not None: self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True) else: diff --git a/tests/models/autoencoders/test_models_autoencoder_rae.py b/tests/models/autoencoders/test_models_autoencoder_rae.py index 8cf2c0bb1114..899b42e757e9 100644 --- a/tests/models/autoencoders/test_models_autoencoder_rae.py +++ b/tests/models/autoencoders/test_models_autoencoder_rae.py @@ -106,19 +106,21 @@ def test_fast_scaling_factor_encode_and_decode_consistency(self): self.assertTrue(torch.allclose(z_scaled, z_base * 2.0, atol=1e-5, rtol=1e-4)) self.assertTrue(torch.allclose(recon_scaled, recon_base, atol=1e-5, rtol=1e-4)) - def test_fast_latent_normalization_matches_formula(self): - latent_mean = torch.full((1, 16, 1, 1), 0.25, dtype=torch.float32) + def test_fast_latents_normalization_matches_formula(self): + latents_mean = torch.full((1, 16, 1, 1), 0.25, dtype=torch.float32) latents_std = torch.full((1, 16, 1, 1), 2.0, dtype=torch.float32) model_raw = self._make_model().eval() - model_norm = self._make_model(latents_mean=latent_mean, latents_std=latents_std).eval() + model_norm = self._make_model(latents_mean=latents_mean, latents_std=latents_std).eval() x = torch.rand(1, 3, 32, 32, device=torch_device) with torch.no_grad(): z_raw = model_raw.encode(x).latent z_norm = 
model_norm.encode(x).latent - expected = (z_raw - latent_mean.to(z_raw.device, z_raw.dtype)) / (latents_std.to(z_raw.device, z_raw.dtype) + 1e-5) + expected = (z_raw - latents_mean.to(z_raw.device, z_raw.dtype)) / ( + latents_std.to(z_raw.device, z_raw.dtype) + 1e-5 + ) self.assertTrue(torch.allclose(z_norm, expected, atol=1e-5, rtol=1e-4)) def test_fast_slicing_matches_non_slicing(self): From 202b14f6a4347d5213c7bfa574b736f4ca1c50fb Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Feb 2026 23:19:53 +0000 Subject: [PATCH 09/30] add rae to diffusers script --- scripts/convert_rae_to_diffusers.py | 290 ++++++++++++++++++++++++++++ 1 file changed, 290 insertions(+) create mode 100644 scripts/convert_rae_to_diffusers.py diff --git a/scripts/convert_rae_to_diffusers.py b/scripts/convert_rae_to_diffusers.py new file mode 100644 index 000000000000..c751ae6f084d --- /dev/null +++ b/scripts/convert_rae_to_diffusers.py @@ -0,0 +1,290 @@ +import argparse +import json +from pathlib import Path +from typing import Any + +import torch +from huggingface_hub import HfApi, hf_hub_download + +from diffusers import AutoencoderRAE + + +DECODER_CONFIGS = { + "ViTB": { + "decoder_hidden_size": 768, + "decoder_intermediate_size": 3072, + "decoder_num_attention_heads": 12, + "decoder_num_hidden_layers": 12, + }, + "ViTL": { + "decoder_hidden_size": 1024, + "decoder_intermediate_size": 4096, + "decoder_num_attention_heads": 16, + "decoder_num_hidden_layers": 24, + }, + "ViTXL": { + "decoder_hidden_size": 1152, + "decoder_intermediate_size": 4096, + "decoder_num_attention_heads": 16, + "decoder_num_hidden_layers": 28, + }, +} + +ENCODER_DEFAULT_NAME_OR_PATH = { + "dinov2": "facebook/dinov2-with-registers-base", + "siglip2": "google/siglip2-base-patch16-256", + "mae": "facebook/vit-mae-base", +} + +DEFAULT_DECODER_SUBDIR = { + "dinov2": "decoders/dinov2/wReg_base", + "mae": "decoders/mae/base_p16", + "siglip2": "decoders/siglip2/base_p16_i256", +} + +DEFAULT_STATS_SUBDIR = { + 
"dinov2": "stats/dinov2/wReg_base", + "mae": "stats/mae/base_p16", + "siglip2": "stats/siglip2/base_p16_i256", +} + +DECODER_FILE_CANDIDATES = ("dinov2_decoder.pt", "model.pt") +STATS_FILE_CANDIDATES = ("stat.pt",) + + +def dataset_case_candidates(name: str) -> tuple[str, ...]: + return (name, name.lower(), name.upper(), name.title(), "imagenet1k", "ImageNet1k") + + +class RepoAccessor: + def __init__(self, repo_or_path: str, cache_dir: str | None = None): + self.repo_or_path = repo_or_path + self.cache_dir = cache_dir + self.local_root: Path | None = None + self.repo_id: str | None = None + self.repo_files: set[str] | None = None + + root = Path(repo_or_path) + if root.exists() and root.is_dir(): + self.local_root = root + else: + self.repo_id = repo_or_path + self.repo_files = set(HfApi().list_repo_files(repo_or_path)) + + def exists(self, relative_path: str) -> bool: + relative_path = relative_path.replace("\\", "/") + if self.local_root is not None: + return (self.local_root / relative_path).is_file() + return relative_path in self.repo_files + + def fetch(self, relative_path: str) -> Path: + relative_path = relative_path.replace("\\", "/") + if self.local_root is not None: + return self.local_root / relative_path + downloaded = hf_hub_download(repo_id=self.repo_id, filename=relative_path, cache_dir=self.cache_dir) + return Path(downloaded) + + +def unwrap_state_dict(maybe_wrapped: dict[str, Any]) -> dict[str, Any]: + state_dict = maybe_wrapped + for k in ("model", "module", "state_dict"): + if isinstance(state_dict, dict) and k in state_dict and isinstance(state_dict[k], dict): + state_dict = state_dict[k] + + out = dict(state_dict) + if len(out) > 0 and all(key.startswith("module.") for key in out): + out = {key[len("module.") :]: value for key, value in out.items()} + if len(out) > 0 and all(key.startswith("decoder.") for key in out): + out = {key[len("decoder.") :]: value for key, value in out.items()} + return out + + +def resolve_decoder_file( + accessor: 
RepoAccessor, encoder_cls: str, variant: str, decoder_checkpoint: str | None +) -> str: + if decoder_checkpoint is not None: + if accessor.exists(decoder_checkpoint): + return decoder_checkpoint + raise FileNotFoundError(f"Decoder checkpoint not found: {decoder_checkpoint}") + + base = f"{DEFAULT_DECODER_SUBDIR[encoder_cls]}/{variant}" + for name in DECODER_FILE_CANDIDATES: + candidate = f"{base}/{name}" + if accessor.exists(candidate): + return candidate + + raise FileNotFoundError( + f"Could not find decoder checkpoint under `{base}`. Tried: {list(DECODER_FILE_CANDIDATES)}" + ) + + +def resolve_stats_file( + accessor: RepoAccessor, + encoder_cls: str, + dataset_name: str, + stats_checkpoint: str | None, +) -> str | None: + if stats_checkpoint is not None: + if accessor.exists(stats_checkpoint): + return stats_checkpoint + raise FileNotFoundError(f"Stats checkpoint not found: {stats_checkpoint}") + + base = DEFAULT_STATS_SUBDIR[encoder_cls] + for dataset in dataset_case_candidates(dataset_name): + for name in STATS_FILE_CANDIDATES: + candidate = f"{base}/{dataset}/{name}" + if accessor.exists(candidate): + return candidate + + return None + + +def extract_latent_stats(stats_obj: Any) -> tuple[Any | None, Any | None]: + if not isinstance(stats_obj, dict): + return None, None + + if "latents_mean" in stats_obj or "latents_std" in stats_obj: + return stats_obj.get("latents_mean", None), stats_obj.get("latents_std", None) + + mean = stats_obj.get("mean", None) + var = stats_obj.get("var", None) + if mean is None and var is None: + return None, None + + latents_std = None + if var is not None: + if isinstance(var, torch.Tensor): + latents_std = torch.sqrt(var + 1e-5) + else: + latents_std = torch.sqrt(torch.tensor(var) + 1e-5) + return mean, latents_std + + +def convert(args: argparse.Namespace) -> None: + accessor = RepoAccessor(args.repo_or_path, cache_dir=args.cache_dir) + encoder_name_or_path = args.encoder_name_or_path or 
ENCODER_DEFAULT_NAME_OR_PATH[args.encoder_cls] + + decoder_relpath = resolve_decoder_file(accessor, args.encoder_cls, args.variant, args.decoder_checkpoint) + stats_relpath = resolve_stats_file(accessor, args.encoder_cls, args.dataset_name, args.stats_checkpoint) + + print(f"Using decoder checkpoint: {decoder_relpath}") + if stats_relpath is not None: + print(f"Using stats checkpoint: {stats_relpath}") + else: + print("No stats checkpoint found; conversion will proceed without latent stats.") + + if args.dry_run: + return + + decoder_path = accessor.fetch(decoder_relpath) + decoder_obj = torch.load(decoder_path, map_location="cpu") + decoder_state_dict = unwrap_state_dict(decoder_obj) + + latents_mean, latents_std = None, None + if stats_relpath is not None: + stats_path = accessor.fetch(stats_relpath) + stats_obj = torch.load(stats_path, map_location="cpu") + latents_mean, latents_std = extract_latent_stats(stats_obj) + + decoder_cfg = DECODER_CONFIGS[args.decoder_config_name] + + model = AutoencoderRAE( + encoder_cls=args.encoder_cls, + encoder_name_or_path=encoder_name_or_path, + encoder_input_size=args.encoder_input_size, + patch_size=args.patch_size, + image_size=args.image_size, + num_channels=args.num_channels, + decoder_hidden_size=decoder_cfg["decoder_hidden_size"], + decoder_num_hidden_layers=decoder_cfg["decoder_num_hidden_layers"], + decoder_num_attention_heads=decoder_cfg["decoder_num_attention_heads"], + decoder_intermediate_size=decoder_cfg["decoder_intermediate_size"], + latents_mean=latents_mean, + latents_std=latents_std, + scaling_factor=args.scaling_factor, + ) + + load_result = model.decoder.load_state_dict(decoder_state_dict, strict=False) + allowed_missing = {"trainable_cls_token"} + missing = set(load_result.missing_keys) + unexpected = set(load_result.unexpected_keys) + + if unexpected: + raise RuntimeError(f"Unexpected decoder keys after conversion: {sorted(unexpected)}") + if missing - allowed_missing: + raise RuntimeError(f"Missing 
decoder keys after conversion: {sorted(missing - allowed_missing)}") + + output_path = Path(args.output_path) + output_path.mkdir(parents=True, exist_ok=True) + model.save_pretrained(output_path) + + metadata = { + "source": args.repo_or_path, + "encoder_cls": args.encoder_cls, + "encoder_name_or_path": encoder_name_or_path, + "decoder_checkpoint": decoder_relpath, + "stats_checkpoint": stats_relpath, + "variant": args.variant, + "dataset_name": args.dataset_name, + "decoder_config_name": args.decoder_config_name, + "missing_decoder_keys": sorted(missing), + "unexpected_decoder_keys": sorted(unexpected), + } + with open(output_path / "conversion_metadata.json", "w", encoding="utf-8") as f: + json.dump(metadata, f, indent=2) + + if args.verify_load: + print("Verifying converted checkpoint with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False)...") + loaded_model = AutoencoderRAE.from_pretrained(output_path, low_cpu_mem_usage=False) + if not isinstance(loaded_model, AutoencoderRAE): + raise RuntimeError("Verification failed: loaded object is not AutoencoderRAE.") + print("Verification passed.") + + print(f"Saved converted AutoencoderRAE to: {output_path}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Convert RAE decoder checkpoints to diffusers AutoencoderRAE format") + parser.add_argument( + "--repo_or_path", type=str, required=True, help="Hub repo id (e.g. 
nyu-visionx/RAE-collections) or local path" + ) + parser.add_argument("--output_path", type=str, required=True, help="Directory to save converted model") + + parser.add_argument("--encoder_cls", type=str, choices=["dinov2", "mae", "siglip2"], required=True) + parser.add_argument("--encoder_name_or_path", type=str, default=None, help="Optional encoder HF id/path override") + + parser.add_argument("--variant", type=str, default="ViTXL_n08", help="Decoder variant folder name") + parser.add_argument("--dataset_name", type=str, default="imagenet1k", help="Stats dataset folder name") + + parser.add_argument( + "--decoder_checkpoint", type=str, default=None, help="Relative path to decoder checkpoint inside repo/path" + ) + parser.add_argument( + "--stats_checkpoint", type=str, default=None, help="Relative path to stats checkpoint inside repo/path" + ) + + parser.add_argument("--decoder_config_name", type=str, choices=list(DECODER_CONFIGS.keys()), default="ViTXL") + parser.add_argument("--encoder_input_size", type=int, default=224) + parser.add_argument("--patch_size", type=int, default=16) + parser.add_argument("--image_size", type=int, default=None) + parser.add_argument("--num_channels", type=int, default=3) + parser.add_argument("--scaling_factor", type=float, default=1.0) + + parser.add_argument("--cache_dir", type=str, default=None) + parser.add_argument("--dry_run", action="store_true", help="Only resolve and print selected files") + parser.add_argument( + "--verify_load", + action="store_true", + help="After conversion, load back with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False).", + ) + + return parser.parse_args() + + +def main() -> None: + args = parse_args() + convert(args) + + +if __name__ == "__main__": + main() From 7cbbf271f3e4d6009117636e63b0a017c3a1ad90 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Feb 2026 23:33:30 +0000 Subject: [PATCH 10/30] use imports --- .../models/autoencoders/autoencoder_rae.py | 76 ++++--------------- 1 
file changed, 16 insertions(+), 60 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index 042949168a50..de9aaa9ce436 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -17,7 +17,6 @@ from types import SimpleNamespace from typing import Any, Callable, Dict, Optional, Tuple, Type, Union -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -27,6 +26,8 @@ from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import BaseOutput, logging from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..embeddings import get_2d_sincos_pos_embed from ..modeling_utils import ModelMixin from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput @@ -138,46 +139,6 @@ def forward(self, images: torch.Tensor) -> torch.Tensor: return image_features -def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray: - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be even") - - omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2.0 - omega = 1.0 / (10000**omega) # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2) - emb_sin = np.sin(out) - emb_cos = np.cos(out) - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - -def _get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, add_cls_token: bool = False) -> np.ndarray: - """ - Returns: - pos_embed: (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim) - """ - grid_h = np.arange(grid_size, dtype=np.float64) - grid_w = np.arange(grid_size, dtype=np.float64) - grid = np.meshgrid(grid_w, grid_h) # w first - grid = np.stack(grid, axis=0) # (2, grid, grid) - grid = grid.reshape([2, 1, grid_size, grid_size]) - - if embed_dim % 2 != 0: - raise 
ValueError("embed_dim must be even") - - emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - pos_embed = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - - if add_cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim], dtype=np.float64), pos_embed], axis=0) - - return pos_embed - - @dataclass class RAEDecoderOutput(BaseOutput): """ @@ -191,14 +152,6 @@ class RAEDecoderOutput(BaseOutput): logits: torch.Tensor -ACT2FN = { - "gelu": F.gelu, - "relu": F.relu, - "silu": F.silu, - "swish": F.silu, -} - - class ViTMAESelfAttention(nn.Module): def __init__(self, hidden_size: int, num_attention_heads: int, qkv_bias: bool = True, attn_dropout: float = 0.0): super().__init__() @@ -272,9 +225,10 @@ class ViTMAEIntermediate(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"): super().__init__() self.dense = nn.Linear(hidden_size, intermediate_size) - self.intermediate_act_fn = ACT2FN.get(hidden_act, None) - if self.intermediate_act_fn is None: - raise ValueError(f"Unsupported hidden_act={hidden_act}") + try: + self.intermediate_act_fn = get_activation(hidden_act) + except ValueError as e: + raise ValueError(f"Unsupported hidden_act={hidden_act}") from e def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) @@ -338,7 +292,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return layer_output -class GeneralDecoder(nn.Module): +class RAEDecoder(nn.Module): """ Decoder implementation ported from RAE-main to keep checkpoint compatibility. 
@@ -391,8 +345,15 @@ def set_trainable_cls_token(self, tensor: Optional[torch.Tensor] = None): def _initialize_weights(self, num_patches: int): grid_size = int(num_patches**0.5) - pos_embed = _get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], grid_size, add_cls_token=True) - self.decoder_pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + pos_embed = get_2d_sincos_pos_embed( + self.decoder_pos_embed.shape[-1], + grid_size, + cls_token=True, + extra_tokens=1, + output_type="pt", + device=self.decoder_pos_embed.device, + ) + self.decoder_pos_embed.data.copy_(pos_embed.unsqueeze(0).to(dtype=self.decoder_pos_embed.dtype)) def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor: embeddings_positions = embeddings.shape[1] - 1 @@ -495,11 +456,6 @@ def forward( return RAEDecoderOutput(logits=logits) -# Backward-compatible alias: keep `RAEDecoder` name used by `AutoencoderRAE` -class RAEDecoder(GeneralDecoder): - pass - - class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): r""" Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images. 
From e6d449933df04e829b677d962ca66612536b9ce9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Feb 2026 23:50:52 +0000 Subject: [PATCH 11/30] use attention --- scripts/convert_rae_to_diffusers.py | 22 +++ .../models/autoencoders/autoencoder_rae.py | 185 +++++++++++------- .../test_models_autoencoder_rae.py | 28 ++- 3 files changed, 164 insertions(+), 71 deletions(-) diff --git a/scripts/convert_rae_to_diffusers.py b/scripts/convert_rae_to_diffusers.py index c751ae6f084d..8db9f448c4a2 100644 --- a/scripts/convert_rae_to_diffusers.py +++ b/scripts/convert_rae_to_diffusers.py @@ -99,6 +99,27 @@ def unwrap_state_dict(maybe_wrapped: dict[str, Any]) -> dict[str, Any]: return out +def remap_decoder_attention_keys_for_diffusers(state_dict: dict[str, Any]) -> dict[str, Any]: + """ + Map official RAE decoder attention key layout to diffusers Attention layout used by AutoencoderRAE decoder. + + Example mappings: + - `...attention.attention.query.*` -> `...attention.attention.to_q.*` + - `...attention.attention.key.*` -> `...attention.attention.to_k.*` + - `...attention.attention.value.*` -> `...attention.attention.to_v.*` + - `...attention.output.dense.*` -> `...attention.attention.to_out.0.*` + """ + remapped: dict[str, Any] = {} + for key, value in state_dict.items(): + new_key = key + new_key = new_key.replace(".attention.attention.query.", ".attention.attention.to_q.") + new_key = new_key.replace(".attention.attention.key.", ".attention.attention.to_k.") + new_key = new_key.replace(".attention.attention.value.", ".attention.attention.to_v.") + new_key = new_key.replace(".attention.output.dense.", ".attention.attention.to_out.0.") + remapped[new_key] = value + return remapped + + def resolve_decoder_file( accessor: RepoAccessor, encoder_cls: str, variant: str, decoder_checkpoint: str | None ) -> str: @@ -179,6 +200,7 @@ def convert(args: argparse.Namespace) -> None: decoder_path = accessor.fetch(decoder_relpath) decoder_obj = torch.load(decoder_path, 
map_location="cpu") decoder_state_dict = unwrap_state_dict(decoder_obj) + decoder_state_dict = remap_decoder_attention_keys_for_diffusers(decoder_state_dict) latents_mean, latents_std = None, None if stats_relpath is not None: diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index de9aaa9ce436..2a7e58a4da45 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -27,6 +27,8 @@ from ...utils import BaseOutput, logging from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation +from ..attention import AttentionMixin +from ..attention_processor import Attention from ..embeddings import get_2d_sincos_pos_embed from ..modeling_utils import ModelMixin from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput @@ -69,12 +71,15 @@ def __init__(self, encoder_name_or_path: str = "facebook/dinov2-with-registers-b self.patch_size = self.model.config.patch_size self.hidden_size = self.model.config.hidden_size - @torch.no_grad() - def forward(self, images: torch.Tensor) -> torch.Tensor: + def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Tensor: """ images is of shape (B, C, H, W) where B is batch size, C is number of channels, H and W are height and """ - outputs = self.model(images, output_hidden_states=True) + if requires_grad: + outputs = self.model(images, output_hidden_states=True) + else: + with torch.no_grad(): + outputs = self.model(images, output_hidden_states=True) unused_token_num = 5 # 1 CLS + 4 register tokens image_features = outputs.last_hidden_state[:, unused_token_num:] return image_features @@ -96,12 +101,15 @@ def __init__(self, encoder_name_or_path: str = "google/siglip2-base-patch16-256" self.hidden_size = self.model.config.hidden_size self.patch_size = self.model.config.patch_size - @torch.no_grad() - def forward(self, images: torch.Tensor) -> 
torch.Tensor: + def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Tensor: """ images is of shape (B, C, H, W) where B is batch size, C is number of channels, H and W are height and """ - outputs = self.model(images, output_hidden_states=True, interpolate_pos_encoding=True) + if requires_grad: + outputs = self.model(images, output_hidden_states=True, interpolate_pos_encoding=True) + else: + with torch.no_grad(): + outputs = self.model(images, output_hidden_states=True, interpolate_pos_encoding=True) image_features = outputs.last_hidden_state return image_features @@ -123,8 +131,7 @@ def __init__(self, encoder_name_or_path: str = "facebook/vit-mae-base"): self.patch_size = self.model.config.patch_size self.model.config.mask_ratio = 0.0 # no masking - @torch.no_grad() - def forward(self, images: torch.Tensor) -> torch.Tensor: + def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Tensor: """ images is of shape (B, C, H, W) where B is batch size, C is number of channels, H and W are height and width of the image @@ -134,7 +141,11 @@ def forward(self, images: torch.Tensor) -> torch.Tensor: if patch_num * self.patch_size**2 != h * w: raise ValueError("Image size should be divisible by patch size.") noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0], -1).to(images.device).to(images.dtype) - outputs = self.model(images, noise, interpolate_pos_encoding=True) + if requires_grad: + outputs = self.model(images, noise, interpolate_pos_encoding=True) + else: + with torch.no_grad(): + outputs = self.model(images, noise, interpolate_pos_encoding=True) image_features = outputs.last_hidden_state[:, 1:] # remove cls token return image_features @@ -152,73 +163,45 @@ class RAEDecoderOutput(BaseOutput): logits: torch.Tensor -class ViTMAESelfAttention(nn.Module): - def __init__(self, hidden_size: int, num_attention_heads: int, qkv_bias: bool = True, attn_dropout: float = 0.0): - super().__init__() - if hidden_size % 
num_attention_heads != 0: - raise ValueError( - f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}" - ) - - self.num_attention_heads = num_attention_heads - self.attention_head_size = int(hidden_size / num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(hidden_size, self.all_head_size, bias=qkv_bias) - self.key = nn.Linear(hidden_size, self.all_head_size, bias=qkv_bias) - self.value = nn.Linear(hidden_size, self.all_head_size, bias=qkv_bias) - self.dropout = nn.Dropout(attn_dropout) - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / (self.attention_head_size**0.5) - attention_probs = torch.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - return context_layer - +@dataclass +class AutoencoderRAELossOutput(BaseOutput): + """ + Output of `AutoencoderRAE.forward(..., return_loss=True)`. 
-class ViTMAESelfOutput(nn.Module): - def __init__(self, hidden_size: int, hidden_dropout_prob: float = 0.0): - super().__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.dropout = nn.Dropout(hidden_dropout_prob) + Args: + sample (`torch.Tensor`): + Reconstructed image tensor of shape `(batch_size, num_channels, image_height, image_width)`. + loss (`torch.Tensor`): + Total training loss. + reconstruction_loss (`torch.Tensor`): + Pixel-space reconstruction loss. + encoder_loss (`torch.Tensor`): + Optional encoder feature-space loss. Zero when disabled. + """ - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - return hidden_states + sample: torch.Tensor + loss: torch.Tensor + reconstruction_loss: torch.Tensor + encoder_loss: torch.Tensor class ViTMAEAttention(nn.Module): def __init__(self, hidden_size: int, num_attention_heads: int, qkv_bias: bool = True, attn_dropout: float = 0.0): super().__init__() - self.attention = ViTMAESelfAttention( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - qkv_bias=qkv_bias, - attn_dropout=attn_dropout, + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}" + ) + self.attention = Attention( + query_dim=hidden_size, + heads=num_attention_heads, + dim_head=hidden_size // num_attention_heads, + dropout=attn_dropout, + bias=qkv_bias, ) - self.output = ViTMAESelfOutput(hidden_size=hidden_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - attn_output = self.attention(hidden_states) - attn_output = self.output(attn_output) - return attn_output + return self.attention(hidden_states) class ViTMAEIntermediate(nn.Module): @@ -456,7 +439,9 @@ def forward( return RAEDecoderOutput(logits=logits) -class AutoencoderRAE(ModelMixin, AutoencoderMixin, ConfigMixin, 
FromOriginalModelMixin, PeftAdapterMixin): +class AutoencoderRAE( + ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin +): r""" Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images. @@ -688,11 +673,24 @@ def _maybe_denormalize_latents(self, z: torch.Tensor) -> torch.Tensor: latents_std = self._latents_std.to(device=z.device, dtype=z.dtype) if self._latents_std is not None else 1 return z * (latents_std + 1e-5) + latents_mean + def _encode_tokens(self, x: torch.Tensor, *, requires_grad: bool) -> torch.Tensor: + # Keep compatibility with custom registered encoders that may not accept `requires_grad`. + try: + return self.encoder(x, requires_grad=requires_grad) + except TypeError: + if requires_grad: + logger.warning( + "Encoder class `%s` does not accept `requires_grad`; falling back to default forward for " + "encoder loss computation.", + self.encoder.__class__.__name__, + ) + return self.encoder(x) + def _encode(self, x: torch.Tensor) -> torch.Tensor: x = self._maybe_resize_and_normalize(x) - # Encoder is frozen; many encoders already run under no_grad - tokens = self.encoder(x) # (B, N, C) + # Encoder is frozen by default for latent extraction. 
+ tokens = self._encode_tokens(x, requires_grad=False) # (B, N, C) if self.training and self.noise_tau > 0: tokens = self._noising(tokens) @@ -714,6 +712,29 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor: return z + def _compute_reconstruction_loss( + self, reconstructed: torch.Tensor, target: torch.Tensor, reconstruction_loss_type: str = "l1" + ) -> torch.Tensor: + if reconstructed.shape[-2:] != target.shape[-2:]: + target = F.interpolate( + target, + size=reconstructed.shape[-2:], + mode="bicubic", + align_corners=False, + ) + if reconstruction_loss_type == "l1": + return F.l1_loss(reconstructed.float(), target.float()) + if reconstruction_loss_type == "mse": + return F.mse_loss(reconstructed.float(), target.float()) + raise ValueError( + f"Unsupported reconstruction_loss_type='{reconstruction_loss_type}'. Expected one of ['l1', 'mse']." + ) + + def _compute_encoder_feature_loss(self, reconstructed: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + target_tokens = self._encode_tokens(self._maybe_resize_and_normalize(target), requires_grad=False).detach() + reconstructed_tokens = self._encode_tokens(self._maybe_resize_and_normalize(reconstructed), requires_grad=True) + return F.mse_loss(reconstructed_tokens.float(), target_tokens.float()) + @apply_forward_hook def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutput, Tuple[torch.Tensor]]: if self.use_slicing and x.shape[0] > 1: @@ -754,9 +775,33 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp return (decoded,) return DecoderOutput(sample=decoded) - def forward(self, sample: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]: + def forward( + self, + sample: torch.Tensor, + return_dict: bool = True, + return_loss: bool = False, + reconstruction_loss_type: str = "l1", + encoder_loss_weight: float = 0.0, + ) -> Union[DecoderOutput, AutoencoderRAELossOutput, Tuple[torch.Tensor]]: latents = self.encode(sample, 
return_dict=False)[0] decoded = self.decode(latents, return_dict=False)[0] + if return_loss: + reconstruction_loss = self._compute_reconstruction_loss( + decoded, sample, reconstruction_loss_type=reconstruction_loss_type + ) + encoder_loss = torch.zeros_like(reconstruction_loss) + if self.use_encoder_loss and encoder_loss_weight > 0: + encoder_loss = self._compute_encoder_feature_loss(decoded, sample) + total_loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss + + if not return_dict: + return (decoded, total_loss, reconstruction_loss, encoder_loss) + return AutoencoderRAELossOutput( + sample=decoded, + loss=total_loss, + reconstruction_loss=reconstruction_loss, + encoder_loss=encoder_loss, + ) if not return_dict: return (decoded,) return DecoderOutput(sample=decoded) diff --git a/tests/models/autoencoders/test_models_autoencoder_rae.py b/tests/models/autoencoders/test_models_autoencoder_rae.py index 899b42e757e9..ee5d74c9b5a4 100644 --- a/tests/models/autoencoders/test_models_autoencoder_rae.py +++ b/tests/models/autoencoders/test_models_autoencoder_rae.py @@ -45,7 +45,7 @@ def __init__(self, encoder_name_or_path: str = "unused"): self.patch_size = 8 self.hidden_size = 16 - def forward(self, images: torch.Tensor) -> torch.Tensor: + def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Tensor: pooled = F.avg_pool2d(images.mean(dim=1, keepdim=True), kernel_size=self.patch_size, stride=self.patch_size) tokens = pooled.flatten(2).transpose(1, 2).contiguous() return tokens.repeat(1, 1, self.hidden_size) @@ -159,6 +159,32 @@ def test_fast_noise_tau_applies_only_in_train(self): self.assertFalse(torch.allclose(z_train_1, z_train_2)) self.assertTrue(torch.allclose(z_eval_1, z_eval_2, atol=1e-6, rtol=1e-5)) + def test_fast_forward_return_loss_reconstruction_only(self): + model = self._make_model(use_encoder_loss=False).train() + x = torch.rand(2, 3, 32, 32, device=torch_device) + + output = model(x, return_loss=True) + + 
self.assertEqual(output.sample.shape, (2, 3, 16, 16)) + self.assertTrue(torch.isfinite(output.loss).all().item()) + self.assertTrue(torch.isfinite(output.reconstruction_loss).all().item()) + self.assertTrue(torch.isfinite(output.encoder_loss).all().item()) + self.assertEqual(output.encoder_loss.item(), 0.0) + self.assertTrue(torch.allclose(output.loss, output.reconstruction_loss)) + + def test_fast_forward_return_loss_with_encoder_loss(self): + model = self._make_model(use_encoder_loss=True).train() + x = torch.rand(2, 3, 32, 32, device=torch_device) + + output = model(x, return_loss=True, encoder_loss_weight=0.5, reconstruction_loss_type="mse") + + self.assertEqual(output.sample.shape, (2, 3, 16, 16)) + self.assertTrue(torch.isfinite(output.loss).all().item()) + self.assertTrue(torch.isfinite(output.reconstruction_loss).all().item()) + self.assertTrue(torch.isfinite(output.encoder_loss).all().item()) + self.assertGreaterEqual(output.encoder_loss.item(), 0.0) + self.assertGreaterEqual(output.loss.item(), output.reconstruction_loss.item()) + @slow class AutoencoderRAEEncoderIntegrationTests(unittest.TestCase): From 6a9bde6964d13872ee8b9bec462ef23cdce846c5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Feb 2026 23:55:06 +0000 Subject: [PATCH 12/30] remove unneeded class --- scripts/convert_rae_to_diffusers.py | 16 ++++----- .../models/autoencoders/autoencoder_rae.py | 34 ++++++------------- 2 files changed, 18 insertions(+), 32 deletions(-) diff --git a/scripts/convert_rae_to_diffusers.py b/scripts/convert_rae_to_diffusers.py index 8db9f448c4a2..39285ad61569 100644 --- a/scripts/convert_rae_to_diffusers.py +++ b/scripts/convert_rae_to_diffusers.py @@ -104,18 +104,18 @@ def remap_decoder_attention_keys_for_diffusers(state_dict: dict[str, Any]) -> di Map official RAE decoder attention key layout to diffusers Attention layout used by AutoencoderRAE decoder. 
Example mappings: - - `...attention.attention.query.*` -> `...attention.attention.to_q.*` - - `...attention.attention.key.*` -> `...attention.attention.to_k.*` - - `...attention.attention.value.*` -> `...attention.attention.to_v.*` - - `...attention.output.dense.*` -> `...attention.attention.to_out.0.*` + - `...attention.attention.query.*` -> `...attention.to_q.*` + - `...attention.attention.key.*` -> `...attention.to_k.*` + - `...attention.attention.value.*` -> `...attention.to_v.*` + - `...attention.output.dense.*` -> `...attention.to_out.0.*` """ remapped: dict[str, Any] = {} for key, value in state_dict.items(): new_key = key - new_key = new_key.replace(".attention.attention.query.", ".attention.attention.to_q.") - new_key = new_key.replace(".attention.attention.key.", ".attention.attention.to_k.") - new_key = new_key.replace(".attention.attention.value.", ".attention.attention.to_v.") - new_key = new_key.replace(".attention.output.dense.", ".attention.attention.to_out.0.") + new_key = new_key.replace(".attention.attention.query.", ".attention.to_q.") + new_key = new_key.replace(".attention.attention.key.", ".attention.to_k.") + new_key = new_key.replace(".attention.attention.value.", ".attention.to_v.") + new_key = new_key.replace(".attention.output.dense.", ".attention.to_out.0.") remapped[new_key] = value return remapped diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index 2a7e58a4da45..cafceacd64f4 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -185,25 +185,6 @@ class AutoencoderRAELossOutput(BaseOutput): encoder_loss: torch.Tensor -class ViTMAEAttention(nn.Module): - def __init__(self, hidden_size: int, num_attention_heads: int, qkv_bias: bool = True, attn_dropout: float = 0.0): - super().__init__() - if hidden_size % num_attention_heads != 0: - raise ValueError( - f"hidden_size={hidden_size} must be 
divisible by num_attention_heads={num_attention_heads}" ) - self.attention = Attention( - query_dim=hidden_size, - heads=num_attention_heads, - dim_head=hidden_size // num_attention_heads, - dropout=attn_dropout, - bias=qkv_bias, ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return self.attention(hidden_states) - - class ViTMAEIntermediate(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"): super().__init__() @@ -250,11 +231,16 @@ def __init__( hidden_act: str = "gelu", ): super().__init__() - self.attention = ViTMAEAttention( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - qkv_bias=qkv_bias, - attn_dropout=attention_probs_dropout_prob, + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}" + ) + self.attention = Attention( + query_dim=hidden_size, + heads=num_attention_heads, + dim_head=hidden_size // num_attention_heads, + dropout=attention_probs_dropout_prob, + bias=qkv_bias, ) self.intermediate = ViTMAEIntermediate( hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act From 9522e68a5b2b0f45c659f6249b351cacc3049b64 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Feb 2026 23:56:19 +0000 Subject: [PATCH 13/30] example training script --- docs/source/en/api/models/autoencoder_rae.md | 7 + .../autoencoder_rae/README.md | 41 +++ .../autoencoder_rae/train_autoencoder_rae.py | 301 ++++++++++++++++++ 3 files changed, 349 insertions(+) create mode 100644 examples/research_projects/autoencoder_rae/README.md create mode 100644 examples/research_projects/autoencoder_rae/train_autoencoder_rae.py diff --git a/docs/source/en/api/models/autoencoder_rae.md b/docs/source/en/api/models/autoencoder_rae.md index e69b6451c3c0..a34b7623e141 100644 --- a/docs/source/en/api/models/autoencoder_rae.md +++ 
b/docs/source/en/api/models/autoencoder_rae.md @@ -45,6 +45,13 @@ with torch.no_grad(): For latent normalization, use `latents_mean` and `latents_std` (matching other diffusers autoencoders). +For training, `forward(...)` also supports: +- `return_loss=True` +- `reconstruction_loss_type` (`"l1"` or `"mse"`) +- `encoder_loss_weight` (used when `use_encoder_loss=True`) + +See `examples/research_projects/autoencoder_rae/train_autoencoder_rae.py` for a stage-1 style training script. + ## AutoencoderRAE class [[autodoc]] AutoencoderRAE diff --git a/examples/research_projects/autoencoder_rae/README.md b/examples/research_projects/autoencoder_rae/README.md new file mode 100644 index 000000000000..beea300f21d9 --- /dev/null +++ b/examples/research_projects/autoencoder_rae/README.md @@ -0,0 +1,41 @@ +# Training AutoencoderRAE + +This example trains the decoder of `AutoencoderRAE` (stage-1 style), while keeping the representation encoder frozen. + +It follows the same high-level training recipe as the official RAE stage-1 setup: +- frozen encoder +- train decoder +- pixel reconstruction loss +- optional encoder feature consistency loss + +## Quickstart + +```bash +accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \ + --train_data_dir /path/to/imagenet_like_folder \ + --output_dir /tmp/autoencoder-rae \ + --encoder_cls dinov2 \ + --encoder_input_size 224 \ + --patch_size 16 \ + --image_size 256 \ + --decoder_hidden_size 1152 \ + --decoder_num_hidden_layers 28 \ + --decoder_num_attention_heads 16 \ + --decoder_intermediate_size 4096 \ + --train_batch_size 8 \ + --learning_rate 1e-4 \ + --num_train_epochs 10 \ + --reconstruction_loss_type l1 \ + --use_encoder_loss \ + --encoder_loss_weight 0.1 +``` + +Dataset format is expected to be `ImageFolder`-compatible: + +```text +train_data_dir/ + class_a/ + img_0001.jpg + class_b/ + img_0002.jpg +``` diff --git a/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py 
b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py new file mode 100644 index 000000000000..b25facba0ed4 --- /dev/null +++ b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py @@ -0,0 +1,301 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import math +import os +from pathlib import Path + +import torch +from accelerate import Accelerator +from accelerate.utils import ProjectConfiguration, set_seed +from torch.utils.data import DataLoader +from torchvision import transforms +from torchvision.datasets import ImageFolder +from tqdm.auto import tqdm + +from diffusers import AutoencoderRAE +from diffusers.optimization import get_scheduler + + +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train a stage-1 Representation Autoencoder (RAE) decoder.") + parser.add_argument( + "--train_data_dir", + type=str, + required=True, + help="Path to an ImageFolder-style dataset root.", + ) + parser.add_argument( + "--output_dir", type=str, default="autoencoder-rae", help="Directory to save checkpoints/model." 
+ ) + parser.add_argument("--logging_dir", type=str, default="logs", help="Accelerate logging directory.") + parser.add_argument("--seed", type=int, default=42) + + parser.add_argument("--resolution", type=int, default=256) + parser.add_argument("--center_crop", action="store_true") + parser.add_argument("--random_flip", action="store_true") + + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--dataloader_num_workers", type=int, default=4) + parser.add_argument("--num_train_epochs", type=int, default=10) + parser.add_argument("--max_train_steps", type=int, default=None) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--max_grad_norm", type=float, default=1.0) + + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--adam_beta1", type=float, default=0.9) + parser.add_argument("--adam_beta2", type=float, default=0.999) + parser.add_argument("--adam_weight_decay", type=float, default=1e-2) + parser.add_argument("--adam_epsilon", type=float, default=1e-8) + parser.add_argument("--lr_scheduler", type=str, default="cosine") + parser.add_argument("--lr_warmup_steps", type=int, default=500) + + parser.add_argument("--checkpointing_steps", type=int, default=1000) + parser.add_argument("--validation_steps", type=int, default=500) + + parser.add_argument("--encoder_cls", type=str, choices=["dinov2", "siglip2", "mae"], default="dinov2") + parser.add_argument("--encoder_name_or_path", type=str, default=None) + parser.add_argument("--encoder_input_size", type=int, default=224) + parser.add_argument("--patch_size", type=int, default=16) + parser.add_argument("--image_size", type=int, default=256) + parser.add_argument("--num_channels", type=int, default=3) + + parser.add_argument("--decoder_hidden_size", type=int, default=1152) + parser.add_argument("--decoder_num_hidden_layers", type=int, default=28) + parser.add_argument("--decoder_num_attention_heads", 
type=int, default=16) + parser.add_argument("--decoder_intermediate_size", type=int, default=4096) + + parser.add_argument("--noise_tau", type=float, default=0.0) + parser.add_argument("--scaling_factor", type=float, default=1.0) + parser.add_argument("--reshape_to_2d", action=argparse.BooleanOptionalAction, default=True) + + parser.add_argument( + "--reconstruction_loss_type", + type=str, + choices=["l1", "mse"], + default="l1", + help="Pixel reconstruction loss.", + ) + parser.add_argument( + "--encoder_loss_weight", + type=float, + default=0.0, + help="Weight for encoder feature consistency loss used in `AutoencoderRAE.forward(return_loss=True)`.", + ) + parser.add_argument( + "--use_encoder_loss", + action="store_true", + help="Enable encoder feature consistency loss in model forward.", + ) + parser.add_argument("--report_to", type=str, default="tensorboard") + + return parser.parse_args() + + +def build_transforms(args): + image_transforms = [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + ] + if args.random_flip: + image_transforms.append(transforms.RandomHorizontalFlip()) + image_transforms.append(transforms.ToTensor()) + return transforms.Compose(image_transforms) + + +def main(): + args = parse_args() + + logging_dir = Path(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + project_config=accelerator_project_config, + log_with=args.report_to, + ) + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + + if args.seed is not None: + set_seed(args.seed) + + if 
accelerator.is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + dataset = ImageFolder(args.train_data_dir, transform=build_transforms(args)) + + def collate_fn(examples): + pixel_values = torch.stack([example[0] for example in examples]).float() + return {"pixel_values": pixel_values} + + train_dataloader = DataLoader( + dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + pin_memory=True, + drop_last=True, + ) + + model = AutoencoderRAE( + encoder_cls=args.encoder_cls, + encoder_name_or_path=args.encoder_name_or_path, + decoder_hidden_size=args.decoder_hidden_size, + decoder_num_hidden_layers=args.decoder_num_hidden_layers, + decoder_num_attention_heads=args.decoder_num_attention_heads, + decoder_intermediate_size=args.decoder_intermediate_size, + patch_size=args.patch_size, + encoder_input_size=args.encoder_input_size, + image_size=args.image_size, + num_channels=args.num_channels, + noise_tau=args.noise_tau, + reshape_to_2d=args.reshape_to_2d, + use_encoder_loss=args.use_encoder_loss, + scaling_factor=args.scaling_factor, + ) + model.encoder.requires_grad_(False) + model.decoder.requires_grad_(True) + model.train() + + optimizer = torch.optim.AdamW( + (p for p in model.parameters() if p.requires_grad), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + 
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + if overrode_max_train_steps: + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + if accelerator.is_main_process: + accelerator.init_trackers("train_autoencoder_rae", config=vars(args)) + + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + global_step = 0 + + for epoch in range(args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(model): + pixel_values = batch["pixel_values"] + + model_output = model( + pixel_values, + return_loss=True, + reconstruction_loss_type=args.reconstruction_loss_type, + encoder_loss_weight=args.encoder_loss_weight, + ) + loss = model_output.loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = { + "loss": loss.detach().item(), + "reconstruction_loss": model_output.reconstruction_loss.detach().item(), + "encoder_loss": 
model_output.encoder_loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step % args.validation_steps == 0: + with torch.no_grad(): + val_output = model( + pixel_values, + return_loss=True, + reconstruction_loss_type=args.reconstruction_loss_type, + encoder_loss_weight=args.encoder_loss_weight, + ) + accelerator.log( + { + "val/loss": val_output.loss.detach().item(), + "val/reconstruction_loss": val_output.reconstruction_loss.detach().item(), + "val/encoder_loss": val_output.encoder_loss.detach().item(), + }, + step=global_step, + ) + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained(save_path) + logger.info(f"Saved checkpoint to {save_path}") + + if global_step >= args.max_train_steps: + break + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained(args.output_dir) + accelerator.end_training() + + +if __name__ == "__main__": + main() From 906d79a43244909b4028eaf7b0429960ebea8b9f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 16 Feb 2026 00:02:27 +0000 Subject: [PATCH 14/30] input and ground truth sizes have to be the same --- .../autoencoder_rae/train_autoencoder_rae.py | 5 +++++ src/diffusers/models/autoencoders/autoencoder_rae.py | 9 ++++----- tests/models/autoencoders/test_models_autoencoder_rae.py | 4 ++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py index b25facba0ed4..7c15ca85cab6 100644 --- 
a/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py +++ b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py @@ -123,6 +123,11 @@ def build_transforms(args): def main(): args = parse_args() + if args.resolution != args.image_size: + raise ValueError( + f"`--resolution` ({args.resolution}) must match `--image_size` ({args.image_size}) " + "for stage-1 reconstruction loss." + ) logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index cafceacd64f4..f0f797b92e27 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -702,11 +702,10 @@ def _compute_reconstruction_loss( self, reconstructed: torch.Tensor, target: torch.Tensor, reconstruction_loss_type: str = "l1" ) -> torch.Tensor: if reconstructed.shape[-2:] != target.shape[-2:]: - target = F.interpolate( - target, - size=reconstructed.shape[-2:], - mode="bicubic", - align_corners=False, + raise ValueError( + "Reconstruction loss requires matching spatial sizes, but got " + f"reconstructed={tuple(reconstructed.shape[-2:])} and target={tuple(target.shape[-2:])}. " + "Configure `image_size` to match training input resolution." 
) if reconstruction_loss_type == "l1": return F.l1_loss(reconstructed.float(), target.float()) diff --git a/tests/models/autoencoders/test_models_autoencoder_rae.py b/tests/models/autoencoders/test_models_autoencoder_rae.py index ee5d74c9b5a4..1f3e6a6aadc7 100644 --- a/tests/models/autoencoders/test_models_autoencoder_rae.py +++ b/tests/models/autoencoders/test_models_autoencoder_rae.py @@ -161,7 +161,7 @@ def test_fast_noise_tau_applies_only_in_train(self): def test_fast_forward_return_loss_reconstruction_only(self): model = self._make_model(use_encoder_loss=False).train() - x = torch.rand(2, 3, 32, 32, device=torch_device) + x = torch.rand(2, 3, 16, 16, device=torch_device) output = model(x, return_loss=True) @@ -174,7 +174,7 @@ def test_fast_forward_return_loss_reconstruction_only(self): def test_fast_forward_return_loss_with_encoder_loss(self): model = self._make_model(use_encoder_loss=True).train() - x = torch.rand(2, 3, 32, 32, device=torch_device) + x = torch.rand(2, 3, 16, 16, device=torch_device) output = model(x, return_loss=True, encoder_loss_weight=0.5, reconstruction_loss_type="mse") From d3cbd5a60ba48c4af86ed613827c4549cd779122 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 16 Feb 2026 00:03:54 +0000 Subject: [PATCH 15/30] fix argument --- examples/research_projects/autoencoder_rae/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/research_projects/autoencoder_rae/README.md b/examples/research_projects/autoencoder_rae/README.md index beea300f21d9..559eb37518ae 100644 --- a/examples/research_projects/autoencoder_rae/README.md +++ b/examples/research_projects/autoencoder_rae/README.md @@ -14,6 +14,7 @@ It follows the same high-level training recipe as the official RAE stage-1 setup accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \ --train_data_dir /path/to/imagenet_like_folder \ --output_dir /tmp/autoencoder-rae \ + --resolution 256 \ --encoder_cls dinov2 \ --encoder_input_size 224 \ 
--patch_size 16 \ @@ -30,6 +31,8 @@ accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_r --encoder_loss_weight 0.1 ``` +Note: stage-1 reconstruction loss assumes matching target/output spatial size, so `--resolution` must equal `--image_size`. + Dataset format is expected to be `ImageFolder`-compatible: ```text From 96520c4ff1c1871e70b7a1abf46ad574806aad21 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 16 Feb 2026 12:35:18 +0000 Subject: [PATCH 16/30] move loss to training script --- docs/source/en/api/models/autoencoder_rae.md | 8 +- .../autoencoder_rae/train_autoencoder_rae.py | 52 +++++++-- .../models/autoencoders/autoencoder_rae.py | 100 ++---------------- .../test_models_autoencoder_rae.py | 33 ++---- 4 files changed, 57 insertions(+), 136 deletions(-) diff --git a/docs/source/en/api/models/autoencoder_rae.md b/docs/source/en/api/models/autoencoder_rae.md index a34b7623e141..43ec89c3b7ec 100644 --- a/docs/source/en/api/models/autoencoder_rae.md +++ b/docs/source/en/api/models/autoencoder_rae.md @@ -45,12 +45,8 @@ with torch.no_grad(): For latent normalization, use `latents_mean` and `latents_std` (matching other diffusers autoencoders). -For training, `forward(...)` also supports: -- `return_loss=True` -- `reconstruction_loss_type` (`"l1"` or `"mse"`) -- `encoder_loss_weight` (used when `use_encoder_loss=True`) - -See `examples/research_projects/autoencoder_rae/train_autoencoder_rae.py` for a stage-1 style training script. +See `examples/research_projects/autoencoder_rae/train_autoencoder_rae.py` for a stage-1 style training script +(reconstruction and optional encoder-feature losses are computed in the training loop, following diffusers training conventions). 
## AutoencoderRAE class diff --git a/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py index 7c15ca85cab6..4b0a2c551671 100644 --- a/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py +++ b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py @@ -21,6 +21,7 @@ from pathlib import Path import torch +import torch.nn.functional as F from accelerate import Accelerator from accelerate.utils import ProjectConfiguration, set_seed from torch.utils.data import DataLoader @@ -98,7 +99,7 @@ def parse_args(): "--encoder_loss_weight", type=float, default=0.0, - help="Weight for encoder feature consistency loss used in `AutoencoderRAE.forward(return_loss=True)`.", + help="Weight for encoder feature consistency loss in the training loop.", ) parser.add_argument( "--use_encoder_loss", @@ -121,6 +122,34 @@ def build_transforms(args): return transforms.Compose(image_transforms) +def compute_losses( + model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float +): + decoded = model(pixel_values).sample + + if decoded.shape[-2:] != pixel_values.shape[-2:]: + raise ValueError( + "Training requires matching reconstruction and target sizes, got " + f"decoded={tuple(decoded.shape[-2:])}, target={tuple(pixel_values.shape[-2:])}." 
+ ) + + if reconstruction_loss_type == "l1": + reconstruction_loss = F.l1_loss(decoded.float(), pixel_values.float()) + else: + reconstruction_loss = F.mse_loss(decoded.float(), pixel_values.float()) + + encoder_loss = torch.zeros_like(reconstruction_loss) + if use_encoder_loss and encoder_loss_weight > 0: + target_tokens = model._encode_tokens( + model._maybe_resize_and_normalize(pixel_values), requires_grad=False + ).detach() + reconstructed_tokens = model._encode_tokens(model._maybe_resize_and_normalize(decoded), requires_grad=True) + encoder_loss = F.mse_loss(reconstructed_tokens.float(), target_tokens.float()) + + loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss + return decoded, loss, reconstruction_loss, encoder_loss + + def main(): args = parse_args() if args.resolution != args.image_size: @@ -237,13 +266,13 @@ def collate_fn(examples): with accelerator.accumulate(model): pixel_values = batch["pixel_values"] - model_output = model( + _, loss, reconstruction_loss, encoder_loss = compute_losses( + model, pixel_values, - return_loss=True, reconstruction_loss_type=args.reconstruction_loss_type, + use_encoder_loss=args.use_encoder_loss, encoder_loss_weight=args.encoder_loss_weight, ) - loss = model_output.loss accelerator.backward(loss) if accelerator.sync_gradients: @@ -258,8 +287,8 @@ def collate_fn(examples): logs = { "loss": loss.detach().item(), - "reconstruction_loss": model_output.reconstruction_loss.detach().item(), - "encoder_loss": model_output.encoder_loss.detach().item(), + "reconstruction_loss": reconstruction_loss.detach().item(), + "encoder_loss": encoder_loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], } progress_bar.set_postfix(**logs) @@ -267,17 +296,18 @@ def collate_fn(examples): if global_step % args.validation_steps == 0: with torch.no_grad(): - val_output = model( + _, val_loss, val_reconstruction_loss, val_encoder_loss = compute_losses( + model, pixel_values, - return_loss=True, 
reconstruction_loss_type=args.reconstruction_loss_type, + use_encoder_loss=args.use_encoder_loss, encoder_loss_weight=args.encoder_loss_weight, ) accelerator.log( { - "val/loss": val_output.loss.detach().item(), - "val/reconstruction_loss": val_output.reconstruction_loss.detach().item(), - "val/encoder_loss": val_output.encoder_loss.detach().item(), + "val/loss": val_loss.detach().item(), + "val/reconstruction_loss": val_reconstruction_loss.detach().item(), + "val/encoder_loss": val_encoder_loss.detach().item(), }, step=global_step, ) diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index f0f797b92e27..83459fcfd391 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -163,36 +163,11 @@ class RAEDecoderOutput(BaseOutput): logits: torch.Tensor -@dataclass -class AutoencoderRAELossOutput(BaseOutput): - """ - Output of `AutoencoderRAE.forward(..., return_loss=True)`. - - Args: - sample (`torch.Tensor`): - Reconstructed image tensor of shape `(batch_size, num_channels, image_height, image_width)`. - loss (`torch.Tensor`): - Total training loss. - reconstruction_loss (`torch.Tensor`): - Pixel-space reconstruction loss. - encoder_loss (`torch.Tensor`): - Optional encoder feature-space loss. Zero when disabled. 
- """ - - sample: torch.Tensor - loss: torch.Tensor - reconstruction_loss: torch.Tensor - encoder_loss: torch.Tensor - - class ViTMAEIntermediate(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"): super().__init__() self.dense = nn.Linear(hidden_size, intermediate_size) - try: - self.intermediate_act_fn = get_activation(hidden_act) - except ValueError as e: - raise ValueError(f"Unsupported hidden_act={hidden_act}") from e + self.intermediate_act_fn = get_activation(hidden_act) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) @@ -562,25 +537,12 @@ def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Op f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}." ) - # Normalization stats from the encoder's image processor - # RAE-main uses AutoImageProcessor mean/std; we follow the same. - encoder_mean = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32).view(1, 3, 1, 1) - encoder_std = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32).view(1, 3, 1, 1) - try: - from transformers import AutoImageProcessor - - try: - proc = AutoImageProcessor.from_pretrained(encoder_name_or_path, local_files_only=True) - except Exception: - proc = AutoImageProcessor.from_pretrained(encoder_name_or_path, local_files_only=False) - encoder_mean = torch.tensor(proc.image_mean, dtype=torch.float32).view(1, 3, 1, 1) - encoder_std = torch.tensor(proc.image_std, dtype=torch.float32).view(1, 3, 1, 1) - except (OSError, ValueError): - # Keep default 0.5/0.5 if processor is unavailable. - logger.warning( - "Falling back to encoder mean/std [0.5, 0.5, 0.5] for `%s` because AutoImageProcessor could not be loaded.", - encoder_name_or_path, - ) + # Normalization stats from the encoder's image processor (strict, same as official RAE). 
+ from transformers import AutoImageProcessor + + proc = AutoImageProcessor.from_pretrained(encoder_name_or_path) + encoder_mean = torch.tensor(proc.image_mean, dtype=torch.float32).view(1, 3, 1, 1) + encoder_std = torch.tensor(proc.image_std, dtype=torch.float32).view(1, 3, 1, 1) self.register_buffer("encoder_mean", encoder_mean, persistent=True) self.register_buffer("encoder_std", encoder_std, persistent=True) @@ -698,28 +660,6 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor: return z - def _compute_reconstruction_loss( - self, reconstructed: torch.Tensor, target: torch.Tensor, reconstruction_loss_type: str = "l1" - ) -> torch.Tensor: - if reconstructed.shape[-2:] != target.shape[-2:]: - raise ValueError( - "Reconstruction loss requires matching spatial sizes, but got " - f"reconstructed={tuple(reconstructed.shape[-2:])} and target={tuple(target.shape[-2:])}. " - "Configure `image_size` to match training input resolution." - ) - if reconstruction_loss_type == "l1": - return F.l1_loss(reconstructed.float(), target.float()) - if reconstruction_loss_type == "mse": - return F.mse_loss(reconstructed.float(), target.float()) - raise ValueError( - f"Unsupported reconstruction_loss_type='{reconstruction_loss_type}'. Expected one of ['l1', 'mse']." 
- ) - - def _compute_encoder_feature_loss(self, reconstructed: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - target_tokens = self._encode_tokens(self._maybe_resize_and_normalize(target), requires_grad=False).detach() - reconstructed_tokens = self._encode_tokens(self._maybe_resize_and_normalize(reconstructed), requires_grad=True) - return F.mse_loss(reconstructed_tokens.float(), target_tokens.float()) - @apply_forward_hook def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutput, Tuple[torch.Tensor]]: if self.use_slicing and x.shape[0] > 1: @@ -760,33 +700,9 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp return (decoded,) return DecoderOutput(sample=decoded) - def forward( - self, - sample: torch.Tensor, - return_dict: bool = True, - return_loss: bool = False, - reconstruction_loss_type: str = "l1", - encoder_loss_weight: float = 0.0, - ) -> Union[DecoderOutput, AutoencoderRAELossOutput, Tuple[torch.Tensor]]: + def forward(self, sample: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]: latents = self.encode(sample, return_dict=False)[0] decoded = self.decode(latents, return_dict=False)[0] - if return_loss: - reconstruction_loss = self._compute_reconstruction_loss( - decoded, sample, reconstruction_loss_type=reconstruction_loss_type - ) - encoder_loss = torch.zeros_like(reconstruction_loss) - if self.use_encoder_loss and encoder_loss_weight > 0: - encoder_loss = self._compute_encoder_feature_loss(decoded, sample) - total_loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss - - if not return_dict: - return (decoded, total_loss, reconstruction_loss, encoder_loss) - return AutoencoderRAELossOutput( - sample=decoded, - loss=total_loss, - reconstruction_loss=reconstruction_loss, - encoder_loss=encoder_loss, - ) if not return_dict: return (decoded,) return DecoderOutput(sample=decoded) diff --git 
a/tests/models/autoencoders/test_models_autoencoder_rae.py b/tests/models/autoencoders/test_models_autoencoder_rae.py index 1f3e6a6aadc7..d9fe351b62a3 100644 --- a/tests/models/autoencoders/test_models_autoencoder_rae.py +++ b/tests/models/autoencoders/test_models_autoencoder_rae.py @@ -15,6 +15,7 @@ import gc import unittest +from unittest.mock import patch import torch import torch.nn.functional as F @@ -74,7 +75,11 @@ def _make_model(self, **overrides) -> AutoencoderRAE: "scaling_factor": 1.0, } config.update(overrides) - return AutoencoderRAE(**config).to(torch_device) + with patch("transformers.AutoImageProcessor.from_pretrained") as mocked_from_pretrained: + mocked_from_pretrained.return_value = type( + "_MockProcessor", (), {"image_mean": [0.5, 0.5, 0.5], "image_std": [0.5, 0.5, 0.5]} + )() + return AutoencoderRAE(**config).to(torch_device) def test_fast_encode_decode_and_forward_shapes(self): model = self._make_model().eval() @@ -159,32 +164,6 @@ def test_fast_noise_tau_applies_only_in_train(self): self.assertFalse(torch.allclose(z_train_1, z_train_2)) self.assertTrue(torch.allclose(z_eval_1, z_eval_2, atol=1e-6, rtol=1e-5)) - def test_fast_forward_return_loss_reconstruction_only(self): - model = self._make_model(use_encoder_loss=False).train() - x = torch.rand(2, 3, 16, 16, device=torch_device) - - output = model(x, return_loss=True) - - self.assertEqual(output.sample.shape, (2, 3, 16, 16)) - self.assertTrue(torch.isfinite(output.loss).all().item()) - self.assertTrue(torch.isfinite(output.reconstruction_loss).all().item()) - self.assertTrue(torch.isfinite(output.encoder_loss).all().item()) - self.assertEqual(output.encoder_loss.item(), 0.0) - self.assertTrue(torch.allclose(output.loss, output.reconstruction_loss)) - - def test_fast_forward_return_loss_with_encoder_loss(self): - model = self._make_model(use_encoder_loss=True).train() - x = torch.rand(2, 3, 16, 16, device=torch_device) - - output = model(x, return_loss=True, encoder_loss_weight=0.5, 
reconstruction_loss_type="mse") - - self.assertEqual(output.sample.shape, (2, 3, 16, 16)) - self.assertTrue(torch.isfinite(output.loss).all().item()) - self.assertTrue(torch.isfinite(output.reconstruction_loss).all().item()) - self.assertTrue(torch.isfinite(output.encoder_loss).all().item()) - self.assertGreaterEqual(output.encoder_loss.item(), 0.0) - self.assertGreaterEqual(output.loss.item(), output.reconstruction_loss.item()) - @slow class AutoencoderRAEEncoderIntegrationTests(unittest.TestCase): From fc5295951a362149ec45f0dd6c1251819bb08943 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 16 Feb 2026 12:40:36 +0000 Subject: [PATCH 17/30] cleanup --- .../models/autoencoders/autoencoder_rae.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index 83459fcfd391..9d7e70a152ee 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -621,24 +621,11 @@ def _maybe_denormalize_latents(self, z: torch.Tensor) -> torch.Tensor: latents_std = self._latents_std.to(device=z.device, dtype=z.dtype) if self._latents_std is not None else 1 return z * (latents_std + 1e-5) + latents_mean - def _encode_tokens(self, x: torch.Tensor, *, requires_grad: bool) -> torch.Tensor: - # Keep compatibility with custom registered encoders that may not accept `requires_grad`. - try: - return self.encoder(x, requires_grad=requires_grad) - except TypeError: - if requires_grad: - logger.warning( - "Encoder class `%s` does not accept `requires_grad`; falling back to default forward for " - "encoder loss computation.", - self.encoder.__class__.__name__, - ) - return self.encoder(x) - def _encode(self, x: torch.Tensor) -> torch.Tensor: x = self._maybe_resize_and_normalize(x) # Encoder is frozen by default for latent extraction. 
- tokens = self._encode_tokens(x, requires_grad=False) # (B, N, C) + tokens = self.encoder(x, requires_grad=False) # (B, N, C) if self.training and self.noise_tau > 0: tokens = self._noising(tokens) From a4fc9f64b2d9ade9e7baaa1e724bfa61608ae660 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 16 Feb 2026 12:52:20 +0000 Subject: [PATCH 18/30] simplify mixins --- src/diffusers/models/autoencoders/autoencoder_rae.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index 9d7e70a152ee..9e4e6ed384ca 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -22,8 +22,6 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin -from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import BaseOutput, logging from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation @@ -400,9 +398,7 @@ def forward( return RAEDecoderOutput(logits=logits) -class AutoencoderRAE( - ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin -): +class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin): r""" Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images. 
From d06b50185042b3fbcd665a984629ca78cdc8cd6b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 16 Feb 2026 13:00:00 +0000 Subject: [PATCH 19/30] fix training script --- .../autoencoder_rae/README.md | 1 + .../autoencoder_rae/train_autoencoder_rae.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/examples/research_projects/autoencoder_rae/README.md b/examples/research_projects/autoencoder_rae/README.md index 559eb37518ae..c6ffe77112f3 100644 --- a/examples/research_projects/autoencoder_rae/README.md +++ b/examples/research_projects/autoencoder_rae/README.md @@ -26,6 +26,7 @@ accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_r --train_batch_size 8 \ --learning_rate 1e-4 \ --num_train_epochs 10 \ + --report_to wandb \ --reconstruction_loss_type l1 \ --use_encoder_loss \ --encoder_loss_weight 0.1 diff --git a/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py index 4b0a2c551671..72f030a4286e 100644 --- a/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py +++ b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py @@ -23,6 +23,7 @@ import torch import torch.nn.functional as F from accelerate import Accelerator +from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from torch.utils.data import DataLoader from torchvision import transforms @@ -33,7 +34,7 @@ from diffusers.optimization import get_scheduler -logger = logging.getLogger(__name__) +logger = get_logger(__name__) def parse_args(): @@ -104,7 +105,7 @@ def parse_args(): parser.add_argument( "--use_encoder_loss", action="store_true", - help="Enable encoder feature consistency loss in model forward.", + help="Enable encoder feature consistency loss term in the training loop.", ) parser.add_argument("--report_to", type=str, default="tensorboard") @@ -122,9 +123,7 @@ def 
build_transforms(args): return transforms.Compose(image_transforms) -def compute_losses( - model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float -): +def compute_losses(model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float): decoded = model(pixel_values).sample if decoded.shape[-2:] != pixel_values.shape[-2:]: @@ -140,10 +139,12 @@ def compute_losses( encoder_loss = torch.zeros_like(reconstruction_loss) if use_encoder_loss and encoder_loss_weight > 0: - target_tokens = model._encode_tokens( - model._maybe_resize_and_normalize(pixel_values), requires_grad=False - ).detach() - reconstructed_tokens = model._encode_tokens(model._maybe_resize_and_normalize(decoded), requires_grad=True) + base_model = model.module if hasattr(model, "module") else model + target_encoder_input = base_model._maybe_resize_and_normalize(pixel_values) + reconstructed_encoder_input = base_model._maybe_resize_and_normalize(decoded) + + target_tokens = base_model.encoder(target_encoder_input, requires_grad=False).detach() + reconstructed_tokens = base_model.encoder(reconstructed_encoder_input, requires_grad=True) encoder_loss = F.mse_loss(reconstructed_tokens.float(), target_tokens.float()) loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss From c68b812cb0a997dd49880c57e8f170ffdba69020 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 23 Feb 2026 09:40:18 +0000 Subject: [PATCH 20/30] fix entrypoint for instantiating the AutoencoderRAE --- docs/source/en/api/models/autoencoder_rae.md | 15 +- .../autoencoder_rae/train_autoencoder_rae.py | 14 +- scripts/convert_rae_to_diffusers.py | 91 +++++++-- .../models/autoencoders/autoencoder_rae.py | 188 +++++++++++------- .../test_models_autoencoder_rae.py | 67 ++----- 5 files changed, 231 insertions(+), 144 deletions(-) diff --git a/docs/source/en/api/models/autoencoder_rae.md b/docs/source/en/api/models/autoencoder_rae.md index 
43ec89c3b7ec..77af9ea0f21f 100644 --- a/docs/source/en/api/models/autoencoder_rae.md +++ b/docs/source/en/api/models/autoencoder_rae.md @@ -26,22 +26,21 @@ The model follows the standard diffusers autoencoder API: import torch from diffusers import AutoencoderRAE -model = AutoencoderRAE( - encoder_cls="dinov2", - encoder_name_or_path="facebook/dinov2-with-registers-base", - encoder_input_size=224, - patch_size=16, - image_size=256, +# Load a converted model from the Hub +model = AutoencoderRAE.from_pretrained( + "nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08" ).to("cuda").eval() # Encode and decode -x = torch.randn(1, 3, 256, 256, device="cuda") +x = torch.randn(1, 3, 224, 224, device="cuda") with torch.no_grad(): latents = model.encode(x).latent recon = model.decode(latents).sample ``` -`encoder_cls` supports `"dinov2"`, `"siglip2"`, and `"mae"`. +`encoder_type` supports `"dinov2"`, `"siglip2"`, and `"mae"`. The encoder is built from config +(with random weights) during `__init__`; use `from_pretrained` to load a converted checkpoint +that includes both encoder and decoder weights. For latent normalization, use `latents_mean` and `latents_std` (matching other diffusers autoencoders). 
diff --git a/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py index 72f030a4286e..a73ea30a67cf 100644 --- a/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py +++ b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py @@ -73,8 +73,9 @@ def parse_args(): parser.add_argument("--checkpointing_steps", type=int, default=1000) parser.add_argument("--validation_steps", type=int, default=500) - parser.add_argument("--encoder_cls", type=str, choices=["dinov2", "siglip2", "mae"], default="dinov2") - parser.add_argument("--encoder_name_or_path", type=str, default=None) + parser.add_argument("--encoder_type", type=str, choices=["dinov2", "siglip2", "mae"], default="dinov2") + parser.add_argument("--encoder_hidden_size", type=int, default=768) + parser.add_argument("--encoder_patch_size", type=int, default=14) parser.add_argument("--encoder_input_size", type=int, default=224) parser.add_argument("--patch_size", type=int, default=16) parser.add_argument("--image_size", type=int, default=256) @@ -123,7 +124,9 @@ def build_transforms(args): return transforms.Compose(image_transforms) -def compute_losses(model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float): +def compute_losses( + model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float +): decoded = model(pixel_values).sample if decoded.shape[-2:] != pixel_values.shape[-2:]: @@ -198,8 +201,9 @@ def collate_fn(examples): ) model = AutoencoderRAE( - encoder_cls=args.encoder_cls, - encoder_name_or_path=args.encoder_name_or_path, + encoder_type=args.encoder_type, + encoder_hidden_size=args.encoder_hidden_size, + encoder_patch_size=args.encoder_patch_size, decoder_hidden_size=args.decoder_hidden_size, decoder_num_hidden_layers=args.decoder_num_hidden_layers, decoder_num_attention_heads=args.decoder_num_attention_heads, 
diff --git a/scripts/convert_rae_to_diffusers.py b/scripts/convert_rae_to_diffusers.py index 39285ad61569..6ef6ed68daac 100644 --- a/scripts/convert_rae_to_diffusers.py +++ b/scripts/convert_rae_to_diffusers.py @@ -36,6 +36,18 @@ "mae": "facebook/vit-mae-base", } +ENCODER_HIDDEN_SIZE = { + "dinov2": 768, + "siglip2": 768, + "mae": 768, +} + +ENCODER_PATCH_SIZE = { + "dinov2": 14, + "siglip2": 16, + "mae": 16, +} + DEFAULT_DECODER_SUBDIR = { "dinov2": "decoders/dinov2/wReg_base", "mae": "decoders/mae/base_p16", @@ -121,14 +133,14 @@ def remap_decoder_attention_keys_for_diffusers(state_dict: dict[str, Any]) -> di def resolve_decoder_file( - accessor: RepoAccessor, encoder_cls: str, variant: str, decoder_checkpoint: str | None + accessor: RepoAccessor, encoder_type: str, variant: str, decoder_checkpoint: str | None ) -> str: if decoder_checkpoint is not None: if accessor.exists(decoder_checkpoint): return decoder_checkpoint raise FileNotFoundError(f"Decoder checkpoint not found: {decoder_checkpoint}") - base = f"{DEFAULT_DECODER_SUBDIR[encoder_cls]}/{variant}" + base = f"{DEFAULT_DECODER_SUBDIR[encoder_type]}/{variant}" for name in DECODER_FILE_CANDIDATES: candidate = f"{base}/{name}" if accessor.exists(candidate): @@ -141,7 +153,7 @@ def resolve_decoder_file( def resolve_stats_file( accessor: RepoAccessor, - encoder_cls: str, + encoder_type: str, dataset_name: str, stats_checkpoint: str | None, ) -> str | None: @@ -150,7 +162,7 @@ def resolve_stats_file( return stats_checkpoint raise FileNotFoundError(f"Stats checkpoint not found: {stats_checkpoint}") - base = DEFAULT_STATS_SUBDIR[encoder_cls] + base = DEFAULT_STATS_SUBDIR[encoder_type] for dataset in dataset_case_candidates(dataset_name): for name in STATS_FILE_CANDIDATES: candidate = f"{base}/{dataset}/{name}" @@ -181,12 +193,33 @@ def extract_latent_stats(stats_obj: Any) -> tuple[Any | None, Any | None]: return mean, latents_std +def _load_hf_encoder_state_dict(encoder_type: str, encoder_name_or_path: str) -> 
dict[str, Any]: + """Download the HF encoder and extract the state dict for the inner model.""" + if encoder_type == "dinov2": + from transformers import Dinov2WithRegistersModel + + hf_model = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path) + return hf_model.state_dict() + elif encoder_type == "siglip2": + from transformers import SiglipModel + + hf_model = SiglipModel.from_pretrained(encoder_name_or_path).vision_model + return hf_model.state_dict() + elif encoder_type == "mae": + from transformers import ViTMAEForPreTraining + + hf_model = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit + return hf_model.state_dict() + else: + raise ValueError(f"Unknown encoder_type: {encoder_type}") + + def convert(args: argparse.Namespace) -> None: accessor = RepoAccessor(args.repo_or_path, cache_dir=args.cache_dir) - encoder_name_or_path = args.encoder_name_or_path or ENCODER_DEFAULT_NAME_OR_PATH[args.encoder_cls] + encoder_name_or_path = args.encoder_name_or_path or ENCODER_DEFAULT_NAME_OR_PATH[args.encoder_type] - decoder_relpath = resolve_decoder_file(accessor, args.encoder_cls, args.variant, args.decoder_checkpoint) - stats_relpath = resolve_stats_file(accessor, args.encoder_cls, args.dataset_name, args.stats_checkpoint) + decoder_relpath = resolve_decoder_file(accessor, args.encoder_type, args.variant, args.decoder_checkpoint) + stats_relpath = resolve_stats_file(accessor, args.encoder_type, args.dataset_name, args.stats_checkpoint) print(f"Using decoder checkpoint: {decoder_relpath}") if stats_relpath is not None: @@ -210,13 +243,39 @@ def convert(args: argparse.Namespace) -> None: decoder_cfg = DECODER_CONFIGS[args.decoder_config_name] + # Read encoder normalization stats from the HF image processor (only place that downloads encoder info) + from transformers import AutoConfig, AutoImageProcessor + + proc = AutoImageProcessor.from_pretrained(encoder_name_or_path) + encoder_norm_mean = list(proc.image_mean) + encoder_norm_std = 
list(proc.image_std) + + # Read encoder hidden size and patch size from HF config + encoder_hidden_size = ENCODER_HIDDEN_SIZE[args.encoder_type] + encoder_patch_size = ENCODER_PATCH_SIZE[args.encoder_type] + try: + hf_config = AutoConfig.from_pretrained(encoder_name_or_path) + # For models like SigLIP that nest vision config + if hasattr(hf_config, "vision_config"): + hf_config = hf_config.vision_config + encoder_hidden_size = hf_config.hidden_size + encoder_patch_size = hf_config.patch_size + except Exception: + pass + + # Load the actual encoder weights from HF to include in the saved model + encoder_state_dict = _load_hf_encoder_state_dict(args.encoder_type, encoder_name_or_path) + model = AutoencoderRAE( - encoder_cls=args.encoder_cls, - encoder_name_or_path=encoder_name_or_path, + encoder_type=args.encoder_type, + encoder_hidden_size=encoder_hidden_size, + encoder_patch_size=encoder_patch_size, encoder_input_size=args.encoder_input_size, patch_size=args.patch_size, image_size=args.image_size, num_channels=args.num_channels, + encoder_norm_mean=encoder_norm_mean, + encoder_norm_std=encoder_norm_std, decoder_hidden_size=decoder_cfg["decoder_hidden_size"], decoder_num_hidden_layers=decoder_cfg["decoder_num_hidden_layers"], decoder_num_attention_heads=decoder_cfg["decoder_num_attention_heads"], @@ -226,6 +285,10 @@ def convert(args: argparse.Namespace) -> None: scaling_factor=args.scaling_factor, ) + # Load encoder weights + encoder_load_result = model.encoder.model.load_state_dict(encoder_state_dict, strict=True) + print(f"Encoder weights loaded: {encoder_load_result}") + load_result = model.decoder.load_state_dict(decoder_state_dict, strict=False) allowed_missing = {"trainable_cls_token"} missing = set(load_result.missing_keys) @@ -242,8 +305,10 @@ def convert(args: argparse.Namespace) -> None: metadata = { "source": args.repo_or_path, - "encoder_cls": args.encoder_cls, + "encoder_type": args.encoder_type, "encoder_name_or_path": encoder_name_or_path, + 
"encoder_hidden_size": encoder_hidden_size, + "encoder_patch_size": encoder_patch_size, "decoder_checkpoint": decoder_relpath, "stats_checkpoint": stats_relpath, "variant": args.variant, @@ -272,8 +337,10 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--output_path", type=str, required=True, help="Directory to save converted model") - parser.add_argument("--encoder_cls", type=str, choices=["dinov2", "mae", "siglip2"], required=True) - parser.add_argument("--encoder_name_or_path", type=str, default=None, help="Optional encoder HF id/path override") + parser.add_argument("--encoder_type", type=str, choices=["dinov2", "mae", "siglip2"], required=True) + parser.add_argument( + "--encoder_name_or_path", type=str, default=None, help="Optional encoder HF model id or local path override" + ) parser.add_argument("--variant", type=str, default="ViTXL_n08", help="Decoder variant folder name") parser.add_argument("--dataset_name", type=str, default="imagenet1k", help="Stats dataset folder name") diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index 9e4e6ed384ca..b68363c53ffb 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from math import sqrt from types import SimpleNamespace -from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from typing import Any, Dict, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -32,42 +32,38 @@ from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput -ENCODER_ARCHS: Dict[str, Type] = {} -ENCODER_DEFAULT_NAME_OR_PATH = { - "dinov2": "facebook/dinov2-with-registers-base", - "siglip2": "google/siglip2-base-patch16-256", - "mae": "facebook/vit-mae-base", -} logger = logging.get_logger(__name__) -def register_encoder(cls: Optional[Type] = None, *, name: Optional[str] = None) -> 
Union[Callable[[Type], Type], Type]: - def decorator(inner_cls: Type) -> Type: - encoder_name = name or inner_cls.__name__ - if encoder_name in ENCODER_ARCHS and ENCODER_ARCHS[encoder_name] is not inner_cls: - raise ValueError(f"Encoder '{encoder_name}' is already registered.") - ENCODER_ARCHS[encoder_name] = inner_cls - return inner_cls - - if cls is None: - return decorator - return decorator(cls) - - -@register_encoder(name="dinov2") class Dinov2Encoder(nn.Module): - def __init__(self, encoder_name_or_path: str = "facebook/dinov2-with-registers-base"): + def __init__( + self, + hidden_size: int = 768, + patch_size: int = 14, + image_size: int = 224, + num_hidden_layers: int = 12, + num_attention_heads: int = 12, + num_register_tokens: int = 4, + ): super().__init__() - from transformers import Dinov2WithRegistersModel - - self.model = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path) + from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel + + config = Dinov2WithRegistersConfig( + hidden_size=hidden_size, + patch_size=patch_size, + image_size=image_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_register_tokens=num_register_tokens, + ) + self.model = Dinov2WithRegistersModel(config) self.model.requires_grad_(False) self.model.layernorm.elementwise_affine = False self.model.layernorm.weight = None self.model.layernorm.bias = None - self.patch_size = self.model.config.patch_size - self.hidden_size = self.model.config.hidden_size + self.patch_size = patch_size + self.hidden_size = hidden_size def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Tensor: """ @@ -83,21 +79,36 @@ def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Te return image_features -@register_encoder(name="siglip2") class Siglip2Encoder(nn.Module): - def __init__(self, encoder_name_or_path: str = "google/siglip2-base-patch16-256"): + def __init__( + self, + 
hidden_size: int = 768, + patch_size: int = 16, + image_size: int = 256, + num_hidden_layers: int = 12, + num_attention_heads: int = 12, + intermediate_size: int = 3072, + ): super().__init__() - from transformers import SiglipModel - - self.model = SiglipModel.from_pretrained(encoder_name_or_path).vision_model + from transformers import SiglipVisionConfig, SiglipVisionModel + + config = SiglipVisionConfig( + hidden_size=hidden_size, + patch_size=patch_size, + image_size=image_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + ) + self.model = SiglipVisionModel(config) self.model.requires_grad_(False) # remove the affine of final layernorm self.model.post_layernorm.elementwise_affine = False # remove the param self.model.post_layernorm.weight = None self.model.post_layernorm.bias = None - self.hidden_size = self.model.config.hidden_size - self.patch_size = self.model.config.patch_size + self.hidden_size = hidden_size + self.patch_size = patch_size def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Tensor: """ @@ -112,22 +123,37 @@ def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Te return image_features -@register_encoder(name="mae") class MAEEncoder(nn.Module): - def __init__(self, encoder_name_or_path: str = "facebook/vit-mae-base"): + def __init__( + self, + hidden_size: int = 768, + patch_size: int = 16, + image_size: int = 224, + num_hidden_layers: int = 12, + num_attention_heads: int = 12, + intermediate_size: int = 3072, + ): super().__init__() - from transformers import ViTMAEForPreTraining - - self.model = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit + from transformers import ViTMAEConfig, ViTMAEModel + + config = ViTMAEConfig( + hidden_size=hidden_size, + patch_size=patch_size, + image_size=image_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + 
intermediate_size=intermediate_size, + mask_ratio=0.0, + ) + self.model = ViTMAEModel(config) self.model.requires_grad_(False) # remove the affine of final layernorm self.model.layernorm.elementwise_affine = False # remove the param self.model.layernorm.weight = None self.model.layernorm.bias = None - self.hidden_size = self.model.config.hidden_size - self.patch_size = self.model.config.patch_size - self.model.config.mask_ratio = 0.0 # no masking + self.hidden_size = hidden_size + self.patch_size = patch_size def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Tensor: """ @@ -148,6 +174,20 @@ def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Te return image_features +_ENCODER_TYPES: Dict[str, Type] = { + "dinov2": Dinov2Encoder, + "siglip2": Siglip2Encoder, + "mae": MAEEncoder, +} + +# Default encoder image sizes matching the HF pretrained models +_ENCODER_DEFAULT_IMAGE_SIZE: Dict[str, int] = { + "dinov2": 518, + "siglip2": 256, + "mae": 224, +} + + @dataclass class RAEDecoderOutput(BaseOutput): """ @@ -409,11 +449,12 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin): all models (such as downloading or saving). Args: - encoder_cls (`str`, *optional*, defaults to `"dinov2"`): + encoder_type (`str`, *optional*, defaults to `"dinov2"`): Type of frozen encoder to use. One of `"dinov2"`, `"siglip2"`, or `"mae"`. - encoder_name_or_path (`str`, *optional*): - Path to pretrained encoder model or model identifier from huggingface.co/models. If not provided, uses an - encoder-specific default model id. + encoder_hidden_size (`int`, *optional*, defaults to `768`): + Hidden size of the encoder model. + encoder_patch_size (`int`, *optional*, defaults to `14`): + Patch size of the encoder model. patch_size (`int`, *optional*, defaults to `16`): Decoder patch size (used for unpatchify and decoder head). 
encoder_input_size (`int`, *optional*, defaults to `224`): @@ -421,9 +462,13 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin): image_size (`int`, *optional*): Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size // - encoder.patch_size) ** 2`. + encoder_patch_size) ** 2`. num_channels (`int`, *optional*, defaults to `3`): Number of input/output channels. + encoder_norm_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`): + Channel-wise mean for encoder input normalization (ImageNet defaults). + encoder_norm_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`): + Channel-wise std for encoder input normalization (ImageNet defaults). latents_mean (`list` or `tuple`, *optional*): Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable lists. @@ -445,8 +490,9 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin): @register_to_config def __init__( self, - encoder_cls: str = "dinov2", - encoder_name_or_path: Optional[str] = None, + encoder_type: str = "dinov2", + encoder_hidden_size: int = 768, + encoder_patch_size: int = 14, decoder_hidden_size: int = 512, decoder_num_hidden_layers: int = 8, decoder_num_attention_heads: int = 16, @@ -455,6 +501,8 @@ def __init__( encoder_input_size: int = 224, image_size: Optional[int] = None, num_channels: int = 3, + encoder_norm_mean: Optional[list] = None, + encoder_norm_std: Optional[list] = None, latents_mean: Optional[Union[list, tuple, torch.Tensor]] = None, latents_std: Optional[Union[list, tuple, torch.Tensor]] = None, noise_tau: float = 0.0, @@ -464,10 +512,8 @@ def __init__( ): super().__init__() - if encoder_cls not in ENCODER_ARCHS: - raise ValueError(f"Unknown encoder_cls='{encoder_cls}'. 
Available: {sorted(ENCODER_ARCHS.keys())}") - if encoder_name_or_path is None: - encoder_name_or_path = ENCODER_DEFAULT_NAME_OR_PATH[encoder_cls] + if encoder_type not in _ENCODER_TYPES: + raise ValueError(f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_TYPES.keys())}") def _to_config_compatible(value: Any) -> Any: if isinstance(value, torch.Tensor): @@ -489,7 +535,6 @@ def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Op # Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors. self.register_to_config( - encoder_name_or_path=encoder_name_or_path, latents_mean=_to_config_compatible(latents_mean), latents_std=_to_config_compatible(latents_std), ) @@ -499,17 +544,17 @@ def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Op self.reshape_to_2d = bool(reshape_to_2d) self.use_encoder_loss = bool(use_encoder_loss) - # Frozen representation encoder - self.encoder: nn.Module = ENCODER_ARCHS[encoder_cls](encoder_name_or_path=encoder_name_or_path) + # Frozen representation encoder (built from config, no downloads) + encoder_patch_size = int(encoder_patch_size) + encoder_image_size = _ENCODER_DEFAULT_IMAGE_SIZE.get(encoder_type, 224) + self.encoder: nn.Module = _ENCODER_TYPES[encoder_type]( + hidden_size=encoder_hidden_size, patch_size=encoder_patch_size, image_size=encoder_image_size + ) # RAE-main: base_patches = (encoder_input_size // encoder_patch_size) ** 2 - encoder_patch_size = getattr(self.encoder, "patch_size", None) - if encoder_patch_size is None: - raise ValueError(f"Encoder '{encoder_cls}' must define `.patch_size` attribute.") - encoder_patch_size = int(encoder_patch_size) if self.encoder_input_size % encoder_patch_size != 0: raise ValueError( - f"encoder_input_size={self.encoder_input_size} must be divisible by encoder.patch_size={encoder_patch_size}." 
+ f"encoder_input_size={self.encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}." ) num_patches = (self.encoder_input_size // encoder_patch_size) ** 2 @@ -533,15 +578,16 @@ def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Op f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}." ) - # Normalization stats from the encoder's image processor (strict, same as official RAE). - from transformers import AutoImageProcessor - - proc = AutoImageProcessor.from_pretrained(encoder_name_or_path) - encoder_mean = torch.tensor(proc.image_mean, dtype=torch.float32).view(1, 3, 1, 1) - encoder_std = torch.tensor(proc.image_std, dtype=torch.float32).view(1, 3, 1, 1) + # Encoder input normalization stats (ImageNet defaults) + if encoder_norm_mean is None: + encoder_norm_mean = [0.485, 0.456, 0.406] + if encoder_norm_std is None: + encoder_norm_std = [0.229, 0.224, 0.225] + encoder_mean_tensor = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1) + encoder_std_tensor = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1) - self.register_buffer("encoder_mean", encoder_mean, persistent=True) - self.register_buffer("encoder_std", encoder_std, persistent=True) + self.register_buffer("encoder_mean", encoder_mean_tensor, persistent=True) + self.register_buffer("encoder_std", encoder_std_tensor, persistent=True) # Optional latent normalization (RAE-main uses mean/var) latents_mean_tensor = _as_optional_tensor(latents_mean) @@ -556,10 +602,6 @@ def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Op self._latents_std = None # ViT-MAE style decoder - encoder_hidden_size = getattr(self.encoder, "hidden_size", None) - if encoder_hidden_size is None: - raise ValueError(f"Encoder '{encoder_cls}' must define `.hidden_size` attribute.") - decoder_config = SimpleNamespace( hidden_size=int(encoder_hidden_size), decoder_hidden_size=int(decoder_hidden_size), diff 
--git a/tests/models/autoencoders/test_models_autoencoder_rae.py b/tests/models/autoencoders/test_models_autoencoder_rae.py index d9fe351b62a3..6846e007f771 100644 --- a/tests/models/autoencoders/test_models_autoencoder_rae.py +++ b/tests/models/autoencoders/test_models_autoencoder_rae.py @@ -15,17 +15,16 @@ import gc import unittest -from unittest.mock import patch import torch import torch.nn.functional as F from diffusers.models.autoencoders.autoencoder_rae import ( + _ENCODER_TYPES, AutoencoderRAE, Dinov2Encoder, MAEEncoder, Siglip2Encoder, - register_encoder, ) from ...testing_utils import backend_empty_cache, enable_full_determinism, slow, torch_device @@ -39,12 +38,11 @@ MAE_MODEL_ID = "facebook/vit-mae-base" -@register_encoder(name="tiny_test") class TinyTestEncoder(torch.nn.Module): - def __init__(self, encoder_name_or_path: str = "unused"): + def __init__(self, hidden_size: int = 16, patch_size: int = 8, **kwargs): super().__init__() - self.patch_size = 8 - self.hidden_size = 16 + self.patch_size = patch_size + self.hidden_size = hidden_size def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Tensor: pooled = F.avg_pool2d(images.mean(dim=1, keepdim=True), kernel_size=self.patch_size, stride=self.patch_size) @@ -52,6 +50,9 @@ def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Te return tokens.repeat(1, 1, self.hidden_size) +_ENCODER_TYPES["tiny_test"] = TinyTestEncoder + + class AutoencoderRAETests(unittest.TestCase): def tearDown(self): super().tearDown() @@ -60,8 +61,9 @@ def tearDown(self): def _make_model(self, **overrides) -> AutoencoderRAE: config = { - "encoder_cls": "tiny_test", - "encoder_name_or_path": "unused", + "encoder_type": "tiny_test", + "encoder_hidden_size": 16, + "encoder_patch_size": 8, "encoder_input_size": 32, "patch_size": 4, "image_size": 16, @@ -70,16 +72,14 @@ def _make_model(self, **overrides) -> AutoencoderRAE: "decoder_num_attention_heads": 4, "decoder_intermediate_size": 
64, "num_channels": 3, + "encoder_norm_mean": [0.5, 0.5, 0.5], + "encoder_norm_std": [0.5, 0.5, 0.5], "noise_tau": 0.0, "reshape_to_2d": True, "scaling_factor": 1.0, } config.update(overrides) - with patch("transformers.AutoImageProcessor.from_pretrained") as mocked_from_pretrained: - mocked_from_pretrained.return_value = type( - "_MockProcessor", (), {"image_mean": [0.5, 0.5, 0.5], "image_std": [0.5, 0.5, 0.5]} - )() - return AutoencoderRAE(**config).to(torch_device) + return AutoencoderRAE(**config).to(torch_device) def test_fast_encode_decode_and_forward_shapes(self): model = self._make_model().eval() @@ -173,9 +173,7 @@ def tearDown(self): backend_empty_cache(torch_device) def test_dinov2_encoder_forward_shape(self): - dino_path = DINO_MODEL_ID - - encoder = Dinov2Encoder(encoder_name_or_path=dino_path).to(torch_device) + encoder = Dinov2Encoder().to(torch_device) x = torch.rand(1, 3, 224, 224, device=torch_device) y = encoder(x) @@ -185,9 +183,7 @@ def test_dinov2_encoder_forward_shape(self): assert y.shape[2] == encoder.hidden_size def test_siglip2_encoder_forward_shape(self): - siglip2_path = SIGLIP2_MODEL_ID - - encoder = Siglip2Encoder(encoder_name_or_path=siglip2_path).to(torch_device) + encoder = Siglip2Encoder().to(torch_device) x = torch.rand(1, 3, 224, 224, device=torch_device) y = encoder(x) @@ -197,9 +193,7 @@ def test_siglip2_encoder_forward_shape(self): assert y.shape[2] == encoder.hidden_size def test_mae_encoder_forward_shape(self): - mae_path = MAE_MODEL_ID - - encoder = MAEEncoder(encoder_name_or_path=mae_path).to(torch_device) + encoder = MAEEncoder().to(torch_device) x = torch.rand(1, 3, 224, 224, device=torch_device) y = encoder(x) @@ -216,30 +210,11 @@ def tearDown(self): gc.collect() backend_empty_cache(torch_device) - def test_autoencoder_rae_encode_decode_forward_shapes_dinov2(self): - # This is a shape & numerical-sanity test. The decoder is randomly initialized unless you load trained weights. 
- dino_path = DINO_MODEL_ID - - encoder_input_size = 224 - decoder_patch_size = 16 - # dinov2 patch=14 -> (224/14)^2 = 256 tokens -> decoder output 256 for patch 16 - image_size = 256 - - model = AutoencoderRAE( - encoder_cls="dinov2", - encoder_name_or_path=dino_path, - image_size=image_size, - encoder_input_size=encoder_input_size, - patch_size=decoder_patch_size, - # keep the decoder lightweight for test runtime - decoder_hidden_size=128, - decoder_num_hidden_layers=1, - decoder_num_attention_heads=4, - decoder_intermediate_size=256, - ).to(torch_device) + def test_autoencoder_rae_from_pretrained_dinov2(self): + model = AutoencoderRAE.from_pretrained("nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08").to(torch_device) model.eval() - x = torch.rand(1, 3, encoder_input_size, encoder_input_size, device=torch_device) + x = torch.rand(1, 3, 224, 224, device=torch_device) with torch.no_grad(): latents = model.encode(x).latent @@ -247,8 +222,8 @@ def test_autoencoder_rae_encode_decode_forward_shapes_dinov2(self): assert latents.shape[0] == 1 decoded = model.decode(latents).sample - assert decoded.shape == (1, 3, image_size, image_size) + assert decoded.shape[0] == 1 + assert decoded.shape[1] == 3 recon = model(x).sample - assert recon.shape == (1, 3, image_size, image_size) assert torch.isfinite(recon).all().item() From 61885f37e3783d2a04cfe52b3bfb3a80cf3689ba Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 23 Feb 2026 09:59:26 +0000 Subject: [PATCH 21/30] added encoder_image_size config --- scripts/convert_rae_to_diffusers.py | 3 +++ src/diffusers/models/autoencoders/autoencoder_rae.py | 11 ++++------- .../autoencoders/test_models_autoencoder_rae.py | 1 + 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/scripts/convert_rae_to_diffusers.py b/scripts/convert_rae_to_diffusers.py index 6ef6ed68daac..39496ab16ea1 100644 --- a/scripts/convert_rae_to_diffusers.py +++ b/scripts/convert_rae_to_diffusers.py @@ -253,6 +253,7 @@ def convert(args: argparse.Namespace) 
-> None: # Read encoder hidden size and patch size from HF config encoder_hidden_size = ENCODER_HIDDEN_SIZE[args.encoder_type] encoder_patch_size = ENCODER_PATCH_SIZE[args.encoder_type] + encoder_image_size = 518 # fallback default try: hf_config = AutoConfig.from_pretrained(encoder_name_or_path) # For models like SigLIP that nest vision config @@ -260,6 +261,7 @@ def convert(args: argparse.Namespace) -> None: hf_config = hf_config.vision_config encoder_hidden_size = hf_config.hidden_size encoder_patch_size = hf_config.patch_size + encoder_image_size = hf_config.image_size except Exception: pass @@ -270,6 +272,7 @@ def convert(args: argparse.Namespace) -> None: encoder_type=args.encoder_type, encoder_hidden_size=encoder_hidden_size, encoder_patch_size=encoder_patch_size, + encoder_image_size=encoder_image_size, encoder_input_size=args.encoder_input_size, patch_size=args.patch_size, image_size=args.image_size, diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index b68363c53ffb..5b7014fa5ca0 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -180,12 +180,7 @@ def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Te "mae": MAEEncoder, } -# Default encoder image sizes matching the HF pretrained models -_ENCODER_DEFAULT_IMAGE_SIZE: Dict[str, int] = { - "dinov2": 518, - "siglip2": 256, - "mae": 224, -} + @dataclass @@ -455,6 +450,8 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin): Hidden size of the encoder model. encoder_patch_size (`int`, *optional*, defaults to `14`): Patch size of the encoder model. + encoder_image_size (`int`, *optional*, defaults to `518`): + Image size the encoder was pretrained with. Controls position embedding dimensions. patch_size (`int`, *optional*, defaults to `16`): Decoder patch size (used for unpatchify and decoder head). 
encoder_input_size (`int`, *optional*, defaults to `224`): @@ -493,6 +490,7 @@ def __init__( encoder_type: str = "dinov2", encoder_hidden_size: int = 768, encoder_patch_size: int = 14, + encoder_image_size: int = 518, decoder_hidden_size: int = 512, decoder_num_hidden_layers: int = 8, decoder_num_attention_heads: int = 16, @@ -546,7 +544,6 @@ def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Op # Frozen representation encoder (built from config, no downloads) encoder_patch_size = int(encoder_patch_size) - encoder_image_size = _ENCODER_DEFAULT_IMAGE_SIZE.get(encoder_type, 224) self.encoder: nn.Module = _ENCODER_TYPES[encoder_type]( hidden_size=encoder_hidden_size, patch_size=encoder_patch_size, image_size=encoder_image_size ) diff --git a/tests/models/autoencoders/test_models_autoencoder_rae.py b/tests/models/autoencoders/test_models_autoencoder_rae.py index 6846e007f771..54b512e4929f 100644 --- a/tests/models/autoencoders/test_models_autoencoder_rae.py +++ b/tests/models/autoencoders/test_models_autoencoder_rae.py @@ -64,6 +64,7 @@ def _make_model(self, **overrides) -> AutoencoderRAE: "encoder_type": "tiny_test", "encoder_hidden_size": 16, "encoder_patch_size": 8, + "encoder_image_size": 32, "encoder_input_size": 32, "patch_size": 4, "image_size": 16, From 28a02eb226852ed58f0f3623fe83a2799381b953 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 23 Feb 2026 10:05:24 +0000 Subject: [PATCH 22/30] undo last change --- scripts/convert_rae_to_diffusers.py | 3 --- src/diffusers/models/autoencoders/autoencoder_rae.py | 9 ++------- tests/models/autoencoders/test_models_autoencoder_rae.py | 1 - 3 files changed, 2 insertions(+), 11 deletions(-) diff --git a/scripts/convert_rae_to_diffusers.py b/scripts/convert_rae_to_diffusers.py index 39496ab16ea1..6ef6ed68daac 100644 --- a/scripts/convert_rae_to_diffusers.py +++ b/scripts/convert_rae_to_diffusers.py @@ -253,7 +253,6 @@ def convert(args: argparse.Namespace) -> None: # Read encoder hidden 
size and patch size from HF config encoder_hidden_size = ENCODER_HIDDEN_SIZE[args.encoder_type] encoder_patch_size = ENCODER_PATCH_SIZE[args.encoder_type] - encoder_image_size = 518 # fallback default try: hf_config = AutoConfig.from_pretrained(encoder_name_or_path) # For models like SigLIP that nest vision config @@ -261,7 +260,6 @@ def convert(args: argparse.Namespace) -> None: hf_config = hf_config.vision_config encoder_hidden_size = hf_config.hidden_size encoder_patch_size = hf_config.patch_size - encoder_image_size = hf_config.image_size except Exception: pass @@ -272,7 +270,6 @@ def convert(args: argparse.Namespace) -> None: encoder_type=args.encoder_type, encoder_hidden_size=encoder_hidden_size, encoder_patch_size=encoder_patch_size, - encoder_image_size=encoder_image_size, encoder_input_size=args.encoder_input_size, patch_size=args.patch_size, image_size=args.image_size, diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index 5b7014fa5ca0..f83a06489f69 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -40,7 +40,7 @@ def __init__( self, hidden_size: int = 768, patch_size: int = 14, - image_size: int = 224, + image_size: int = 518, num_hidden_layers: int = 12, num_attention_heads: int = 12, num_register_tokens: int = 4, @@ -181,8 +181,6 @@ def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Te } - - @dataclass class RAEDecoderOutput(BaseOutput): """ @@ -450,8 +448,6 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin): Hidden size of the encoder model. encoder_patch_size (`int`, *optional*, defaults to `14`): Patch size of the encoder model. - encoder_image_size (`int`, *optional*, defaults to `518`): - Image size the encoder was pretrained with. Controls position embedding dimensions. 
patch_size (`int`, *optional*, defaults to `16`): Decoder patch size (used for unpatchify and decoder head). encoder_input_size (`int`, *optional*, defaults to `224`): @@ -490,7 +486,6 @@ def __init__( encoder_type: str = "dinov2", encoder_hidden_size: int = 768, encoder_patch_size: int = 14, - encoder_image_size: int = 518, decoder_hidden_size: int = 512, decoder_num_hidden_layers: int = 8, decoder_num_attention_heads: int = 16, @@ -545,7 +540,7 @@ def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Op # Frozen representation encoder (built from config, no downloads) encoder_patch_size = int(encoder_patch_size) self.encoder: nn.Module = _ENCODER_TYPES[encoder_type]( - hidden_size=encoder_hidden_size, patch_size=encoder_patch_size, image_size=encoder_image_size + hidden_size=encoder_hidden_size, patch_size=encoder_patch_size ) # RAE-main: base_patches = (encoder_input_size // encoder_patch_size) ** 2 diff --git a/tests/models/autoencoders/test_models_autoencoder_rae.py b/tests/models/autoencoders/test_models_autoencoder_rae.py index 54b512e4929f..6846e007f771 100644 --- a/tests/models/autoencoders/test_models_autoencoder_rae.py +++ b/tests/models/autoencoders/test_models_autoencoder_rae.py @@ -64,7 +64,6 @@ def _make_model(self, **overrides) -> AutoencoderRAE: "encoder_type": "tiny_test", "encoder_hidden_size": 16, "encoder_patch_size": 8, - "encoder_image_size": 32, "encoder_input_size": 32, "patch_size": 4, "image_size": 16, From b297868201767a5ffd7b05dacf886806716f31c8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 25 Feb 2026 13:38:22 +0000 Subject: [PATCH 23/30] fixes from pretrained weights --- scripts/convert_rae_to_diffusers.py | 5 ++- .../models/autoencoders/autoencoder_rae.py | 45 +++++++++++-------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/scripts/convert_rae_to_diffusers.py b/scripts/convert_rae_to_diffusers.py index 6ef6ed68daac..625c0e3cc76f 100644 --- a/scripts/convert_rae_to_diffusers.py +++ 
b/scripts/convert_rae_to_diffusers.py @@ -203,8 +203,11 @@ def _load_hf_encoder_state_dict(encoder_type: str, encoder_name_or_path: str) -> elif encoder_type == "siglip2": from transformers import SiglipModel + # SiglipModel.vision_model is a SiglipVisionTransformer. + # Our Siglip2Encoder wraps it inside SiglipVisionModel which nests it + # under .vision_model, so we add the prefix to match the diffusers key layout. hf_model = SiglipModel.from_pretrained(encoder_name_or_path).vision_model - return hf_model.state_dict() + return {f"vision_model.{k}": v for k, v in hf_model.state_dict().items()} elif encoder_type == "mae": from transformers import ViTMAEForPreTraining diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index f83a06489f69..a40ab10f22ca 100644 --- a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -42,19 +42,18 @@ def __init__( patch_size: int = 14, image_size: int = 518, num_hidden_layers: int = 12, - num_attention_heads: int = 12, - num_register_tokens: int = 4, + **kwargs, ): super().__init__() from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel + num_attention_heads = hidden_size // 64 # all dinov2 variants use head_dim=64 config = Dinov2WithRegistersConfig( hidden_size=hidden_size, patch_size=patch_size, image_size=image_size, - num_hidden_layers=num_hidden_layers, num_attention_heads=num_attention_heads, - num_register_tokens=num_register_tokens, + num_hidden_layers=num_hidden_layers, ) self.model = Dinov2WithRegistersModel(config) self.model.requires_grad_(False) @@ -85,28 +84,26 @@ def __init__( hidden_size: int = 768, patch_size: int = 16, image_size: int = 256, - num_hidden_layers: int = 12, - num_attention_heads: int = 12, - intermediate_size: int = 3072, + **kwargs, ): super().__init__() from transformers import SiglipVisionConfig, SiglipVisionModel + num_attention_heads = 
hidden_size // 64 # all siglip2 variants use head_dim=64 + num_hidden_layers = kwargs.get("num_hidden_layers", 12) config = SiglipVisionConfig( hidden_size=hidden_size, patch_size=patch_size, image_size=image_size, - num_hidden_layers=num_hidden_layers, num_attention_heads=num_attention_heads, - intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, ) self.model = SiglipVisionModel(config) self.model.requires_grad_(False) # remove the affine of final layernorm - self.model.post_layernorm.elementwise_affine = False - # remove the param - self.model.post_layernorm.weight = None - self.model.post_layernorm.bias = None + self.model.vision_model.post_layernorm.elementwise_affine = False + self.model.vision_model.post_layernorm.weight = None + self.model.vision_model.post_layernorm.bias = None self.hidden_size = hidden_size self.patch_size = patch_size @@ -129,20 +126,19 @@ def __init__( hidden_size: int = 768, patch_size: int = 16, image_size: int = 224, - num_hidden_layers: int = 12, - num_attention_heads: int = 12, - intermediate_size: int = 3072, + **kwargs, ): super().__init__() from transformers import ViTMAEConfig, ViTMAEModel + num_attention_heads = hidden_size // 64 # all MAE variants use head_dim=64 + num_hidden_layers = kwargs.get("num_hidden_layers", 12) config = ViTMAEConfig( hidden_size=hidden_size, patch_size=patch_size, image_size=image_size, - num_hidden_layers=num_hidden_layers, num_attention_heads=num_attention_heads, - intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, mask_ratio=0.0, ) self.model = ViTMAEModel(config) @@ -319,6 +315,12 @@ def set_trainable_cls_token(self, tensor: Optional[torch.Tensor] = None): self.trainable_cls_token = nn.Parameter(tensor) def _initialize_weights(self, num_patches: int): + # Skip initialization when parameters are on meta device (e.g. during + # accelerate.init_empty_weights() used by low_cpu_mem_usage loading). 
+ # The weights will be loaded from the checkpoint afterwards. + if self.decoder_pos_embed.device.type == "meta": + return + grid_size = int(num_patches**0.5) pos_embed = get_2d_sincos_pos_embed( self.decoder_pos_embed.shape[-1], @@ -448,6 +450,8 @@ class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin): Hidden size of the encoder model. encoder_patch_size (`int`, *optional*, defaults to `14`): Patch size of the encoder model. + encoder_num_hidden_layers (`int`, *optional*, defaults to `12`): + Number of hidden layers in the encoder model. patch_size (`int`, *optional*, defaults to `16`): Decoder patch size (used for unpatchify and decoder head). encoder_input_size (`int`, *optional*, defaults to `224`): @@ -486,6 +490,7 @@ def __init__( encoder_type: str = "dinov2", encoder_hidden_size: int = 768, encoder_patch_size: int = 14, + encoder_num_hidden_layers: int = 12, decoder_hidden_size: int = 512, decoder_num_hidden_layers: int = 8, decoder_num_attention_heads: int = 16, @@ -540,7 +545,9 @@ def _as_optional_tensor(value: Optional[Union[torch.Tensor, list, tuple]]) -> Op # Frozen representation encoder (built from config, no downloads) encoder_patch_size = int(encoder_patch_size) self.encoder: nn.Module = _ENCODER_TYPES[encoder_type]( - hidden_size=encoder_hidden_size, patch_size=encoder_patch_size + hidden_size=encoder_hidden_size, + patch_size=encoder_patch_size, + num_hidden_layers=encoder_num_hidden_layers, ) # RAE-main: base_patches = (encoder_input_size // encoder_patch_size) ** 2 From b3ffd6344aacb79f855a968713875f4ffe16c0ed Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 26 Feb 2026 10:26:30 +0000 Subject: [PATCH 24/30] cleanups --- .../models/autoencoders/autoencoder_rae.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py index a40ab10f22ca..5091243e96cd 100644 --- 
a/src/diffusers/models/autoencoders/autoencoder_rae.py +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,18 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput, logging from ...utils.accelerate_utils import apply_forward_hook +from ...utils.import_utils import is_transformers_available + +if is_transformers_available(): + from transformers import ( + Dinov2WithRegistersConfig, + Dinov2WithRegistersModel, + SiglipVisionConfig, + SiglipVisionModel, + ViTMAEConfig, + ViTMAEModel, + ) + from ..activations import get_activation from ..attention import AttentionMixin from ..attention_processor import Attention @@ -45,8 +57,6 @@ def __init__( **kwargs, ): super().__init__() - from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel - num_attention_heads = hidden_size // 64 # all dinov2 variants use head_dim=64 config = Dinov2WithRegistersConfig( hidden_size=hidden_size, @@ -61,8 +71,6 @@ def __init__( self.model.layernorm.weight = None self.model.layernorm.bias = None - self.patch_size = patch_size - self.hidden_size = hidden_size def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Tensor: """ @@ -87,8 +95,6 @@ def __init__( **kwargs, ): super().__init__() - from transformers import SiglipVisionConfig, SiglipVisionModel - num_attention_heads = hidden_size // 64 # all siglip2 variants use head_dim=64 num_hidden_layers = kwargs.get("num_hidden_layers", 12) config = SiglipVisionConfig( @@ -104,8 +110,6 @@ def __init__( self.model.vision_model.post_layernorm.elementwise_affine = False self.model.vision_model.post_layernorm.weight = None 
self.model.vision_model.post_layernorm.bias = None - self.hidden_size = hidden_size - self.patch_size = patch_size def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Tensor: """ @@ -129,8 +133,6 @@ def __init__( **kwargs, ): super().__init__() - from transformers import ViTMAEConfig, ViTMAEModel - num_attention_heads = hidden_size // 64 # all MAE variants use head_dim=64 num_hidden_layers = kwargs.get("num_hidden_layers", 12) config = ViTMAEConfig( @@ -148,7 +150,6 @@ def __init__( # remove the param self.model.layernorm.weight = None self.model.layernorm.bias = None - self.hidden_size = hidden_size self.patch_size = patch_size def forward(self, images: torch.Tensor, requires_grad: bool = False) -> torch.Tensor: From dca59233f627ee9ac54deb378bab2116fd580d4c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 26 Feb 2026 10:30:26 +0000 Subject: [PATCH 25/30] address reviews --- docs/source/en/api/models/autoencoder_rae.md | 67 ++++++++++++++------ 1 file changed, 48 insertions(+), 19 deletions(-) diff --git a/docs/source/en/api/models/autoencoder_rae.md b/docs/source/en/api/models/autoencoder_rae.md index 77af9ea0f21f..e03ec9d07eee 100644 --- a/docs/source/en/api/models/autoencoder_rae.md +++ b/docs/source/en/api/models/autoencoder_rae.md @@ -1,4 +1,4 @@ -