diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0613cd65d74d..5f836dd87294 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -198,6 +198,8 @@ title: Model accelerators and hardware - isExpanded: false sections: + - local: using-diffusers/anyflow + title: AnyFlow - local: using-diffusers/helios title: Helios - local: using-diffusers/consisid @@ -328,6 +330,10 @@ title: AceStepTransformer1DModel - local: api/models/allegro_transformer3d title: AllegroTransformer3DModel + - local: api/models/anyflow_transformer3d + title: AnyFlowTransformer3DModel + - local: api/models/anyflow_far_transformer3d + title: AnyFlowFARTransformer3DModel - local: api/models/aura_flow_transformer2d title: AuraFlowTransformer2DModel - local: api/models/transformer_bria_fibo @@ -506,6 +512,8 @@ - sections: - local: api/pipelines/animatediff title: AnimateDiff + - local: api/pipelines/anyflow + title: AnyFlow - local: api/pipelines/aura_flow title: AuraFlow - local: api/pipelines/bria_3_2 @@ -735,6 +743,8 @@ title: EulerAncestralDiscreteScheduler - local: api/schedulers/euler title: EulerDiscreteScheduler + - local: api/schedulers/flow_map_euler_discrete + title: FlowMapEulerDiscreteScheduler - local: api/schedulers/flow_match_euler_discrete title: FlowMatchEulerDiscreteScheduler - local: api/schedulers/flow_match_heun_discrete diff --git a/docs/source/en/api/models/anyflow_far_transformer3d.md b/docs/source/en/api/models/anyflow_far_transformer3d.md new file mode 100644 index 000000000000..d29c3fefc07d --- /dev/null +++ b/docs/source/en/api/models/anyflow_far_transformer3d.md @@ -0,0 +1,45 @@ + + +# AnyFlowFARTransformer3DModel + +The causal (FAR) 3D Transformer used by [`AnyFlowFARPipeline`](../pipelines/anyflow#anyflowfarpipeline) — +the FAR variant of [AnyFlow](https://huggingface.co/papers/2605.13724) (Yuchao Gu, Guian Fang et al., NUS +ShowLab × NVIDIA). It extends the v0.35.1 Wan2.1 backbone with three additions: + +1. **FAR causal block-mask** via `torch.nn.attention.flex_attention`, supporting frame-level autoregressive + generation as introduced in [FAR (Gu et al., 2025)](https://arxiv.org/abs/2503.19325). +2. **Compressed-frame patch embedding** (`far_patch_embedding`) for context (already-generated) frames, + warm-started from the full-resolution `patch_embedding` at construction time via trilinear interpolation. +3. **Dual-timestep flow-map embedding** (same as + [`AnyFlowTransformer3DModel`](anyflow_transformer3d)) — every forward call conditions on both the source + timestep ``t`` and the target timestep ``r``. + +The chunk schedule (`chunk_partition`) is **not** baked into the model config. It is a per-call argument to +`forward`, so the same checkpoint handles different `num_frames` configurations without retraining. + +```python +from diffusers import AnyFlowFARTransformer3DModel + +# Causal AnyFlow checkpoint (FAR): +transformer = AnyFlowFARTransformer3DModel.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", subfolder="transformer" +) +``` + +## AnyFlowFARTransformer3DModel + +[[autodoc]] AnyFlowFARTransformer3DModel + +## AnyFlowFARTransformerOutput + +[[autodoc]] models.transformers.transformer_anyflow.AnyFlowFARTransformerOutput diff --git a/docs/source/en/api/models/anyflow_transformer3d.md b/docs/source/en/api/models/anyflow_transformer3d.md new file mode 100644 index 000000000000..95888080c0ce --- /dev/null +++ b/docs/source/en/api/models/anyflow_transformer3d.md @@ -0,0 +1,36 @@ + + +# AnyFlowTransformer3DModel + +The bidirectional 3D Transformer used by [`AnyFlowPipeline`](../pipelines/anyflow#anyflowpipeline). It is the +v0.35.1 Wan2.1 backbone with one structural change: the timestep embedder is replaced by +``AnyFlowDualTimestepTextImageEmbedding``, so every forward call conditions on both the source timestep +``t`` and the target timestep ``r``. This is the embedding required to learn the flow map +:math:`\Phi_{r\leftarrow t}` introduced in +[AnyFlow](https://huggingface.co/papers/2605.13724) (Yuchao Gu, Guian Fang et al., NUS ShowLab × NVIDIA). + +For frame-level autoregressive (FAR causal) generation, use +[`AnyFlowFARTransformer3DModel`](anyflow_far_transformer3d) instead. + +```python +from diffusers import AnyFlowTransformer3DModel + +# Bidirectional AnyFlow checkpoint (T2V): +transformer = AnyFlowTransformer3DModel.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer" +) +``` + +## AnyFlowTransformer3DModel + +[[autodoc]] AnyFlowTransformer3DModel diff --git a/docs/source/en/api/pipelines/anyflow.md b/docs/source/en/api/pipelines/anyflow.md new file mode 100644 index 000000000000..c8948cdb8f59 --- /dev/null +++ b/docs/source/en/api/pipelines/anyflow.md @@ -0,0 +1,216 @@ + + +
+
+ + LoRA + +
+
+ +# AnyFlow + +[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/2605.13724) by Yuchao Gu, Guian Fang and collaborators at [NUS ShowLab](https://sites.google.com/view/showlab) in collaboration with NVIDIA. + +*Few-step video generation has been significantly advanced by consistency models. However, their performance often degrades in any-step video diffusion models due to the fixed-point formulation. To address this limitation, we present AnyFlow, the first any-step video diffusion distillation framework built on flow maps. Instead of learning only the mapping z_t → z_0, AnyFlow learns transitions z_t → z_r over arbitrary time intervals, enabling a single model to adapt to different inference budgets. We design an improved forward flow map training recipe that fine-tunes pretrained video diffusion models into flow map models, and introduce Flow Map Backward Simulation to enable on-policy distillation for flow map models. Extensive experiments across both bidirectional and causal architectures, at scales ranging from 1.3B to 14B, on text-to-video and image-to-video tasks demonstrate that AnyFlow outperforms consistency-based baselines while preserving high fidelity and flexible sampling under varying step budgets.* + +The original training code is at [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow). The project page is at [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow). + +The following AnyFlow checkpoints are supported: + +| Checkpoint | Backbone | Description | +|------------|----------|-------------| +| [`nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers) | Wan2.1 1.3B | Bidirectional T2V, lightweight | +| [`nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers) | Wan2.1 14B | Bidirectional T2V, full quality | +| [`nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers) | FAR + Wan2.1 1.3B | Causal T2V / I2V / V2V | +| [`nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers) | FAR + Wan2.1 14B | Causal T2V / I2V / V2V | + +All four are grouped under the [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection. + +> [!TIP] +> Choose `AnyFlowPipeline` for traditional bidirectional text-to-video generation. Choose `AnyFlowFARPipeline` for streaming I2V, video continuation (V2V), or any setup that benefits from frame-by-frame autoregressive sampling. + +> [!TIP] +> AnyFlow supports any-step sampling: a single distilled checkpoint can be evaluated at 1, 2, 4, 8, 16... NFE without retraining. Quality scales monotonically with steps in our benchmarks. + +### Optimizing Memory and Inference Speed + + + + +```py +import torch +from diffusers import AnyFlowPipeline +from diffusers.hooks import apply_group_offloading + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 +) +apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level") +pipe.vae.enable_slicing() +pipe.vae.enable_tiling() +``` + + + + +```py +import torch +from diffusers import AnyFlowPipeline + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") +pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") +``` + + + + +### Generation with AnyFlow (Bidirectional T2V) + + + + +```py +import torch +from diffusers import AnyFlowPipeline +from diffusers.utils import export_to_video + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +prompt = "A red panda eating bamboo in a forest, cinematic lighting" +video = pipe(prompt, num_inference_steps=4, num_frames=33).frames[0] +export_to_video(video, "out.mp4", fps=16) +``` + + + + +### Generation with AnyFlow (FAR Causal) + +The causal pipeline selects between T2V / I2V / V2V via the ``context_sequence`` argument: pass ``None`` +for plain text-to-video, or a dict with a ``"raw"`` key holding a video tensor of shape +``(B, C, T, H, W)`` with ``T = 4n + 1`` to condition on existing frames. Use a single conditioning frame +for I2V and a longer clip for V2V continuation. + +> [!IMPORTANT] +> `AnyFlowFARPipeline.default_chunk_partition = [1, 3, 3, 3, 3, 3, 3, 2]` (sum 21) is matched to the +> released checkpoints' canonical 81 raw frames (21 latent frames at the VAE temporal stride of 4). When +> you change `num_frames`, you must also pass a matching `chunk_partition` summing to +> `(num_frames - 1) // 4 + 1`, otherwise the pipeline raises an `AssertionError`. + + + + +```py +import torch +from diffusers import AnyFlowFARPipeline +from diffusers.utils import export_to_video + +pipe = AnyFlowFARPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +video = pipe( + prompt="A cat surfing a wave, sunset", + num_inference_steps=4, + num_frames=81, +).frames[0] +export_to_video(video, "out.mp4", fps=16) +``` + + + + +```py +import numpy as np +import torch +from diffusers import AnyFlowFARPipeline +from diffusers.utils import export_to_video, load_image + +pipe = AnyFlowFARPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Wrap the conditioning image as a one-frame video tensor: (1, 3, 1, H, W) in [0, 1]. +first_frame = load_image("path/to/first_frame.png").resize((832, 480)) +arr = np.asarray(first_frame).astype("float32") / 255.0 # (480, 832, 3) +context_tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).to("cuda") + +video = pipe( + prompt="a cat walks across a sunlit lawn", + context_sequence={"raw": context_tensor}, + num_inference_steps=4, + num_frames=81, +).frames[0] +export_to_video(video, "out.mp4", fps=16) +``` + + + + +```py +import numpy as np +import torch +from diffusers import AnyFlowFARPipeline +from diffusers.utils import export_to_video, load_video + +pipe = AnyFlowFARPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Context clip — 9 raw frames map to 3 latent frames (9 = 4·2 + 1, 3 = 2 + 1). +context_frames = load_video("path/to/context.mp4")[:9] +arr = np.stack([np.asarray(f.resize((832, 480))) for f in context_frames]).astype("float32") / 255.0 +context_tensor = torch.from_numpy(arr).permute(3, 0, 1, 2).unsqueeze(0).to("cuda") # (1, 3, 9, 480, 832) + +video = pipe( + prompt="continue the story", + context_sequence={"raw": context_tensor}, + num_inference_steps=4, + num_frames=81, + # Override chunk_partition so the first chunk covers exactly the 3 latent context frames. + chunk_partition=[3, 3, 3, 3, 3, 3, 3], +).frames[0] +export_to_video(video, "out.mp4", fps=16) +``` + + + + +## Notes + +- Classifier-free guidance is fused into the released checkpoints, so inference does not run a second guided forward pass. Keep the default `guidance_scale=1.0` unless your own checkpoint requires otherwise. +- `FlowMapEulerDiscreteScheduler` is general-purpose. You can attach it to any flow-map-distilled checkpoint via `from_pretrained(..., scheduler=FlowMapEulerDiscreteScheduler.from_config(...))`. +- `AnyFlowPipeline` uses [`AnyFlowTransformer3DModel`](../models/anyflow_transformer3d) (bidirectional). `AnyFlowFARPipeline` uses [`AnyFlowFARTransformer3DModel`](../models/anyflow_far_transformer3d), which adds a compressed-frame patch embedding and the FAR causal block-mask. +- LoRA loading is supported via `WanLoraLoaderMixin`, the same mixin used by the upstream Wan pipelines. +- For training recipes (forward flow-map training and on-policy distillation), refer to the original AnyFlow training framework at [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow); training is out of scope for diffusers. + +## AnyFlowPipeline + +[[autodoc]] AnyFlowPipeline + - all + - __call__ + +## AnyFlowFARPipeline + +[[autodoc]] AnyFlowFARPipeline + - all + - __call__ + +## AnyFlowPipelineOutput + +[[autodoc]] pipelines.anyflow.pipeline_output.AnyFlowPipelineOutput diff --git a/docs/source/en/api/schedulers/flow_map_euler_discrete.md b/docs/source/en/api/schedulers/flow_map_euler_discrete.md new file mode 100644 index 000000000000..27a0c8612d70 --- /dev/null +++ b/docs/source/en/api/schedulers/flow_map_euler_discrete.md @@ -0,0 +1,28 @@ + + +# FlowMapEulerDiscreteScheduler + +`FlowMapEulerDiscreteScheduler` is an Euler-style sampler designed for flow-map-distilled diffusion +models. Flow-map models learn arbitrary-interval transitions $\mathbf{z}_t \to \mathbf{z}_r$ rather than +the fixed $\mathbf{z}_t \to \mathbf{z}_0$ mapping of consistency models. Both endpoints of the step are +caller-provided, which is what enables any-step sampling: a single distilled checkpoint can be evaluated at +1, 2, 4, 8, 16... NFE without retraining. + +The scheduler was introduced in +[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/2605.13724) +and ships with the `AnyFlowPipeline` and `AnyFlowFARPipeline` integrations, but it is not +AnyFlow-specific — any flow-map-distilled checkpoint can use it. + +## FlowMapEulerDiscreteScheduler + +[[autodoc]] FlowMapEulerDiscreteScheduler diff --git a/docs/source/en/using-diffusers/anyflow.md b/docs/source/en/using-diffusers/anyflow.md new file mode 100644 index 000000000000..9bf3ba13f258 --- /dev/null +++ b/docs/source/en/using-diffusers/anyflow.md @@ -0,0 +1,260 @@ + + +# AnyFlow + +[AnyFlow](https://huggingface.co/papers/2605.13724) is a video diffusion **distillation** framework that turns +a pretrained Wan2.1 teacher into an *any-step* student under standard Euler sampling. A single distilled +checkpoint can be evaluated at 1, 2, 4, 8, 16... NFE without retraining and quality scales **monotonically** +with steps — unlike consistency models, which often degrade as NFE grows. + +The key idea is to learn the **flow map** $\Phi_{r\leftarrow t}: \mathbf{z}_t \to \mathbf{z}_r$ for arbitrary +$1 \ge t \ge r \ge 0$ instead of the fixed endpoint map $\mathbf{z}_t \to \mathbf{z}_0$ used by consistency +models. Composability of the flow map removes re-noising between sampling steps; on-policy distillation with +**DMD reverse-divergence supervision** plus **Flow-Map backward simulation** (3-segment shortcut) closes the +exposure-bias gap that consistency-based distillation leaves open. + +AnyFlow was developed by Yuchao Gu, Guian Fang and collaborators at [NUS ShowLab](https://sites.google.com/view/showlab) in collaboration with NVIDIA. The original training code lives at [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow); the project page is at [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow). The four released checkpoints are grouped under the [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection. + +This guide walks through the practical decisions: which pipeline to pick, how to use any-step sampling, and +how to plug AnyFlow into typical T2V / I2V / V2V workflows. + +## Bidirectional vs causal — pick a pipeline + +AnyFlow ships in two flavors that share the same scheduler and the same flow-map distillation but differ in +how they sample frames: + +- [`AnyFlowPipeline`](../api/pipelines/anyflow#anyflowpipeline) — **bidirectional** T2V. Denoises the entire + video tensor in one pass with global self-attention. Use this when the input is a single text prompt and you + do not need streaming output. +- [`AnyFlowFARPipeline`](../api/pipelines/anyflow#anyflowfarpipeline) — **causal (FAR)**. Denoises the + video chunk by chunk with block-sparse causal attention and reuses KV cache across chunks. Use this for + image-to-video (I2V), video-to-video (V2V) continuation, or any setup that benefits from frame-level + autoregressive sampling. The same model handles all three task modes via the `context_sequence` argument. + +A quick selector: + +| Scenario | Pipeline | How to invoke | +|----------|----------|---------------| +| Pure text-to-video, max quality at fixed NFE | `AnyFlowPipeline` | `pipe(prompt, ...)` | +| Image-to-video (start from a still image) | `AnyFlowFARPipeline` | `pipe(prompt, context_sequence={"raw": }, ...)` | +| Video continuation / V2V | `AnyFlowFARPipeline` | `pipe(prompt, context_sequence={"raw": }, ...)` | +| Streaming / progressive generation | `AnyFlowFARPipeline` | — | + +The bidirectional variant is faster per token at high resolution; the causal variant trades that for the +ability to start sampling before all latent frames are allocated, useful for very long sequences. + +## Loading checkpoints + +NVIDIA released four AnyFlow checkpoints, one per pipeline + scale combination: + +```py +import torch +from diffusers import AnyFlowPipeline, AnyFlowFARPipeline + +# Bidirectional, lightweight +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Bidirectional, full quality +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Causal (FAR), 1.3B +pipe = AnyFlowFARPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Causal (FAR), 14B +pipe = AnyFlowFARPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") +``` + +All four use the same [`FlowMapEulerDiscreteScheduler`](../api/schedulers/flow_map_euler_discrete) with +`shift=5.0` baked in. + +## Any-step sampling + +The defining feature of AnyFlow is that the same checkpoint produces increasing quality as you raise NFE, +with no schedule retuning. Sweep step counts on a fixed prompt to see how the model trades latency for +fidelity: + +```py +import torch +from diffusers import AnyFlowPipeline +from diffusers.utils import export_to_video + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +prompt = "A red panda eating bamboo in a forest, cinematic lighting" + +for nfe in [1, 2, 4, 8, 16, 32]: + # Re-seed the generator inside the loop so the only changing variable across runs is NFE. + generator = torch.Generator("cuda").manual_seed(0) + video = pipe(prompt, num_inference_steps=nfe, num_frames=33, generator=generator).frames[0] + export_to_video(video, f"out_nfe{nfe}.mp4", fps=16) +``` + +In our benchmarks (paper Tab 3 / Fig 1) every AnyFlow checkpoint improves monotonically from 4 → 32 NFE +on VBench Quality, while consistency-based baselines (rCM, Self-Forcing) degrade in the same regime. + +> [!TIP] +> Classifier-free guidance (CFG) was *fused* into the released model weights during training. +> The pipeline does not run a second guided forward pass at inference time — +> guidance comes from the distilled weights themselves. Leave `guidance_scale=1.0` (the default) for the +> released checkpoints. + +## Image-to-video and video-to-video + +The causal pipeline supports three task modes from a single distilled model. The mode is selected +implicitly by the ``context_sequence`` argument (a dict with a ``"raw"`` video tensor or ``"latent"`` +pre-encoded latents). Frame counts in the context tensor must satisfy ``T = 4n + 1`` to align with the +VAE temporal stride. + +> [!IMPORTANT] +> The FAR pipeline runs a chunked rollout, and `num_frames` must agree with the chunk schedule. The default +> `chunk_partition=[1, 3, 3, 3, 3, 3, 3, 2]` sums to 21 latent frames, which is matched to the released +> checkpoints' canonical `num_frames=81` (21 = (81 − 1) // 4 + 1). When you change `num_frames`, you **must** +> pass a matching `chunk_partition` whose entries sum to `(num_frames - 1) // 4 + 1`, otherwise the pipeline +> raises an `AssertionError`. For example, `num_frames=33` corresponds to 9 latent frames, so a valid +> override is `chunk_partition=[1, 4, 4]`. + +```py +import numpy as np +import torch +from diffusers import AnyFlowFARPipeline +from diffusers.utils import export_to_video, load_image, load_video + +pipe = AnyFlowFARPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + + +def to_video_tensor(images, height=480, width=832): + """Convert a list of PIL images into the (B, C, T, H, W) [0, 1] tensor the FAR pipeline consumes.""" + frames = np.stack([np.asarray(img.resize((width, height))) for img in images]).astype("float32") / 255.0 + return torch.from_numpy(frames).permute(3, 0, 1, 2).unsqueeze(0) # (1, C, T, H, W) + + +# 1) Text-to-video (no context). 81 frames matches the default chunk_partition. +video = pipe(prompt="A cat surfing a wave at sunset", num_inference_steps=4, num_frames=81).frames[0] +export_to_video(video, "t2v.mp4", fps=16) + +# 2) Image-to-video — a single conditioning frame produces 1 latent frame, which exactly fits the first +# entry of the default chunk_partition (`[1, 3, 3, ...]`). +first_frame = load_image("path/to/first_frame.png") +context_tensor = to_video_tensor([first_frame]).to("cuda") # (1, 3, 1, 480, 832), [0, 1] +video = pipe( + prompt="a cat walks across a sunlit lawn", + context_sequence={"raw": context_tensor}, + num_inference_steps=4, + num_frames=81, +).frames[0] +export_to_video(video, "i2v.mp4", fps=16) + +# 3) Video-to-video continuation. A 9-frame raw context maps to 3 latent frames; we override +# chunk_partition so the first chunk covers the whole context exactly. +context_frames = load_video("path/to/context.mp4")[:9] # 9 = 4·2 + 1 +context_tensor = to_video_tensor(context_frames).to("cuda") # (1, 3, 9, 480, 832) +video = pipe( + prompt="continue the story", + context_sequence={"raw": context_tensor}, + num_inference_steps=4, + num_frames=81, + chunk_partition=[3, 3, 3, 3, 3, 3, 3], # 7 chunks × 3 = 21 latent frames; first chunk = context +).frames[0] +export_to_video(video, "v2v.mp4", fps=16) +``` + +Internally, the patchification chunk schedule depends on whether (and how long) ``context_sequence`` is set: +without context the model uses kernel sizes 2 (full) and 4 (compressed); with a context clip the first chunk +uses kernel size 1 so the conditioning frames keep full resolution. + +If you already have VAE-encoded latents, pass them via ``context_sequence={"latent": ...}`` to skip the +``vae_encode`` step. + +## Memory and inference speed + +A 14B AnyFlow model fits on a single 40 GB device with group offloading + VAE slicing: + +```py +import torch +from diffusers import AnyFlowPipeline +from diffusers.hooks import apply_group_offloading + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 +) +apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level") +pipe.vae.enable_slicing() +pipe.vae.enable_tiling() +``` + +For latency, `torch.compile` works well on the transformer (the heaviest module by far): + +```py +pipe = pipe.to("cuda") +pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") +``` + +Compile costs are amortized after a few steps; combined with low NFE (4–8 for AnyFlow) `torch.compile` +delivers a noticeable speedup over eager mode on the 14B path. + +## LoRA fine-tuning + +Both pipelines reuse [`WanLoraLoaderMixin`](../api/loaders/lora), so any LoRA adapter trained for the +matching Wan2.1 backbone loads directly: + +```py +pipe.load_lora_weights("path/or/repo/with/wan_lora") +``` + +For continued **on-policy** fine-tuning with DMD-style reverse-divergence supervision (the same recipe used +to produce the released checkpoints), refer to the original AnyFlow training framework at +[`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow), which is out of scope for diffusers. + +## Common gotchas + +- **Always-1.0 `guidance_scale`.** The distilled checkpoints already encode CFG. Setting `guidance_scale > 1` + will run a redundant unconditional pass, double the latency, and slightly hurt quality. +- **Bidirectional pipeline does not stream.** All `num_frames` worth of latents are denoised together. Use + the causal pipeline if you want to start playback before sampling completes. +- **Causal pipeline KV cache assumes the chunk schedule is consistent across calls.** Rebuilding the cache + mid-generation is not supported by the released model. +- **`num_frames` must satisfy the VAE temporal stride.** Use values of the form `(N - 1) % 4 == 0` (e.g., 9, + 17, 33, 81) for the released checkpoints. + +## Citation + +```bibtex +@misc{gu2026anyflowanystepvideodiffusion, + title={AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation}, + author={Yuchao Gu and Guian Fang and Yuxin Jiang and Weijia Mao and Song Han and Han Cai and Mike Zheng Shou}, + year={2026}, + eprint={2605.13724}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2605.13724}, +} + +@article{gu2025long, + title={Long-Context Autoregressive Video Modeling with Next-Frame Prediction}, + author={Gu, Yuchao and Mao, Weijia and Shou, Mike Zheng}, + journal={arXiv preprint arXiv:2503.19325}, + year={2025} +} +``` diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml index af51506746b2..b49820dd76e7 100644 --- a/docs/source/zh/_toctree.yml +++ b/docs/source/zh/_toctree.yml @@ -130,6 +130,8 @@ - title: Specific pipeline examples isExpanded: false sections: + - local: using-diffusers/anyflow + title: AnyFlow - local: using-diffusers/consisid title: ConsisID - local: using-diffusers/helios diff --git a/docs/source/zh/using-diffusers/anyflow.md b/docs/source/zh/using-diffusers/anyflow.md new file mode 100644 index 000000000000..9a20f2eafd9d --- /dev/null +++ b/docs/source/zh/using-diffusers/anyflow.md @@ -0,0 +1,244 @@ + + +# AnyFlow + +[AnyFlow](https://huggingface.co/papers/2605.13724) 是一个视频扩散**蒸馏**框架,把预训练的 Wan2.1 教师 +模型蒸馏成在标准 Euler 采样下支持*任意步数 (any-step)* 的学生模型。同一个蒸馏出来的 checkpoint 可以 +在 1、2、4、8、16... NFE 下推理,**质量随步数单调提升** —— 这一点和 consistency models 不同,后者 +NFE 增加反而经常掉点。 + +核心思路是学习 **flow map** $\Phi_{r\leftarrow t}: \mathbf{z}_t \to \mathbf{z}_r$(任意 $1 \ge t \ge r \ge 0$), +而不是 consistency models 学的固定端点映射 $\mathbf{z}_t \to \mathbf{z}_0$。Flow map 的可组合性消除了 +采样步之间的 re-noising;on-policy 蒸馏阶段额外用 **DMD 反向散度监督** + **Flow-Map backward simulation** +(3 段 shortcut)补上 consistency 蒸馏遗留的 exposure-bias 缺口。 + +AnyFlow 由 Yuchao Gu、Guian Fang 等人在 [NUS ShowLab](https://sites.google.com/view/showlab) 与 NVIDIA 合作完成。原始训练代码在 [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow),项目主页是 [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow)。4 个发布 checkpoint 归在 [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection 里。 + +本文档梳理实战要点:怎么选 pipeline、怎么用 any-step 采样、怎么把 AnyFlow 嵌进 T2V / I2V / V2V 工作流。 + +## Bidirectional 还是 Causal —— 怎么选 pipeline + +AnyFlow 提供两个 pipeline 形态,scheduler 和蒸馏方法相同,区别在于**怎么对帧采样**: + +- [`AnyFlowPipeline`](../api/pipelines/anyflow#anyflowpipeline) —— **bidirectional** T2V。一次性对整个 + 视频张量去噪,全局自注意力。**纯 prompt 输入、不要流式输出**时选这个。 +- [`AnyFlowFARPipeline`](../api/pipelines/anyflow#anyflowfarpipeline) —— **causal (FAR)**。 + 按 chunk 分段去噪,块稀疏因果注意力 + 跨 chunk 复用 KV cache。**图生视频 (I2V)**、**视频续写 (V2V)**、 + 或任何受益于逐帧自回归采样的场景选这个。同一个模型通过传入 `context_sequence` 来切换三种任务模式。 + +简化对照表: + +| 场景 | Pipeline | 调用方式 | +|------|----------|----------| +| 纯文生视频,固定 NFE 求最大质量 | `AnyFlowPipeline` | `pipe(prompt, ...)` | +| 图生视频(首帧给定) | `AnyFlowFARPipeline` | `pipe(prompt, context_sequence={"raw": <单帧 tensor>}, ...)` | +| 视频续写 / V2V | `AnyFlowFARPipeline` | `pipe(prompt, context_sequence={"raw": <多帧 tensor>}, ...)` | +| 流式 / 渐进式生成 | `AnyFlowFARPipeline` | — | + +高分辨率下 bidirectional 单 token 更快;causal 牺牲一点单步速度,换来在所有 latent 帧分配前就能开始 +采样的能力,对超长序列尤其有用。 + +## 加载 checkpoint + +NVIDIA 发布了 4 个 AnyFlow checkpoint,pipeline × 规模各一份: + +```py +import torch +from diffusers import AnyFlowPipeline, AnyFlowFARPipeline + +# Bidirectional, 轻量 +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Bidirectional, 满血 +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Causal (FAR), 1.3B +pipe = AnyFlowFARPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Causal (FAR), 14B +pipe = AnyFlowFARPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") +``` + +四个 checkpoint 共用同一份 [`FlowMapEulerDiscreteScheduler`](../api/schedulers/flow_map_euler_discrete), +默认 `shift=5.0`。 + +## Any-step 采样 + +AnyFlow 最关键的特性是同一个 checkpoint **不需重新调度**,NFE 越大质量越高。固定 prompt、扫一下步数 +就能看出模型怎么在延迟和保真度之间权衡: + +```py +import torch +from diffusers import AnyFlowPipeline +from diffusers.utils import export_to_video + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +prompt = "森林里一只小熊猫在啃竹子,电影感光照" + +for nfe in [1, 2, 4, 8, 16, 32]: + # 每轮重建 generator —— 这样跨步数对比时唯一变量是 NFE。 + generator = torch.Generator("cuda").manual_seed(0) + video = pipe(prompt, num_inference_steps=nfe, num_frames=33, generator=generator).frames[0] + export_to_video(video, f"out_nfe{nfe}.mp4", fps=16) +``` + +paper 的 Tab 3 / Fig 1 表明:每个 AnyFlow checkpoint 在 4 → 32 NFE 范围 VBench Quality 都单调上升,而 +consistency 类基线(rCM、Self-Forcing)在同区间反而掉点。 + +> [!TIP] +> Classifier-free guidance (CFG) 已经在训练阶段融进权重。pipeline 推理 +> 时**不会**再跑一次 unconditional 前向 —— guidance 直接由蒸馏后的权重带出。release 出来的 checkpoint +> 都用默认的 `guidance_scale=1.0` 即可。 + +## 图生视频 与 视频续写 + +Causal pipeline 用同一个蒸馏模型支持三种任务模式,**通过 `context_sequence` 隐式选择**(dict,含 +`"raw"` 视频张量或 `"latent"` 已编码 latent)。Context tensor 的帧数必须满足 `T = 4n + 1`,跟 VAE +时间步长对齐。 + +> [!IMPORTANT] +> FAR pipeline 是分块 (chunk) rollout,`num_frames` 必须配合 chunk 调度。默认 +> `chunk_partition=[1, 3, 3, 3, 3, 3, 3, 2]`(求和 21)对应发布 checkpoint 的标准 `num_frames=81` +> (21 = (81 − 1) // 4 + 1)。改 `num_frames` 时**必须**显式传匹配的 `chunk_partition`,使其求和等于 +> `(num_frames - 1) // 4 + 1`,否则 pipeline 会抛 `AssertionError`。比如 `num_frames=33` 对应 9 个 latent +> 帧,可用 `chunk_partition=[1, 4, 4]`。 + +```py +import numpy as np +import torch +from diffusers import AnyFlowFARPipeline +from diffusers.utils import export_to_video, load_image, load_video + +pipe = AnyFlowFARPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + + +def to_video_tensor(images, height=480, width=832): + """把 PIL 列表转成 FAR pipeline 需要的 (B, C, T, H, W) [0, 1] 张量。""" + frames = np.stack([np.asarray(img.resize((width, height))) for img in images]).astype("float32") / 255.0 + return torch.from_numpy(frames).permute(3, 0, 1, 2).unsqueeze(0) # (1, C, T, H, W) + + +# 1) 文生视频(无 context)。81 帧匹配默认 chunk_partition。 +video = pipe(prompt="一只猫在夕阳下冲浪", num_inference_steps=4, num_frames=81).frames[0] +export_to_video(video, "t2v.mp4", fps=16) + +# 2) 图生视频 —— 单帧 context 经过 VAE 是 1 个 latent,正好对上默认 chunk_partition 的第一项 (`[1, ...]`)。 +first_frame = load_image("path/to/first_frame.png") +context_tensor = to_video_tensor([first_frame]).to("cuda") # (1, 3, 1, 480, 832), [0, 1] +video = pipe( + prompt="一只猫走过阳光下的草坪", + context_sequence={"raw": context_tensor}, + num_inference_steps=4, + num_frames=81, +).frames[0] +export_to_video(video, "i2v.mp4", fps=16) + +# 3) 视频续写。9 帧 raw context → 3 个 latent context;显式覆盖 chunk_partition,让第一块正好覆盖 context。 +context_frames = load_video("path/to/context.mp4")[:9] # 9 = 4·2 + 1 +context_tensor = to_video_tensor(context_frames).to("cuda") # (1, 3, 9, 480, 832) +video = pipe( + prompt="继续这个故事", + context_sequence={"raw": context_tensor}, + num_inference_steps=4, + num_frames=81, + chunk_partition=[3, 3, 3, 3, 3, 3, 3], # 7 个 chunk × 3 = 21 latent;首块就是 context +).frames[0] +export_to_video(video, "v2v.mp4", fps=16) +``` + +底层 patchify chunk 调度根据 `context_sequence` 自动调整:纯文生用 kernel 2 (full) 和 4 (compressed); +有 context 时第一个 chunk 改成 kernel 1,让条件帧保留全分辨率。 + +如果你已经有 VAE 编码过的 latent,可以直接传 `context_sequence={"latent": ...}` 跳过 `vae_encode` 步骤。 + +## 显存与推理速度 + +14B 的 AnyFlow 模型用 group offload + VAE slicing 单卡 40 GB 能跑: + +```py +import torch +from diffusers import AnyFlowPipeline +from diffusers.hooks import apply_group_offloading + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 +) +apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level") +pipe.vae.enable_slicing() +pipe.vae.enable_tiling() +``` + +延迟方面,`torch.compile` 对 transformer(最重的模块)效果很好: + +```py +pipe = pipe.to("cuda") +pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") +``` + +编译开销跑几步就摊销掉;配合 AnyFlow 的低 NFE(4-8 步),`torch.compile` 在 14B 上相比 eager +模式有明显加速。 + +## LoRA 微调 + +两个 pipeline 都复用 [`WanLoraLoaderMixin`](../api/loaders/lora),因此为对应 Wan2.1 backbone 训练的 +LoRA adapter 直接加载即可: + +```py +pipe.load_lora_weights("path/or/repo/with/wan_lora") +``` + +如果要做**继续 on-policy 蒸馏微调**(用论文里相同的 DMD 反向散度监督配方训新 LoRA),请参考原始 +AnyFlow 训练框架 [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow),这套训练流程不在 +diffusers 范围内。 + +## 常见坑 + +- **永远 `guidance_scale=1.0`。** 蒸馏后的 checkpoint 已经把 CFG 融进权重。设 `> 1` 会多跑一遍 + unconditional 前向、延迟翻倍、质量微降。 +- **Bidirectional pipeline 不支持流式。** 所有 `num_frames` 一起去噪。需要边采边播请用 causal pipeline。 +- **Causal pipeline KV cache 假设 chunk 调度跨调用一致。** 中途重建 cache 不被 release 模型支持。 +- **`num_frames` 必须满足 VAE 时间步长。** release checkpoint 用 `(N - 1) % 4 == 0` 的值(如 9、17、33、81)。 + +## 引用 + +```bibtex +@misc{gu2026anyflowanystepvideodiffusion, + title={AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation}, + author={Yuchao Gu and Guian Fang and Yuxin Jiang and Weijia Mao and Song Han and Han Cai and Mike Zheng Shou}, + year={2026}, + eprint={2605.13724}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2605.13724}, +} + +@article{gu2025long, + title={Long-Context Autoregressive Video Modeling with Next-Frame Prediction}, + author={Gu, Yuchao and Mao, Weijia and Shou, Mike Zheng}, + journal={arXiv preprint arXiv:2503.19325}, + year={2025} +} +``` diff --git a/scripts/convert_anyflow_to_diffusers.py b/scripts/convert_anyflow_to_diffusers.py new file mode 100644 index 000000000000..60574ca23a1e --- /dev/null +++ b/scripts/convert_anyflow_to_diffusers.py @@ -0,0 +1,152 @@ +# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert AnyFlow training checkpoints to the diffusers ``save_pretrained`` layout. + +The AnyFlow training pipeline emits ``.pt`` files containing an ``ema`` key whose value is a flat state +dict for the transformer. This script: + +1. Loads the matching base Wan2.1 pipeline from the Hub (provides VAE, tokenizer, and text encoder). +2. Constructs an ``AnyFlowTransformer3DModel`` with the right config flags for the chosen variant. +3. Loads the ``ema`` weights into the transformer. +4. Wraps everything in an ``AnyFlowPipeline`` (bidirectional) or ``AnyFlowFARPipeline`` (FAR causal). +5. Calls ``pipeline.save_pretrained(output_dir)``. + +Example: + +```bash +python scripts/convert_anyflow_to_diffusers.py \\ + --variant AnyFlow-FAR-Wan2.1-1.3B-Diffusers \\ + --ckpt /path/to/anyflow-checkpoint.pt \\ + --output-dir /path/to/output/AnyFlow-FAR-Wan2.1-1.3B-Diffusers +``` +""" + +import argparse +import logging +import os + +import torch + +from diffusers import ( + AnyFlowFARPipeline, + AnyFlowFARTransformer3DModel, + AnyFlowPipeline, + AnyFlowTransformer3DModel, + FlowMapEulerDiscreteScheduler, +) + + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") + + +# Per-variant configuration. ``base_model`` is fetched from the Hub to source the matching VAE / text encoder. +VARIANTS = { + "AnyFlow-FAR-Wan2.1-1.3B-Diffusers": { + "base_model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + "transformer_cls": AnyFlowFARTransformer3DModel, + "transformer_kwargs": {"full_chunk_limit": 3, "compressed_patch_size": [1, 4, 4]}, + "pipeline_cls": AnyFlowFARPipeline, + }, + "AnyFlow-FAR-Wan2.1-14B-Diffusers": { + "base_model": "Wan-AI/Wan2.1-T2V-14B-Diffusers", + "transformer_cls": AnyFlowFARTransformer3DModel, + "transformer_kwargs": {"full_chunk_limit": 3, "compressed_patch_size": [1, 4, 4]}, + "pipeline_cls": AnyFlowFARPipeline, + }, + "AnyFlow-Wan2.1-T2V-1.3B-Diffusers": { + "base_model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + "transformer_cls": AnyFlowTransformer3DModel, + "transformer_kwargs": {}, + "pipeline_cls": AnyFlowPipeline, + }, + "AnyFlow-Wan2.1-T2V-14B-Diffusers": { + "base_model": "Wan-AI/Wan2.1-T2V-14B-Diffusers", + "transformer_cls": AnyFlowTransformer3DModel, + "transformer_kwargs": {}, + "pipeline_cls": AnyFlowPipeline, + }, +} + + +def build_pipeline(variant: str, ckpt_path: str): + if variant not in VARIANTS: + raise ValueError(f"Unknown variant {variant!r}. Choices: {list(VARIANTS)}.") + spec = VARIANTS[variant] + + transformer = spec["transformer_cls"].from_pretrained( + spec["base_model"], + subfolder="transformer", + gate_value=0.25, + deltatime_type="r", + **spec["transformer_kwargs"], + ) + # NVlabs/AnyFlow training checkpoints are wrapped Python objects (the `ema` key carries metadata + # alongside tensors), so the unpickle is required. Only run this script on checkpoints you trust. + state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["ema"] + missing, unexpected = transformer.load_state_dict(state_dict, strict=False) + if unexpected: + logger.warning( + "Unexpected keys in state dict (ignored): %s%s", + unexpected[:5], + "..." if len(unexpected) > 5 else "", + ) + if missing: + logger.warning( + "Missing keys not loaded from state dict: %s%s", + missing[:5], + "..." if len(missing) > 5 else "", + ) + + scheduler = FlowMapEulerDiscreteScheduler(num_train_timesteps=1000, shift=5.0) + + pipeline = spec["pipeline_cls"].from_pretrained( + spec["base_model"], + transformer=transformer, + scheduler=scheduler, + ) + return pipeline + + +def main(): + parser = argparse.ArgumentParser( + description="Convert an AnyFlow training checkpoint into a diffusers pipeline directory." + ) + parser.add_argument( + "--variant", + required=True, + choices=list(VARIANTS), + help="Which AnyFlow variant the checkpoint corresponds to.", + ) + parser.add_argument( + "--ckpt", + required=True, + help="Path to the AnyFlow training checkpoint (a .pt file containing an 'ema' key).", + ) + parser.add_argument( + "--output-dir", + required=True, + help="Destination directory for pipeline.save_pretrained.", + ) + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + pipeline = build_pipeline(args.variant, args.ckpt) + pipeline.save_pretrained(args.output_dir) + logger.info("Saved %s pipeline to %s", args.variant, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d120d0a22818..3a8332dc0c3a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -191,6 +191,8 @@ [ "AceStepTransformer1DModel", "AllegroTransformer3DModel", + "AnyFlowFARTransformer3DModel", + "AnyFlowTransformer3DModel", "AsymmetricAutoencoderKL", "AttentionBackendName", "AuraFlowTransformer2DModel", @@ -380,6 +382,7 @@ "EDMEulerScheduler", "EulerAncestralDiscreteScheduler", "EulerDiscreteScheduler", + "FlowMapEulerDiscreteScheduler", "FlowMatchEulerDiscreteScheduler", "FlowMatchHeunDiscreteScheduler", "FlowMatchLCMScheduler", @@ -511,6 +514,8 @@ "AnimateDiffSparseControlNetPipeline", "AnimateDiffVideoToVideoControlNetPipeline", "AnimateDiffVideoToVideoPipeline", + "AnyFlowFARPipeline", + "AnyFlowPipeline", "AudioLDM2Pipeline", "AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel", @@ -1019,6 +1024,8 @@ from .models import ( AceStepTransformer1DModel, AllegroTransformer3DModel, + AnyFlowFARTransformer3DModel, + AnyFlowTransformer3DModel, AsymmetricAutoencoderKL, AttentionBackendName, AuraFlowTransformer2DModel, @@ -1204,6 +1211,7 @@ EDMEulerScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, + FlowMapEulerDiscreteScheduler, FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler, FlowMatchLCMScheduler, @@ -1316,6 +1324,8 @@ AnimateDiffSparseControlNetPipeline, AnimateDiffVideoToVideoControlNetPipeline, AnimateDiffVideoToVideoPipeline, + AnyFlowFARPipeline, + AnyFlowPipeline, AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index ff8e16aad447..05d29aabb7e0 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -95,6 +95,10 @@ _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] + _import_structure["transformers.transformer_anyflow"] = [ + "AnyFlowFARTransformer3DModel", + "AnyFlowTransformer3DModel", + ] _import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"] _import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"] _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"] @@ -214,6 +218,8 @@ from .transformers import ( AceStepTransformer1DModel, AllegroTransformer3DModel, + AnyFlowFARTransformer3DModel, + AnyFlowTransformer3DModel, AuraFlowTransformer2DModel, BriaFiboTransformer2DModel, BriaTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 156b54e7f07d..4b0a55d0a8ad 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -18,6 +18,7 @@ from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel from .transformer_allegro import AllegroTransformer3DModel + from .transformer_anyflow import AnyFlowFARTransformer3DModel, AnyFlowTransformer3DModel from .transformer_bria import BriaTransformer2DModel from .transformer_bria_fibo import BriaFiboTransformer2DModel from .transformer_chroma import ChromaTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_anyflow.py b/src/diffusers/models/transformers/transformer_anyflow.py new file mode 100644 index 000000000000..55fadc317559 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_anyflow.py @@ -0,0 +1,1683 @@ +# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file derives from the FAR architecture (Gu et al., 2025, arXiv:2503.19325) and adds AnyFlow's +# dual-timestep flow-map embedding (AnyFlowDualTimestepTextImageEmbedding). The base 3D DiT structure +# is adapted from the v0.35.1 Wan2.1 transformer (transformer_wan.py); upstream Wan has since been +# refactored, so this file is intentionally self-contained rather than annotated with `# Copied from`. + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.attention.flex_attention import create_block_mask +from torch.nn.attention.flex_attention import flex_attention as _flex_attention_eager + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import BaseOutput, logging +from ..attention import AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class AnyFlowFARTransformerOutput(BaseOutput): + """ + Output dataclass for ``AnyFlowTransformer3DModel``'s causal forward paths. + + Args: + sample (`torch.Tensor` or `None`): + Predicted denoising target for the autoregressive chunk. ``None`` for the cache-prefill path, + which only writes the KV cache and produces no usable sample. + kv_cache (`list[dict[str, torch.Tensor]]`, *optional*): + Per-block KV cache state used by subsequent autoregressive steps. + """ + + sample: Optional[torch.Tensor] = None + kv_cache: Optional[List[Dict[str, torch.Tensor]]] = None + + +# `flex_attention` is compiled lazily on first CUDA use. The compiled kernel goes through +# Triton/Inductor C++ codegen which can fail on small or unusual shapes seen in fast tests +# (e.g. CPU-side `pipe.to("cpu")` with tiny dummy components), so we always fall back to the +# eager kernel on CPU. Both paths produce identical numeric output — the wrapper only differs +# in execution speed. +_flex_attention_compiled = None + + +def _get_compiled_flex_attention(): + global _flex_attention_compiled + if _flex_attention_compiled is None: + try: + _flex_attention_compiled = torch.compile(_flex_attention_eager, dynamic=True) + except Exception as e: # pragma: no cover - environment-dependent + logger.warning( + "Failed to torch.compile flex_attention; falling back to the eager kernel. Error: %s", + e, + ) + _flex_attention_compiled = _flex_attention_eager + return _flex_attention_compiled + + +def flex_attention(query, key, value, *args, **kwargs): + """Dispatch to the compiled flex_attention on CUDA (fast path) or the eager one on CPU + (avoids Triton/Inductor codegen failures on tiny test shapes).""" + if query.device.type == "cuda": + return _get_compiled_flex_attention()(query, key, value, *args, **kwargs) + return _flex_attention_eager(query, key, value, *args, **kwargs) + + +def build_block_mask(mask_2d, device): + if mask_2d.dim() != 2 or mask_2d.dtype != torch.bool: + raise ValueError( + f"`mask_2d` must be a 2D boolean tensor, got shape {tuple(mask_2d.shape)} and dtype {mask_2d.dtype}." + ) + mask_2d = mask_2d.contiguous() + + Q_LEN, KV_LEN = mask_2d.shape + + def mask_mod(b, h, q_idx, kv_idx): + return mask_2d[q_idx, kv_idx] + + return create_block_mask(mask_mod, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=KV_LEN, device=device, _compile=False) + + +def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): + # MPS / NPU backends do not support complex128 / float64; fall back to float32 on those devices. + is_mps = hidden_states.device.type == "mps" + is_npu = hidden_states.device.type == "npu" + rotary_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + x_rotated = torch.view_as_complex(hidden_states.to(rotary_dtype).unflatten(3, (-1, 2))) + x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) + return x_out.type_as(hidden_states) + + +class AnyFlowAttnProcessor: + """ + Self-attention processor for AnyFlow. Supports three modes: + + * Bidirectional (``attention_mask`` is ``None``) — standard SDPA / dispatched backend. + * FAR causal — :class:`~torch.nn.attention.flex_attention.BlockMask` via ``flex_attention``. + * Autoregressive inference with KV cache (read or write) for the FAR variant. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AnyFlowAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or higher." + ) + + def __call__( + self, + attn: "AnyFlowAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[Any] = None, + rotary_emb: Optional[Dict[str, torch.Tensor]] = None, + kv_cache: Optional[Dict[str, torch.Tensor]] = None, + kv_cache_flag: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Layout (B, H, L, D) is required by KV-cache slicing, rotary application, and flex_attention. + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if kv_cache is not None: + if kv_cache_flag["is_cache_step"]: + kv_cache["compressed_cache"][0, :, :, : kv_cache_flag["num_compressed_tokens"], :] = key[ + :, :, : kv_cache_flag["num_compressed_tokens"] + ] + kv_cache["compressed_cache"][1, :, :, : kv_cache_flag["num_compressed_tokens"], :] = value[ + :, :, : kv_cache_flag["num_compressed_tokens"] + ] + kv_cache["full_cache"][0, :, :, : kv_cache_flag["num_full_tokens"], :] = key[ + :, :, kv_cache_flag["num_compressed_tokens"] : + ] + kv_cache["full_cache"][1, :, :, : kv_cache_flag["num_full_tokens"], :] = value[ + :, :, kv_cache_flag["num_compressed_tokens"] : + ] + else: + key = torch.cat( + [ + kv_cache["compressed_cache"][0, :, :, : kv_cache_flag["num_cached_compressed_tokens"], :], + kv_cache["full_cache"][0, :, :, : kv_cache_flag["num_cached_full_tokens"], :], + key, + ], + dim=2, + ) + value = torch.cat( + [ + kv_cache["compressed_cache"][1, :, :, : kv_cache_flag["num_cached_compressed_tokens"], :], + kv_cache["full_cache"][1, :, :, : kv_cache_flag["num_cached_full_tokens"], :], + value, + ], + dim=2, + ) + + if rotary_emb is not None: + query = apply_rotary_emb(query, rotary_emb["query"]) + key = apply_rotary_emb(key, rotary_emb["key"]) + + if attention_mask is None: + # Pass (B, L, H, D) to dispatch_attention_fn (its native backend permutes back to (B, H, L, D) + # internally before calling the kernel). + hidden_states = dispatch_attention_fn( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + # Output already in (B, L, H, D); fold heads into the channel dim. + hidden_states = hidden_states.flatten(2, 3) + else: + # FAR causal path: BlockMask is consumed by flex_attention only, so we keep the + # (B, H, L, D) layout and pad to a multiple of 128 for the BlockMask block size. + seq_len = query.shape[2] + head_dim = query.shape[3] + padded_length = int(math.ceil(seq_len / 128.0) * 128.0 - seq_len) + query = torch.cat( + [ + query, + torch.zeros( + [query.shape[0], query.shape[1], padded_length, query.shape[3]], + device=query.device, + dtype=query.dtype, + ), + ], + dim=2, + ) + key = torch.cat( + [ + key, + torch.zeros( + [key.shape[0], key.shape[1], padded_length, key.shape[3]], + device=key.device, + dtype=key.dtype, + ), + ], + dim=2, + ) + value = torch.cat( + [ + value, + torch.zeros( + [value.shape[0], value.shape[1], padded_length, value.shape[3]], + device=value.device, + dtype=value.dtype, + ), + ], + dim=2, + ) + # flex_attention requires head_dim >= 16. When tiny dummy components use head_dim < 16 + # (real ckpts use 128) we right-pad q/k/v with zeros and pass an explicit scale matched + # to the original head_dim; padded value rows contribute 0, so trimming back preserves + # output equivalence. + head_pad = max(0, 16 - head_dim) + scale = 1.0 / (head_dim**0.5) if head_pad > 0 else None + if head_pad > 0: + query = F.pad(query, (0, head_pad)) + key = F.pad(key, (0, head_pad)) + value = F.pad(value, (0, head_pad)) + hidden_states = flex_attention(query, key, value, block_mask=attention_mask, scale=scale) + if head_pad > 0: + hidden_states = hidden_states[..., :head_dim] + hidden_states = hidden_states[:, :, :seq_len] + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + + hidden_states = hidden_states.type_as(query) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class AnyFlowCrossAttnProcessor: + """ + Cross-attention processor for AnyFlow. Always uses the dispatched SDPA-compatible backend; no rotary + embedding or KV cache is applied to the text→video cross-attention path. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AnyFlowCrossAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or higher." + ) + + def __call__( + self, + attn: "AnyFlowAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # (B, L, H, D) layout for dispatch_attention_fn. + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class AnyFlowAttention(torch.nn.Module, AttentionModuleMixin): + """ + Attention module used by :class:`AnyFlowTransformerBlock`. Layout matches the legacy + :class:`~diffusers.models.attention_processor.Attention` so existing AnyFlow checkpoints load + bit-exactly into this class. + """ + + _default_processor_cls = AnyFlowAttnProcessor + _available_processors = [AnyFlowAttnProcessor, AnyFlowCrossAttnProcessor] + + def __init__( + self, + dim: int, + heads: int, + dim_head: int, + eps: float = 1e-6, + processor: Optional[Any] = None, + ): + super().__init__() + self.heads = heads + self.inner_dim = heads * dim_head + + self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = torch.nn.Linear(dim, self.inner_dim, bias=True) + self.to_v = torch.nn.Linear(dim, self.inner_dim, bias=True) + self.to_out = torch.nn.ModuleList( + [ + torch.nn.Linear(self.inner_dim, dim, bias=True), + torch.nn.Dropout(0.0), + ] + ) + # ``rms_norm_across_heads`` per-axis: normalize Q and K across the entire ``heads * dim_head`` + # channel axis. We use diffusers' RMSNorm (rather than ``torch.nn.RMSNorm``) so the numerics + # match the legacy Attention class that produced the released checkpoints. + self.norm_q = RMSNorm(self.inner_dim, eps=eps) + self.norm_k = RMSNorm(self.inner_dim, eps=eps) + + self.set_processor(processor if processor is not None else self._default_processor_cls()) + + def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + return self.processor(self, hidden_states, **kwargs) + + +class AnyFlowImageEmbedding(torch.nn.Module): + def __init__(self, in_features: int, out_features: int): + super().__init__() + + self.norm1 = FP32LayerNorm(in_features) + self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") + self.norm2 = FP32LayerNorm(out_features) + + def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm1(encoder_hidden_states_image) + hidden_states = self.ff(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +class AnyFlowDualTimestepTextImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + gate_value: float, + deltatime_type: str, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: Optional[int] = None, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.delta_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + self.image_embedder = None + if image_embed_dim is not None: + self.image_embedder = AnyFlowImageEmbedding(image_embed_dim, dim) + + self.register_buffer("delta_emb_gate", torch.tensor([gate_value], dtype=torch.float32), persistent=False) + self.deltatime_type = deltatime_type + + def forward_timestep( + self, timestep: torch.Tensor, delta_timestep: torch.Tensor, encoder_hidden_states, token_per_frame + ): + batch_size, num_frames = timestep.shape + timestep = timestep.reshape(-1) + delta_timestep = delta_timestep.reshape(-1) + + timestep = self.timesteps_proj(timestep) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + + delta_timestep = self.timesteps_proj(delta_timestep) + + delta_embedder_dtype = next(iter(self.delta_embedder.parameters())).dtype + if delta_timestep.dtype != delta_embedder_dtype and delta_embedder_dtype != torch.int8: + delta_timestep = delta_timestep.to(delta_embedder_dtype) + delta_emb = self.delta_embedder(delta_timestep).type_as(encoder_hidden_states) + + gate = self.delta_emb_gate.to(delta_embedder_dtype) + + rt_emb = (1 - gate) * temb + gate * delta_emb + timestep_proj = self.time_proj(self.act_fn(rt_emb)) + + rt_emb = rt_emb.unflatten(0, (batch_size, num_frames)).repeat_interleave(token_per_frame, dim=1) + timestep_proj = timestep_proj.unflatten(0, (batch_size, num_frames)).repeat_interleave(token_per_frame, dim=1) + + return rt_emb, timestep_proj + + def forward( + self, + timestep: torch.Tensor, + r_timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + far_cfg=None, + clean_timestep=None, + is_causal=True, + ): + if self.deltatime_type == "r": + delta_timestep = r_timestep + elif self.deltatime_type == "t-r": + delta_timestep = timestep - r_timestep + else: + raise NotImplementedError + + if is_causal: + full_frame_timestep, full_frame_timestep_proj = self.forward_timestep( + timestep[:, -far_cfg["num_full_frames"] :], + delta_timestep[:, -far_cfg["num_full_frames"] :], + encoder_hidden_states, + far_cfg["full_token_per_frame"], + ) # noqa: E501 + compressed_frame_timestep, compressed_frame_timestep_proj = self.forward_timestep( + timestep[:, : -far_cfg["num_full_frames"]], + delta_timestep[:, : -far_cfg["num_full_frames"]], + encoder_hidden_states, + far_cfg["compressed_token_per_frame"], + ) # noqa: E501 + + if clean_timestep is not None: + clean_timestep, clean_timestep_proj = self.forward_timestep( + clean_timestep, clean_timestep, encoder_hidden_states, far_cfg["full_token_per_frame"] + ) # noqa: E501 + timestep = torch.cat([compressed_frame_timestep, full_frame_timestep, clean_timestep], dim=1) + timestep_proj = torch.cat( + [compressed_frame_timestep_proj, full_frame_timestep_proj, clean_timestep_proj], dim=1 + ) + else: + timestep = torch.cat([compressed_frame_timestep, full_frame_timestep], dim=1) + timestep_proj = torch.cat([compressed_frame_timestep_proj, full_frame_timestep_proj], dim=1) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + else: + timestep, timestep_proj = self.forward_timestep( + timestep, delta_timestep, encoder_hidden_states, far_cfg["full_token_per_frame"] + ) # noqa: E501 + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return timestep, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + + +class AnyFlowRotaryPosEmbed(nn.Module): + def __init__( + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + compressed_patch_size: Tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.compressed_patch_size = compressed_patch_size + self.max_seq_len = max_seq_len + self.theta = theta + + # Frequency table is lazily built per-device in ``_build_freqs``: MPS / NPU don't support + # complex128, so we downcast to complex64 there. + self._freqs_cache: Optional[Tuple[Any, torch.Tensor]] = None + + def _build_freqs(self, device: torch.device) -> torch.Tensor: + cache_key = (device.type, str(device)) + if self._freqs_cache is not None and self._freqs_cache[0] == cache_key: + return self._freqs_cache[1] + + is_mps = device.type == "mps" + is_npu = device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + + h_dim = w_dim = 2 * (self.attention_head_dim // 6) + t_dim = self.attention_head_dim - h_dim - w_dim + + freqs_list = [] + for dim in (t_dim, h_dim, w_dim): + f = get_1d_rotary_pos_embed( + dim, + self.max_seq_len, + self.theta, + use_real=False, + repeat_interleave_real=False, + freqs_dtype=freqs_dtype, + ) + freqs_list.append(f.to(device)) + freqs = torch.cat(freqs_list, dim=1) + self._freqs_cache = (cache_key, freqs) + return freqs + + def avg_pool_complex(self, freq: torch.Tensor, kernel_size: int, stride: int): + + real = freq.real # [B, C, L], float + real = real.transpose(0, 1).unsqueeze(0) + imag = freq.imag # [B, C, L], float + imag = imag.transpose(0, 1).unsqueeze(0) + + pr = F.avg_pool1d(real, kernel_size, stride) + pi = F.avg_pool1d(imag, kernel_size, stride) + + pr = pr.squeeze(0).transpose(0, 1) + pi = pi.squeeze(0).transpose(0, 1) + + norm = torch.sqrt(pr**2 + pi**2) + pr_unit = pr / norm + pi_unit = pi / norm + + return torch.complex(pr_unit, pi_unit) + + def _forward_compressed_frame(self, num_frames, height, width, device): + ppf, pph, ppw = num_frames, height, width + # Tiny dummy components (e.g. height=16/width=16 with compressed_patch_size=(1,4,4) and + # an upstream VAE stride of 8) can produce 0-element grids; the .view(0, k, 1, -1) reshape + # below would be ambiguous. Real ckpts use 60x104 latents and never hit this path. + freqs_full = self._build_freqs(device) + if min(ppf, pph, ppw) <= 0: + freq_channels = self.attention_head_dim // 2 + return torch.empty((ppf, pph, ppw, freq_channels), dtype=freqs_full.dtype, device=device) + downscale = [self.compressed_patch_size[i] // self.patch_size[i] for i in range(len(self.patch_size))] + + freqs = freqs_full.split_with_sizes( + [ + self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), + self.attention_head_dim // 6, + self.attention_head_dim // 6, + ], + dim=1, + ) + + freqs_f = self.avg_pool_complex(freqs[0], kernel_size=downscale[0], stride=downscale[0]) + freqs_h = self.avg_pool_complex(freqs[1], kernel_size=downscale[1], stride=downscale[1]) + freqs_w = self.avg_pool_complex(freqs[2], kernel_size=downscale[2], stride=downscale[2]) + + freqs_f = freqs_f[:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_h = freqs_h[:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_w = freqs_w[:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1) + return freqs + + def _forward_full_frame(self, num_frames, height, width, device) -> torch.Tensor: + ppf, pph, ppw = num_frames, height, width + + freqs_full = self._build_freqs(device) + if min(ppf, pph, ppw) <= 0: + freq_channels = self.attention_head_dim // 2 + return torch.empty((ppf, pph, ppw, freq_channels), dtype=freqs_full.dtype, device=device) + + freqs = freqs_full.split_with_sizes( + [ + self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), + self.attention_head_dim // 6, + self.attention_head_dim // 6, + ], + dim=1, + ) + + freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1) + return freqs + + def forward(self, far_cfg, device, clean_hidden_states=None, is_causal=True): + if is_causal: + full_frame_freqs = self._forward_full_frame( + num_frames=far_cfg["total_frames"], + height=far_cfg["full_frame_shape"][0], + width=far_cfg["full_frame_shape"][1], + device=device, + ) + compressed_frame_freqs = self._forward_compressed_frame( + num_frames=far_cfg["total_frames"], + height=far_cfg["compressed_frame_shape"][0], + width=far_cfg["compressed_frame_shape"][1], + device=device, + ) + + compressed_frame_freqs, full_frame_freqs = ( + compressed_frame_freqs[: far_cfg["num_compressed_frames"]], + full_frame_freqs[far_cfg["num_compressed_frames"] :], + ) # noqa: E501 + + compressed_frame_freqs = compressed_frame_freqs.flatten(start_dim=0, end_dim=2) + full_frame_freqs = full_frame_freqs.flatten(start_dim=0, end_dim=2) + + if clean_hidden_states is not None: + freqs = torch.cat([compressed_frame_freqs, full_frame_freqs, full_frame_freqs], dim=0) + else: + freqs = torch.cat([compressed_frame_freqs, full_frame_freqs], dim=0) + + freqs = freqs[None, None, ...] + + return {"query": freqs, "key": freqs} + else: + freqs = self._forward_full_frame( + num_frames=far_cfg["total_frames"], + height=far_cfg["full_frame_shape"][0], + width=far_cfg["full_frame_shape"][1], + device=device, + ) + freqs = freqs.flatten(start_dim=0, end_dim=2) + freqs = freqs[None, None, ...] + return {"query": freqs, "key": freqs} + + +class AnyFlowTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + cross_attn_norm: bool = False, + eps: float = 1e-6, + ): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = AnyFlowAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + processor=AnyFlowAttnProcessor(), + ) + + # 2. Cross-attention + self.attn2 = AnyFlowAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + processor=AnyFlowCrossAttnProcessor(), + ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 3. Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + attention_mask: torch.Tensor, + kv_cache=None, + kv_cache_flag=None, + ) -> torch.Tensor: + + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=2) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + shift_msa.squeeze(2), + scale_msa.squeeze(2), + gate_msa.squeeze(2), + c_shift_msa.squeeze(2), + c_scale_msa.squeeze(2), + c_gate_msa.squeeze(2), + ) # noqa: E501 + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1( + hidden_states=norm_hidden_states, + rotary_emb=rotary_emb, + attention_mask=attention_mask, + kv_cache=kv_cache, + kv_cache_flag=kv_cache_flag, + ) # noqa: E501 + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + hidden_states + ) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + + return hidden_states + + +class AnyFlowTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + r""" + Bidirectional 3D Transformer for AnyFlow flow-map sampling. + + The architecture is the v0.35.1 Wan2.1 3D DiT backbone with one structural change: the timestep + embedder is replaced by ``AnyFlowDualTimestepTextImageEmbedding`` so that every forward call conditions + on both the source timestep ``t`` and the target timestep ``r``. This is the embedding required to + learn the flow map :math:`\Phi_{r\leftarrow t}` introduced in + [AnyFlow](https://huggingface.co/papers/2605.13724). + + For frame-level autoregressive (FAR causal) generation, use ``AnyFlowFARTransformer3DModel`` instead; + that variant adds the FAR causal block-mask and a compressed-frame patch embedding on top of the same + backbone. + + Args: + patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + num_attention_heads (`int`, defaults to `40`): + Number of attention heads. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input latent. + out_channels (`int`, defaults to `16`): + The number of channels in the output latent. + text_dim (`int`, defaults to `4096`): + Input dimension for text embeddings (UMT5). + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `13824`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `40`): + Number of transformer blocks. + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + eps (`float`, defaults to `1e-6`): + Epsilon for normalization layers. + image_dim (`Optional[int]`, *optional*, defaults to `None`): + Image embedding dimension for I2V conditioning (`1280` for the original Wan2.1-I2V model). + rope_max_seq_len (`int`, defaults to `1024`): + Maximum sequence length used to precompute rotary position frequencies. + gate_value (`float`, defaults to `0.25`): + Mixing gate between source-timestep and delta-timestep embeddings (the AnyFlow paper's :math:`g` + parameter, fixed at 0.25 in stage-1 distillation). + deltatime_type (`str`, defaults to `'r'`): + Either ``"r"`` (delta is the target timestep) or ``"t-r"`` (delta is the absolute interval). + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] + _no_split_modules = ["AnyFlowTransformerBlock"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _repeated_blocks = ["AnyFlowTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: Tuple[int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + eps: float = 1e-6, + image_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + gate_value: float = 0.25, + deltatime_type: str = "r", + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Patch & position embedding. ``compressed_patch_size`` is unused in the bidirectional path; + # we forward ``patch_size`` itself so the rotary helper has a valid (no-op) value. + self.rope = AnyFlowRotaryPosEmbed(attention_head_dim, patch_size, patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. Condition embedding (always dual-timestep for AnyFlow distilled checkpoints). + self.condition_embedder = AnyFlowDualTimestepTextImageEmbedding( + dim=inner_dim, + gate_value=gate_value, + deltatime_type=deltatime_type, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList( + [ + AnyFlowTransformerBlock(inner_dim, ffn_dim, num_attention_heads, cross_attn_norm, eps) + for _ in range(num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + def _unpack_latent_sequence(self, latents, num_frames, height, width, patch_size): + batch_size, num_patches, channels = latents.shape + height, width = height // patch_size, width // patch_size + + latents = latents.view( + batch_size * num_frames, height, width, patch_size, patch_size, channels // (patch_size * patch_size) + ) + latents = latents.permute(0, 5, 1, 3, 2, 4) + latents = latents.reshape( + batch_size, num_frames, channels // (patch_size * patch_size), height * patch_size, width * patch_size + ) + return latents + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + r_timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Transformer2DModelOutput, Tuple]: + """ + Bidirectional flow-map forward pass. + + ``hidden_states`` is laid out as ``(B, F, C, H, W)`` (per-frame latents). The input is patchified + with the standard ``patch_embedding`` (kernel = stride = ``patch_size``) and denoised with global + bidirectional self-attention over the resulting flat token sequence. + """ + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + full_token_per_frame = (height * width) // (self.config.patch_size[1] * self.config.patch_size[2]) + + far_cfg = { + "total_frames": num_frames, + "full_frame_shape": (height // self.config.patch_size[1], width // self.config.patch_size[2]), + "full_token_per_frame": full_token_per_frame, + } + + rotary_emb = self.rope(far_cfg=far_cfg, device=hidden_states.device, is_causal=False) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, + r_timestep, + encoder_hidden_states, + encoder_hidden_states_image, + is_causal=False, + far_cfg=far_cfg, + ) + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + + attention_mask = None + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask + ) + else: + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask) + + # Output norm, projection & unpatchify + if temb.ndim == 3: + shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move shift/scale to hidden_states' device for multi-GPU accelerate inference. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + output = self._unpack_latent_sequence( + hidden_states, + num_frames=far_cfg["total_frames"], + height=height, + width=width, + patch_size=self.config.patch_size[1], + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +class AnyFlowFARTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + r""" + Causal (FAR) 3D Transformer for AnyFlow flow-map sampling with frame-level autoregressive generation. + + Extends the v0.35.1 Wan2.1 backbone with: + + * **FAR causal block-mask** via :func:`torch.nn.attention.flex_attention`, supporting frame-level + autoregressive generation (FAR; [Gu et al., 2025](https://arxiv.org/abs/2503.19325)). + * **Compressed-frame patch embedding** ``far_patch_embedding`` for context (already-generated) frames, + initialized from ``patch_embedding`` via trilinear interpolation so a freshly constructed model is + already at a reasonable starting point even before LoRA fine-tuning. + * **Dual-timestep flow-map embedding** for any-step sampling (same as ``AnyFlowTransformer3DModel``). + + Use ``AnyFlowTransformer3DModel`` instead for plain bidirectional T2V — that variant skips the FAR + causal masking and ``far_patch_embedding`` and is ~5–10% smaller. + + Args: + patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for full-resolution chunks. + compressed_patch_size (`Tuple[int]`, defaults to `(1, 4, 4)`): + Larger patch dimensions for the FAR-compressed (context) chunks. + full_chunk_limit (`int`, defaults to `3`): + Maximum number of full-resolution chunks before earlier chunks are demoted to compressed FAR + context. The released checkpoints use ``3``. + num_attention_heads (`int`, defaults to `40`): + Number of attention heads. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input latent. + out_channels (`int`, defaults to `16`): + The number of channels in the output latent. + text_dim (`int`, defaults to `4096`): + Input dimension for text embeddings (UMT5). + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `13824`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `40`): + Number of transformer blocks. + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + eps (`float`, defaults to `1e-6`): + Epsilon for normalization layers. + image_dim (`Optional[int]`, *optional*, defaults to `None`): + Image embedding dimension for I2V conditioning. + rope_max_seq_len (`int`, defaults to `1024`): + Maximum sequence length used to precompute rotary position frequencies. + gate_value (`float`, defaults to `0.25`): + Mixing gate between source-timestep and delta-timestep embeddings. + deltatime_type (`str`, defaults to `'r'`): + Either ``"r"`` (delta is the target timestep) or ``"t-r"`` (delta is the absolute interval). + + .. note:: + ``chunk_partition`` is **not** a model config field — it is a per-call argument passed to + :meth:`forward`. Different inference setups (varying ``num_frames`` or full-vs-compressed schedules) + therefore do not require separate checkpoints. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["patch_embedding", "far_patch_embedding", "condition_embedder", "norm"] + _no_split_modules = ["AnyFlowTransformerBlock"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _repeated_blocks = ["AnyFlowTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: Tuple[int] = (1, 2, 2), + compressed_patch_size: Tuple[int] = (1, 4, 4), + full_chunk_limit: int = 3, + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + eps: float = 1e-6, + image_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + gate_value: float = 0.25, + deltatime_type: str = "r", + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Patch & position embedding (full + FAR-compressed branches). + self.rope = AnyFlowRotaryPosEmbed(attention_head_dim, patch_size, compressed_patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + self.far_patch_embedding = nn.Conv3d( + in_channels, inner_dim, kernel_size=compressed_patch_size, stride=compressed_patch_size + ) + # Warm-start the compressed branch from the full-resolution branch by trilinear interpolation. This + # matches FAR-Dev's `setup_far_model()` initialization. State-dict loading will overwrite these + # weights for trained checkpoints; the warm-start only matters when constructing a fresh model. + original_weight = self.patch_embedding.weight.data.view(-1, 1, *patch_size) + new_weight = F.interpolate(original_weight, size=compressed_patch_size, mode="trilinear", align_corners=False) + new_weight = new_weight.view(inner_dim, in_channels, *compressed_patch_size) + with torch.no_grad(): + self.far_patch_embedding.weight.copy_(new_weight) + self.far_patch_embedding.bias.copy_(self.patch_embedding.bias) + + # 2. Condition embedding (always dual-timestep for AnyFlow distilled checkpoints). + self.condition_embedder = AnyFlowDualTimestepTextImageEmbedding( + dim=inner_dim, + gate_value=gate_value, + deltatime_type=deltatime_type, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList( + [ + AnyFlowTransformerBlock(inner_dim, ffn_dim, num_attention_heads, cross_attn_norm, eps) + for _ in range(num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + r_timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + chunk_partition: List[int], + encoder_hidden_states_image: Optional[torch.Tensor] = None, + clean_hidden_states: Optional[torch.Tensor] = None, + clean_timestep: Optional[torch.Tensor] = None, + kv_cache: Optional[List[Dict[str, torch.Tensor]]] = None, + kv_cache_flag: Optional[Dict[str, Any]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Transformer2DModelOutput, AnyFlowFARTransformerOutput, Tuple]: + """ + FAR causal forward pass. Dispatches to one of three internal paths: + + * ``kv_cache is None`` → causal training rollout (returns + :class:`Transformer2DModelOutput`). + * ``kv_cache is not None`` and ``kv_cache_flag["is_cache_step"]`` → cache-prefill (returns + :class:`AnyFlowFARTransformerOutput` with ``sample=None``). + * Otherwise → autoregressive inference step (returns :class:`AnyFlowFARTransformerOutput`). + + Args: + hidden_states (`torch.Tensor`): Latent input of shape ``(B, F, C, H, W)``. + timestep, r_timestep (`torch.Tensor`): Source / target diffusion timesteps. + encoder_hidden_states (`torch.Tensor`): UMT5 text embeddings. + chunk_partition (`List[int]`): Per-chunk frame counts; total must match the number of latent + frames in ``hidden_states``. + encoder_hidden_states_image (`torch.Tensor`, *optional*): I2V image embedding. + clean_hidden_states, clean_timestep (`torch.Tensor`, *optional*): Clean conditioning frames + used by the training rollout. + kv_cache, kv_cache_flag (*optional*): Per-block KV cache and metadata for autoregressive + inference. + attention_kwargs (*optional*): forwarded to the attention processors. + return_dict (`bool`, defaults to `True`): If `False`, returns positional tuples. + """ + common = { + "hidden_states": hidden_states, + "chunk_partition": chunk_partition, + "timestep": timestep, + "r_timestep": r_timestep, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_image": encoder_hidden_states_image, + "return_dict": return_dict, + "attention_kwargs": attention_kwargs, + } + if kv_cache is not None: + common["kv_cache"] = kv_cache + common["kv_cache_flag"] = kv_cache_flag + if kv_cache_flag is not None and kv_cache_flag.get("is_cache_step"): + return self._forward_cache( + clean_hidden_states=clean_hidden_states, + clean_timestep=clean_timestep, + **common, + ) + return self._forward_inference(**common) + return self._forward_train( + clean_hidden_states=clean_hidden_states, + clean_timestep=clean_timestep, + **common, + ) + + def _unpack_latent_sequence(self, latents, num_frames, height, width, patch_size): + batch_size, num_patches, channels = latents.shape + height, width = height // patch_size, width // patch_size + + latents = latents.view( + batch_size * num_frames, height, width, patch_size, patch_size, channels // (patch_size * patch_size) + ) + + latents = latents.permute(0, 5, 1, 3, 2, 4) + latents = latents.reshape( + batch_size, num_frames, channels // (patch_size * patch_size), height * patch_size, width * patch_size + ) + return latents + + def forward_far_patchify(self, hidden_states, far_cfg, clean_hidden_states=None): + + full_hidden_states, compressed_hidden_states = ( + hidden_states[:, :, far_cfg["num_compressed_frames"] :], + hidden_states[:, :, : far_cfg["num_compressed_frames"]], + ) # noqa: E501 + + patchified_full_hidden_states = ( + self.patch_embedding(full_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + ) + if clean_hidden_states is not None: + clean_hidden_states = ( + self.patch_embedding(clean_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + ) + patchified_full_hidden_states = torch.cat([patchified_full_hidden_states, clean_hidden_states], dim=1) + + if far_cfg["num_compressed_frames"] > 0: + patchified_compressed_hidden_states = ( + self.far_patch_embedding(compressed_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + ) + hidden_states = torch.cat([patchified_compressed_hidden_states, patchified_full_hidden_states], dim=1) + else: + hidden_states = patchified_full_hidden_states + return hidden_states + + def forward_far_patchify_inference(self, hidden_states): + hidden_states = self.patch_embedding(hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + return hidden_states + + def _build_causal_mask(self, far_cfg, clean_hidden_states, device, dtype): + chunk_partition = far_cfg["chunk_partition"] + + noise_seq_len = clean_seq_len = far_cfg["num_full_frames"] * far_cfg["full_token_per_frame"] + context_seq_len = far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"] + + noise_start = context_seq_len + noise_end = noise_start + noise_seq_len + + clean_start = context_seq_len + noise_seq_len + clean_end = clean_start + clean_seq_len + + if clean_hidden_states is not None: + real_seq_len = context_seq_len + noise_seq_len + clean_seq_len + else: + real_seq_len = context_seq_len + noise_seq_len + + padded_seq_len = int(math.ceil(real_seq_len / 128.0) * 128.0) + + if clean_hidden_states is not None: + context_chunk_partition, noise_chunk_partition = ( + chunk_partition[: far_cfg["num_compressed_chunk"]], + chunk_partition[far_cfg["num_compressed_chunk"] :], + ) # noqa: E501 + + if len(context_chunk_partition) != 0: + context_frame_idx = torch.cat( + [ + torch.ones(chunk_len * far_cfg["compressed_token_per_frame"], device=device) * chunk_idx + for chunk_idx, chunk_len in enumerate(context_chunk_partition) + ] + ) # noqa: E501 + else: + context_frame_idx = None + noise_frame_idx = clean_frame_idx = torch.cat( + [ + torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device) + * (chunk_idx + len(context_chunk_partition)) + for chunk_idx, chunk_len in enumerate(noise_chunk_partition) + ] + ) # noqa: E501 + pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device) + + if len(context_chunk_partition) != 0: + frame_idx = torch.cat([context_frame_idx, noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0) + else: + frame_idx = torch.cat([noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0) + + def mask_mod(b, h, q_idx, kv_idx): + # q_idx, kv_idx: LongTensor, range: [0, padded_seq_len) + + # 1) whether is padding + is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len) + + # 3) chunk casual + base = frame_idx[q_idx] >= frame_idx[kv_idx] + + # 4) interval mask + q_is_noise = (q_idx >= noise_start) & (q_idx < noise_end) + q_is_clean = (q_idx >= clean_start) & (q_idx < clean_end) + + k_is_noise = (kv_idx >= noise_start) & (kv_idx < noise_end) + k_is_clean = (kv_idx >= clean_start) & (kv_idx < clean_end) + + # 5) clean -> noise: disallowed + is_clean_to_noise = q_is_clean & k_is_noise + + # 6) noise -> noise: only same frame + same_frame_idx = frame_idx[q_idx] == frame_idx[kv_idx] + + noise_to_noise = q_is_noise & k_is_noise + noise_to_clean = q_is_noise & k_is_clean + + noise_to_noise_allow = noise_to_noise & same_frame_idx + noise_to_noise_mask = (~noise_to_noise) | noise_to_noise_allow + + noise_to_clean_same = noise_to_clean & same_frame_idx + noise_to_clean_disallow = noise_to_clean_same + + # attention mask is chunk casual + allowed = base & ~is_padding & ~is_clean_to_noise & noise_to_noise_mask & ~noise_to_clean_disallow + return allowed + + return create_block_mask( + mask_mod, + B=None, + H=None, + Q_LEN=padded_seq_len, + KV_LEN=padded_seq_len, + device=device, + _compile=False, + ) + else: + context_chunk_partition, noise_chunk_partition = ( + chunk_partition[: far_cfg["num_compressed_chunk"]], + chunk_partition[far_cfg["num_compressed_chunk"] :], + ) # noqa: E501 + + if len(context_chunk_partition) != 0: + context_frame_idx = torch.cat( + [ + torch.ones(chunk_len * far_cfg["compressed_token_per_frame"], device=device) * chunk_idx + for chunk_idx, chunk_len in enumerate(context_chunk_partition) + ] + ) # noqa: E501 + else: + context_frame_idx = None + + noise_frame_idx = torch.cat( + [ + torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device) + * (chunk_idx + len(context_chunk_partition)) + for chunk_idx, chunk_len in enumerate(noise_chunk_partition) + ] + ) # noqa: E501 + pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device) + + if len(context_chunk_partition) != 0: + frame_idx = torch.cat([context_frame_idx, noise_frame_idx, pad_frame_idx], dim=0) + else: + frame_idx = torch.cat([noise_frame_idx, pad_frame_idx], dim=0) + + def mask_mod(b, h, q_idx, kv_idx): + is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len) + base = frame_idx[q_idx] >= frame_idx[kv_idx] + return base & ~is_padding + + return create_block_mask( + mask_mod, + B=None, + H=None, + Q_LEN=padded_seq_len, + KV_LEN=padded_seq_len, + device=device, + _compile=False, + ) + + def _forward_inference( + self, + hidden_states: torch.Tensor, + chunk_partition, + timestep: torch.LongTensor, + r_timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + kv_cache=None, + kv_cache_flag=None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + full_token_per_frame = (height // self.config.patch_size[1]) * (width // self.config.patch_size[2]) + compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * ( + width // self.config.compressed_patch_size[2] + ) + + total_chunks = 1 + kv_cache_flag["num_cached_chunks"] + + if total_chunks >= self.config.full_chunk_limit: + num_full_chunk, num_compressed_chunk = ( + self.config.full_chunk_limit, + total_chunks - self.config.full_chunk_limit, + ) + else: + num_full_chunk, num_compressed_chunk = total_chunks, 0 + + kv_cache_flag["num_cached_full_tokens"] = ( + sum(chunk_partition[num_compressed_chunk : num_compressed_chunk + (num_full_chunk - 1)]) + * full_token_per_frame + ) # noqa: E501 + kv_cache_flag["num_cached_compressed_tokens"] = ( + sum(chunk_partition[:num_compressed_chunk]) * compressed_token_per_frame + ) + + far_cfg = { + "total_frames": sum(chunk_partition), + "num_full_frames": sum(chunk_partition[num_compressed_chunk:]), + "num_compressed_frames": sum(chunk_partition[:num_compressed_chunk]), + "full_frame_shape": (height // self.config.patch_size[1], width // self.config.patch_size[2]), + "compressed_frame_shape": ( + height // self.config.compressed_patch_size[1], + width // self.config.compressed_patch_size[2], + ), + "full_token_per_frame": full_token_per_frame, + "compressed_token_per_frame": compressed_token_per_frame, + } + + # step 3: generate attention mask + attention_mask = None + hidden_states = self.forward_far_patchify_inference(hidden_states) + + rotary_emb = self.rope(far_cfg=far_cfg, device=hidden_states.device) + rotary_emb["query"] = rotary_emb["query"][:, :, -hidden_states.shape[1] :] + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, + r_timestep, + encoder_hidden_states, + encoder_hidden_states_image, + far_cfg=far_cfg, # noqa: E501 + ) + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 4. Transformer blocks + for index_block, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + attention_mask, + kv_cache[index_block], + kv_cache_flag, + ) + else: + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + attention_mask, + kv_cache[index_block], + kv_cache_flag, + ) + + # 5. Output norm, projection & unpatchify + shift, scale = (self.scale_shift_table + temb.unsqueeze(2)).chunk(2, dim=2) + shift, scale = shift.squeeze(2), scale.squeeze(2) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + + output = self.proj_out(hidden_states) + output = self._unpack_latent_sequence( + output, num_frames=chunk_partition[-1], height=height, width=width, patch_size=self.config.patch_size[1] + ) + + if not return_dict: + return output, kv_cache + + return AnyFlowFARTransformerOutput(sample=output, kv_cache=kv_cache) + + def _forward_cache( + self, + hidden_states: torch.Tensor, + chunk_partition, + timestep: torch.LongTensor, + r_timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + clean_hidden_states=None, + clean_timestep=None, + kv_cache=None, + kv_cache_flag=None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + if clean_hidden_states is not None: + clean_hidden_states = clean_hidden_states.permute(0, 2, 1, 3, 4) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + full_token_per_frame = (height // self.config.patch_size[1]) * (width // self.config.patch_size[2]) + compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * ( + width // self.config.compressed_patch_size[2] + ) + total_chunks = len(chunk_partition) + + full_chunk_limit = self.config.full_chunk_limit - 1 + + if total_chunks > full_chunk_limit: + num_full_chunk, num_compressed_chunk = full_chunk_limit, total_chunks - full_chunk_limit + else: + num_full_chunk, num_compressed_chunk = total_chunks, 0 + + far_cfg = { + "total_frames": sum(chunk_partition), + "num_full_chunk": num_full_chunk, + "num_full_frames": sum(chunk_partition[num_compressed_chunk:]), + "num_compressed_chunk": num_compressed_chunk, + "num_compressed_frames": sum(chunk_partition[:num_compressed_chunk]), + "full_frame_shape": (height // self.config.patch_size[1], width // self.config.patch_size[2]), + "compressed_frame_shape": ( + height // self.config.compressed_patch_size[1], + width // self.config.compressed_patch_size[2], + ), + "full_token_per_frame": full_token_per_frame, + "compressed_token_per_frame": compressed_token_per_frame, + "chunk_partition": chunk_partition, + } + + kv_cache_flag["num_full_tokens"] = far_cfg["num_full_frames"] * far_cfg["full_token_per_frame"] + kv_cache_flag["num_compressed_tokens"] = ( + far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"] + ) + + # step 3: generate attention mask + attention_mask = self._build_causal_mask( + far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype + ) + + rotary_emb = self.rope(far_cfg=far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device) + hidden_states = self.forward_far_patchify( + hidden_states, far_cfg=far_cfg, clean_hidden_states=clean_hidden_states + ) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, + r_timestep, + encoder_hidden_states, + encoder_hidden_states_image, + far_cfg=far_cfg, + clean_timestep=clean_timestep, + ) + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 4. Transformer blocks + for index_block, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + attention_mask, + kv_cache[index_block], + kv_cache_flag, + ) + else: + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + attention_mask, + kv_cache[index_block], + kv_cache_flag, + ) + + if not return_dict: + return None, kv_cache + + return AnyFlowFARTransformerOutput(sample=None, kv_cache=kv_cache) + + def _forward_train( + self, + hidden_states: torch.Tensor, + chunk_partition, + timestep: torch.LongTensor, + r_timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + clean_hidden_states=None, + clean_timestep=None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + if clean_hidden_states is not None: + clean_hidden_states = clean_hidden_states.permute(0, 2, 1, 3, 4) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + full_token_per_frame = (height // self.config.patch_size[1]) * (width // self.config.patch_size[2]) + compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * ( + width // self.config.compressed_patch_size[2] + ) + total_chunks = len(chunk_partition) + + if total_chunks > self.config.full_chunk_limit: + num_full_chunk, num_compressed_chunk = ( + self.config.full_chunk_limit, + total_chunks - self.config.full_chunk_limit, + ) + else: + num_full_chunk, num_compressed_chunk = total_chunks, 0 + + far_cfg = { + "total_frames": sum(chunk_partition), + "num_full_chunk": num_full_chunk, + "num_full_frames": sum(chunk_partition[num_compressed_chunk:]), + "num_compressed_chunk": num_compressed_chunk, + "num_compressed_frames": sum(chunk_partition[:num_compressed_chunk]), + "full_frame_shape": (height // self.config.patch_size[1], width // self.config.patch_size[2]), + "compressed_frame_shape": ( + height // self.config.compressed_patch_size[1], + width // self.config.compressed_patch_size[2], + ), + "full_token_per_frame": full_token_per_frame, + "compressed_token_per_frame": compressed_token_per_frame, + "chunk_partition": chunk_partition, + } + + # step 3: generate attention mask + attention_mask = self._build_causal_mask( + far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype + ) + + rotary_emb = self.rope(far_cfg=far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device) + + hidden_states = self.forward_far_patchify( + hidden_states, far_cfg=far_cfg, clean_hidden_states=clean_hidden_states + ) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, + r_timestep, + encoder_hidden_states, + encoder_hidden_states_image, + far_cfg=far_cfg, + clean_timestep=clean_timestep, + ) + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 4. Transformer blocks + for index_block, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + attention_mask, + ) + else: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask) + + # 5. Output norm, projection & unpatchify + shift, scale = (self.scale_shift_table + temb.unsqueeze(2)).chunk(2, dim=2) + shift, scale = shift.squeeze(2), scale.squeeze(2) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + + if clean_hidden_states is not None: + hidden_states = hidden_states[ + :, : -(far_cfg["num_full_frames"] * far_cfg["full_token_per_frame"]) + ] # remove clean copy + output = self.proj_out( + hidden_states[:, far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"] :] + ) # remove far context + output = self._unpack_latent_sequence( + output, + num_frames=far_cfg["num_full_frames"], + height=height, + width=width, + patch_size=self.config.patch_size[1], + ) # noqa: E501 + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index d4b3974322b4..c0d12121d5e8 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -164,6 +164,10 @@ "AnimateDiffVideoToVideoPipeline", "AnimateDiffVideoToVideoControlNetPipeline", ] + _import_structure["anyflow"] = [ + "AnyFlowPipeline", + "AnyFlowFARPipeline", + ] _import_structure["bria"] = ["BriaPipeline"] _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"] _import_structure["flux2"] = [ @@ -603,6 +607,10 @@ AnimateDiffVideoToVideoControlNetPipeline, AnimateDiffVideoToVideoPipeline, ) + from .anyflow import ( + AnyFlowFARPipeline, + AnyFlowPipeline, + ) from .audioldm2 import ( AudioLDM2Pipeline, AudioLDM2ProjectionModel, diff --git a/src/diffusers/pipelines/anyflow/__init__.py b/src/diffusers/pipelines/anyflow/__init__.py new file mode 100644 index 000000000000..10603cdedc3b --- /dev/null +++ b/src/diffusers/pipelines/anyflow/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_anyflow"] = ["AnyFlowPipeline"] + _import_structure["pipeline_anyflow_far"] = ["AnyFlowFARPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_anyflow import AnyFlowPipeline + from .pipeline_anyflow_far import AnyFlowFARPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py new file mode 100644 index 000000000000..19a0ec69b08d --- /dev/null +++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py @@ -0,0 +1,669 @@ +# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Adapted from diffusers.pipelines.wan.pipeline_wan.WanPipeline (v0.35.1) for any-step flow-map sampling. + +import html +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import WanLoraLoaderMixin +from ...models import AnyFlowTransformer3DModel, AutoencoderKLWan +from ...models.autoencoders.vae import DiagonalGaussianDistribution +from ...schedulers import FlowMapEulerDiscreteScheduler +from ...utils import is_ftfy_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import AnyFlowPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import AnyFlowPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = AnyFlowPipeline.from_pretrained( + ... "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 + ... ).to("cuda") + + >>> prompt = "A red panda eating bamboo in a forest, cinematic lighting" + >>> video = pipe(prompt, num_inference_steps=4, num_frames=33).frames[0] + >>> export_to_video(video, "anyflow_t2v.mp4", fps=16) + ``` +""" + + +# Copied from diffusers.pipelines.wan.pipeline_wan.basic_clean +def basic_clean(text): + if is_ftfy_available(): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +# Copied from diffusers.pipelines.wan.pipeline_wan.whitespace_clean +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +# Copied from diffusers.pipelines.wan.pipeline_wan.prompt_clean +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class AnyFlowPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Bidirectional text-to-video generation pipeline for AnyFlow flow-map-distilled checkpoints. + + AnyFlow learns arbitrary-interval transitions :math:`z_t \to z_r` rather than the fixed + :math:`z_t \to z_0` mapping of consistency models, so a single distilled checkpoint can be evaluated at + 1, 2, 4, 8, 16... NFE without retraining. This pipeline operates over the full video tensor in one + bidirectional pass; for frame-level autoregressive (causal) generation use ``AnyFlowFARPipeline``. + + Sampling is plain Euler in mean-velocity form (``z_r = z_t - (t - r) * u``) with no re-noising. The + released NVIDIA checkpoints fold classifier-free guidance into the model weights, so the default + ``guidance_scale=1.0`` is the recommended setting. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`AutoTokenizer`]): + Tokenizer from [google/umt5-xxl](https://huggingface.co/google/umt5-xxl). + text_encoder ([`UMT5EncoderModel`]): + [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) text encoder. + transformer ([`AnyFlowTransformer3DModel`]): + Bidirectional flow-map 3D Transformer. + vae ([`AutoencoderKLWan`]): + VAE that encodes/decodes videos to and from latent representations. + scheduler ([`FlowMapEulerDiscreteScheduler`]): + Flow-map sampler. The pipeline drives ``scheduler.step(..., timestep, sample, r_timestep)`` per + inference step. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: AnyFlowTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMapEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" # noqa: E501 + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + def vae_encode(self, context_sequence): + # normalize: [0, 1] -> [-1, 1] + context_sequence = context_sequence * 2 - 1 + context_sequence = self.encode_latents( + context_sequence.to(dtype=self.vae.dtype, device=self._execution_device), sample=False + ) + context_sequence = context_sequence.permute(0, 2, 1, 3, 4) + return context_sequence + + def _normalize_latents(self, latents, latents_mean, latents_std): + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device) + latents = ((latents.float() - latents_mean) * latents_std).to(latents) + return latents + + @torch.no_grad() + def encode_latents(self, videos, sample=True): + videos = videos.permute(0, 2, 1, 3, 4) + moments = self.vae._encode(videos) + + latents_mean = torch.tensor(self.vae.config.latents_mean) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std) + + mu, logvar = torch.chunk(moments, 2, dim=1) + mu = self._normalize_latents(mu, latents_mean, latents_std) + + if sample: + logvar = self._normalize_latents(logvar, latents_mean, latents_std) + + latents = torch.cat([mu, logvar], dim=1) + posterior = DiagonalGaussianDistribution(latents) + latents = posterior.sample(generator=None) + del posterior + else: + latents = mu + return latents + + def _denoise_rollout( + self, + context_sequence=None, + num_inference_steps: int = 50, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + guidance_scale: float = 1.0, + use_mean_velocity: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + ): + device = self._execution_device + + if negative_prompt_embeds is not None: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + context_length = context_sequence.shape[1] if context_sequence is not None else 0 + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + for i, t in enumerate(tqdm(timesteps[:-1])): + r = timesteps[i + 1] + + if t == r: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1) + timestep = timestep.repeat((1, latent_model_input.shape[1])) + + if use_mean_velocity: + r_timestep = r.expand(latent_model_input.shape[0]).unsqueeze(-1) + r_timestep = r_timestep.repeat((1, latent_model_input.shape[1])) + else: + r_timestep = timestep + + if context_sequence is not None: + latent_model_input[:, :context_length, ...] = context_sequence + timestep[:, :context_length] = 0 + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + r_timestep=r_timestep, + encoder_hidden_states=prompt_embeds, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + latents = self.scheduler.step(noise_pred, t, latents, r_timestep=r, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs or []: + if k == "latents": + callback_kwargs[k] = latents + elif k == "prompt_embeds": + callback_kwargs[k] = prompt_embeds + elif k == "negative_prompt_embeds": + callback_kwargs[k] = negative_prompt_embeds + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + context_sequence: Optional[torch.Tensor] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 1.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + use_mean_velocity: bool = True, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` + instead. + context_sequence (`torch.Tensor`, *optional*): + Pre-VAE conditioning frames of shape `(B, C, T, H, W)` in `[0, 1]`. When provided, the pipeline + VAE-encodes them and keeps the corresponding latent prefix fixed during sampling. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during video generation. Ignored when not using guidance + (`guidance_scale < 1`). + height (`int`, defaults to `480`): + The height in pixels of the generated video. + width (`int`, defaults to `832`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. Must satisfy `(num_frames - 1) % + vae_scale_factor_temporal == 0`. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. Distilled AnyFlow checkpoints support any-step sampling, so + values as low as `1`, `2`, `4`, or `8` are typical. + guidance_scale (`float`, defaults to `1.0`): + Classifier-free guidance scale. The released AnyFlow checkpoints fuse CFG into the weights + during training; keep at `1.0` unless you know your checkpoint expects otherwise. + num_videos_per_prompt (`int`, *optional*, defaults to `1`): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents to use as inputs. If not provided, latents are sampled from the + supplied `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to tweak text inputs (e.g., prompt weighting). If + not provided, embeddings are generated from `prompt`. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + output_type (`str`, *optional*, defaults to `"np"`): + The output format. One of `"pil"`, `"np"`, `"pt"`, or `"latent"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return an [`AnyFlowPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Reserved for future use; currently forwarded to the transformer but not consumed. + callback_on_step_end (`Callable`, *optional*): + A function or [`PipelineCallback`] called at the end of each inference step. See + [`callbacks`](../callbacks) for details. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to `["latents"]`): + The tensor inputs forwarded to the callback. Must be a subset of + `self._callback_tensor_inputs`. + max_sequence_length (`int`, defaults to `512`): + The maximum text-encoder sequence length. Longer prompts are truncated. + use_mean_velocity (`bool`, defaults to `True`): + When `True`, the flow-map model is conditioned on both the source timestep `t` and the target + timestep `r` to predict a mean velocity, matching the training-time behavior. Disable to + mirror raw Euler stepping (`r = t`). + + Examples: + + Returns: + [`~AnyFlowPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`AnyFlowPipelineOutput`] is returned, otherwise a `tuple` whose + first element is the generated video. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._num_timesteps = num_inference_steps + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 5. Prepare latent variables. ``prepare_latents`` returns the standard ``(B, C, T, H, W)`` + # diffusers layout; the AnyFlow rollout expects ``(B, T, C, H, W)`` so we permute here. + num_channels_latents = self.transformer.config.in_channels + init_latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + init_latents = init_latents.permute(0, 2, 1, 3, 4).to(transformer_dtype) + + # setup start sequence + if context_sequence is not None: + context_sequence = self.vae_encode(context_sequence) + context_length = context_sequence.shape[1] + + latents = self._denoise_rollout( + context_sequence=context_sequence, + num_inference_steps=num_inference_steps, + latents=init_latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + use_mean_velocity=use_mean_velocity, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + if context_sequence is not None: + latents[:, :context_length, ...] = context_sequence + latents = latents.permute(0, 2, 1, 3, 4) + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnyFlowPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py new file mode 100644 index 000000000000..3b26904cb82f --- /dev/null +++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py @@ -0,0 +1,874 @@ +# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Adapted from diffusers.pipelines.wan.pipeline_wan.WanPipeline (v0.35.1) for FAR causal flow-map sampling. + +import copy +import html +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import WanLoraLoaderMixin +from ...models import AnyFlowFARTransformer3DModel, AutoencoderKLWan +from ...models.autoencoders.vae import DiagonalGaussianDistribution +from ...schedulers import FlowMapEulerDiscreteScheduler +from ...utils import is_ftfy_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import AnyFlowPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import numpy as np + >>> import torch + >>> from diffusers import AnyFlowFARPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = AnyFlowFARPipeline.from_pretrained( + ... "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 + ... ).to("cuda") + + >>> # Single-frame I2V: wrap the conditioning image as a (1, 3, 1, H, W) tensor in [0, 1]. + >>> first_frame = load_image("path/to/first_frame.png").resize((832, 480)) + >>> arr = np.asarray(first_frame).astype("float32") / 255.0 + >>> context = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).to("cuda") + + >>> video = pipe( + ... prompt="a cat walks across a sunlit lawn", + ... context_sequence={"raw": context}, + ... num_inference_steps=4, + ... num_frames=81, + ... ).frames[0] + >>> export_to_video(video, "anyflow_far.mp4", fps=16) + ``` +""" + + +# Copied from diffusers.pipelines.wan.pipeline_wan.basic_clean +def basic_clean(text): + if is_ftfy_available(): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +# Copied from diffusers.pipelines.wan.pipeline_wan.whitespace_clean +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +# Copied from diffusers.pipelines.wan.pipeline_wan.prompt_clean +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class AnyFlowFARPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Causal (FAR-based) text-to-video / image-to-video / video-to-video pipeline for AnyFlow checkpoints. + + The pipeline drives a frame-level autoregressive sampling loop over chunks: each chunk is denoised with + flow-map steps while attending only to past chunks via block-sparse causal attention, and intermediate + KV cache is reused across chunks. + + The task mode (T2V / I2V / V2V) is selected by the ``context_sequence`` argument passed to ``__call__``: + + - ``context_sequence=None`` — pure text-to-video. + - ``context_sequence={"raw":