25 commits
66b5a07
feat(sana-wm): add diffusers-style SANA-WM camera-controlled I2V pipe…
HaoyiZhu May 25, 2026
a764dee
feat(sana-wm): align pipeline with merged sana_video style; fix mp4 e…
lawrence-cj Jun 2, 2026
bd08244
feat(sana-wm): port chunk-causal AR refiner mode (RefinerChunkRunner …
lawrence-cj Jun 2, 2026
34f0d81
feat(sana-wm): block-level checkpoint for AR refiner (resume after pr…
lawrence-cj Jun 2, 2026
4d8b2cc
Merge branch 'main' into feat/sana-wm-diffusers-cleanup
lawrence-cj Jun 9, 2026
44aa5cb
test(sana-wm): add CPU unit tests + slow GPU integration stub
lawrence-cj Jun 9, 2026
bf48b22
Merge branch 'main' into feat/sana-wm-diffusers-cleanup
lawrence-cj Jun 10, 2026
32ea160
Merge branch 'main' into feat/sana-wm-diffusers-cleanup
lawrence-cj Jun 15, 2026
1afe995
Merge branch 'main' into feat/sana-wm-diffusers-cleanup
dg845 Jun 16, 2026
c0712d3
feat(sana-wm): make triton optional + auto-fallback to pure-PyTorch a…
lawrence-cj Jun 16, 2026
e1d13d5
Merge branch 'main' into feat/sana-wm-diffusers-cleanup
lawrence-cj Jun 18, 2026
efada20
Merge branch 'main' into feat/sana-wm-diffusers-cleanup
dg845 Jun 24, 2026
0c23442
fix(sana-wm): make optional deps lazy + register transformer in __init__
lawrence-cj Jun 25, 2026
7b7dea1
style(sana-wm): apply make style + fix-copies
lawrence-cj Jun 25, 2026
c3fe5b4
Merge branch 'main' into feat/sana-wm-diffusers-cleanup
dg845 Jun 25, 2026
eb7b3df
refactor(sana-wm): drop einops dependency, inline with torch ops
lawrence-cj Jun 25, 2026
1b81344
refactor(sana-wm): remove dead code per @dg845's review
lawrence-cj Jun 25, 2026
2723228
Merge branch 'main' into feat/sana-wm-diffusers-cleanup
dg845 Jun 25, 2026
d352443
fix(sana-wm): defer transformers import + document mask/return_dict
lawrence-cj Jun 26, 2026
aa06c13
Merge branch 'main' into feat/sana-wm-diffusers-cleanup
lawrence-cj Jun 29, 2026
271d94f
Merge branch 'main' into feat/sana-wm-diffusers-cleanup
dg845 Jul 1, 2026
08d3bde
Merge branch 'main' into feat/sana-wm-diffusers-cleanup
lawrence-cj Jul 2, 2026
f259221
refactor(sana-wm): address review feedback on SanaWMPipeline
lawrence-cj Jul 2, 2026
6317ce3
refactor(sana-wm): make SanaWMLTX2Refiner a standalone DiffusionPipeline
lawrence-cj Jul 2, 2026
b66576e
Merge branch 'main' into feat/sana-wm-diffusers-cleanup
lawrence-cj Jul 3, 2026
4 changes: 4 additions & 0 deletions docs/source/en/_toctree.yml
@@ -387,6 +387,8 @@
  title: SanaTransformer2DModel
- local: api/models/sana_video_transformer3d
  title: SanaVideoTransformer3DModel
- local: api/models/sana_wm_transformer3d
  title: SanaWMTransformer3DModel
- local: api/models/sd3_transformer2d
  title: SD3Transformer2DModel
- local: api/models/skyreels_v2_transformer_3d
@@ -605,6 +607,8 @@
  title: Sana Sprint
- local: api/pipelines/sana_video
  title: Sana Video
- local: api/pipelines/sana_wm
  title: SANA-WM
- local: api/pipelines/shap_e
  title: Shap-E
- local: api/pipelines/stable_cascade
46 changes: 46 additions & 0 deletions docs/source/en/api/models/sana_wm_transformer3d.md
@@ -0,0 +1,46 @@
<!-- Copyright 2025 The HuggingFace Team and SANA-WM Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# SanaWMTransformer3DModel

A 3D Diffusion Transformer (1.6B parameters) for camera-controlled image-to-video generation, used as the stage-1
sampler of [`SanaWMPipeline`]. The transformer combines:

* a bidirectional GDN-Triton linear-attention main branch (depth 20, hidden 2240, 20 heads),
* a UCPE (Unified Camera Pose Embedding) camera-control branch that consumes a raymap + Plücker representation of
the requested trajectory, and
* a Wan-style 3D rotary position embedding plus periodic softmax-attention blocks injected every `softmax_every_n`
layers.

The state-dict layout matches the public SANA-WM release one-to-one, except that the diffusers wrapper places the
inner DiT under an `_inner.` prefix. See [`SanaWMTransformer3DModel.add_inner_prefix`] for the helper used by the
conversion script.
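
As a rough illustration of what such a prefixing helper does, the sketch below rewrites the keys of a plain dict.
The real `add_inner_prefix` implementation is not shown in this page, so `add_prefix` and the example key names are
hypothetical stand-ins.

```python
from typing import Dict, TypeVar

T = TypeVar("T")


def add_prefix(state_dict: Dict[str, T], prefix: str = "_inner.") -> Dict[str, T]:
    """Return a copy of `state_dict` with `prefix` prepended to every key.

    Hypothetical stand-in for `SanaWMTransformer3DModel.add_inner_prefix`.
    Keys that already carry the prefix are left unchanged.
    """
    return {k if k.startswith(prefix) else prefix + k: v for k, v in state_dict.items()}


# Made-up key names, used only to show the renaming.
released = {"blocks.0.attn.qkv.weight": 1, "final_layer.linear.bias": 2}
wrapped = add_prefix(released)
print(sorted(wrapped))
```

Making the helper idempotent is a design choice of this sketch, so that converting an already converted state dict
is harmless.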

The model can be loaded with:

```python
import torch
from diffusers import SanaWMTransformer3DModel

transformer = SanaWMTransformer3DModel.from_pretrained(
    "Efficient-Large-Model/SANA-WM_bidirectional-diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
```

## SanaWMTransformer3DModel

[[autodoc]] SanaWMTransformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
119 changes: 119 additions & 0 deletions docs/source/en/api/pipelines/sana_wm.md
@@ -0,0 +1,119 @@
<!-- Copyright 2025 The HuggingFace Team and SANA-WM Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# SANA-WM

SANA-WM is a camera-controlled image-to-video world model built on top of SANA. Given a first-frame image, a text
prompt, and a camera trajectory (either explicit `c2w` poses or a WASD/IJKL action string), it generates a video
whose motion follows the requested camera path.

Inference runs in two stages:

1. **Stage 1 — SANA-WM DiT.** A 1.6B-parameter bidirectional DiT with GDN-Triton linear attention and a UCPE
camera-control branch. Sampling uses an LTX-style flow-matching Euler scheduler with per-token timesteps; the
first latent frame is the conditioning anchor.
2. **Stage 2 — LTX-2 refiner (optional).** A sink-bidirectional Euler refiner ([`SanaWMLTX2Refiner`]) that wraps
diffusers' own `LTX2VideoTransformer3DModel` + `LTX2TextConnectors` and a Gemma-3 text encoder, run for 3
distilled sigma steps.

Both stages decode through the [`AutoencoderKLLTX2Video`] VAE.
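
The flow-matching Euler update with per-token timesteps can be sketched on a toy latent. This is an illustration
under stated assumptions, not the pipeline's code: the per-frame `denoise_mask`, the oracle velocity that stands in
for the DiT prediction, and the helper names are all hypothetical. The `shift=9.8` value is the one the conversion
script writes into the stage-1 scheduler config.

```python
import numpy as np


def shift_sigmas(sigmas: np.ndarray, shift: float) -> np.ndarray:
    """Time-shift a sigma schedule: shift * s / (1 + (shift - 1) * s)."""
    return shift * sigmas / (1.0 + (shift - 1.0) * sigmas)


def euler_step(x, velocity, sigma, sigma_next):
    """One flow-matching Euler update; sigma may broadcast per frame/token."""
    return x + (sigma_next - sigma) * velocity


rng = np.random.default_rng(0)
clean = rng.standard_normal((4, 2, 2))  # toy latent: (frames, height, width)
noise = rng.standard_normal(clean.shape)

# Assumption: frame 0 is the conditioning anchor, so its sigma stays at 0.
denoise_mask = np.array([0.0, 1.0, 1.0, 1.0]).reshape(-1, 1, 1)

sigmas = shift_sigmas(np.linspace(1.0, 0.0, 11), shift=9.8)
x = (1.0 - denoise_mask) * clean + denoise_mask * noise
for s, s_next in zip(sigmas[:-1], sigmas[1:]):
    velocity = noise - clean  # oracle velocity standing in for the model output
    x = euler_step(x, velocity, s * denoise_mask, s_next * denoise_mask)

print(np.allclose(x, clean))
```

With an exact velocity the Euler steps telescope, so the noised frames land on the clean latent while the anchor
frame is never modified.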

Available models:

| Model | Recommended dtype |
|:-----:|:-----------------:|
| [`Efficient-Large-Model/SANA-WM_bidirectional-diffusers`](https://huggingface.co/Efficient-Large-Model/SANA-WM_bidirectional-diffusers) | `torch.bfloat16` |

> [!TIP]
> SANA-WM is trained at a fixed 704×1280 resolution. The recommended dtype is for the transformer weights — keep
> the text encoder in `torch.bfloat16` and the VAE in `torch.float32` for best numerics. The pipeline expects
> camera intrinsics `[fx, fy, cx, cy]` in *original-image* pixel coordinates; the resize-and-center-crop transform
> is applied internally.

## Inference

```python
import torch
from PIL import Image

from diffusers import SanaWMPipeline
from diffusers.utils import export_to_video

pipe = SanaWMPipeline.from_pretrained(
    "Efficient-Large-Model/SANA-WM_bidirectional-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload() # ~45 GB of weights — offload between stages

output = pipe(
    image=Image.open("input.png").convert("RGB"),
    prompt="A car driving across a vast desert plain at golden hour.",
    action="w-80,jw-40,w-40",  # WASD-style action DSL: forward 80f, jump+forward 40f, forward 40f
    intrinsics=[800.0, 800.0, 845.0, 464.0],  # fx, fy, cx, cy in original-image pixels
    num_frames=161,
    num_inference_steps=60,
    guidance_scale=5.0,
    seed=42,
)
export_to_video(list(output.frames), "sana_wm.mp4", fps=16)
```
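
To make the structure of the `action` string concrete, here is a minimal parser sketch. The grammar
(`<keys>-<frame count>` segments separated by commas) is only inferred from the example above, and `parse_action` is
a hypothetical helper, not part of the pipeline.

```python
from typing import List, Tuple


def parse_action(action: str) -> List[Tuple[str, int]]:
    """Split an action string into (keys, frame_count) segments.

    Assumed grammar, inferred from the docs example: "keys-count,keys-count,...".
    """
    segments = []
    for part in action.split(","):
        keys, _, count = part.strip().partition("-")
        if not keys or not count.isdigit():
            raise ValueError(f"malformed action segment: {part!r}")
        segments.append((keys, int(count)))
    return segments


segments = parse_action("w-80,jw-40,w-40")
total = sum(n for _, n in segments)
print(segments, total)
```

The three segments add up to 160 frames, one fewer than `num_frames=161` in the example. That would be consistent
with the first frame being the conditioning image, although the source does not state this explicitly.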

Pass `action=None` and supply your own `c2w` poses (`(F, 4, 4)` numpy array) to drive the camera trajectory
explicitly. Set `use_refiner=False` to skip stage 2.
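
A minimal sketch of building such a pose array with NumPy follows. The axis convention (forward motion along +z),
the step size, and the `forward_trajectory` helper are assumptions made for illustration; check the pipeline
docstring for the convention it actually expects.

```python
import numpy as np


def forward_trajectory(num_frames: int, step: float = 0.05) -> np.ndarray:
    """Camera-to-world poses with identity rotation, translating along +z.

    Returns an array of shape (num_frames, 4, 4). The +z forward axis and the
    step size are illustrative assumptions.
    """
    c2w = np.tile(np.eye(4, dtype=np.float32), (num_frames, 1, 1))
    c2w[:, 2, 3] = step * np.arange(num_frames, dtype=np.float32)
    return c2w


c2w = forward_trajectory(161)
print(c2w.shape)
```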

If you don't have camera intrinsics, [`pi3-vision`](https://github.com/OliverSFAC/pi3-vision) can estimate them
from a single frame:

```python
from diffusers.pipelines.sana_wm.cam_utils import estimate_intrinsics_with_pi3x
intrinsics = estimate_intrinsics_with_pi3x(image) # `pip install pi3-vision`
```
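
The tip above notes that intrinsics are given in original-image pixels and that the resize-and-center-crop is
applied internally. As a sketch of what that kind of transform does to `[fx, fy, cx, cy]`, the hypothetical helper
below uses the standard pinhole rescaling. It is not the pipeline's own code and ignores half-pixel offsets.

```python
from typing import List, Sequence, Tuple


def rescale_intrinsics(
    intrinsics: Sequence[float],
    src_hw: Tuple[int, int],
    dst_hw: Tuple[int, int] = (704, 1280),
) -> List[float]:
    """Map [fx, fy, cx, cy] through a cover-resize followed by a center crop."""
    fx, fy, cx, cy = intrinsics
    src_h, src_w = src_hw
    dst_h, dst_w = dst_hw
    # Resize so the image fully covers the target, then crop the overflow evenly.
    scale = max(dst_h / src_h, dst_w / src_w)
    crop_x = (src_w * scale - dst_w) / 2.0
    crop_y = (src_h * scale - dst_h) / 2.0
    return [fx * scale, fy * scale, cx * scale - crop_x, cy * scale - crop_y]


# A 1408x2560 image is exactly twice the target size, so everything halves.
print(rescale_intrinsics([1000.0, 1000.0, 1280.0, 704.0], src_hw=(1408, 2560)))
```

Focal lengths scale with the resize factor, while the principal point is scaled and then shifted by the crop offset.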

## Converting the released checkpoint

If you have the source SANA-WM release (not the pre-converted diffusers snapshot), run the conversion script once:

```bash
python scripts/sana_wm/convert_sana_wm_to_diffusers.py \
--src Efficient-Large-Model/SANA-WM_bidirectional \
--dst ./SANA-WM_bidirectional-diffusers
```

Then load from the local path as usual.
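
Before loading, a quick sanity check of the converted directory can catch an incomplete run. The sketch below
assumes the layout listed in the conversion script's docstring; `missing_components` is a hypothetical helper, not
something the script provides.

```python
from pathlib import Path
from typing import List

# Sub-folders the conversion script is documented to write at the top level.
EXPECTED = ("tokenizer", "text_encoder", "vae", "transformer", "scheduler")


def missing_components(root: Path, expect_refiner: bool = True) -> List[str]:
    """Return the expected entries that are absent under `root`."""
    names = list(EXPECTED) + (["refiner"] if expect_refiner else [])
    missing = [name for name in names if not (root / name).is_dir()]
    if not (root / "model_index.json").is_file():
        missing.append("model_index.json")
    return missing


print(missing_components(Path("./SANA-WM_bidirectional-diffusers")))
```

An empty list means every expected entry is present. Pass `expect_refiner=False` if the conversion was run with
`--no-refiner`.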

## Components

- `tokenizer` — [`GemmaTokenizerFast`]
- `text_encoder` — Gemma-2 (returns decoder hidden states)
- `vae` — [`AutoencoderKLLTX2Video`] (LTX-2, spatial ×32 / temporal ×8)
- `transformer` — [`SanaWMTransformer3DModel`], 1.6B-parameter bidirectional DiT
- `scheduler` — [`FlowMatchEulerDiscreteScheduler`]
- `refiner` (optional) — [`SanaWMLTX2Refiner`], wraps `LTX2VideoTransformer3DModel`, `LTX2TextConnectors`, and a
Gemma-3 text encoder

## SanaWMPipeline

[[autodoc]] SanaWMPipeline
  - all
  - __call__

## SanaWMLTX2Refiner

The optional LTX-2 stage-2 refiner is itself a [`DiffusionPipeline`]. [`SanaWMPipeline`] runs it automatically when
`use_refiner=True`, but it can also be used standalone on stage-1 latents.

[[autodoc]] SanaWMLTX2Refiner
  - all
  - __call__

## SanaWMPipelineOutput

[[autodoc]] pipelines.sana_wm.pipeline_output.SanaWMPipelineOutput
187 changes: 187 additions & 0 deletions scripts/sana_wm/convert_sana_wm_to_diffusers.py
@@ -0,0 +1,187 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team and SANA-WM Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert the public SANA-WM release into a diffusers-loadable directory.

Reads the ``Efficient-Large-Model/SANA-WM_bidirectional`` HF repo (or a local
mirror) and writes a directory ready for ``SanaWMPipeline.from_pretrained(path)``:

    <output_dir>/
    ├── model_index.json
    ├── tokenizer/
    ├── text_encoder/
    ├── vae/
    ├── transformer/
    ├── scheduler/
    └── refiner/
        ├── model_index.json
        ├── transformer/
        ├── connectors/
        ├── text_encoder/
        ├── tokenizer/
        └── scheduler/

Usage:
    python scripts/sana_wm/convert_sana_wm_to_diffusers.py \\
        --src Efficient-Large-Model/SANA-WM_bidirectional \\
        --dst /path/to/SANA-WM_bidirectional-diffusers \\
        [--no-refiner]

The output is local-only; no upload to the Hub.
"""

from __future__ import annotations

import argparse
import json
import shutil
from pathlib import Path

import torch
from huggingface_hub import snapshot_download


def _copy_subdir(src: Path, dst: Path) -> None:
    if dst.exists():
        shutil.rmtree(dst)
    shutil.copytree(src, dst, symlinks=False)


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--src", default="Efficient-Large-Model/SANA-WM_bidirectional", help="HF repo or local dir")
    parser.add_argument("--dst", required=True, type=Path, help="Output directory")
    parser.add_argument("--no-refiner", action="store_true", help="Skip refiner export")
    parser.add_argument(
        "--torch-dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"], help="Weight dtype"
    )
    args = parser.parse_args()

    torch_dtype = getattr(torch, args.torch_dtype)
    dst: Path = args.dst.absolute()
    dst.mkdir(parents=True, exist_ok=True)

    # Resolve the source on disk (snapshot_download for HF repos, otherwise use as-is).
    src_path = Path(args.src)
    if not src_path.is_dir():
        print(f"[convert] snapshot_download({args.src}) …")
        src_path = Path(snapshot_download(args.src))
    print(f"[convert] source: {src_path}")

    # 1. VAE (already diffusers format under <src>/vae).
    print("[convert] vae …")
    _copy_subdir(src_path / "vae", dst / "vae")

    # 2. Tokenizer + text encoder — fetch via the configured Gemma-2 repo.
    # We save the full ``Gemma2ForCausalLM``; the pipeline grabs the decoder
    # at runtime via ``self.text_encoder.model(...)``. This matches the sana
    # inference recipe of ``AutoModelForCausalLM.from_pretrained(...).get_decoder()``
    # and avoids subtle state-dict prefix differences when saving just the
    # decoder submodule.
    print("[convert] tokenizer + text_encoder (gemma-2-2b-it) …")
    from transformers import AutoModelForCausalLM, AutoTokenizer

    gemma_repo = "Efficient-Large-Model/gemma-2-2b-it"
    tokenizer = AutoTokenizer.from_pretrained(gemma_repo)
    tokenizer.padding_side = "right"
    tokenizer.save_pretrained(dst / "tokenizer")
    text_encoder = AutoModelForCausalLM.from_pretrained(gemma_repo, torch_dtype=torch_dtype)
    text_encoder.save_pretrained(dst / "text_encoder")
    del text_encoder

    # 3. Transformer (SanaWMTransformer3DModel) — load the public DiT, save in diffusers format.
    print("[convert] transformer (SanaWMTransformer3DModel) …")
    from diffusers import SanaWMTransformer3DModel

    transformer = SanaWMTransformer3DModel().to(torch_dtype).eval()
    dit_ckpt = src_path / "dit" / "sana_wm_1600m_720p.safetensors"
    if not dit_ckpt.is_file():
        raise FileNotFoundError(f"DiT checkpoint not found at {dit_ckpt}")
    from safetensors.torch import load_file

    sd = load_file(str(dit_ckpt))
    sd.pop("pos_embed", None)  # unused at inference (wan_rope is computed on-the-fly)
    sd = SanaWMTransformer3DModel.add_inner_prefix(sd)
    missing, unexpected = transformer.load_state_dict(sd, strict=False)
    if missing:
        missing_nontrivial = [k for k in missing if not k.endswith(".pos_embed")]
        if missing_nontrivial:
            print(f" missing keys: {missing_nontrivial[:10]}{' …' if len(missing_nontrivial) > 10 else ''}")
    if unexpected:
        print(f" unexpected keys: {unexpected[:10]}{' …' if len(unexpected) > 10 else ''}")
    transformer.save_pretrained(dst / "transformer")
    del transformer, sd

    # 4. Scheduler — FlowMatchEulerDiscreteScheduler config.
    print("[convert] scheduler …")
    from diffusers import FlowMatchEulerDiscreteScheduler

    FlowMatchEulerDiscreteScheduler(shift=9.8).save_pretrained(dst / "scheduler")

    # 5. Refiner (LTX-2): now a standalone DiffusionPipeline saved in the
    # ``refiner/`` subfolder with its own ``model_index.json``. Copy the
    # LTX-2 sub-model folders as-is, split out a ``tokenizer/`` folder, add a
    # ``scheduler/`` (FlowMatchEulerDiscreteScheduler), and write the manifest.
    if not args.no_refiner:
        print("[convert] refiner …")
        from transformers import AutoTokenizer

        refiner_src = src_path / "refiner"
        refiner_dst = dst / "refiner"
        refiner_dst.mkdir(exist_ok=True)
        for sub in ("transformer", "connectors", "text_encoder"):
            if (refiner_src / sub).is_dir():
                _copy_subdir(refiner_src / sub, refiner_dst / sub)

        # Tokenizer lives co-located with the Gemma-3 text encoder in the release;
        # re-save it into its own subfolder so it registers as a pipeline component.
        refiner_tokenizer = AutoTokenizer.from_pretrained(refiner_src / "text_encoder")
        refiner_tokenizer.save_pretrained(refiner_dst / "tokenizer")

        # Scheduler carries the distilled sigma schedule; shift=1.0 leaves the
        # explicit sigmas passed at inference time unmodified.
        FlowMatchEulerDiscreteScheduler(shift=1.0).save_pretrained(refiner_dst / "scheduler")

        refiner_index = {
            "_class_name": "SanaWMLTX2Refiner",
            "_diffusers_version": "0.38.0",
            "transformer": ["diffusers", "LTX2VideoTransformer3DModel"],
            # LTX2TextConnectors lives in diffusers.pipelines.ltx2 (not top-level),
            # so the loader resolves it via the pipeline-module path ("ltx2", ...).
            "connectors": ["ltx2", "LTX2TextConnectors"],
            "tokenizer": ["transformers", type(refiner_tokenizer).__name__],
            "text_encoder": ["transformers", "Gemma3ForConditionalGeneration"],
            "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
            "text_max_sequence_length": 1024,
        }
        (refiner_dst / "model_index.json").write_text(json.dumps(refiner_index, indent=2))

    # 6. model_index.json — the top-level diffusers manifest.
    print("[convert] model_index.json …")
    index = {
        "_class_name": "SanaWMPipeline",
        "_diffusers_version": "0.38.0",
        "tokenizer": ["transformers", "GemmaTokenizerFast"],
        "text_encoder": ["transformers", "Gemma2ForCausalLM"],
        "vae": ["diffusers", "AutoencoderKLLTX2Video"],
        "transformer": ["diffusers", "SanaWMTransformer3DModel"],
        "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
    }
    if not args.no_refiner:
        index["refiner"] = ["diffusers", "SanaWMLTX2Refiner"]
    (dst / "model_index.json").write_text(json.dumps(index, indent=2))

    print(f"[convert] done — wrote {dst}")


if __name__ == "__main__":
    main()