Add SANA-WM camera-controlled image-to-video pipeline#13881
Add SANA-WM camera-controlled image-to-video pipeline#13881lawrence-cj wants to merge 25 commits into
Conversation
…line
Adds the public SANA-WM bidirectional camera-controlled image-to-video
model as a first-class diffusers pipeline + transformer. Layout mirrors
``sana_video``: the model lives under ``src/diffusers/models/transformers/``
as a near-single-file (kernels split off so the ``@triton.jit`` decorators
don't drown the model body); the pipeline lives under
``src/diffusers/pipelines/sana_wm/``.
Files added:
src/diffusers/models/transformers/
├── transformer_sana_wm.py # SanaWMTransformer3DModel + blocks + helpers
└── transformer_sana_wm_kernels.py # fused Triton kernels + camera math
src/diffusers/pipelines/sana_wm/
├── __init__.py
├── pipeline_sana_wm.py
├── pipeline_output.py
├── refiner.py
└── cam_utils.py
Pipeline architecture:
* Stage 1: 1600M ``SanaWMTransformer3DModel`` DiT with bidirectional
GDN-Triton linear attention + UCPE camera-control branch, LTX-style
flow-matching Euler scheduler with per-token timesteps.
* Stage 2: LTX-2 sink-bidirectional Euler refiner (3 distilled sigma
steps, reuses diffusers' ``LTX2VideoTransformer3DModel`` +
``LTX2TextConnectors`` + Gemma-3 text encoder).
* Decode through the LTX-2 VAE (``AutoencoderKLLTX2Video``).
One-line usage:
pipe = SanaWMPipeline.from_pretrained(
"Efficient-Large-Model/SANA-WM_bidirectional-diffusers",
torch_dtype=torch.bfloat16,
).to("cuda")
out = pipe(image=img, prompt="...", action="w-80,jw-40,w-40",
intrinsics=[fx, fy, cx, cy])
End-to-end smoke test (stage-1 + refiner + VAE decode) passes on H100.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…xport transformer_sana_wm.py: * License header switched to the "HuggingFace Team and SANA-WM Authors" style used by merged sana_video. * Imports rewritten in stdlib -> third-party -> diffusers order; use diffusers `from ...utils import logging` instead of stdlib `logging`. * Fix 9 `Optional[X]` annotations written as `X or None` (Python's `or` short-circuits and silently returns `X`). * Fix two `assert (cond, msg)` tuple-asserts in PatchEmbedMS3D.forward that always pass (SyntaxWarning at import time). * Remove duplicate `__all__` declarations (the second silently overwrote the first). * Remove dead `reset_bn` (imports a nonexistent `packages.apps.utils`, would crash on call). * Remove the duplicate `logger = logging.getLogger(__name__)` further down in the file. transformer_sana_wm_kernels.py: * License header normalized; collapse three duplicate triton/torch import blocks into one. pipeline_sana_wm.py: * License header normalized. * `_decode_latents` now returns `(T, H, W, 3)` float in [0, 1], matching the diffusers convention used by `VideoProcessor`. Returning uint8 silently broke `export_to_video`: it does `frame * 255` assuming float input, so uint8 overflows to `(-x) mod 256` and inverts colors. * `__call__` converts to PIL/uint8 only when `output_type="pil"`. * Intrinsics argument now accepts (4,), (F, 4), (3, 3), and (F, 3, 3) forms (auto-extracts fx, fy, cx, cy from a 3x3 K) and auto-trims to `num_frames` when a longer-than-needed trajectory is passed. * Inline `retrieve_timesteps` with the standard `# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps` marker, matching merged sana_video. * Docstrings + EXAMPLE_DOC_STRING updated to reflect the new return type. pipeline_output.py: * Update `frames` field docstring to describe the new float [0, 1] return. refiner.py, cam_utils.py, scripts/sana_wm/convert_sana_wm_to_diffusers.py: * License headers normalized. Docs: * New `docs/source/en/api/pipelines/sana_wm.md` and `docs/source/en/api/models/sana_wm_transformer3d.md`, modeled on sana_video.md / sana_video_transformer3d.md, wired into `docs/source/en/_toctree.yml` under Models and Pipelines. 5s end-to-end smoke test (81 frames @ 16fps, 30 stage-1 steps + 3-step LTX-2 refiner) passes on 1x H100 80GB with `enable_model_cpu_offload`. Round-trip diff vs raw float frames is 2.06/255 mean (h264 lossy noise), confirming the export_to_video fix.
…+ KV cache hooks)
The first cleanup pass only kept the legacy single-shot refiner path. That
path is what the model was *not* trained on — its docstring even says
"feeding the full sequence at once is out-of-distribution" — and its cost
is O(T^2) attention over the full latent volume, which made longer videos
unusable (~21 min per refiner step at 321 frames on an H100).
Port the chunk-causal AR mode from the upstream reference so the refiner
matches the training contract:
* `refine_latents` now defaults to `block_size=3, kv_max_frames=11`
(the canonical AR recipe). Pass `block_size=None` to fall back to the
legacy single-shot path.
* New `_refine_latents_ar` + `_RefinerChunkRunner` orchestrate the sliding
window: pre-capture pre-RoPE sink K/V on `z_sana[:source_sink_frames]`
at sigma=0, then for each `block_size`-frame chunk run a 3-step Euler
with prefix `{sink_k_pre, sink_v, sink_pe, history_k, history_v}` and
capture post-RoPE K/V to feed the next window. History is bounded to
`kv_max_frames - source_sink_frames` so per-block compute is constant.
* New `_predict_x0_active_block` runs the transformer on the active block
only (Q from active, K/V from prefix+active).
* New `_capture_block_kv` runs sigma=0 forward with a pre_rope/post_rope
capture flag set on each `attn1`.
* New `_forward_video_only_with_rope` takes a pre-built RoPE so each block
can use absolute frame positions in the source video.
* `_streaming_self_attention` extended with the `_kv_cache_capture`,
`_tf_capture_kv`, `_tf_kv_prefix` hook contract that AR mode uses to
inject and capture K/V on each block.
* New helpers: `_build_rotary_emb_for_absolute_positions`,
`_set_kv_prefix_on_blocks`, `_clear_kv_prefix_on_blocks`,
`_set_capture_flag_on_blocks`, `_collect_captured_kv_from_blocks`.
* `_encode_prompt` now also moves the Gemma-3 text encoder back to CPU
after producing the embeds — otherwise it stays resident through the
entire AR loop and gates how much GPU memory the refiner transformer
has left.
Module-level docstring updated to document both modes; existing
single-shot path preserved verbatim.
…eemption)
The AR refiner is expensive (~3-5 min per block) and the refinement loop
ran end-to-end has no in-progress state to recover, so a SLURM preemption
mid-refinement loses all progress. With the canonical
``block_size=3, kv_max_frames=11`` setup, refining a 50s video is 34
blocks of work that has to make it through without preemption on a
backfill queue.
Add per-block atomic checkpointing:
* ``SanaWMLTX2Refiner.refine_latents(checkpoint_dir=Path)`` and
``_refine_latents_ar`` accept a directory. After each completed AR
block, the AR loop writes ``checkpoint_dir/state.pt`` atomically
(tmp + os.replace).
* The payload is ``{block_idx_done, n_blocks, sink_size, block_size,
output_shape, output, runner_state}``. ``runner_state`` is a CPU snapshot
of the runner's ``_sink_kv_pre``, ``_history_kv_post``,
``_history_frames`` and ``torch.Generator`` state.
* On entry, if ``state.pt`` exists with a compatible shape signature, the
AR loop loads the persisted output tensor + runner state and resumes
from ``block_idx_done + 1`` instead of recomputing from scratch.
* ``SanaWMPipeline.__call__(refiner_checkpoint_dir=...)`` plumbs the
directory through to the refiner.
Checkpoint size: ~output_volume + sink_KV (~360MB for 50 layers) +
rolling history KV (~3-4GB at full capacity) — saved once per block,
total per-block save overhead ~10s on lustre.
* CPU unit tests for cam_utils helpers (action DSL → c2w, intrinsics rescale-for-crop, resize+center-crop, snap_num_frames 8k+1 rounding). * Public-surface registration tests (top-level diffusers symbols, SanaWMPipelineOutput dataclass shape, refiner signature has AR defaults + checkpoint_dir, pipeline __call__ accepts c2w/action/intrinsics/ refiner_checkpoint_dir). * @slow @require_torch_accelerator integration stub for an end-to-end I2V against the public checkpoint, currently @unittest.skip — wires up the nightly GPU path without exploding regular CI. SanaWMTransformer3DModel has hardcoded depth/hidden_size/num_heads inside its inner SanaMSVideoCamCtrl (not exposed through register_to_config), so the usual PipelineTesterMixin small-config fast tests aren't applicable without a transformer refactor (followup PR).
|
As a preliminary comment, would it be possible to use PyTorch ops instead of custom Triton kernels (or add pure PyTorch fallback paths) for now? We will work on supporting the custom kernels through |
Yes, love to do that. |
…ttention `transformer_sana_wm_kernels.py` previously did a hard `import triton` at the top of the file. That blocked importing the SANA-WM transformer on any environment without Triton (CPU-only, ROCm without Triton, older Triton, etc.), even though the model has pure-PyTorch attention classes for every `*Triton` variant. Make Triton optional and have the dispatcher transparently fall back: * Wrap `import triton` / `import triton.language as tl` in try/except. When unavailable, install a shim where `@triton.jit` is a no-op so the kernel function definitions still load (they just aren't compiled by Triton). Module-level `triton.X` / `tl.X` lookups return a self-shimming sentinel so signature parsing doesn't blow up either. * Add `is_triton_available()` + `_require_triton(entry_point)`. The four Triton-backed entry points called by the model (`fused_qk_inv_rms`, `fused_bigdn_func`, `cam_prep_func`, `cam_scan_bidi_chunkwise`) now raise a clear RuntimeError on a Triton-less host with a hint to use the pure-PyTorch attention variants — but the dispatcher does this automatically (see below) so users shouldn't ever see it. * Delete the leftover duplicate `import torch / triton / triton.language` block at line 262 (left over from the upstream port). * Register `BidirectionalGDNUCPESinglePathLiteLA` in `ATTENTION_BLOCKS` so the fallback chain can find it. * New `_resolve_attention_block(name, role)` walks the requested class's MRO at dispatch time. If Triton isn't usable AND the requested class name ends in `Triton`, route to the closest registered non-`Triton` ancestor (BidirectionalGDNUCPESinglePathLiteLABothTriton -> BidirectionalGDNUCPESinglePathLiteLA, etc.) and log a one-shot warning. * Rewire both `SanaVideoMSCamCtrlBlock` dispatch sites to use `_resolve_attention_block` for the GDN+UCPE camera branch and the main attention branch (the `BidirectionalSoftmaxUCPESinglePathLiteLA` branch doesn't use Triton at all so it stays hard-coded). Tests: * `test_kernels_module_imports_with_triton_hidden` — reloads the kernels module with `sys.modules['triton'] = None` and verifies the module imports, `is_triton_available()` is False, and the pure-PyTorch helpers remain callable. * `test_resolve_attention_block_cpu_fallback` — on a CPU-only host, the three `*Triton` attn types resolve to the correct non-Triton ancestor. * `test_triton_entry_point_raises_clean_error_without_triton` — verifies the `_require_triton` guard yields a RuntimeError that mentions Triton.
|
Done in
Triton remains the default on CUDA + Triton ≥ 3. CPU tests added under |
Three CI checks were failing on the PR: 1. `check_code_quality` (43 ruff errors): mix of unused imports / import sorting / E731 lambdas (auto-fixable) plus a handful of F821 dead-code references inherited from the upstream research codebase (`xformers.*` inside `if _xformers_available:` blocks, an undefined `BlockHook` type annotation, two `x_sa`/`mlp_out` references in a block forward whose live assignment was already overridden by subclasses). Ran `ruff check --fix --unsafe-fixes` + `ruff format`, fixed the type annotation manually, and added targeted `# noqa: F821` markers on the conditionally unreachable lines. 2. `check_torch_dependencies`: `transformer_sana_wm.py` hard-imported `einops`, `fla`, `timm`, `termcolor`. The minimum-deps CI environment doesn't have them, and diffusers' lazy loader rewrites `ModuleNotFoundError` as `RuntimeError` so `test_pipeline_imports` blew up. Wrapped each of the four optional imports in a try/except shim — `rearrange`/ `ShortConvolution`/`DropPath`/`Attention_`/`Mlp` become placeholders that raise a clear `ImportError` on construction, `colored` falls back to plain text. Class bodies that subclass these still parse at module load, so `import diffusers.models.transformers.transformer_sana_wm` succeeds anywhere. Same treatment for the kernels file's `from einops import rearrange, repeat`. 3. `build_pr_documentation`: doc-builder imported `SanaWMTransformer3DModel` from `diffusers.models.transformers` (not the diffusers top level) and that subpackage's `__init__.py` was missing the entry. Added the import.
* `doc-builder style src/diffusers docs/source --max_len 119` rewraps docstrings in the six SANA-WM files (transformer, kernels, pipeline, refiner, output, cam_utils) to the repo-wide 119-column limit. No behaviour change — purely whitespace inside docstrings. * `make fix-copies` regenerates `dummy_pt_objects.py` and `dummy_torch_and_transformers_objects.py` to add `DummyObject` stubs for the three new public classes (`SanaWMTransformer3DModel`, `SanaWMPipeline`, `SanaWMLTX2Refiner`), so `from diffusers import …` gives the standard "missing backend" message on installs without torch / transformers. Verified: `make quality` passes (ruff check, ruff format check, doc-builder style check_only, check_doc_toc). Test suite still 15 passed / 1 skipped.
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
Can you remove dead code (that isn't used by any existing Sana-WM checkpoint) from the PR so that it is easier to review? For example, in
|
Per @dg845's review on `transformer_sana_wm.py:44`. The 9 call sites all match well-known patterns that have one-liner torch equivalents: rearrange(s, "b d -> (b d)") -> s.reshape(b * d) rearrange(x, "(b d) d2 -> b (d d2)", ...) -> x.reshape(b, d * d2) rearrange(R, "b t h w i j -> b t h w j i") -> R.transpose(-1, -2) repeat(x, "b h w c -> b t h w c", t=T) -> x.unsqueeze(1).expand(-1, T, -1, -1, -1) repeat(x, "b t c -> b t h w c", h=H, w=W) -> x[:, :, None, None, :].expand(-1, -1, H, W, -1) repeat(x, "... -> b ...", b=B) -> x.unsqueeze(0).expand(B, *x.shape) repeat(x, "H W C -> B T H W C", B, T) -> x[None, None].expand(B, T, -1, -1, -1) repeat(x, "B H W C -> B T H W C", T) -> x.unsqueeze(1).expand(-1, T, -1, -1, -1) Each replacement is bit-identical to the einops original — verified against a fresh `einops` install on random tensors before swapping. The optional `from einops import ...` shim block is gone from both `transformer_sana_wm.py` and `transformer_sana_wm_kernels.py`.
Drop 23 unused symbols (1451 lines) from `transformer_sana_wm.py` that aren't reachable from the public SANA-WM checkpoint's `SanaMSVideoCamCtrl` -> `SanaVideoMSCamCtrlBlock` -> GDN/UCPE attention path. Each was verified to have zero call sites outside of its own definition (or only within other-deleted items). Classes: * `SanaMS`, `SanaMSBlock` — alternative `Sana` subclass + block, not used by the SANA-WM checkpoint (which goes through `SanaMSVideoCamCtrl`). * `ChunkCausalAttention`, `CachedCausalAttention`, `ChunkedLiteLAReLURope`, `LiteLAReLURope` — chunk-causal / cached attention variants and their common base; SANA-WM uses the bidi GDN path. `LiteLAReLURope` had only the three (now-deleted) subclasses referencing it. * `PAGCFGIdentitySelfAttnProcessorLiteLA`, `PAGIdentitySelfAttnProcessorLiteLA`, `SelfAttnProcessorLiteLA`, `SelfAttnProcessorLiteLAReLURope` — PAG processors; we don't expose PAG in the SANA-WM pipeline. * `ChunkGLUMBConvTemp`, `CachedGLUMBConvTemp`, `MBConvPreGLU` — alternative FFN/conv blocks; the checkpoint uses `GLUMBConvTemp`. * `MaskFinalLayer`, `DecoderLayer` — alternative final layers; the checkpoint uses `T2IFinalLayer`. * `LabelEmbedder`, `CaptionEmbedderDoubleBr` — alternative embedders; the checkpoint uses `CaptionEmbedder`. Helpers: * `set_grad_checkpoint`, `prepare_prompt_ar`, `resize_and_crop_tensor`, `generate_temporal_head_mask_mod`, `is_chunk_causal_request`, `get_chunk_index_from_config` — training-only or chunk-causal-mode helpers with no inference call sites. Verified after the diff: * `make quality` clean. * CPU test suite 15 passed / 1 skipped (slow GPU integration). * `import diffusers ; SanaWMPipeline / SanaWMTransformer3DModel / SanaWMLTX2Refiner` resolve identically.
|
Addressed both in
|
Two PR-CI failures, both ours: * `check_torch_dependencies`: line 33 hard-imported `transformers`, which isn't present in the minimum-deps environment. Move the lone `AutoModelForCausalLM` use site inside `initialize_gemma_params` (a training-only helper). * `check_repository_consistency`: `SanaWMTransformer3DModel.forward`'s docstring was missing entries for `mask` and `return_dict`. Added.
|
@dg845 Hi, gentle ping. Seems the PR CI test failing is not due to our code. Anything else to do? |
|
Hi @lawrence-cj, I am reviewing the code and will try to have a full review out soon. Thanks for your patience! |
| vae_scale_factor_spatial: int = 32 | ||
| vae_scale_factor_temporal: int = 8 |
There was a problem hiding this comment.
I think we should get these attributes from the vae component like LTX2Pipeline does rather than hardcoding them here. This would make it easier for the pipeline to support different VAEs.
| # Stage-1 DiT sampling — LTX-style per-token timesteps | ||
| # ------------------------------------------------------------------ | ||
|
|
||
| def _sample_stage1( |
There was a problem hiding this comment.
I think we should inline the logic in _sample_stage1 into __call__, as this follows the design used by other pipelines. We could then define standard methods like prepare_latents, etc. to organize it.
| latent_channels = first_latent.shape[1] | ||
| do_cfg = guidance_scale > 1.0 | ||
|
|
||
| scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift) |
There was a problem hiding this comment.
We should use the self.scheduler component rather than creating a new scheduler here.
| **cam_kwargs, | ||
| } | ||
|
|
||
| for t in tqdm(timesteps, disable=os.getenv("DPM_TQDM", "False") == "True"): |
There was a problem hiding this comment.
We should use self.progress_bar here. For example, this is what LTX2Pipeline does:
diffusers/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
Lines 1213 to 1214 in 6d71b76
I think we should also respect progress bar config changes via DiffusionPipeline.set_progress_bar_config. For example, users can use it to disable the progress bar, as some tests do:
diffusers/tests/pipelines/ltx2/test_ltx2.py
Line 207 in 6d71b76
so I think we don't need to have an explicit condition for disable here.
| if isinstance(image, (str, Path)): | ||
| image = PIL.Image.open(image).convert("RGB") | ||
|
|
||
| if (c2w is None) == (action is None): | ||
| raise ValueError("Provide exactly one of `c2w` or `action`.") | ||
| if action is not None: | ||
| c2w = action_string_to_c2w(action) | ||
| c2w = np.asarray(c2w, dtype=np.float32) | ||
| if c2w.ndim != 3 or c2w.shape[1:] != (4, 4): | ||
| raise ValueError(f"`c2w` must be `(F, 4, 4)`; got {c2w.shape}.") |
There was a problem hiding this comment.
We should move the validation checks here (and below) into a separate check_inputs method.
| cropped, src_size, resized_size, crop_offset = resize_and_center_crop(image, height, width) | ||
| intr = transform_intrinsics_for_crop(intr, src_size, resized_size, crop_offset) |
There was a problem hiding this comment.
I think it would be better to move the image and intrinsics pre-processing into a custom VaeImageProcessor subclass, similar to what Wan Animate does:
For example, I think resize_and_center_crop, estimate_intrinsics_with_pi3x, transform_intrinsics_for_crop, and the image normalization code in _encode_first_frame could all potentially be refactored into a custom image processor. CC @yiyixuxu
| c2w, intr, (height, width), device=device, dtype=dtype, do_cfg=guidance_scale > 1.0 | ||
| ) | ||
|
|
||
| generator = torch.Generator(device=device).manual_seed(seed) |
There was a problem hiding this comment.
I think we should make generator a __call__ argument, in line with other pipelines like LTX-2:
| return ( | ||
| SanaWMPipelineOutput(frames=latents.cpu(), c2w=c2w, latent=latents.cpu()) | ||
| if return_dict | ||
| else (latents.cpu(),) | ||
| ) |
There was a problem hiding this comment.
nit: I think we can remove the cpu() casts here since we don't normally place the output latents on CPU when output_type="latent".
| if output_type == "pil": | ||
| video_uint8 = (video.numpy() * 255.0).round().clip(0, 255).astype(np.uint8) | ||
| frames: list | np.ndarray = [PIL.Image.fromarray(f) for f in video_uint8] | ||
| elif output_type == "np": | ||
| frames = video.numpy() | ||
| else: | ||
| frames = video |
There was a problem hiding this comment.
We should use a VideoProcessor to post-process the generated video latents rather than re-implementing the post-processing logic here.
| @@ -0,0 +1,63 @@ | |||
| # SANA-WM diffusers pipeline | |||
There was a problem hiding this comment.
I think we should move the content here to the docs (e.g. docs/source/en/api/pipelines/sana_wm.md).
| STAGE_2_DISTILLED_SIGMA_VALUES: tuple[float, ...] = (0.909375, 0.725, 0.421875, 0.0) | ||
|
|
||
|
|
||
| class SanaWMLTX2Refiner(ModelMixin, ConfigMixin): |
There was a problem hiding this comment.
I think SanaWMLTX2Refiner makes more sense as a standalone pipeline (that is, as a DiffusionPipeline subclass rather than a ModelMixin subclass) since it wants to have components (e.g. self.transformer, self.connectors, etc.) and implements a denoising loop (e.g. in refine_latents). Can you refactor it to be a pipeline (using a scheduler to implement the denoising steps)?
dg845
left a comment
There was a problem hiding this comment.
Thanks for the PR! I left an initial design review for the pipeline code, am still working on reviewing the modeling code.
Thanks so much for the review. I'll keep reformatting the code as your requirements. |
Per @dg845's review comments on pipeline_sana_wm.py: * Read the VAE spatial/temporal strides from the ``vae`` component (``vae_spatial_compression_ratio`` / ``vae_temporal_compression_ratio``) instead of hardcoding 32/8, mirroring LTX2Pipeline. * Inline the stage-1 sampling loop into ``__call__`` and factor the noise init into a standard ``prepare_latents`` method. * Use ``self.scheduler`` for the flow-matching Euler steps instead of constructing a new scheduler per call. * Drive the sampling loop with ``self.progress_bar`` (respects ``set_progress_bar_config``) instead of a bare tqdm. * Move input validation/normalization into a ``check_inputs`` method. * Move first-frame resize+center-crop, [-1, 1] normalization and the intrinsics rescale into a ``SanaWMImageProcessor(VaeImageProcessor)`` subclass (new ``image_processor.py``). * Add ``generator`` as a ``__call__`` argument (``seed`` kept as a convenience shortcut). * Post-process decoded latents with ``VideoProcessor.postprocess_video`` rather than a hand-rolled conversion; drop the ``.cpu()`` casts on the ``output_type="latent"`` path. * Move the pipeline README into the docs (docs/.../sana_wm.md); delete the in-package README.
Per @dg845's review: the refiner has components (transformer, connectors, text encoder, tokenizer) and runs a denoising loop, so it fits better as a pipeline than a ModelMixin. * SanaWMLTX2Refiner now subclasses DiffusionPipeline and registers its components via ``register_modules`` (dropping the bespoke ``from_pretrained`` / ``save_pretrained``); standard load/save now handle the ``refiner/`` subfolder. * Add a ``FlowMatchEulerDiscreteScheduler`` component (shift=1.0) and drive the Euler steps through ``scheduler.step`` / ``scheduler.scale_noise`` (single-shot and per-AR-block) instead of hand-rolled updates. Numerically equivalent to the previous flow-matching update. * Rename the entry point ``refine_latents`` -> ``__call__``; add a ``device`` arg so the parent can hand it the execution device without a bulk move. * SanaWMPipeline: keep the refiner as an optional nested component; free the parent's GPU weights before running it (it manages its own sub-module placement) and bring the VAE back for decode. * Conversion script emits the new refiner layout (model_index.json + scheduler/ + tokenizer/); test asserts the refiner is a DiffusionPipeline with the canonical AR ``__call__`` defaults. Validated end-to-end on 1xH100 (stage-1 + refiner on the official demo, coherent video output).
|
Pushed
|
What does this PR do?
Hi @sayakpaul @dg845 , Long time no see. Hoping your are doing great.♥️
Adds SANA-WM, the camera-controlled image-to-video world model from NVIDIA + MIT HAN Lab, as a first-class diffusers pipeline and transformer. Given a first-frame image, a text prompt, and a camera trajectory (explicit
c2wposes or a WASD/IJKL action-DSL string), the pipeline generates a video whose motion follows the requested camera path. Trained natively for minute-scale generation at 704×1280.The pipeline runs in two stages:
SanaWMTransformer3DModel. A 1.6B-parameter bidirectional DiT with GDN-Triton linear attention and a UCPE camera-control branch; samples with an LTX-style flow-matching Euler scheduler at per-token timesteps. The first latent frame is the conditioning anchor.SanaWMLTX2Refiner(optional). A chunk-causal AR refiner that wraps diffusers'LTX2VideoTransformer3DModel+LTX2TextConnectors+ Gemma-3 text encoder. Processes 3 latent frames at a time with a sliding window of[source_sink + recent_history + active_block]K/V, so per-block compute is bounded and total refinement cost is linear in video length.Both stages decode through
AutoencoderKLLTX2Video.Layout
Usage
Demo
5-second sample (30 stage-1 steps + 3-step distilled AR refiner, official
asset/sana_wm/demo_0inputs, 704×1280 @ 16 fps) :sana_wm_5s.mp4
Smoke tests
End-to-end on 1× H100 80GB with `enable_model_cpu_offload` and the official `asset/sana_wm/demo_0.{png,txt,_pose.npy,_intrinsics.npy}`:
Checkpoint conversion
scripts/sana_wm/convert_sana_wm_to_diffusers.py --src Efficient-Large-Model/SANA-WM_bidirectional --dst /local/pathconverts the public release into a `from_pretrained`-loadable directory (VAE, Gemma-2 tokenizer + text_encoder, transformer, scheduler, refiner subfolders, top-level `model_index.json`).Related
Paper: https://arxiv.org/abs/2605.15178