33 commits
382aad0
feat: implement three RAE encoders (dinov2, siglip2, mae)
Jan 24, 2026
f82cecc
feat: finish first version of autoencoder_rae
Jan 28, 2026
a3926d7
Merge branch 'main' into rae
Ando233 Jan 28, 2026
3ecf89d
Merge branch 'main' into rae
kashif Feb 15, 2026
0850c8c
fix formatting
kashif Feb 15, 2026
24acab0
make fix-copies
kashif Feb 15, 2026
25bc9e3
initial doc
kashif Feb 15, 2026
f06ea7a
fix latent_mean / latent_var init types to accept config-friendly inputs
kashif Feb 15, 2026
d7cb124
use mean and std convention
kashif Feb 15, 2026
0d59b22
cleanup
kashif Feb 15, 2026
202b14f
add rae to diffusers script
kashif Feb 15, 2026
7cbbf27
use imports
kashif Feb 15, 2026
e6d4499
use attention
kashif Feb 15, 2026
6a9bde6
remove unneeded class
kashif Feb 15, 2026
9522e68
example training script
kashif Feb 15, 2026
906d79a
input and ground truth sizes have to be the same
kashif Feb 16, 2026
d3cbd5a
fix argument
kashif Feb 16, 2026
96520c4
move loss to training script
kashif Feb 16, 2026
fc52959
cleanup
kashif Feb 16, 2026
a4fc9f6
simplify mixins
kashif Feb 16, 2026
d06b501
fix training script
kashif Feb 16, 2026
d8b2983
Merge branch 'main' into rae
kashif Feb 17, 2026
c68b812
fix entrypoint for instantiating the AutoencoderRAE
kashif Feb 23, 2026
61885f3
added encoder_image_size config
kashif Feb 23, 2026
28a02eb
undo last change
kashif Feb 23, 2026
b297868
fixes from pretrained weights
kashif Feb 25, 2026
7debd07
Merge branch 'main' into rae
kashif Feb 26, 2026
b3ffd63
cleanups
kashif Feb 26, 2026
dca5923
address reviews
kashif Feb 26, 2026
c71cb44
Merge branch 'rae' of https://github.com/Ando233/diffusers into rae
kashif Feb 26, 2026
5c85781
fix train script to use pretrained
kashif Feb 26, 2026
d965cab
fix conversion script review
kashif Feb 26, 2026
663b580
latent normalization buffers are now always registered with no-op def…
kashif Feb 26, 2026
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -456,6 +456,8 @@
title: AutoencoderKLQwenImage
- local: api/models/autoencoder_kl_wan
title: AutoencoderKLWan
- local: api/models/autoencoder_rae
title: AutoencoderRAE
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/autoencoder_oobleck
88 changes: 88 additions & 0 deletions docs/source/en/api/models/autoencoder_rae.md
@@ -0,0 +1,88 @@
<!-- Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# AutoencoderRAE

The Representation Autoencoder (RAE) model was introduced in [Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) by Boyang Zheng, Nanye Ma, Shengbang Tong, and Saining Xie from NYU VISIONx.

RAE combines a frozen pretrained vision encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT-MAE-style decoder. In the two-stage RAE training recipe, the autoencoder is trained in stage 1 (reconstruction), and then a diffusion model is trained on the resulting latent space in stage 2 (generation).

The following RAE models are released and supported in Diffusers:

| Model | Encoder | Latent shape (224px input) |
|:------|:--------|:---------------------------|
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08) | DINOv2-base | 768 x 16 x 16 |
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512) | DINOv2-base (512px) | 768 x 32 x 32 |
| [`nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08) | DINOv2-small | 384 x 16 x 16 |
| [`nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08) | DINOv2-large | 1024 x 16 x 16 |
| [`nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08) | SigLIP2-base | 768 x 16 x 16 |
| [`nyu-visionx/RAE-mae-base-p16-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-mae-base-p16-ViTXL-n08) | MAE-base | 768 x 16 x 16 |

## Loading a pretrained model

```python
from diffusers import AutoencoderRAE

model = AutoencoderRAE.from_pretrained(
    "nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()
```

## Encoding and decoding a real image

```python
import torch
from diffusers import AutoencoderRAE
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image

model = AutoencoderRAE.from_pretrained(
    "nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()

image = Image.open("cat.png").convert("RGB").resize((224, 224))
x = to_tensor(image).unsqueeze(0).to("cuda") # (1, 3, 224, 224), values in [0, 1]

with torch.no_grad():
    latents = model.encode(x).latent  # (1, 768, 16, 16)
    recon = model.decode(latents).sample  # (1, 3, 256, 256)

recon_image = to_pil_image(recon[0].clamp(0, 1).cpu())
recon_image.save("recon.png")
```
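
Note that for this checkpoint `decode` returns a 256x256 image while the input was 224x224, so the two tensors cannot be compared elementwise as-is. A quick, illustrative quality check (not part of the Diffusers API; the random tensors below stand in for `x` and `recon` from the example above) is to resize the reconstruction and compute PSNR:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(1, 3, 224, 224)      # stands in for the input tensor above
recon = torch.rand(1, 3, 256, 256)  # stands in for model.decode(...).sample

# Resize the reconstruction back to the input size before comparing.
recon_224 = F.interpolate(recon, size=(224, 224), mode="bilinear", align_corners=False)
mse = F.mse_loss(recon_224.clamp(0, 1), x)
psnr = 10 * torch.log10(1.0 / mse)  # valid because both tensors are in [0, 1]
print(f"PSNR: {psnr.item():.2f} dB")
```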

## Latent normalization

Some pretrained checkpoints include per-channel `latents_mean` and `latents_std` statistics for normalizing the latent space. When present, `encode` and `decode` automatically apply the normalization and denormalization, respectively.

```python
import torch
from diffusers import AutoencoderRAE

model = AutoencoderRAE.from_pretrained(
    "nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()

x = torch.rand(1, 3, 224, 224, device="cuda")  # or a preprocessed image tensor in [0, 1]

# Latent normalization is handled automatically inside encode/decode
# when the checkpoint config includes latents_mean/latents_std.
with torch.no_grad():
    latents = model.encode(x).latent  # normalized latents
    recon = model.decode(latents).sample
```
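
Conceptually this is a per-channel affine map over the latent channels. A minimal sketch with dummy statistics (the real `latents_mean`/`latents_std` values come from the checkpoint config, not the ones invented here):

```python
import torch

torch.manual_seed(0)
# Dummy per-channel stats, broadcastable over a (N, C, H, W) latent.
latents_mean = torch.randn(1, 768, 1, 1)
latents_std = torch.rand(1, 768, 1, 1) + 0.5  # keep std away from zero

z = torch.randn(1, 768, 16, 16)
z_norm = (z - latents_mean) / latents_std        # applied at the end of encode()
z_denorm = z_norm * latents_std + latents_mean   # applied at the start of decode()

# Round-tripping through encode/decode normalization recovers the latent.
assert torch.allclose(z, z_denorm, atol=1e-5)
```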

## AutoencoderRAE

[[autodoc]] AutoencoderRAE
- encode
- decode
- all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
65 changes: 65 additions & 0 deletions examples/research_projects/autoencoder_rae/README.md
@@ -0,0 +1,65 @@
# Training AutoencoderRAE

This example trains the decoder of `AutoencoderRAE` (stage-1 style) while keeping the representation encoder frozen.

It follows the same high-level training recipe as the official RAE stage-1 setup:
- frozen encoder
- train decoder
- pixel reconstruction loss
- optional encoder feature consistency loss
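
This recipe can be sketched with stand-in modules (the conv encoder/decoder and the `0.1` weight below are purely illustrative; in the real script the encoder and decoder live inside `AutoencoderRAE` and the weight comes from `--encoder_loss_weight`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
encoder = nn.Conv2d(3, 8, kernel_size=16, stride=16)           # stand-in frozen encoder
decoder = nn.ConvTranspose2d(8, 3, kernel_size=16, stride=16)  # stand-in trainable decoder
encoder.requires_grad_(False)  # frozen, as in the RAE stage-1 recipe

images = torch.rand(2, 3, 64, 64)
latents = encoder(images)  # (2, 8, 4, 4)
recon = decoder(latents)   # (2, 3, 64, 64)

rec_loss = F.l1_loss(recon, images)  # pixel reconstruction loss
# Optional encoder feature consistency: re-encode the reconstruction and
# match it to the original latents.
feat_loss = F.mse_loss(encoder(recon), latents)
loss = rec_loss + 0.1 * feat_loss
loss.backward()

assert encoder.weight.grad is None      # encoder stays frozen
assert decoder.weight.grad is not None  # only the decoder receives gradients
```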

## Quickstart

### Resume or finetune from pretrained weights

```bash
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
--pretrained_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \
--train_data_dir /path/to/imagenet_like_folder \
--output_dir /tmp/autoencoder-rae \
--resolution 256 \
--train_batch_size 8 \
--learning_rate 1e-4 \
--num_train_epochs 10 \
--report_to wandb \
--reconstruction_loss_type l1 \
--use_encoder_loss \
--encoder_loss_weight 0.1
```

### Train from scratch with a pretrained encoder

> Reviewer comment: Maybe also include the pretrained encoder path in the example command?

```bash
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
--train_data_dir /path/to/imagenet_like_folder \
--output_dir /tmp/autoencoder-rae \
--resolution 256 \
--encoder_type dinov2 \
--encoder_name_or_path facebook/dinov2-with-registers-base \
--encoder_input_size 224 \
--patch_size 16 \
--image_size 256 \
--decoder_hidden_size 1152 \
--decoder_num_hidden_layers 28 \
--decoder_num_attention_heads 16 \
--decoder_intermediate_size 4096 \
--train_batch_size 8 \
--learning_rate 1e-4 \
--num_train_epochs 10 \
--report_to wandb \
--reconstruction_loss_type l1 \
--use_encoder_loss \
--encoder_loss_weight 0.1
```

Note: the stage-1 reconstruction loss assumes the target and output have the same spatial size, so `--resolution` must equal `--image_size`.

Dataset format is expected to be `ImageFolder`-compatible:

```text
train_data_dir/
class_a/
img_0001.jpg
class_b/
img_0002.jpg
```
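
For a quick smoke test, a toy dataset in this layout can be generated with Pillow (the root path and image size are illustrative; any RGB images work):

```python
import os
from PIL import Image

root = "train_data_dir"  # illustrative; pass this as --train_data_dir
for cls, name in [("class_a", "img_0001.jpg"), ("class_b", "img_0002.jpg")]:
    os.makedirs(os.path.join(root, cls), exist_ok=True)
    # Solid-color placeholder image standing in for real training data.
    Image.new("RGB", (256, 256), (127, 127, 127)).save(os.path.join(root, cls, name))
```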