diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 0613cd65d74d..5f836dd87294 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -198,6 +198,8 @@
title: Model accelerators and hardware
- isExpanded: false
sections:
+ - local: using-diffusers/anyflow
+ title: AnyFlow
- local: using-diffusers/helios
title: Helios
- local: using-diffusers/consisid
@@ -328,6 +330,10 @@
title: AceStepTransformer1DModel
- local: api/models/allegro_transformer3d
title: AllegroTransformer3DModel
+ - local: api/models/anyflow_transformer3d
+ title: AnyFlowTransformer3DModel
+ - local: api/models/anyflow_far_transformer3d
+ title: AnyFlowFARTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/transformer_bria_fibo
@@ -506,6 +512,8 @@
- sections:
- local: api/pipelines/animatediff
title: AnimateDiff
+ - local: api/pipelines/anyflow
+ title: AnyFlow
- local: api/pipelines/aura_flow
title: AuraFlow
- local: api/pipelines/bria_3_2
@@ -735,6 +743,8 @@
title: EulerAncestralDiscreteScheduler
- local: api/schedulers/euler
title: EulerDiscreteScheduler
+ - local: api/schedulers/flow_map_euler_discrete
+ title: FlowMapEulerDiscreteScheduler
- local: api/schedulers/flow_match_euler_discrete
title: FlowMatchEulerDiscreteScheduler
- local: api/schedulers/flow_match_heun_discrete
diff --git a/docs/source/en/api/models/anyflow_far_transformer3d.md b/docs/source/en/api/models/anyflow_far_transformer3d.md
new file mode 100644
index 000000000000..d29c3fefc07d
--- /dev/null
+++ b/docs/source/en/api/models/anyflow_far_transformer3d.md
@@ -0,0 +1,45 @@
+
+
+# AnyFlowFARTransformer3DModel
+
+The causal (FAR) 3D Transformer used by [`AnyFlowFARPipeline`](../pipelines/anyflow#anyflowfarpipeline) —
+the FAR variant of [AnyFlow](https://huggingface.co/papers/2605.13724) (Yuchao Gu, Guian Fang et al., NUS
+ShowLab × NVIDIA). It extends the v0.35.1 Wan2.1 backbone with three additions:
+
+1. **FAR causal block-mask** via `torch.nn.attention.flex_attention`, supporting frame-level autoregressive
+ generation as introduced in [FAR (Gu et al., 2025)](https://arxiv.org/abs/2503.19325).
+2. **Compressed-frame patch embedding** (`far_patch_embedding`) for context (already-generated) frames,
+ warm-started from the full-resolution `patch_embedding` at construction time via trilinear interpolation.
+3. **Dual-timestep flow-map embedding** (same as
+   [`AnyFlowTransformer3DModel`](anyflow_transformer3d)) — every forward call conditions on both the source
+   timestep `t` and the target timestep `r`.
+
+The chunk schedule (`chunk_partition`) is **not** baked into the model config. It is a per-call argument to
+`forward`, so the same checkpoint handles different `num_frames` configurations without retraining.
+
+```python
+from diffusers import AnyFlowFARTransformer3DModel
+
+# Causal AnyFlow checkpoint (FAR):
+transformer = AnyFlowFARTransformer3DModel.from_pretrained(
+ "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", subfolder="transformer"
+)
+```
+
+## AnyFlowFARTransformer3DModel
+
+[[autodoc]] AnyFlowFARTransformer3DModel
+
+## AnyFlowFARTransformerOutput
+
+[[autodoc]] models.transformers.transformer_anyflow.AnyFlowFARTransformerOutput
diff --git a/docs/source/en/api/models/anyflow_transformer3d.md b/docs/source/en/api/models/anyflow_transformer3d.md
new file mode 100644
index 000000000000..95888080c0ce
--- /dev/null
+++ b/docs/source/en/api/models/anyflow_transformer3d.md
@@ -0,0 +1,36 @@
+
+
+# AnyFlowTransformer3DModel
+
+The bidirectional 3D Transformer used by [`AnyFlowPipeline`](../pipelines/anyflow#anyflowpipeline). It is the
+v0.35.1 Wan2.1 backbone with one structural change: the timestep embedder is replaced by
+`AnyFlowDualTimestepTextImageEmbedding`, so every forward call conditions on both the source timestep
+`t` and the target timestep `r`. This is the embedding required to learn the flow map
+$\Phi_{r\leftarrow t}$ introduced in
+[AnyFlow](https://huggingface.co/papers/2605.13724) (Yuchao Gu, Guian Fang et al., NUS ShowLab × NVIDIA).
+
+For frame-level autoregressive (FAR causal) generation, use
+[`AnyFlowFARTransformer3DModel`](anyflow_far_transformer3d) instead.
+
+```python
+from diffusers import AnyFlowTransformer3DModel
+
+# Bidirectional AnyFlow checkpoint (T2V):
+transformer = AnyFlowTransformer3DModel.from_pretrained(
+ "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer"
+)
+```
+
+## AnyFlowTransformer3DModel
+
+[[autodoc]] AnyFlowTransformer3DModel
diff --git a/docs/source/en/api/pipelines/anyflow.md b/docs/source/en/api/pipelines/anyflow.md
new file mode 100644
index 000000000000..c8948cdb8f59
--- /dev/null
+++ b/docs/source/en/api/pipelines/anyflow.md
@@ -0,0 +1,216 @@
+
+
+
+
+# AnyFlow
+
+[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/2605.13724) by Yuchao Gu, Guian Fang and collaborators at [NUS ShowLab](https://sites.google.com/view/showlab) in collaboration with NVIDIA.
+
+*Few-step video generation has been significantly advanced by consistency models. However, their performance often degrades in any-step video diffusion models due to the fixed-point formulation. To address this limitation, we present AnyFlow, the first any-step video diffusion distillation framework built on flow maps. Instead of learning only the mapping z_t → z_0, AnyFlow learns transitions z_t → z_r over arbitrary time intervals, enabling a single model to adapt to different inference budgets. We design an improved forward flow map training recipe that fine-tunes pretrained video diffusion models into flow map models, and introduce Flow Map Backward Simulation to enable on-policy distillation for flow map models. Extensive experiments across both bidirectional and causal architectures, at scales ranging from 1.3B to 14B, on text-to-video and image-to-video tasks demonstrate that AnyFlow outperforms consistency-based baselines while preserving high fidelity and flexible sampling under varying step budgets.*
+
+The original training code is at [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow). The project page is at [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow).
+
+The following AnyFlow checkpoints are supported:
+
+| Checkpoint | Backbone | Description |
+|------------|----------|-------------|
+| [`nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers) | Wan2.1 1.3B | Bidirectional T2V, lightweight |
+| [`nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers) | Wan2.1 14B | Bidirectional T2V, full quality |
+| [`nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers) | FAR + Wan2.1 1.3B | Causal T2V / I2V / V2V |
+| [`nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers) | FAR + Wan2.1 14B | Causal T2V / I2V / V2V |
+
+All four are grouped under the [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection.
+
+> [!TIP]
+> Choose `AnyFlowPipeline` for traditional bidirectional text-to-video generation. Choose `AnyFlowFARPipeline` for streaming I2V, video continuation (V2V), or any setup that benefits from frame-by-frame autoregressive sampling.
+
+> [!TIP]
+> AnyFlow supports any-step sampling: a single distilled checkpoint can be evaluated at 1, 2, 4, 8, 16... NFE without retraining. Quality scales monotonically with the step count in the paper's benchmarks.
+
+### Optimizing Memory and Inference Speed
+
+
+
+
+```py
+import torch
+from diffusers import AnyFlowPipeline
+from diffusers.hooks import apply_group_offloading
+
+pipe = AnyFlowPipeline.from_pretrained(
+ "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
+)
+apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level")
+pipe.vae.enable_slicing()
+pipe.vae.enable_tiling()
+```
+
+
+
+
+```py
+import torch
+from diffusers import AnyFlowPipeline
+
+pipe = AnyFlowPipeline.from_pretrained(
+ "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
+```
+
+
+
+
+### Generation with AnyFlow (Bidirectional T2V)
+
+
+
+
+```py
+import torch
+from diffusers import AnyFlowPipeline
+from diffusers.utils import export_to_video
+
+pipe = AnyFlowPipeline.from_pretrained(
+ "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+prompt = "A red panda eating bamboo in a forest, cinematic lighting"
+video = pipe(prompt, num_inference_steps=4, num_frames=33).frames[0]
+export_to_video(video, "out.mp4", fps=16)
+```
+
+
+
+
+### Generation with AnyFlow (FAR Causal)
+
+The causal pipeline selects between T2V / I2V / V2V via the `context_sequence` argument: pass `None`
+for plain text-to-video, or a dict with a `"raw"` key holding a video tensor of shape
+`(B, C, T, H, W)` with `T = 4n + 1` to condition on existing frames. Use a single conditioning frame
+for I2V and a longer clip for V2V continuation.
+
+> [!IMPORTANT]
+> `AnyFlowFARPipeline.default_chunk_partition = [1, 3, 3, 3, 3, 3, 3, 2]` (sum 21) is matched to the
+> released checkpoints' canonical 81 raw frames (21 latent frames at the VAE temporal stride of 4). When
+> you change `num_frames`, you must also pass a matching `chunk_partition` summing to
+> `(num_frames - 1) // 4 + 1`, otherwise the pipeline raises an `AssertionError`.
+
+
+
+
+```py
+import torch
+from diffusers import AnyFlowFARPipeline
+from diffusers.utils import export_to_video
+
+pipe = AnyFlowFARPipeline.from_pretrained(
+ "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+video = pipe(
+ prompt="A cat surfing a wave, sunset",
+ num_inference_steps=4,
+ num_frames=81,
+).frames[0]
+export_to_video(video, "out.mp4", fps=16)
+```
+
+
+
+
+```py
+import numpy as np
+import torch
+from diffusers import AnyFlowFARPipeline
+from diffusers.utils import export_to_video, load_image
+
+pipe = AnyFlowFARPipeline.from_pretrained(
+ "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Wrap the conditioning image as a one-frame video tensor: (1, 3, 1, H, W) in [0, 1].
+first_frame = load_image("path/to/first_frame.png").resize((832, 480))
+arr = np.asarray(first_frame).astype("float32") / 255.0 # (480, 832, 3)
+context_tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).to("cuda")
+
+video = pipe(
+ prompt="a cat walks across a sunlit lawn",
+ context_sequence={"raw": context_tensor},
+ num_inference_steps=4,
+ num_frames=81,
+).frames[0]
+export_to_video(video, "out.mp4", fps=16)
+```
+
+
+
+
+```py
+import numpy as np
+import torch
+from diffusers import AnyFlowFARPipeline
+from diffusers.utils import export_to_video, load_video
+
+pipe = AnyFlowFARPipeline.from_pretrained(
+ "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Context clip — 9 raw frames map to 3 latent frames (9 = 4·2 + 1, 3 = 2 + 1).
+context_frames = load_video("path/to/context.mp4")[:9]
+arr = np.stack([np.asarray(f.resize((832, 480))) for f in context_frames]).astype("float32") / 255.0
+context_tensor = torch.from_numpy(arr).permute(3, 0, 1, 2).unsqueeze(0).to("cuda") # (1, 3, 9, 480, 832)
+
+video = pipe(
+ prompt="continue the story",
+ context_sequence={"raw": context_tensor},
+ num_inference_steps=4,
+ num_frames=81,
+ # Override chunk_partition so the first chunk covers exactly the 3 latent context frames.
+ chunk_partition=[3, 3, 3, 3, 3, 3, 3],
+).frames[0]
+export_to_video(video, "out.mp4", fps=16)
+```
+
+
+
+
+## Notes
+
+- Classifier-free guidance is fused into the released checkpoints, so inference does not run a second guided forward pass. Keep the default `guidance_scale=1.0` unless your own checkpoint requires otherwise.
+- `FlowMapEulerDiscreteScheduler` is general-purpose. You can attach it to any flow-map-distilled checkpoint via `from_pretrained(..., scheduler=FlowMapEulerDiscreteScheduler.from_config(...))`; see the sketch after this list.
+- `AnyFlowPipeline` uses [`AnyFlowTransformer3DModel`](../models/anyflow_transformer3d) (bidirectional). `AnyFlowFARPipeline` uses [`AnyFlowFARTransformer3DModel`](../models/anyflow_far_transformer3d), which adds a compressed-frame patch embedding and the FAR causal block-mask.
+- LoRA loading is supported via `WanLoraLoaderMixin`, the same mixin used by the upstream Wan pipelines.
+- For training recipes (forward flow-map training and on-policy distillation), refer to the original AnyFlow training framework at [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow); training is out of scope for diffusers.
+
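+As a minimal sketch of the scheduler swap mentioned in the notes above (the repo id is a placeholder for your own flow-map-distilled checkpoint):
+
+```py
+import torch
+
+from diffusers import DiffusionPipeline, FlowMapEulerDiscreteScheduler
+
+pipe = DiffusionPipeline.from_pretrained("path/or/repo/of/flow-map-distilled-model", torch_dtype=torch.bfloat16)
+# Rebuild the scheduler from the existing config so shared fields (e.g. `num_train_timesteps`) carry over.
+pipe.scheduler = FlowMapEulerDiscreteScheduler.from_config(pipe.scheduler.config)
+```
+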
+## AnyFlowPipeline
+
+[[autodoc]] AnyFlowPipeline
+ - all
+ - __call__
+
+## AnyFlowFARPipeline
+
+[[autodoc]] AnyFlowFARPipeline
+ - all
+ - __call__
+
+## AnyFlowPipelineOutput
+
+[[autodoc]] pipelines.anyflow.pipeline_output.AnyFlowPipelineOutput
diff --git a/docs/source/en/api/schedulers/flow_map_euler_discrete.md b/docs/source/en/api/schedulers/flow_map_euler_discrete.md
new file mode 100644
index 000000000000..27a0c8612d70
--- /dev/null
+++ b/docs/source/en/api/schedulers/flow_map_euler_discrete.md
@@ -0,0 +1,28 @@
+
+
+# FlowMapEulerDiscreteScheduler
+
+`FlowMapEulerDiscreteScheduler` is an Euler-style sampler designed for flow-map-distilled diffusion
+models. Flow-map models learn arbitrary-interval transitions $\mathbf{z}_t \to \mathbf{z}_r$ rather than
+the fixed $\mathbf{z}_t \to \mathbf{z}_0$ mapping of consistency models. Both endpoints of the step are
+caller-provided, which is what enables any-step sampling: a single distilled checkpoint can be evaluated at
+1, 2, 4, 8, 16... NFE without retraining.
+
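+Schematically (a sketch of the sampling loop rather than a description of the exact scheduler internals), an
+$N$-step run chains the learned flow map along a caller-chosen timestep schedule
+$1 = t_0 > t_1 > \dots > t_N = 0$:
+
+$$\mathbf{z}_{t_{i+1}} = \Phi_{t_{i+1} \leftarrow t_i}(\mathbf{z}_{t_i}), \qquad i = 0, \dots, N - 1,$$
+
+so changing the step budget only changes the schedule handed to the model, not the model itself.
+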
+The scheduler was introduced in
+[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/2605.13724)
+and ships with the `AnyFlowPipeline` and `AnyFlowFARPipeline` integrations, but it is not
+AnyFlow-specific — any flow-map-distilled checkpoint can use it.
+
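+As a minimal sketch, the scheduler can also be constructed directly; the arguments shown match the released
+AnyFlow checkpoints (`shift=5.0`) and should be adjusted for your own checkpoint:
+
+```py
+from diffusers import FlowMapEulerDiscreteScheduler
+
+scheduler = FlowMapEulerDiscreteScheduler(num_train_timesteps=1000, shift=5.0)
+```
+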
+## FlowMapEulerDiscreteScheduler
+
+[[autodoc]] FlowMapEulerDiscreteScheduler
diff --git a/docs/source/en/using-diffusers/anyflow.md b/docs/source/en/using-diffusers/anyflow.md
new file mode 100644
index 000000000000..9bf3ba13f258
--- /dev/null
+++ b/docs/source/en/using-diffusers/anyflow.md
@@ -0,0 +1,260 @@
+
+
+# AnyFlow
+
+[AnyFlow](https://huggingface.co/papers/2605.13724) is a video diffusion **distillation** framework that turns
+a pretrained Wan2.1 teacher into an *any-step* student under standard Euler sampling. A single distilled
+checkpoint can be evaluated at 1, 2, 4, 8, 16... NFE without retraining, and quality scales **monotonically**
+with steps — unlike consistency models, which often degrade as NFE grows.
+
+The key idea is to learn the **flow map** $\Phi_{r\leftarrow t}: \mathbf{z}_t \to \mathbf{z}_r$ for arbitrary
+$1 \ge t \ge r \ge 0$ instead of the fixed endpoint map $\mathbf{z}_t \to \mathbf{z}_0$ used by consistency
+models. Composability of the flow map removes re-noising between sampling steps; on-policy distillation with
+**DMD reverse-divergence supervision** plus **Flow-Map backward simulation** (3-segment shortcut) closes the
+exposure-bias gap that consistency-based distillation leaves open.
+
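+In flow-map terms, the property that removes re-noising is composability: for $1 \ge t \ge s \ge r \ge 0$ the
+ideal flow map satisfies
+
+$$\Phi_{r \leftarrow s} \circ \Phi_{s \leftarrow t} = \Phi_{r \leftarrow t},$$
+
+so consecutive sampler steps compose directly into one larger jump instead of returning to a noisier state
+between steps (a schematic statement of the idea above, not a property guaranteed for the trained network).
+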
+AnyFlow was developed by Yuchao Gu, Guian Fang and collaborators at [NUS ShowLab](https://sites.google.com/view/showlab) in collaboration with NVIDIA. The original training code lives at [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow); the project page is at [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow). The four released checkpoints are grouped under the [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection.
+
+This guide walks through the practical decisions: which pipeline to pick, how to use any-step sampling, and
+how to plug AnyFlow into typical T2V / I2V / V2V workflows.
+
+## Bidirectional vs causal — pick a pipeline
+
+AnyFlow ships in two flavors that share the same scheduler and the same flow-map distillation but differ in
+how they sample frames:
+
+- [`AnyFlowPipeline`](../api/pipelines/anyflow#anyflowpipeline) — **bidirectional** T2V. Denoises the entire
+ video tensor in one pass with global self-attention. Use this when the input is a single text prompt and you
+ do not need streaming output.
+- [`AnyFlowFARPipeline`](../api/pipelines/anyflow#anyflowfarpipeline) — **causal (FAR)**. Denoises the
+ video chunk by chunk with block-sparse causal attention and reuses KV cache across chunks. Use this for
+ image-to-video (I2V), video-to-video (V2V) continuation, or any setup that benefits from frame-level
+ autoregressive sampling. The same model handles all three task modes via the `context_sequence` argument.
+
+A quick selector:
+
+| Scenario | Pipeline | How to invoke |
+|----------|----------|---------------|
+| Pure text-to-video, max quality at fixed NFE | `AnyFlowPipeline` | `pipe(prompt, ...)` |
+| Image-to-video (start from a still image) | `AnyFlowFARPipeline` | `pipe(prompt, context_sequence={"raw": <single-frame tensor>}, ...)` |
+| Video continuation / V2V | `AnyFlowFARPipeline` | `pipe(prompt, context_sequence={"raw": <multi-frame tensor>}, ...)` |
+| Streaming / progressive generation | `AnyFlowFARPipeline` | — |
+
+The bidirectional variant is faster per token at high resolution; the causal variant trades that for the
+ability to start sampling before all latent frames are allocated, useful for very long sequences.
+
+## Loading checkpoints
+
+NVIDIA released four AnyFlow checkpoints, one per pipeline + scale combination:
+
+```py
+import torch
+from diffusers import AnyFlowPipeline, AnyFlowFARPipeline
+
+# Bidirectional, lightweight
+pipe = AnyFlowPipeline.from_pretrained(
+ "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Bidirectional, full quality
+pipe = AnyFlowPipeline.from_pretrained(
+ "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Causal (FAR), 1.3B
+pipe = AnyFlowFARPipeline.from_pretrained(
+ "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Causal (FAR), 14B
+pipe = AnyFlowFARPipeline.from_pretrained(
+ "nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+```
+
+All four use the same [`FlowMapEulerDiscreteScheduler`](../api/schedulers/flow_map_euler_discrete) with
+`shift=5.0` baked in.
+
+## Any-step sampling
+
+The defining feature of AnyFlow is that the same checkpoint produces increasing quality as you raise NFE,
+with no schedule retuning. Sweep step counts on a fixed prompt to see how the model trades latency for
+fidelity:
+
+```py
+import torch
+from diffusers import AnyFlowPipeline
+from diffusers.utils import export_to_video
+
+pipe = AnyFlowPipeline.from_pretrained(
+ "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+prompt = "A red panda eating bamboo in a forest, cinematic lighting"
+
+for nfe in [1, 2, 4, 8, 16, 32]:
+ # Re-seed the generator inside the loop so the only changing variable across runs is NFE.
+ generator = torch.Generator("cuda").manual_seed(0)
+ video = pipe(prompt, num_inference_steps=nfe, num_frames=33, generator=generator).frames[0]
+ export_to_video(video, f"out_nfe{nfe}.mp4", fps=16)
+```
+
+In the paper's benchmarks (Table 3 and Figure 1), every AnyFlow checkpoint improves monotonically from 4 → 32 NFE
+on VBench Quality, while consistency-based baselines (rCM, Self-Forcing) degrade in the same regime.
+
+> [!TIP]
+> Classifier-free guidance (CFG) was *fused* into the released model weights during training.
+> The pipeline does not run a second guided forward pass at inference time —
+> guidance comes from the distilled weights themselves. Leave `guidance_scale=1.0` (the default) for the
+> released checkpoints.
+
+## Image-to-video and video-to-video
+
+The causal pipeline supports three task modes from a single distilled model. The mode is selected
+implicitly by the `context_sequence` argument (a dict with a `"raw"` video tensor or `"latent"`
+pre-encoded latents). Frame counts in the context tensor must satisfy `T = 4n + 1` to align with the
+VAE temporal stride.
+
+> [!IMPORTANT]
+> The FAR pipeline runs a chunked rollout, and `num_frames` must agree with the chunk schedule. The default
+> `chunk_partition=[1, 3, 3, 3, 3, 3, 3, 2]` sums to 21 latent frames, which is matched to the released
+> checkpoints' canonical `num_frames=81` (21 = (81 − 1) // 4 + 1). When you change `num_frames`, you **must**
+> pass a matching `chunk_partition` whose entries sum to `(num_frames - 1) // 4 + 1`, otherwise the pipeline
+> raises an `AssertionError`. For example, `num_frames=33` corresponds to 9 latent frames, so a valid
+> override is `chunk_partition=[1, 4, 4]`.
+
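+Before launching a long run you can sanity-check a custom schedule with a few lines of arithmetic (an
+illustrative helper; the pipeline performs the equivalent check internally and raises an `AssertionError`
+when it fails):
+
+```py
+def latent_frame_count(num_frames: int) -> int:
+    # VAE temporal stride of 4: raw frame counts of the form 4n + 1 map to n + 1 latent frames.
+    assert (num_frames - 1) % 4 == 0, "num_frames must satisfy (num_frames - 1) % 4 == 0"
+    return (num_frames - 1) // 4 + 1
+
+
+chunk_partition = [1, 4, 4]  # custom schedule for a shorter clip
+num_frames = 33              # 33 raw frames -> 9 latent frames
+assert sum(chunk_partition) == latent_frame_count(num_frames)
+```
+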
+```py
+import numpy as np
+import torch
+from diffusers import AnyFlowFARPipeline
+from diffusers.utils import export_to_video, load_image, load_video
+
+pipe = AnyFlowFARPipeline.from_pretrained(
+ "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+
+def to_video_tensor(images, height=480, width=832):
+ """Convert a list of PIL images into the (B, C, T, H, W) [0, 1] tensor the FAR pipeline consumes."""
+ frames = np.stack([np.asarray(img.resize((width, height))) for img in images]).astype("float32") / 255.0
+ return torch.from_numpy(frames).permute(3, 0, 1, 2).unsqueeze(0) # (1, C, T, H, W)
+
+
+# 1) Text-to-video (no context). 81 frames matches the default chunk_partition.
+video = pipe(prompt="A cat surfing a wave at sunset", num_inference_steps=4, num_frames=81).frames[0]
+export_to_video(video, "t2v.mp4", fps=16)
+
+# 2) Image-to-video — a single conditioning frame produces 1 latent frame, which exactly fits the first
+# entry of the default chunk_partition (`[1, 3, 3, ...]`).
+first_frame = load_image("path/to/first_frame.png")
+context_tensor = to_video_tensor([first_frame]).to("cuda") # (1, 3, 1, 480, 832), [0, 1]
+video = pipe(
+ prompt="a cat walks across a sunlit lawn",
+ context_sequence={"raw": context_tensor},
+ num_inference_steps=4,
+ num_frames=81,
+).frames[0]
+export_to_video(video, "i2v.mp4", fps=16)
+
+# 3) Video-to-video continuation. A 9-frame raw context maps to 3 latent frames; we override
+# chunk_partition so the first chunk covers the whole context exactly.
+context_frames = load_video("path/to/context.mp4")[:9] # 9 = 4·2 + 1
+context_tensor = to_video_tensor(context_frames).to("cuda") # (1, 3, 9, 480, 832)
+video = pipe(
+ prompt="continue the story",
+ context_sequence={"raw": context_tensor},
+ num_inference_steps=4,
+ num_frames=81,
+ chunk_partition=[3, 3, 3, 3, 3, 3, 3], # 7 chunks × 3 = 21 latent frames; first chunk = context
+).frames[0]
+export_to_video(video, "v2v.mp4", fps=16)
+```
+
+Internally, the patchification chunk schedule depends on whether `context_sequence` is set and how long it is:
+without context the model uses kernel sizes 2 (full) and 4 (compressed); with a context clip the first chunk
+uses kernel size 1 so the conditioning frames keep full resolution.
+
+If you already have VAE-encoded latents, pass them via `context_sequence={"latent": ...}` to skip the
+`vae_encode` step.
+
+## Memory and inference speed
+
+A 14B AnyFlow model fits on a single 40 GB device with group offloading + VAE slicing:
+
+```py
+import torch
+from diffusers import AnyFlowPipeline
+from diffusers.hooks import apply_group_offloading
+
+pipe = AnyFlowPipeline.from_pretrained(
+ "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
+)
+apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level")
+pipe.vae.enable_slicing()
+pipe.vae.enable_tiling()
+```
+
+For latency, `torch.compile` works well on the transformer (the heaviest module by far):
+
+```py
+pipe = pipe.to("cuda")
+pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
+```
+
+Compile costs are amortized after a few steps; combined with AnyFlow's low NFE (4–8), `torch.compile`
+delivers a noticeable speedup over eager mode on the 14B path.
+
+## LoRA fine-tuning
+
+Both pipelines reuse [`WanLoraLoaderMixin`](../api/loaders/lora), so any LoRA adapter trained for the
+matching Wan2.1 backbone loads directly:
+
+```py
+pipe.load_lora_weights("path/or/repo/with/wan_lora")
+```
+
+For continued **on-policy** fine-tuning with DMD-style reverse-divergence supervision (the same recipe used
+to produce the released checkpoints), refer to the original AnyFlow training framework at
+[`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow), which is out of scope for diffusers.
+
+## Common gotchas
+
+- **Keep `guidance_scale` at 1.0.** The distilled checkpoints already encode CFG. Setting `guidance_scale > 1`
+ will run a redundant unconditional pass, double the latency, and slightly hurt quality.
+- **Bidirectional pipeline does not stream.** All `num_frames` worth of latents are denoised together. Use
+ the causal pipeline if you want to start playback before sampling completes.
+- **Causal pipeline KV cache assumes the chunk schedule is consistent across calls.** Rebuilding the cache
+ mid-generation is not supported by the released model.
+- **`num_frames` must be compatible with the VAE temporal stride.** Use values `N` with `(N - 1) % 4 == 0` (e.g., 9,
+ 17, 33, 81) for the released checkpoints.
+
+## Citation
+
+```bibtex
+@misc{gu2026anyflowanystepvideodiffusion,
+ title={AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation},
+ author={Yuchao Gu and Guian Fang and Yuxin Jiang and Weijia Mao and Song Han and Han Cai and Mike Zheng Shou},
+ year={2026},
+ eprint={2605.13724},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+ url={https://arxiv.org/abs/2605.13724},
+}
+
+@article{gu2025long,
+ title={Long-Context Autoregressive Video Modeling with Next-Frame Prediction},
+ author={Gu, Yuchao and Mao, Weijia and Shou, Mike Zheng},
+ journal={arXiv preprint arXiv:2503.19325},
+ year={2025}
+}
+```
diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml
index af51506746b2..b49820dd76e7 100644
--- a/docs/source/zh/_toctree.yml
+++ b/docs/source/zh/_toctree.yml
@@ -130,6 +130,8 @@
- title: Specific pipeline examples
isExpanded: false
sections:
+ - local: using-diffusers/anyflow
+ title: AnyFlow
- local: using-diffusers/consisid
title: ConsisID
- local: using-diffusers/helios
diff --git a/docs/source/zh/using-diffusers/anyflow.md b/docs/source/zh/using-diffusers/anyflow.md
new file mode 100644
index 000000000000..9a20f2eafd9d
--- /dev/null
+++ b/docs/source/zh/using-diffusers/anyflow.md
@@ -0,0 +1,244 @@
+
+
+# AnyFlow
+
+[AnyFlow](https://huggingface.co/papers/2605.13724) 是一个视频扩散**蒸馏**框架,把预训练的 Wan2.1 教师
+模型蒸馏成在标准 Euler 采样下支持*任意步数 (any-step)* 的学生模型。同一个蒸馏出来的 checkpoint 可以
+在 1、2、4、8、16... NFE 下推理,**质量随步数单调提升** —— 这一点和 consistency models 不同,后者
+NFE 增加反而经常掉点。
+
+核心思路是学习 **flow map** $\Phi_{r\leftarrow t}: \mathbf{z}_t \to \mathbf{z}_r$(任意 $1 \ge t \ge r \ge 0$),
+而不是 consistency models 学的固定端点映射 $\mathbf{z}_t \to \mathbf{z}_0$。Flow map 的可组合性消除了
+采样步之间的 re-noising;on-policy 蒸馏阶段额外用 **DMD 反向散度监督** + **Flow-Map backward simulation**
+(3 段 shortcut)补上 consistency 蒸馏遗留的 exposure-bias 缺口。
+
+AnyFlow 由 Yuchao Gu、Guian Fang 等人在 [NUS ShowLab](https://sites.google.com/view/showlab) 与 NVIDIA 合作完成。原始训练代码在 [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow),项目主页是 [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow)。4 个发布 checkpoint 归在 [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection 里。
+
+本文档梳理实战要点:怎么选 pipeline、怎么用 any-step 采样、怎么把 AnyFlow 嵌进 T2V / I2V / V2V 工作流。
+
+## Bidirectional 还是 Causal —— 怎么选 pipeline
+
+AnyFlow 提供两个 pipeline 形态,scheduler 和蒸馏方法相同,区别在于**怎么对帧采样**:
+
+- [`AnyFlowPipeline`](../api/pipelines/anyflow#anyflowpipeline) —— **bidirectional** T2V。一次性对整个
+ 视频张量去噪,全局自注意力。**纯 prompt 输入、不要流式输出**时选这个。
+- [`AnyFlowFARPipeline`](../api/pipelines/anyflow#anyflowfarpipeline) —— **causal (FAR)**。
+ 按 chunk 分段去噪,块稀疏因果注意力 + 跨 chunk 复用 KV cache。**图生视频 (I2V)**、**视频续写 (V2V)**、
+ 或任何受益于逐帧自回归采样的场景选这个。同一个模型通过传入 `context_sequence` 来切换三种任务模式。
+
+简化对照表:
+
+| 场景 | Pipeline | 调用方式 |
+|------|----------|----------|
+| 纯文生视频,固定 NFE 求最大质量 | `AnyFlowPipeline` | `pipe(prompt, ...)` |
+| 图生视频(首帧给定) | `AnyFlowFARPipeline` | `pipe(prompt, context_sequence={"raw": <单帧 tensor>}, ...)` |
+| 视频续写 / V2V | `AnyFlowFARPipeline` | `pipe(prompt, context_sequence={"raw": <多帧 tensor>}, ...)` |
+| 流式 / 渐进式生成 | `AnyFlowFARPipeline` | — |
+
+高分辨率下 bidirectional 单 token 更快;causal 牺牲一点单步速度,换来在所有 latent 帧分配前就能开始
+采样的能力,对超长序列尤其有用。
+
+## 加载 checkpoint
+
+NVIDIA 发布了 4 个 AnyFlow checkpoint,pipeline × 规模各一份:
+
+```py
+import torch
+from diffusers import AnyFlowPipeline, AnyFlowFARPipeline
+
+# Bidirectional, 轻量
+pipe = AnyFlowPipeline.from_pretrained(
+ "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Bidirectional, 满血
+pipe = AnyFlowPipeline.from_pretrained(
+ "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Causal (FAR), 1.3B
+pipe = AnyFlowFARPipeline.from_pretrained(
+ "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Causal (FAR), 14B
+pipe = AnyFlowFARPipeline.from_pretrained(
+ "nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+```
+
+四个 checkpoint 共用同一份 [`FlowMapEulerDiscreteScheduler`](../api/schedulers/flow_map_euler_discrete),
+默认 `shift=5.0`。
+
+## Any-step 采样
+
+AnyFlow 最关键的特性是同一个 checkpoint **不需重新调度**,NFE 越大质量越高。固定 prompt、扫一下步数
+就能看出模型怎么在延迟和保真度之间权衡:
+
+```py
+import torch
+from diffusers import AnyFlowPipeline
+from diffusers.utils import export_to_video
+
+pipe = AnyFlowPipeline.from_pretrained(
+ "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+prompt = "森林里一只小熊猫在啃竹子,电影感光照"
+
+for nfe in [1, 2, 4, 8, 16, 32]:
+ # 每轮重建 generator —— 这样跨步数对比时唯一变量是 NFE。
+ generator = torch.Generator("cuda").manual_seed(0)
+ video = pipe(prompt, num_inference_steps=nfe, num_frames=33, generator=generator).frames[0]
+ export_to_video(video, f"out_nfe{nfe}.mp4", fps=16)
+```
+
+paper 的 Tab 3 / Fig 1 表明:每个 AnyFlow checkpoint 在 4 → 32 NFE 范围 VBench Quality 都单调上升,而
+consistency 类基线(rCM、Self-Forcing)在同区间反而掉点。
+
+> [!TIP]
+> Classifier-free guidance (CFG) 已经在训练阶段融进权重。pipeline 推理
+> 时**不会**再跑一次 unconditional 前向 —— guidance 直接由蒸馏后的权重带出。release 出来的 checkpoint
+> 都用默认的 `guidance_scale=1.0` 即可。
+
+## 图生视频 与 视频续写
+
+Causal pipeline 用同一个蒸馏模型支持三种任务模式,**通过 `context_sequence` 隐式选择**(dict,含
+`"raw"` 视频张量或 `"latent"` 已编码 latent)。Context tensor 的帧数必须满足 `T = 4n + 1`,跟 VAE
+时间步长对齐。
+
+> [!IMPORTANT]
+> FAR pipeline 是分块 (chunk) rollout,`num_frames` 必须配合 chunk 调度。默认
+> `chunk_partition=[1, 3, 3, 3, 3, 3, 3, 2]`(求和 21)对应发布 checkpoint 的标准 `num_frames=81`
+> (21 = (81 − 1) // 4 + 1)。改 `num_frames` 时**必须**显式传匹配的 `chunk_partition`,使其求和等于
+> `(num_frames - 1) // 4 + 1`,否则 pipeline 会抛 `AssertionError`。比如 `num_frames=33` 对应 9 个 latent
+> 帧,可用 `chunk_partition=[1, 4, 4]`。
+
+```py
+import numpy as np
+import torch
+from diffusers import AnyFlowFARPipeline
+from diffusers.utils import export_to_video, load_image, load_video
+
+pipe = AnyFlowFARPipeline.from_pretrained(
+ "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
+).to("cuda")
+
+
+def to_video_tensor(images, height=480, width=832):
+ """把 PIL 列表转成 FAR pipeline 需要的 (B, C, T, H, W) [0, 1] 张量。"""
+ frames = np.stack([np.asarray(img.resize((width, height))) for img in images]).astype("float32") / 255.0
+ return torch.from_numpy(frames).permute(3, 0, 1, 2).unsqueeze(0) # (1, C, T, H, W)
+
+
+# 1) 文生视频(无 context)。81 帧匹配默认 chunk_partition。
+video = pipe(prompt="一只猫在夕阳下冲浪", num_inference_steps=4, num_frames=81).frames[0]
+export_to_video(video, "t2v.mp4", fps=16)
+
+# 2) 图生视频 —— 单帧 context 经过 VAE 是 1 个 latent,正好对上默认 chunk_partition 的第一项 (`[1, ...]`)。
+first_frame = load_image("path/to/first_frame.png")
+context_tensor = to_video_tensor([first_frame]).to("cuda") # (1, 3, 1, 480, 832), [0, 1]
+video = pipe(
+ prompt="一只猫走过阳光下的草坪",
+ context_sequence={"raw": context_tensor},
+ num_inference_steps=4,
+ num_frames=81,
+).frames[0]
+export_to_video(video, "i2v.mp4", fps=16)
+
+# 3) 视频续写。9 帧 raw context → 3 个 latent context;显式覆盖 chunk_partition,让第一块正好覆盖 context。
+context_frames = load_video("path/to/context.mp4")[:9] # 9 = 4·2 + 1
+context_tensor = to_video_tensor(context_frames).to("cuda") # (1, 3, 9, 480, 832)
+video = pipe(
+ prompt="继续这个故事",
+ context_sequence={"raw": context_tensor},
+ num_inference_steps=4,
+ num_frames=81,
+ chunk_partition=[3, 3, 3, 3, 3, 3, 3], # 7 个 chunk × 3 = 21 latent;首块就是 context
+).frames[0]
+export_to_video(video, "v2v.mp4", fps=16)
+```
+
+底层 patchify chunk 调度根据 `context_sequence` 自动调整:纯文生用 kernel 2 (full) 和 4 (compressed);
+有 context 时第一个 chunk 改成 kernel 1,让条件帧保留全分辨率。
+
+如果你已经有 VAE 编码过的 latent,可以直接传 `context_sequence={"latent": ...}` 跳过 `vae_encode` 步骤。
+
+## 显存与推理速度
+
+14B 的 AnyFlow 模型用 group offload + VAE slicing 单卡 40 GB 能跑:
+
+```py
+import torch
+from diffusers import AnyFlowPipeline
+from diffusers.hooks import apply_group_offloading
+
+pipe = AnyFlowPipeline.from_pretrained(
+ "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
+)
+apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level")
+pipe.vae.enable_slicing()
+pipe.vae.enable_tiling()
+```
+
+延迟方面,`torch.compile` 对 transformer(最重的模块)效果很好:
+
+```py
+pipe = pipe.to("cuda")
+pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
+```
+
+编译开销跑几步就摊销掉;配合 AnyFlow 的低 NFE(4-8 步),`torch.compile` 在 14B 上相比 eager
+模式有明显加速。
+
+## LoRA 微调
+
+两个 pipeline 都复用 [`WanLoraLoaderMixin`](../api/loaders/lora),因此为对应 Wan2.1 backbone 训练的
+LoRA adapter 直接加载即可:
+
+```py
+pipe.load_lora_weights("path/or/repo/with/wan_lora")
+```
+
+如果要做**继续 on-policy 蒸馏微调**(用论文里相同的 DMD 反向散度监督配方训新 LoRA),请参考原始
+AnyFlow 训练框架 [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow),这套训练流程不在
+diffusers 范围内。
+
+## 常见坑
+
+- **永远 `guidance_scale=1.0`。** 蒸馏后的 checkpoint 已经把 CFG 融进权重。设 `> 1` 会多跑一遍
+ unconditional 前向、延迟翻倍、质量微降。
+- **Bidirectional pipeline 不支持流式。** 所有 `num_frames` 一起去噪。需要边采边播请用 causal pipeline。
+- **Causal pipeline KV cache 假设 chunk 调度跨调用一致。** 中途重建 cache 不被 release 模型支持。
+- **`num_frames` 必须满足 VAE 时间步长。** release checkpoint 用 `(N - 1) % 4 == 0` 的值(如 9、17、33、81)。
+
+## 引用
+
+```bibtex
+@misc{gu2026anyflowanystepvideodiffusion,
+ title={AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation},
+ author={Yuchao Gu and Guian Fang and Yuxin Jiang and Weijia Mao and Song Han and Han Cai and Mike Zheng Shou},
+ year={2026},
+ eprint={2605.13724},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+ url={https://arxiv.org/abs/2605.13724},
+}
+
+@article{gu2025long,
+ title={Long-Context Autoregressive Video Modeling with Next-Frame Prediction},
+ author={Gu, Yuchao and Mao, Weijia and Shou, Mike Zheng},
+ journal={arXiv preprint arXiv:2503.19325},
+ year={2025}
+}
+```
diff --git a/scripts/convert_anyflow_to_diffusers.py b/scripts/convert_anyflow_to_diffusers.py
new file mode 100644
index 000000000000..60574ca23a1e
--- /dev/null
+++ b/scripts/convert_anyflow_to_diffusers.py
@@ -0,0 +1,152 @@
+# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Convert AnyFlow training checkpoints to the diffusers ``save_pretrained`` layout.
+
+The AnyFlow training pipeline emits ``.pt`` files containing an ``ema`` key whose value is a flat state
+dict for the transformer. This script:
+
+1. Loads the matching base Wan2.1 pipeline from the Hub (provides VAE, tokenizer, and text encoder).
+2. Constructs an ``AnyFlowTransformer3DModel`` with the right config flags for the chosen variant.
+3. Loads the ``ema`` weights into the transformer.
+4. Wraps everything in an ``AnyFlowPipeline`` (bidirectional) or ``AnyFlowFARPipeline`` (FAR causal).
+5. Calls ``pipeline.save_pretrained(output_dir)``.
+
+Example:
+
+```bash
+python scripts/convert_anyflow_to_diffusers.py \\
+ --variant AnyFlow-FAR-Wan2.1-1.3B-Diffusers \\
+ --ckpt /path/to/anyflow-checkpoint.pt \\
+ --output-dir /path/to/output/AnyFlow-FAR-Wan2.1-1.3B-Diffusers
+```
+"""
+
+import argparse
+import logging
+import os
+
+import torch
+
+from diffusers import (
+ AnyFlowFARPipeline,
+ AnyFlowFARTransformer3DModel,
+ AnyFlowPipeline,
+ AnyFlowTransformer3DModel,
+ FlowMapEulerDiscreteScheduler,
+)
+
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+
+
+# Per-variant configuration. ``base_model`` is fetched from the Hub to source the matching VAE / text encoder.
+VARIANTS = {
+ "AnyFlow-FAR-Wan2.1-1.3B-Diffusers": {
+ "base_model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+ "transformer_cls": AnyFlowFARTransformer3DModel,
+ "transformer_kwargs": {"full_chunk_limit": 3, "compressed_patch_size": [1, 4, 4]},
+ "pipeline_cls": AnyFlowFARPipeline,
+ },
+ "AnyFlow-FAR-Wan2.1-14B-Diffusers": {
+ "base_model": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
+ "transformer_cls": AnyFlowFARTransformer3DModel,
+ "transformer_kwargs": {"full_chunk_limit": 3, "compressed_patch_size": [1, 4, 4]},
+ "pipeline_cls": AnyFlowFARPipeline,
+ },
+ "AnyFlow-Wan2.1-T2V-1.3B-Diffusers": {
+ "base_model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+ "transformer_cls": AnyFlowTransformer3DModel,
+ "transformer_kwargs": {},
+ "pipeline_cls": AnyFlowPipeline,
+ },
+ "AnyFlow-Wan2.1-T2V-14B-Diffusers": {
+ "base_model": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
+ "transformer_cls": AnyFlowTransformer3DModel,
+ "transformer_kwargs": {},
+ "pipeline_cls": AnyFlowPipeline,
+ },
+}
+
+
+def build_pipeline(variant: str, ckpt_path: str):
+ if variant not in VARIANTS:
+ raise ValueError(f"Unknown variant {variant!r}. Choices: {list(VARIANTS)}.")
+ spec = VARIANTS[variant]
+
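+    # Instantiate the AnyFlow transformer from the base Wan2.1 repo; AnyFlow-specific flags are passed as config overrides.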
+ transformer = spec["transformer_cls"].from_pretrained(
+ spec["base_model"],
+ subfolder="transformer",
+ gate_value=0.25,
+ deltatime_type="r",
+ **spec["transformer_kwargs"],
+ )
+ # NVlabs/AnyFlow training checkpoints are wrapped Python objects (the `ema` key carries metadata
+ # alongside tensors), so the unpickle is required. Only run this script on checkpoints you trust.
+ state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["ema"]
+ missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
+ if unexpected:
+ logger.warning(
+ "Unexpected keys in state dict (ignored): %s%s",
+ unexpected[:5],
+ "..." if len(unexpected) > 5 else "",
+ )
+ if missing:
+ logger.warning(
+ "Missing keys not loaded from state dict: %s%s",
+ missing[:5],
+ "..." if len(missing) > 5 else "",
+ )
+
+ scheduler = FlowMapEulerDiscreteScheduler(num_train_timesteps=1000, shift=5.0)
+
+ pipeline = spec["pipeline_cls"].from_pretrained(
+ spec["base_model"],
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ return pipeline
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Convert an AnyFlow training checkpoint into a diffusers pipeline directory."
+ )
+ parser.add_argument(
+ "--variant",
+ required=True,
+ choices=list(VARIANTS),
+ help="Which AnyFlow variant the checkpoint corresponds to.",
+ )
+ parser.add_argument(
+ "--ckpt",
+ required=True,
+ help="Path to the AnyFlow training checkpoint (a .pt file containing an 'ema' key).",
+ )
+ parser.add_argument(
+ "--output-dir",
+ required=True,
+ help="Destination directory for pipeline.save_pretrained.",
+ )
+ args = parser.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ pipeline = build_pipeline(args.variant, args.ckpt)
+ pipeline.save_pretrained(args.output_dir)
+ logger.info("Saved %s pipeline to %s", args.variant, args.output_dir)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index d120d0a22818..3a8332dc0c3a 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -191,6 +191,8 @@
[
"AceStepTransformer1DModel",
"AllegroTransformer3DModel",
+ "AnyFlowFARTransformer3DModel",
+ "AnyFlowTransformer3DModel",
"AsymmetricAutoencoderKL",
"AttentionBackendName",
"AuraFlowTransformer2DModel",
@@ -380,6 +382,7 @@
"EDMEulerScheduler",
"EulerAncestralDiscreteScheduler",
"EulerDiscreteScheduler",
+ "FlowMapEulerDiscreteScheduler",
"FlowMatchEulerDiscreteScheduler",
"FlowMatchHeunDiscreteScheduler",
"FlowMatchLCMScheduler",
@@ -511,6 +514,8 @@
"AnimateDiffSparseControlNetPipeline",
"AnimateDiffVideoToVideoControlNetPipeline",
"AnimateDiffVideoToVideoPipeline",
+ "AnyFlowFARPipeline",
+ "AnyFlowPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
@@ -1019,6 +1024,8 @@
from .models import (
AceStepTransformer1DModel,
AllegroTransformer3DModel,
+ AnyFlowFARTransformer3DModel,
+ AnyFlowTransformer3DModel,
AsymmetricAutoencoderKL,
AttentionBackendName,
AuraFlowTransformer2DModel,
@@ -1204,6 +1211,7 @@
EDMEulerScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
+ FlowMapEulerDiscreteScheduler,
FlowMatchEulerDiscreteScheduler,
FlowMatchHeunDiscreteScheduler,
FlowMatchLCMScheduler,
@@ -1316,6 +1324,8 @@
AnimateDiffSparseControlNetPipeline,
AnimateDiffVideoToVideoControlNetPipeline,
AnimateDiffVideoToVideoPipeline,
+ AnyFlowFARPipeline,
+ AnyFlowPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index ff8e16aad447..05d29aabb7e0 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -95,6 +95,10 @@
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
+ _import_structure["transformers.transformer_anyflow"] = [
+ "AnyFlowFARTransformer3DModel",
+ "AnyFlowTransformer3DModel",
+ ]
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
@@ -214,6 +218,8 @@
from .transformers import (
AceStepTransformer1DModel,
AllegroTransformer3DModel,
+ AnyFlowFARTransformer3DModel,
+ AnyFlowTransformer3DModel,
AuraFlowTransformer2DModel,
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 156b54e7f07d..4b0a55d0a8ad 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -18,6 +18,7 @@
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
+ from .transformer_anyflow import AnyFlowFARTransformer3DModel, AnyFlowTransformer3DModel
from .transformer_bria import BriaTransformer2DModel
from .transformer_bria_fibo import BriaFiboTransformer2DModel
from .transformer_chroma import ChromaTransformer2DModel
diff --git a/src/diffusers/models/transformers/transformer_anyflow.py b/src/diffusers/models/transformers/transformer_anyflow.py
new file mode 100644
index 000000000000..55fadc317559
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_anyflow.py
@@ -0,0 +1,1683 @@
+# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This file derives from the FAR architecture (Gu et al., 2025, arXiv:2503.19325) and adds AnyFlow's
+# dual-timestep flow-map embedding (AnyFlowDualTimestepTextImageEmbedding). The base 3D DiT structure
+# is adapted from the v0.35.1 Wan2.1 transformer (transformer_wan.py); upstream Wan has since been
+# refactored, so this file is intentionally self-contained rather than annotated with `# Copied from`.
+
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.attention.flex_attention import create_block_mask
+from torch.nn.attention.flex_attention import flex_attention as _flex_attention_eager
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import BaseOutput, logging
+from ..attention import AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import FP32LayerNorm, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class AnyFlowFARTransformerOutput(BaseOutput):
+ """
+    Output dataclass for ``AnyFlowFARTransformer3DModel``'s causal forward paths.
+
+ Args:
+ sample (`torch.Tensor` or `None`):
+ Predicted denoising target for the autoregressive chunk. ``None`` for the cache-prefill path,
+ which only writes the KV cache and produces no usable sample.
+ kv_cache (`list[dict[str, torch.Tensor]]`, *optional*):
+ Per-block KV cache state used by subsequent autoregressive steps.
+ """
+
+ sample: Optional[torch.Tensor] = None
+ kv_cache: Optional[List[Dict[str, torch.Tensor]]] = None
+
+
+# `flex_attention` is compiled lazily on first CUDA use. The compiled kernel goes through
+# Triton/Inductor C++ codegen which can fail on small or unusual shapes seen in fast tests
+# (e.g. CPU-side `pipe.to("cpu")` with tiny dummy components), so we always fall back to the
+# eager kernel on CPU. Both paths produce identical numeric output — the wrapper only differs
+# in execution speed.
+_flex_attention_compiled = None
+
+
+def _get_compiled_flex_attention():
+ global _flex_attention_compiled
+ if _flex_attention_compiled is None:
+ try:
+ _flex_attention_compiled = torch.compile(_flex_attention_eager, dynamic=True)
+ except Exception as e: # pragma: no cover - environment-dependent
+ logger.warning(
+ "Failed to torch.compile flex_attention; falling back to the eager kernel. Error: %s",
+ e,
+ )
+ _flex_attention_compiled = _flex_attention_eager
+ return _flex_attention_compiled
+
+
+def flex_attention(query, key, value, *args, **kwargs):
+ """Dispatch to the compiled flex_attention on CUDA (fast path) or the eager one on CPU
+ (avoids Triton/Inductor codegen failures on tiny test shapes)."""
+ if query.device.type == "cuda":
+ return _get_compiled_flex_attention()(query, key, value, *args, **kwargs)
+ return _flex_attention_eager(query, key, value, *args, **kwargs)
+
+
+def build_block_mask(mask_2d, device):
+ if mask_2d.dim() != 2 or mask_2d.dtype != torch.bool:
+ raise ValueError(
+ f"`mask_2d` must be a 2D boolean tensor, got shape {tuple(mask_2d.shape)} and dtype {mask_2d.dtype}."
+ )
+ mask_2d = mask_2d.contiguous()
+
+ Q_LEN, KV_LEN = mask_2d.shape
+
+ def mask_mod(b, h, q_idx, kv_idx):
+ return mask_2d[q_idx, kv_idx]
+
+ return create_block_mask(mask_mod, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=KV_LEN, device=device, _compile=False)
+
+
+def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+ # MPS / NPU backends do not support complex128 / float64; fall back to float32 on those devices.
+ is_mps = hidden_states.device.type == "mps"
+ is_npu = hidden_states.device.type == "npu"
+ rotary_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+ x_rotated = torch.view_as_complex(hidden_states.to(rotary_dtype).unflatten(3, (-1, 2)))
+ x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+ return x_out.type_as(hidden_states)
+
+
+class AnyFlowAttnProcessor:
+ """
+ Self-attention processor for AnyFlow. Supports three modes:
+
+ * Bidirectional (``attention_mask`` is ``None``) — standard SDPA / dispatched backend.
+ * FAR causal — :class:`~torch.nn.attention.flex_attention.BlockMask` via ``flex_attention``.
+ * Autoregressive inference with KV cache (read or write) for the FAR variant.
+ """
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "AnyFlowAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or higher."
+ )
+
+ def __call__(
+ self,
+ attn: "AnyFlowAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[Any] = None,
+ rotary_emb: Optional[Dict[str, torch.Tensor]] = None,
+ kv_cache: Optional[Dict[str, torch.Tensor]] = None,
+ kv_cache_flag: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Layout (B, H, L, D) is required by KV-cache slicing, rotary application, and flex_attention.
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ if kv_cache is not None:
+ if kv_cache_flag["is_cache_step"]:
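+                # Prefill: write this chunk's compressed/full keys and values into the preallocated caches.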
+ kv_cache["compressed_cache"][0, :, :, : kv_cache_flag["num_compressed_tokens"], :] = key[
+ :, :, : kv_cache_flag["num_compressed_tokens"]
+ ]
+ kv_cache["compressed_cache"][1, :, :, : kv_cache_flag["num_compressed_tokens"], :] = value[
+ :, :, : kv_cache_flag["num_compressed_tokens"]
+ ]
+ kv_cache["full_cache"][0, :, :, : kv_cache_flag["num_full_tokens"], :] = key[
+ :, :, kv_cache_flag["num_compressed_tokens"] :
+ ]
+ kv_cache["full_cache"][1, :, :, : kv_cache_flag["num_full_tokens"], :] = value[
+ :, :, kv_cache_flag["num_compressed_tokens"] :
+ ]
+ else:
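+                # Autoregressive step: prepend the cached compressed/full keys and values to this chunk's K/V.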
+ key = torch.cat(
+ [
+ kv_cache["compressed_cache"][0, :, :, : kv_cache_flag["num_cached_compressed_tokens"], :],
+ kv_cache["full_cache"][0, :, :, : kv_cache_flag["num_cached_full_tokens"], :],
+ key,
+ ],
+ dim=2,
+ )
+ value = torch.cat(
+ [
+ kv_cache["compressed_cache"][1, :, :, : kv_cache_flag["num_cached_compressed_tokens"], :],
+ kv_cache["full_cache"][1, :, :, : kv_cache_flag["num_cached_full_tokens"], :],
+ value,
+ ],
+ dim=2,
+ )
+
+ if rotary_emb is not None:
+ query = apply_rotary_emb(query, rotary_emb["query"])
+ key = apply_rotary_emb(key, rotary_emb["key"])
+
+ if attention_mask is None:
+ # Pass (B, L, H, D) to dispatch_attention_fn (its native backend permutes back to (B, H, L, D)
+ # internally before calling the kernel).
+ hidden_states = dispatch_attention_fn(
+ query.transpose(1, 2),
+ key.transpose(1, 2),
+ value.transpose(1, 2),
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ # Output already in (B, L, H, D); fold heads into the channel dim.
+ hidden_states = hidden_states.flatten(2, 3)
+ else:
+ # FAR causal path: BlockMask is consumed by flex_attention only, so we keep the
+ # (B, H, L, D) layout and pad to a multiple of 128 for the BlockMask block size.
+ seq_len = query.shape[2]
+ head_dim = query.shape[3]
+ padded_length = int(math.ceil(seq_len / 128.0) * 128.0 - seq_len)
+ query = torch.cat(
+ [
+ query,
+ torch.zeros(
+ [query.shape[0], query.shape[1], padded_length, query.shape[3]],
+ device=query.device,
+ dtype=query.dtype,
+ ),
+ ],
+ dim=2,
+ )
+ key = torch.cat(
+ [
+ key,
+ torch.zeros(
+ [key.shape[0], key.shape[1], padded_length, key.shape[3]],
+ device=key.device,
+ dtype=key.dtype,
+ ),
+ ],
+ dim=2,
+ )
+ value = torch.cat(
+ [
+ value,
+ torch.zeros(
+ [value.shape[0], value.shape[1], padded_length, value.shape[3]],
+ device=value.device,
+ dtype=value.dtype,
+ ),
+ ],
+ dim=2,
+ )
+ # flex_attention requires head_dim >= 16. When tiny dummy components use head_dim < 16
+ # (real ckpts use 128) we right-pad q/k/v with zeros and pass an explicit scale matched
+ # to the original head_dim; padded value rows contribute 0, so trimming back preserves
+ # output equivalence.
+ head_pad = max(0, 16 - head_dim)
+ scale = 1.0 / (head_dim**0.5) if head_pad > 0 else None
+ if head_pad > 0:
+ query = F.pad(query, (0, head_pad))
+ key = F.pad(key, (0, head_pad))
+ value = F.pad(value, (0, head_pad))
+ hidden_states = flex_attention(query, key, value, block_mask=attention_mask, scale=scale)
+ if head_pad > 0:
+ hidden_states = hidden_states[..., :head_dim]
+ hidden_states = hidden_states[:, :, :seq_len]
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
+ hidden_states = hidden_states.type_as(query)
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class AnyFlowCrossAttnProcessor:
+ """
+ Cross-attention processor for AnyFlow. Always uses the dispatched SDPA-compatible backend; no rotary
+ embedding or KV cache is applied to the text→video cross-attention path.
+ """
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "AnyFlowCrossAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or higher."
+ )
+
+ def __call__(
+ self,
+ attn: "AnyFlowAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # (B, L, H, D) layout for dispatch_attention_fn.
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class AnyFlowAttention(torch.nn.Module, AttentionModuleMixin):
+ """
+ Attention module used by :class:`AnyFlowTransformerBlock`. Layout matches the legacy
+ :class:`~diffusers.models.attention_processor.Attention` so existing AnyFlow checkpoints load
+ bit-exactly into this class.
+ """
+
+ _default_processor_cls = AnyFlowAttnProcessor
+ _available_processors = [AnyFlowAttnProcessor, AnyFlowCrossAttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int,
+ dim_head: int,
+ eps: float = 1e-6,
+ processor: Optional[Any] = None,
+ ):
+ super().__init__()
+ self.heads = heads
+ self.inner_dim = heads * dim_head
+
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_k = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_v = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_out = torch.nn.ModuleList(
+ [
+ torch.nn.Linear(self.inner_dim, dim, bias=True),
+ torch.nn.Dropout(0.0),
+ ]
+ )
+ # ``rms_norm_across_heads`` per-axis: normalize Q and K across the entire ``heads * dim_head``
+ # channel axis. We use diffusers' RMSNorm (rather than ``torch.nn.RMSNorm``) so the numerics
+ # match the legacy Attention class that produced the released checkpoints.
+ self.norm_q = RMSNorm(self.inner_dim, eps=eps)
+ self.norm_k = RMSNorm(self.inner_dim, eps=eps)
+
+ self.set_processor(processor if processor is not None else self._default_processor_cls())
+
+ def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
+ return self.processor(self, hidden_states, **kwargs)
+
+
+class AnyFlowImageEmbedding(torch.nn.Module):
+ def __init__(self, in_features: int, out_features: int):
+ super().__init__()
+
+ self.norm1 = FP32LayerNorm(in_features)
+ self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+ self.norm2 = FP32LayerNorm(out_features)
+
+ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.norm1(encoder_hidden_states_image)
+ hidden_states = self.ff(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+class AnyFlowDualTimestepTextImageEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ gate_value: float,
+ deltatime_type: str,
+ time_freq_dim: int,
+ time_proj_dim: int,
+ text_embed_dim: int,
+ image_embed_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ self.delta_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ self.act_fn = nn.SiLU()
+ self.time_proj = nn.Linear(dim, time_proj_dim)
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+ self.image_embedder = None
+ if image_embed_dim is not None:
+ self.image_embedder = AnyFlowImageEmbedding(image_embed_dim, dim)
+
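+        # Fixed (non-trained) mixing gate taken from the config; ``persistent=False`` keeps it out of the
+        # state dict so it is always rebuilt from ``gate_value``.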
+ self.register_buffer("delta_emb_gate", torch.tensor([gate_value], dtype=torch.float32), persistent=False)
+ self.deltatime_type = deltatime_type
+
+ def forward_timestep(
+ self, timestep: torch.Tensor, delta_timestep: torch.Tensor, encoder_hidden_states, token_per_frame
+ ):
+ batch_size, num_frames = timestep.shape
+ timestep = timestep.reshape(-1)
+ delta_timestep = delta_timestep.reshape(-1)
+
+ timestep = self.timesteps_proj(timestep)
+
+ time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ timestep = timestep.to(time_embedder_dtype)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+
+ delta_timestep = self.timesteps_proj(delta_timestep)
+
+ delta_embedder_dtype = next(iter(self.delta_embedder.parameters())).dtype
+ if delta_timestep.dtype != delta_embedder_dtype and delta_embedder_dtype != torch.int8:
+ delta_timestep = delta_timestep.to(delta_embedder_dtype)
+ delta_emb = self.delta_embedder(delta_timestep).type_as(encoder_hidden_states)
+
+ gate = self.delta_emb_gate.to(delta_embedder_dtype)
+
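+        # Flow-map conditioning: blend the source-timestep embedding with the delta / target-timestep
+        # embedding using the fixed gate.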
+ rt_emb = (1 - gate) * temb + gate * delta_emb
+ timestep_proj = self.time_proj(self.act_fn(rt_emb))
+
+ rt_emb = rt_emb.unflatten(0, (batch_size, num_frames)).repeat_interleave(token_per_frame, dim=1)
+ timestep_proj = timestep_proj.unflatten(0, (batch_size, num_frames)).repeat_interleave(token_per_frame, dim=1)
+
+ return rt_emb, timestep_proj
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ r_timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ far_cfg=None,
+ clean_timestep=None,
+ is_causal=True,
+ ):
+ if self.deltatime_type == "r":
+ delta_timestep = r_timestep
+ elif self.deltatime_type == "t-r":
+ delta_timestep = timestep - r_timestep
+ else:
+ raise NotImplementedError
+
+ if is_causal:
+ full_frame_timestep, full_frame_timestep_proj = self.forward_timestep(
+ timestep[:, -far_cfg["num_full_frames"] :],
+ delta_timestep[:, -far_cfg["num_full_frames"] :],
+ encoder_hidden_states,
+ far_cfg["full_token_per_frame"],
+ ) # noqa: E501
+ compressed_frame_timestep, compressed_frame_timestep_proj = self.forward_timestep(
+ timestep[:, : -far_cfg["num_full_frames"]],
+ delta_timestep[:, : -far_cfg["num_full_frames"]],
+ encoder_hidden_states,
+ far_cfg["compressed_token_per_frame"],
+ ) # noqa: E501
+
+ if clean_timestep is not None:
+ clean_timestep, clean_timestep_proj = self.forward_timestep(
+ clean_timestep, clean_timestep, encoder_hidden_states, far_cfg["full_token_per_frame"]
+ ) # noqa: E501
+ timestep = torch.cat([compressed_frame_timestep, full_frame_timestep, clean_timestep], dim=1)
+ timestep_proj = torch.cat(
+ [compressed_frame_timestep_proj, full_frame_timestep_proj, clean_timestep_proj], dim=1
+ )
+ else:
+ timestep = torch.cat([compressed_frame_timestep, full_frame_timestep], dim=1)
+ timestep_proj = torch.cat([compressed_frame_timestep_proj, full_frame_timestep_proj], dim=1)
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+ else:
+ timestep, timestep_proj = self.forward_timestep(
+ timestep, delta_timestep, encoder_hidden_states, far_cfg["full_token_per_frame"]
+ ) # noqa: E501
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+ return timestep, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+class AnyFlowRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ compressed_patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.compressed_patch_size = compressed_patch_size
+ self.max_seq_len = max_seq_len
+ self.theta = theta
+
+ # Frequency table is lazily built per-device in ``_build_freqs``: MPS / NPU don't support
+ # complex128, so we downcast to complex64 there.
+ self._freqs_cache: Optional[Tuple[Any, torch.Tensor]] = None
+
+ def _build_freqs(self, device: torch.device) -> torch.Tensor:
+ cache_key = (device.type, str(device))
+ if self._freqs_cache is not None and self._freqs_cache[0] == cache_key:
+ return self._freqs_cache[1]
+
+ is_mps = device.type == "mps"
+ is_npu = device.type == "npu"
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+
+ h_dim = w_dim = 2 * (self.attention_head_dim // 6)
+ t_dim = self.attention_head_dim - h_dim - w_dim
+
+ freqs_list = []
+ for dim in (t_dim, h_dim, w_dim):
+ f = get_1d_rotary_pos_embed(
+ dim,
+ self.max_seq_len,
+ self.theta,
+ use_real=False,
+ repeat_interleave_real=False,
+ freqs_dtype=freqs_dtype,
+ )
+ freqs_list.append(f.to(device))
+ freqs = torch.cat(freqs_list, dim=1)
+ self._freqs_cache = (cache_key, freqs)
+ return freqs
+
+ def avg_pool_complex(self, freq: torch.Tensor, kernel_size: int, stride: int):
+
+        real = freq.real  # ``freq`` is a complex [L, C] table; take the real part
+        real = real.transpose(0, 1).unsqueeze(0)  # -> [1, C, L] for avg_pool1d
+        imag = freq.imag
+        imag = imag.transpose(0, 1).unsqueeze(0)
+
+ pr = F.avg_pool1d(real, kernel_size, stride)
+ pi = F.avg_pool1d(imag, kernel_size, stride)
+
+ pr = pr.squeeze(0).transpose(0, 1)
+ pi = pi.squeeze(0).transpose(0, 1)
+
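+        # Re-normalize to unit modulus so the pooled frequencies remain valid RoPE rotations (pure phases).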
+ norm = torch.sqrt(pr**2 + pi**2)
+ pr_unit = pr / norm
+ pi_unit = pi / norm
+
+ return torch.complex(pr_unit, pi_unit)
+
+ def _forward_compressed_frame(self, num_frames, height, width, device):
+ ppf, pph, ppw = num_frames, height, width
+ # Tiny dummy components (e.g. height=16/width=16 with compressed_patch_size=(1,4,4) and
+ # an upstream VAE stride of 8) can produce 0-element grids; the .view(0, k, 1, -1) reshape
+ # below would be ambiguous. Real ckpts use 60x104 latents and never hit this path.
+ freqs_full = self._build_freqs(device)
+ if min(ppf, pph, ppw) <= 0:
+ freq_channels = self.attention_head_dim // 2
+ return torch.empty((ppf, pph, ppw, freq_channels), dtype=freqs_full.dtype, device=device)
+ downscale = [self.compressed_patch_size[i] // self.patch_size[i] for i in range(len(self.patch_size))]
+
+ freqs = freqs_full.split_with_sizes(
+ [
+ self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
+ self.attention_head_dim // 6,
+ self.attention_head_dim // 6,
+ ],
+ dim=1,
+ )
+
+ freqs_f = self.avg_pool_complex(freqs[0], kernel_size=downscale[0], stride=downscale[0])
+ freqs_h = self.avg_pool_complex(freqs[1], kernel_size=downscale[1], stride=downscale[1])
+ freqs_w = self.avg_pool_complex(freqs[2], kernel_size=downscale[2], stride=downscale[2])
+
+ freqs_f = freqs_f[:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_h = freqs_h[:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_w = freqs_w[:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1)
+ return freqs
+
+ def _forward_full_frame(self, num_frames, height, width, device) -> torch.Tensor:
+ ppf, pph, ppw = num_frames, height, width
+
+ freqs_full = self._build_freqs(device)
+ if min(ppf, pph, ppw) <= 0:
+ freq_channels = self.attention_head_dim // 2
+ return torch.empty((ppf, pph, ppw, freq_channels), dtype=freqs_full.dtype, device=device)
+
+ freqs = freqs_full.split_with_sizes(
+ [
+ self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
+ self.attention_head_dim // 6,
+ self.attention_head_dim // 6,
+ ],
+ dim=1,
+ )
+
+ freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+ freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1)
+ return freqs
+
+ def forward(self, far_cfg, device, clean_hidden_states=None, is_causal=True):
+ if is_causal:
+ full_frame_freqs = self._forward_full_frame(
+ num_frames=far_cfg["total_frames"],
+ height=far_cfg["full_frame_shape"][0],
+ width=far_cfg["full_frame_shape"][1],
+ device=device,
+ )
+ compressed_frame_freqs = self._forward_compressed_frame(
+ num_frames=far_cfg["total_frames"],
+ height=far_cfg["compressed_frame_shape"][0],
+ width=far_cfg["compressed_frame_shape"][1],
+ device=device,
+ )
+
+ compressed_frame_freqs, full_frame_freqs = (
+ compressed_frame_freqs[: far_cfg["num_compressed_frames"]],
+ full_frame_freqs[far_cfg["num_compressed_frames"] :],
+ ) # noqa: E501
+
+ compressed_frame_freqs = compressed_frame_freqs.flatten(start_dim=0, end_dim=2)
+ full_frame_freqs = full_frame_freqs.flatten(start_dim=0, end_dim=2)
+
+ if clean_hidden_states is not None:
+ freqs = torch.cat([compressed_frame_freqs, full_frame_freqs, full_frame_freqs], dim=0)
+ else:
+ freqs = torch.cat([compressed_frame_freqs, full_frame_freqs], dim=0)
+
+ freqs = freqs[None, None, ...]
+
+ return {"query": freqs, "key": freqs}
+ else:
+ freqs = self._forward_full_frame(
+ num_frames=far_cfg["total_frames"],
+ height=far_cfg["full_frame_shape"][0],
+ width=far_cfg["full_frame_shape"][1],
+ device=device,
+ )
+ freqs = freqs.flatten(start_dim=0, end_dim=2)
+ freqs = freqs[None, None, ...]
+ return {"query": freqs, "key": freqs}
+
+
+class AnyFlowTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ cross_attn_norm: bool = False,
+ eps: float = 1e-6,
+ ):
+ super().__init__()
+
+ # 1. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.attn1 = AnyFlowAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ processor=AnyFlowAttnProcessor(),
+ )
+
+ # 2. Cross-attention
+ self.attn2 = AnyFlowAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ processor=AnyFlowCrossAttnProcessor(),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+ # 3. Feed-forward
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ attention_mask: torch.Tensor,
+ kv_cache=None,
+ kv_cache_flag=None,
+ ) -> torch.Tensor:
+
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=2)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ shift_msa.squeeze(2),
+ scale_msa.squeeze(2),
+ gate_msa.squeeze(2),
+ c_shift_msa.squeeze(2),
+ c_scale_msa.squeeze(2),
+ c_gate_msa.squeeze(2),
+ ) # noqa: E501
+
+ # 1. Self-attention
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+ attn_output = self.attn1(
+ hidden_states=norm_hidden_states,
+ rotary_emb=rotary_emb,
+ attention_mask=attention_mask,
+ kv_cache=kv_cache,
+ kv_cache_flag=kv_cache_flag,
+ ) # noqa: E501
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+ # 2. Cross-attention
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ hidden_states = hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ hidden_states
+ )
+ ff_output = self.ffn(norm_hidden_states)
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+ return hidden_states
+
+
+class AnyFlowTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ r"""
+ Bidirectional 3D Transformer for AnyFlow flow-map sampling.
+
+ The architecture is the v0.35.1 Wan2.1 3D DiT backbone with one structural change: the timestep
+ embedder is replaced by ``AnyFlowDualTimestepTextImageEmbedding`` so that every forward call conditions
+ on both the source timestep ``t`` and the target timestep ``r``. This is the embedding required to
+ learn the flow map :math:`\Phi_{r\leftarrow t}` introduced in
+ [AnyFlow](https://huggingface.co/papers/2605.13724).
+
+ For frame-level autoregressive (FAR causal) generation, use ``AnyFlowFARTransformer3DModel`` instead;
+ that variant adds the FAR causal block-mask and a compressed-frame patch embedding on top of the same
+ backbone.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+ num_attention_heads (`int`, defaults to `40`):
+ Number of attention heads.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input latent.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output latent.
+ text_dim (`int`, defaults to `4096`):
+ Input dimension for text embeddings (UMT5).
+ freq_dim (`int`, defaults to `256`):
+ Dimension for sinusoidal time embeddings.
+ ffn_dim (`int`, defaults to `13824`):
+ Intermediate dimension in feed-forward network.
+ num_layers (`int`, defaults to `40`):
+ Number of transformer blocks.
+ cross_attn_norm (`bool`, defaults to `True`):
+ Enable cross-attention normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon for normalization layers.
+ image_dim (`Optional[int]`, *optional*, defaults to `None`):
+ Image embedding dimension for I2V conditioning (`1280` for the original Wan2.1-I2V model).
+ rope_max_seq_len (`int`, defaults to `1024`):
+ Maximum sequence length used to precompute rotary position frequencies.
+ gate_value (`float`, defaults to `0.25`):
+ Mixing gate between source-timestep and delta-timestep embeddings (the AnyFlow paper's :math:`g`
+ parameter, fixed at 0.25 in stage-1 distillation).
+ deltatime_type (`str`, defaults to `'r'`):
+ Either ``"r"`` (delta is the target timestep) or ``"t-r"`` (delta is the absolute interval).
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+ _no_split_modules = ["AnyFlowTransformerBlock"]
+ _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+ _repeated_blocks = ["AnyFlowTransformerBlock"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (1, 2, 2),
+ num_attention_heads: int = 40,
+ attention_head_dim: int = 128,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ text_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 13824,
+ num_layers: int = 40,
+ cross_attn_norm: bool = True,
+ eps: float = 1e-6,
+ image_dim: Optional[int] = None,
+ rope_max_seq_len: int = 1024,
+ gate_value: float = 0.25,
+ deltatime_type: str = "r",
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Patch & position embedding. ``compressed_patch_size`` is unused in the bidirectional path;
+ # we forward ``patch_size`` itself so the rotary helper has a valid (no-op) value.
+ self.rope = AnyFlowRotaryPosEmbed(attention_head_dim, patch_size, patch_size, rope_max_seq_len)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Condition embedding (always dual-timestep for AnyFlow distilled checkpoints).
+ self.condition_embedder = AnyFlowDualTimestepTextImageEmbedding(
+ dim=inner_dim,
+ gate_value=gate_value,
+ deltatime_type=deltatime_type,
+ time_freq_dim=freq_dim,
+ time_proj_dim=inner_dim * 6,
+ text_embed_dim=text_dim,
+ image_embed_dim=image_dim,
+ )
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ AnyFlowTransformerBlock(inner_dim, ffn_dim, num_attention_heads, cross_attn_norm, eps)
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+ self.gradient_checkpointing = False
+
+ def _unpack_latent_sequence(self, latents, num_frames, height, width, patch_size):
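+        # Inverse of patchification: fold the flat (B, num_tokens, C * p * p) token sequence back into
+        # per-frame (B, F, C, H, W) latents.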
+ batch_size, num_patches, channels = latents.shape
+ height, width = height // patch_size, width // patch_size
+
+ latents = latents.view(
+ batch_size * num_frames, height, width, patch_size, patch_size, channels // (patch_size * patch_size)
+ )
+ latents = latents.permute(0, 5, 1, 3, 2, 4)
+ latents = latents.reshape(
+ batch_size, num_frames, channels // (patch_size * patch_size), height * patch_size, width * patch_size
+ )
+ return latents
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ r_timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[Transformer2DModelOutput, Tuple]:
+ """
+ Bidirectional flow-map forward pass.
+
+ ``hidden_states`` is laid out as ``(B, F, C, H, W)`` (per-frame latents). The input is patchified
+ with the standard ``patch_embedding`` (kernel = stride = ``patch_size``) and denoised with global
+ bidirectional self-attention over the resulting flat token sequence.
+ """
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+ full_token_per_frame = (height * width) // (self.config.patch_size[1] * self.config.patch_size[2])
+
+ far_cfg = {
+ "total_frames": num_frames,
+ "full_frame_shape": (height // self.config.patch_size[1], width // self.config.patch_size[2]),
+ "full_token_per_frame": full_token_per_frame,
+ }
+
+ rotary_emb = self.rope(far_cfg=far_cfg, device=hidden_states.device, is_causal=False)
+
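+        # Patchify: (B, C, F, H, W) -> (B, num_tokens, inner_dim) flat token sequence.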
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep,
+ r_timestep,
+ encoder_hidden_states,
+ encoder_hidden_states_image,
+ is_causal=False,
+ far_cfg=far_cfg,
+ )
+ timestep_proj = timestep_proj.unflatten(2, (6, -1))
+
+ attention_mask = None
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.blocks:
+ hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask
+ )
+ else:
+ for block in self.blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask)
+
+ # Output norm, projection & unpatchify
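+        # ``temb`` from the dual-timestep embedder is per-token, i.e. (B, num_tokens, dim); the 2-D branch
+        # below handles a per-sample (B, dim) embedding.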
+ if temb.ndim == 3:
+ shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+ shift = shift.squeeze(2)
+ scale = scale.squeeze(2)
+ else:
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
+ # Move shift/scale to hidden_states' device for multi-GPU accelerate inference.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+
+ output = self._unpack_latent_sequence(
+ hidden_states,
+ num_frames=far_cfg["total_frames"],
+ height=height,
+ width=width,
+ patch_size=self.config.patch_size[1],
+ )
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
+
+
+class AnyFlowFARTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ r"""
+ Causal (FAR) 3D Transformer for AnyFlow flow-map sampling with frame-level autoregressive generation.
+
+ Extends the v0.35.1 Wan2.1 backbone with:
+
+ * **FAR causal block-mask** via :func:`torch.nn.attention.flex_attention`, supporting frame-level
+ autoregressive generation (FAR; [Gu et al., 2025](https://arxiv.org/abs/2503.19325)).
+ * **Compressed-frame patch embedding** ``far_patch_embedding`` for context (already-generated) frames,
+ initialized from ``patch_embedding`` via trilinear interpolation so a freshly constructed model is
+ already at a reasonable starting point even before LoRA fine-tuning.
+ * **Dual-timestep flow-map embedding** for any-step sampling (same as ``AnyFlowTransformer3DModel``).
+
+ Use ``AnyFlowTransformer3DModel`` instead for plain bidirectional T2V — that variant skips the FAR
+ causal masking and ``far_patch_embedding`` and is ~5–10% smaller.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for full-resolution chunks.
+ compressed_patch_size (`Tuple[int]`, defaults to `(1, 4, 4)`):
+ Larger patch dimensions for the FAR-compressed (context) chunks.
+ full_chunk_limit (`int`, defaults to `3`):
+ Maximum number of full-resolution chunks before earlier chunks are demoted to compressed FAR
+ context. The released checkpoints use ``3``.
+ num_attention_heads (`int`, defaults to `40`):
+ Number of attention heads.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input latent.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output latent.
+ text_dim (`int`, defaults to `4096`):
+ Input dimension for text embeddings (UMT5).
+ freq_dim (`int`, defaults to `256`):
+ Dimension for sinusoidal time embeddings.
+ ffn_dim (`int`, defaults to `13824`):
+ Intermediate dimension in feed-forward network.
+ num_layers (`int`, defaults to `40`):
+ Number of transformer blocks.
+ cross_attn_norm (`bool`, defaults to `True`):
+ Enable cross-attention normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon for normalization layers.
+ image_dim (`Optional[int]`, *optional*, defaults to `None`):
+ Image embedding dimension for I2V conditioning.
+ rope_max_seq_len (`int`, defaults to `1024`):
+ Maximum sequence length used to precompute rotary position frequencies.
+ gate_value (`float`, defaults to `0.25`):
+ Mixing gate between source-timestep and delta-timestep embeddings.
+ deltatime_type (`str`, defaults to `'r'`):
+ Either ``"r"`` (delta is the target timestep) or ``"t-r"`` (delta is the absolute interval).
+
+ .. note::
+ ``chunk_partition`` is **not** a model config field — it is a per-call argument passed to
+ :meth:`forward`. Different inference setups (varying ``num_frames`` or full-vs-compressed schedules)
+ therefore do not require separate checkpoints.
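+
+    Example (an illustrative, hypothetical chunk schedule; in practice it is typically supplied by the
+    pipeline based on ``num_frames``):
+
+    ```python
+    # Hypothetical schedule: 13 latent frames split into four chunks.
+    # ``sum(chunk_partition)`` must equal the number of latent frames passed to ``forward``.
+    chunk_partition = [1, 4, 4, 4]
+    ```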
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "far_patch_embedding", "condition_embedder", "norm"]
+ _no_split_modules = ["AnyFlowTransformerBlock"]
+ _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+ _repeated_blocks = ["AnyFlowTransformerBlock"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (1, 2, 2),
+ compressed_patch_size: Tuple[int] = (1, 4, 4),
+ full_chunk_limit: int = 3,
+ num_attention_heads: int = 40,
+ attention_head_dim: int = 128,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ text_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 13824,
+ num_layers: int = 40,
+ cross_attn_norm: bool = True,
+ eps: float = 1e-6,
+ image_dim: Optional[int] = None,
+ rope_max_seq_len: int = 1024,
+ gate_value: float = 0.25,
+ deltatime_type: str = "r",
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Patch & position embedding (full + FAR-compressed branches).
+ self.rope = AnyFlowRotaryPosEmbed(attention_head_dim, patch_size, compressed_patch_size, rope_max_seq_len)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ self.far_patch_embedding = nn.Conv3d(
+ in_channels, inner_dim, kernel_size=compressed_patch_size, stride=compressed_patch_size
+ )
+ # Warm-start the compressed branch from the full-resolution branch by trilinear interpolation. This
+ # matches FAR-Dev's `setup_far_model()` initialization. State-dict loading will overwrite these
+ # weights for trained checkpoints; the warm-start only matters when constructing a fresh model.
+ original_weight = self.patch_embedding.weight.data.view(-1, 1, *patch_size)
+ new_weight = F.interpolate(original_weight, size=compressed_patch_size, mode="trilinear", align_corners=False)
+ new_weight = new_weight.view(inner_dim, in_channels, *compressed_patch_size)
+ with torch.no_grad():
+ self.far_patch_embedding.weight.copy_(new_weight)
+ self.far_patch_embedding.bias.copy_(self.patch_embedding.bias)
+
+ # 2. Condition embedding (always dual-timestep for AnyFlow distilled checkpoints).
+ self.condition_embedder = AnyFlowDualTimestepTextImageEmbedding(
+ dim=inner_dim,
+ gate_value=gate_value,
+ deltatime_type=deltatime_type,
+ time_freq_dim=freq_dim,
+ time_proj_dim=inner_dim * 6,
+ text_embed_dim=text_dim,
+ image_embed_dim=image_dim,
+ )
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ AnyFlowTransformerBlock(inner_dim, ffn_dim, num_attention_heads, cross_attn_norm, eps)
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ r_timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ chunk_partition: List[int],
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ clean_hidden_states: Optional[torch.Tensor] = None,
+ clean_timestep: Optional[torch.Tensor] = None,
+ kv_cache: Optional[List[Dict[str, torch.Tensor]]] = None,
+ kv_cache_flag: Optional[Dict[str, Any]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[Transformer2DModelOutput, AnyFlowFARTransformerOutput, Tuple]:
+ """
+ FAR causal forward pass. Dispatches to one of three internal paths:
+
+ * ``kv_cache is None`` → causal training rollout (returns
+ :class:`Transformer2DModelOutput`).
+ * ``kv_cache is not None`` and ``kv_cache_flag["is_cache_step"]`` → cache-prefill (returns
+ :class:`AnyFlowFARTransformerOutput` with ``sample=None``).
+ * Otherwise → autoregressive inference step (returns :class:`AnyFlowFARTransformerOutput`).
+
+ Args:
+ hidden_states (`torch.Tensor`): Latent input of shape ``(B, F, C, H, W)``.
+ timestep, r_timestep (`torch.Tensor`): Source / target diffusion timesteps.
+ encoder_hidden_states (`torch.Tensor`): UMT5 text embeddings.
+ chunk_partition (`List[int]`): Per-chunk frame counts; total must match the number of latent
+ frames in ``hidden_states``.
+ encoder_hidden_states_image (`torch.Tensor`, *optional*): I2V image embedding.
+ clean_hidden_states, clean_timestep (`torch.Tensor`, *optional*): Clean conditioning frames
+ used by the training rollout.
+ kv_cache, kv_cache_flag (*optional*): Per-block KV cache and metadata for autoregressive
+ inference.
+            attention_kwargs (*optional*): Forwarded to the attention processors.
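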
+ return_dict (`bool`, defaults to `True`): If `False`, returns positional tuples.
+ """
+ common = {
+ "hidden_states": hidden_states,
+ "chunk_partition": chunk_partition,
+ "timestep": timestep,
+ "r_timestep": r_timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_hidden_states_image": encoder_hidden_states_image,
+ "return_dict": return_dict,
+ "attention_kwargs": attention_kwargs,
+ }
+ if kv_cache is not None:
+ common["kv_cache"] = kv_cache
+ common["kv_cache_flag"] = kv_cache_flag
+ if kv_cache_flag is not None and kv_cache_flag.get("is_cache_step"):
+ return self._forward_cache(
+ clean_hidden_states=clean_hidden_states,
+ clean_timestep=clean_timestep,
+ **common,
+ )
+ return self._forward_inference(**common)
+ return self._forward_train(
+ clean_hidden_states=clean_hidden_states,
+ clean_timestep=clean_timestep,
+ **common,
+ )
+
+ def _unpack_latent_sequence(self, latents, num_frames, height, width, patch_size):
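+        # Inverse of patchification: fold the flat (B, num_tokens, C * p * p) token sequence back into
+        # per-frame (B, F, C, H, W) latents.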
+ batch_size, num_patches, channels = latents.shape
+ height, width = height // patch_size, width // patch_size
+
+ latents = latents.view(
+ batch_size * num_frames, height, width, patch_size, patch_size, channels // (patch_size * patch_size)
+ )
+
+ latents = latents.permute(0, 5, 1, 3, 2, 4)
+ latents = latents.reshape(
+ batch_size, num_frames, channels // (patch_size * patch_size), height * patch_size, width * patch_size
+ )
+ return latents
+
+ def forward_far_patchify(self, hidden_states, far_cfg, clean_hidden_states=None):
+
+ full_hidden_states, compressed_hidden_states = (
+ hidden_states[:, :, far_cfg["num_compressed_frames"] :],
+ hidden_states[:, :, : far_cfg["num_compressed_frames"]],
+ ) # noqa: E501
+
+ patchified_full_hidden_states = (
+ self.patch_embedding(full_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2)
+ )
+ if clean_hidden_states is not None:
+ clean_hidden_states = (
+ self.patch_embedding(clean_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2)
+ )
+ patchified_full_hidden_states = torch.cat([patchified_full_hidden_states, clean_hidden_states], dim=1)
+
+ if far_cfg["num_compressed_frames"] > 0:
+ patchified_compressed_hidden_states = (
+ self.far_patch_embedding(compressed_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2)
+ )
+ hidden_states = torch.cat([patchified_compressed_hidden_states, patchified_full_hidden_states], dim=1)
+ else:
+ hidden_states = patchified_full_hidden_states
+ return hidden_states
+
+ def forward_far_patchify_inference(self, hidden_states):
+ hidden_states = self.patch_embedding(hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2)
+ return hidden_states
+
+ def _build_causal_mask(self, far_cfg, clean_hidden_states, device, dtype):
+ chunk_partition = far_cfg["chunk_partition"]
+
+ noise_seq_len = clean_seq_len = far_cfg["num_full_frames"] * far_cfg["full_token_per_frame"]
+ context_seq_len = far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"]
+
+ noise_start = context_seq_len
+ noise_end = noise_start + noise_seq_len
+
+ clean_start = context_seq_len + noise_seq_len
+ clean_end = clean_start + clean_seq_len
+
+ if clean_hidden_states is not None:
+ real_seq_len = context_seq_len + noise_seq_len + clean_seq_len
+ else:
+ real_seq_len = context_seq_len + noise_seq_len
+
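+        # Pad the sequence length to a multiple of 128 (the default flex_attention block size); padded
+        # positions are masked out below via ``is_padding``.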
+ padded_seq_len = int(math.ceil(real_seq_len / 128.0) * 128.0)
+
+ if clean_hidden_states is not None:
+ context_chunk_partition, noise_chunk_partition = (
+ chunk_partition[: far_cfg["num_compressed_chunk"]],
+ chunk_partition[far_cfg["num_compressed_chunk"] :],
+ ) # noqa: E501
+
+ if len(context_chunk_partition) != 0:
+ context_frame_idx = torch.cat(
+ [
+ torch.ones(chunk_len * far_cfg["compressed_token_per_frame"], device=device) * chunk_idx
+ for chunk_idx, chunk_len in enumerate(context_chunk_partition)
+ ]
+ ) # noqa: E501
+ else:
+ context_frame_idx = None
+ noise_frame_idx = clean_frame_idx = torch.cat(
+ [
+ torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device)
+ * (chunk_idx + len(context_chunk_partition))
+ for chunk_idx, chunk_len in enumerate(noise_chunk_partition)
+ ]
+ ) # noqa: E501
+ pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device)
+
+ if len(context_chunk_partition) != 0:
+ frame_idx = torch.cat([context_frame_idx, noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0)
+ else:
+ frame_idx = torch.cat([noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0)
+
+ def mask_mod(b, h, q_idx, kv_idx):
+ # q_idx, kv_idx: LongTensor, range: [0, padded_seq_len)
+
+                # 1) padding positions
+ is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len)
+
+                # 2) chunk-causal base: attend only to the same or earlier chunks
+ base = frame_idx[q_idx] >= frame_idx[kv_idx]
+
+                # 3) interval membership (noise vs. clean spans)
+ q_is_noise = (q_idx >= noise_start) & (q_idx < noise_end)
+ q_is_clean = (q_idx >= clean_start) & (q_idx < clean_end)
+
+ k_is_noise = (kv_idx >= noise_start) & (kv_idx < noise_end)
+ k_is_clean = (kv_idx >= clean_start) & (kv_idx < clean_end)
+
+                # 4) clean -> noise: disallowed
+ is_clean_to_noise = q_is_clean & k_is_noise
+
+                # 5) noise -> noise: allowed only within the same frame
+ same_frame_idx = frame_idx[q_idx] == frame_idx[kv_idx]
+
+ noise_to_noise = q_is_noise & k_is_noise
+ noise_to_clean = q_is_noise & k_is_clean
+
+ noise_to_noise_allow = noise_to_noise & same_frame_idx
+ noise_to_noise_mask = (~noise_to_noise) | noise_to_noise_allow
+
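+                # 6) noise -> clean: a noise frame must not attend to its own clean copy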
+ noise_to_clean_same = noise_to_clean & same_frame_idx
+ noise_to_clean_disallow = noise_to_clean_same
+
+                # combine: chunk-causal minus padding and the disallowed interactions above
+ allowed = base & ~is_padding & ~is_clean_to_noise & noise_to_noise_mask & ~noise_to_clean_disallow
+ return allowed
+
+ return create_block_mask(
+ mask_mod,
+ B=None,
+ H=None,
+ Q_LEN=padded_seq_len,
+ KV_LEN=padded_seq_len,
+ device=device,
+ _compile=False,
+ )
+ else:
+ context_chunk_partition, noise_chunk_partition = (
+ chunk_partition[: far_cfg["num_compressed_chunk"]],
+ chunk_partition[far_cfg["num_compressed_chunk"] :],
+ ) # noqa: E501
+
+ if len(context_chunk_partition) != 0:
+ context_frame_idx = torch.cat(
+ [
+ torch.ones(chunk_len * far_cfg["compressed_token_per_frame"], device=device) * chunk_idx
+ for chunk_idx, chunk_len in enumerate(context_chunk_partition)
+ ]
+ ) # noqa: E501
+ else:
+ context_frame_idx = None
+
+ noise_frame_idx = torch.cat(
+ [
+ torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device)
+ * (chunk_idx + len(context_chunk_partition))
+ for chunk_idx, chunk_len in enumerate(noise_chunk_partition)
+ ]
+ ) # noqa: E501
+ pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device)
+
+ if len(context_chunk_partition) != 0:
+ frame_idx = torch.cat([context_frame_idx, noise_frame_idx, pad_frame_idx], dim=0)
+ else:
+ frame_idx = torch.cat([noise_frame_idx, pad_frame_idx], dim=0)
+
+ def mask_mod(b, h, q_idx, kv_idx):
+ is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len)
+ base = frame_idx[q_idx] >= frame_idx[kv_idx]
+ return base & ~is_padding
+
+ return create_block_mask(
+ mask_mod,
+ B=None,
+ H=None,
+ Q_LEN=padded_seq_len,
+ KV_LEN=padded_seq_len,
+ device=device,
+ _compile=False,
+ )
+
+ def _forward_inference(
+ self,
+ hidden_states: torch.Tensor,
+ chunk_partition,
+ timestep: torch.LongTensor,
+ r_timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ kv_cache=None,
+ kv_cache_flag=None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+ full_token_per_frame = (height // self.config.patch_size[1]) * (width // self.config.patch_size[2])
+ compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * (
+ width // self.config.compressed_patch_size[2]
+ )
+
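+        # The chunk being denoised now plus every chunk already committed to the KV cache.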
+ total_chunks = 1 + kv_cache_flag["num_cached_chunks"]
+
+ if total_chunks >= self.config.full_chunk_limit:
+ num_full_chunk, num_compressed_chunk = (
+ self.config.full_chunk_limit,
+ total_chunks - self.config.full_chunk_limit,
+ )
+ else:
+ num_full_chunk, num_compressed_chunk = total_chunks, 0
+
+ kv_cache_flag["num_cached_full_tokens"] = (
+ sum(chunk_partition[num_compressed_chunk : num_compressed_chunk + (num_full_chunk - 1)])
+ * full_token_per_frame
+ ) # noqa: E501
+ kv_cache_flag["num_cached_compressed_tokens"] = (
+ sum(chunk_partition[:num_compressed_chunk]) * compressed_token_per_frame
+ )
+
+ far_cfg = {
+ "total_frames": sum(chunk_partition),
+ "num_full_frames": sum(chunk_partition[num_compressed_chunk:]),
+ "num_compressed_frames": sum(chunk_partition[:num_compressed_chunk]),
+ "full_frame_shape": (height // self.config.patch_size[1], width // self.config.patch_size[2]),
+ "compressed_frame_shape": (
+ height // self.config.compressed_patch_size[1],
+ width // self.config.compressed_patch_size[2],
+ ),
+ "full_token_per_frame": full_token_per_frame,
+ "compressed_token_per_frame": compressed_token_per_frame,
+ }
+
+        # 3. Patchify the current chunk. No block mask is needed here: the only keys available are the
+        #    cached context chunks and the current chunk itself.
+        attention_mask = None
+        hidden_states = self.forward_far_patchify_inference(hidden_states)
+
+ rotary_emb = self.rope(far_cfg=far_cfg, device=hidden_states.device)
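+        # Only the current chunk is present as queries during autoregressive decoding, so keep just the
+        # trailing query positions of the RoPE table.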
+ rotary_emb["query"] = rotary_emb["query"][:, :, -hidden_states.shape[1] :]
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep,
+ r_timestep,
+ encoder_hidden_states,
+ encoder_hidden_states_image,
+ far_cfg=far_cfg, # noqa: E501
+ )
+ timestep_proj = timestep_proj.unflatten(2, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ # 4. Transformer blocks
+ for index_block, block in enumerate(self.blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ timestep_proj,
+ rotary_emb,
+ attention_mask,
+ kv_cache[index_block],
+ kv_cache_flag,
+ )
+ else:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ timestep_proj,
+ rotary_emb,
+ attention_mask,
+ kv_cache[index_block],
+ kv_cache_flag,
+ )
+
+ # 5. Output norm, projection & unpatchify
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(2)).chunk(2, dim=2)
+ shift, scale = shift.squeeze(2), scale.squeeze(2)
+
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up
+ # on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+
+ output = self.proj_out(hidden_states)
+ output = self._unpack_latent_sequence(
+ output, num_frames=chunk_partition[-1], height=height, width=width, patch_size=self.config.patch_size[1]
+ )
+
+ if not return_dict:
+ return output, kv_cache
+
+ return AnyFlowFARTransformerOutput(sample=output, kv_cache=kv_cache)
+
+ def _forward_cache(
+ self,
+ hidden_states: torch.Tensor,
+ chunk_partition,
+ timestep: torch.LongTensor,
+ r_timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ clean_hidden_states=None,
+ clean_timestep=None,
+ kv_cache=None,
+ kv_cache_flag=None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
+ if clean_hidden_states is not None:
+ clean_hidden_states = clean_hidden_states.permute(0, 2, 1, 3, 4)
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+ full_token_per_frame = (height // self.config.patch_size[1]) * (width // self.config.patch_size[2])
+ compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * (
+ width // self.config.compressed_patch_size[2]
+ )
+ total_chunks = len(chunk_partition)
+
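+        # One full-resolution slot is held back for the chunk denoised next, so the prefill keeps at most
+        # ``full_chunk_limit - 1`` context chunks at full resolution.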
+ full_chunk_limit = self.config.full_chunk_limit - 1
+
+ if total_chunks > full_chunk_limit:
+ num_full_chunk, num_compressed_chunk = full_chunk_limit, total_chunks - full_chunk_limit
+ else:
+ num_full_chunk, num_compressed_chunk = total_chunks, 0
+
+ far_cfg = {
+ "total_frames": sum(chunk_partition),
+ "num_full_chunk": num_full_chunk,
+ "num_full_frames": sum(chunk_partition[num_compressed_chunk:]),
+ "num_compressed_chunk": num_compressed_chunk,
+ "num_compressed_frames": sum(chunk_partition[:num_compressed_chunk]),
+ "full_frame_shape": (height // self.config.patch_size[1], width // self.config.patch_size[2]),
+ "compressed_frame_shape": (
+ height // self.config.compressed_patch_size[1],
+ width // self.config.compressed_patch_size[2],
+ ),
+ "full_token_per_frame": full_token_per_frame,
+ "compressed_token_per_frame": compressed_token_per_frame,
+ "chunk_partition": chunk_partition,
+ }
+
+ kv_cache_flag["num_full_tokens"] = far_cfg["num_full_frames"] * far_cfg["full_token_per_frame"]
+ kv_cache_flag["num_compressed_tokens"] = (
+ far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"]
+ )
+
+        # 3. Build the FAR causal block-mask
+ attention_mask = self._build_causal_mask(
+ far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype
+ )
+
+ rotary_emb = self.rope(far_cfg=far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device)
+ hidden_states = self.forward_far_patchify(
+ hidden_states, far_cfg=far_cfg, clean_hidden_states=clean_hidden_states
+ )
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep,
+ r_timestep,
+ encoder_hidden_states,
+ encoder_hidden_states_image,
+ far_cfg=far_cfg,
+ clean_timestep=clean_timestep,
+ )
+ timestep_proj = timestep_proj.unflatten(2, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ # 4. Transformer blocks
+ for index_block, block in enumerate(self.blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ timestep_proj,
+ rotary_emb,
+ attention_mask,
+ kv_cache[index_block],
+ kv_cache_flag,
+ )
+ else:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ timestep_proj,
+ rotary_emb,
+ attention_mask,
+ kv_cache[index_block],
+ kv_cache_flag,
+ )
+
+ if not return_dict:
+ return None, kv_cache
+
+ return AnyFlowFARTransformerOutput(sample=None, kv_cache=kv_cache)
+
+ def _forward_train(
+ self,
+ hidden_states: torch.Tensor,
+ chunk_partition,
+ timestep: torch.LongTensor,
+ r_timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ clean_hidden_states=None,
+ clean_timestep=None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
+ if clean_hidden_states is not None:
+ clean_hidden_states = clean_hidden_states.permute(0, 2, 1, 3, 4)
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+ full_token_per_frame = (height // self.config.patch_size[1]) * (width // self.config.patch_size[2])
+ compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * (
+ width // self.config.compressed_patch_size[2]
+ )
+ total_chunks = len(chunk_partition)
+
+ if total_chunks > self.config.full_chunk_limit:
+ num_full_chunk, num_compressed_chunk = (
+ self.config.full_chunk_limit,
+ total_chunks - self.config.full_chunk_limit,
+ )
+ else:
+ num_full_chunk, num_compressed_chunk = total_chunks, 0
+
+ far_cfg = {
+ "total_frames": sum(chunk_partition),
+ "num_full_chunk": num_full_chunk,
+ "num_full_frames": sum(chunk_partition[num_compressed_chunk:]),
+ "num_compressed_chunk": num_compressed_chunk,
+ "num_compressed_frames": sum(chunk_partition[:num_compressed_chunk]),
+ "full_frame_shape": (height // self.config.patch_size[1], width // self.config.patch_size[2]),
+ "compressed_frame_shape": (
+ height // self.config.compressed_patch_size[1],
+ width // self.config.compressed_patch_size[2],
+ ),
+ "full_token_per_frame": full_token_per_frame,
+ "compressed_token_per_frame": compressed_token_per_frame,
+ "chunk_partition": chunk_partition,
+ }
+
+        # 3. Build the FAR causal block-mask
+ attention_mask = self._build_causal_mask(
+ far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype
+ )
+
+ rotary_emb = self.rope(far_cfg=far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device)
+
+ hidden_states = self.forward_far_patchify(
+ hidden_states, far_cfg=far_cfg, clean_hidden_states=clean_hidden_states
+ )
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep,
+ r_timestep,
+ encoder_hidden_states,
+ encoder_hidden_states_image,
+ far_cfg=far_cfg,
+ clean_timestep=clean_timestep,
+ )
+ timestep_proj = timestep_proj.unflatten(2, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ # 4. Transformer blocks
+ for index_block, block in enumerate(self.blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ timestep_proj,
+ rotary_emb,
+ attention_mask,
+ )
+ else:
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask)
+
+ # 5. Output norm, projection & unpatchify
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(2)).chunk(2, dim=2)
+ shift, scale = shift.squeeze(2), scale.squeeze(2)
+
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up
+ # on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+
+ if clean_hidden_states is not None:
+ hidden_states = hidden_states[
+ :, : -(far_cfg["num_full_frames"] * far_cfg["full_token_per_frame"])
+ ] # remove clean copy
+ output = self.proj_out(
+ hidden_states[:, far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"] :]
+ ) # remove far context
+ output = self._unpack_latent_sequence(
+ output,
+ num_frames=far_cfg["num_full_frames"],
+ height=height,
+ width=width,
+ patch_size=self.config.patch_size[1],
+ ) # noqa: E501
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index d4b3974322b4..c0d12121d5e8 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -164,6 +164,10 @@
"AnimateDiffVideoToVideoPipeline",
"AnimateDiffVideoToVideoControlNetPipeline",
]
+ _import_structure["anyflow"] = [
+ "AnyFlowPipeline",
+ "AnyFlowFARPipeline",
+ ]
_import_structure["bria"] = ["BriaPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
_import_structure["flux2"] = [
@@ -603,6 +607,10 @@
AnimateDiffVideoToVideoControlNetPipeline,
AnimateDiffVideoToVideoPipeline,
)
+ from .anyflow import (
+ AnyFlowFARPipeline,
+ AnyFlowPipeline,
+ )
from .audioldm2 import (
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
diff --git a/src/diffusers/pipelines/anyflow/__init__.py b/src/diffusers/pipelines/anyflow/__init__.py
new file mode 100644
index 000000000000..10603cdedc3b
--- /dev/null
+++ b/src/diffusers/pipelines/anyflow/__init__.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_anyflow"] = ["AnyFlowPipeline"]
+ _import_structure["pipeline_anyflow_far"] = ["AnyFlowFARPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_anyflow import AnyFlowPipeline
+ from .pipeline_anyflow_far import AnyFlowFARPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py
new file mode 100644
index 000000000000..19a0ec69b08d
--- /dev/null
+++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py
@@ -0,0 +1,669 @@
+# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Adapted from diffusers.pipelines.wan.pipeline_wan.WanPipeline (v0.35.1) for any-step flow-map sampling.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import regex as re
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import WanLoraLoaderMixin
+from ...models import AnyFlowTransformer3DModel, AutoencoderKLWan
+from ...models.autoencoders.vae import DiagonalGaussianDistribution
+from ...schedulers import FlowMapEulerDiscreteScheduler
+from ...utils import is_ftfy_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import AnyFlowPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import AnyFlowPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> pipe = AnyFlowPipeline.from_pretrained(
+ ... "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
+ ... ).to("cuda")
+
+ >>> prompt = "A red panda eating bamboo in a forest, cinematic lighting"
+ >>> video = pipe(prompt, num_inference_steps=4, num_frames=33).frames[0]
+ >>> export_to_video(video, "anyflow_t2v.mp4", fps=16)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.wan.pipeline_wan.basic_clean
+def basic_clean(text):
+ if is_ftfy_available():
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+# Copied from diffusers.pipelines.wan.pipeline_wan.whitespace_clean
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+# Copied from diffusers.pipelines.wan.pipeline_wan.prompt_clean
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class AnyFlowPipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Bidirectional text-to-video generation pipeline for AnyFlow flow-map-distilled checkpoints.
+
+ AnyFlow learns arbitrary-interval transitions :math:`z_t \to z_r` rather than the fixed
+ :math:`z_t \to z_0` mapping of consistency models, so a single distilled checkpoint can be evaluated at
+ 1, 2, 4, 8, 16... NFE without retraining. This pipeline operates over the full video tensor in one
+ bidirectional pass; for frame-level autoregressive (causal) generation use ``AnyFlowFARPipeline``.
+
+ Sampling is plain Euler in mean-velocity form (``z_r = z_t - (t - r) * u``) with no re-noising. The
+ released NVIDIA checkpoints fold classifier-free guidance into the model weights, so the default
+ ``guidance_scale=1.0`` is the recommended setting.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`AutoTokenizer`]):
+ Tokenizer from [google/umt5-xxl](https://huggingface.co/google/umt5-xxl).
+ text_encoder ([`UMT5EncoderModel`]):
+ [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) text encoder.
+ transformer ([`AnyFlowTransformer3DModel`]):
+ Bidirectional flow-map 3D Transformer.
+ vae ([`AutoencoderKLWan`]):
+ VAE that encodes/decodes videos to and from latent representations.
+ scheduler ([`FlowMapEulerDiscreteScheduler`]):
+ Flow-map sampler. The pipeline drives ``scheduler.step(..., timestep, sample, r_timestep)`` per
+ inference step.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: AnyFlowTransformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMapEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: str | list[str] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: torch.device | None = None,
+ dtype: torch.dtype | None = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str | list[str],
+ negative_prompt: str | list[str] | None = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ max_sequence_length: int = 226,
+ device: torch.device | None = None,
+ dtype: torch.dtype | None = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `list[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" # noqa: E501
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: torch.dtype | None = None,
+ device: torch.device | None = None,
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ def vae_encode(self, context_sequence):
+ # normalize: [0, 1] -> [-1, 1]
+ context_sequence = context_sequence * 2 - 1
+ context_sequence = self.encode_latents(
+ context_sequence.to(dtype=self.vae.dtype, device=self._execution_device), sample=False
+ )
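+        # move latent frames to dim 1: the denoising rollout slices the context prefix along the frame axis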
+ context_sequence = context_sequence.permute(0, 2, 1, 3, 4)
+ return context_sequence
+
+ def _normalize_latents(self, latents, latents_mean, latents_std):
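+        # callers pass the reciprocal of the std (see encode_latents), so this computes (x - mean) / std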
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
+ latents = ((latents.float() - latents_mean) * latents_std).to(latents)
+ return latents
+
+ @torch.no_grad()
+ def encode_latents(self, videos, sample=True):
+ videos = videos.permute(0, 2, 1, 3, 4)
+ moments = self.vae._encode(videos)
+
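+        # Wan VAE channel statistics; the std is pre-inverted so normalization is a single multiply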
+ latents_mean = torch.tensor(self.vae.config.latents_mean)
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std)
+
+ mu, logvar = torch.chunk(moments, 2, dim=1)
+ mu = self._normalize_latents(mu, latents_mean, latents_std)
+
+ if sample:
+ logvar = self._normalize_latents(logvar, latents_mean, latents_std)
+
+ latents = torch.cat([mu, logvar], dim=1)
+ posterior = DiagonalGaussianDistribution(latents)
+ latents = posterior.sample(generator=None)
+ del posterior
+ else:
+ latents = mu
+ return latents
+
+ def _denoise_rollout(
+ self,
+ context_sequence=None,
+ num_inference_steps: int = 50,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ guidance_scale: float = 1.0,
+ use_mean_velocity: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
+ ):
+ device = self._execution_device
+
+ if negative_prompt_embeds is not None:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ context_length = context_sequence.shape[1] if context_sequence is not None else 0
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
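+        # walk adjacent (t, r) pairs of the flow-map schedule; each step jumps directly from t to r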
+        for i, t in enumerate(tqdm(timesteps[:-1])):
+            if self.interrupt:
+                continue
+
+            r = timesteps[i + 1]
+
+            if t == r:
+                continue
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
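+            # broadcast the scalar source/target timesteps to per-frame maps of shape (batch, num_latent_frames)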
+ timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1)
+ timestep = timestep.repeat((1, latent_model_input.shape[1]))
+
+ if use_mean_velocity:
+ r_timestep = r.expand(latent_model_input.shape[0]).unsqueeze(-1)
+ r_timestep = r_timestep.repeat((1, latent_model_input.shape[1]))
+ else:
+ r_timestep = timestep
+
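+            # pin conditioning frames: overwrite the latent prefix and mark those frames as already denoised (t=0)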
+ if context_sequence is not None:
+ latent_model_input[:, :context_length, ...] = context_sequence
+ timestep[:, :context_length] = 0
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ r_timestep=r_timestep,
+ encoder_hidden_states=prompt_embeds,
+ return_dict=False,
+ )[0]
+
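+            # classifier-free guidance; skipped at the default guidance_scale=1.0 (CFG is folded into the weights)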
+ if self.do_classifier_free_guidance:
+ noise_uncond, noise_pred = noise_pred.chunk(2)
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
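+            # flow-map Euler update in mean-velocity form: z_r = z_t - (t - r) * u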
+ latents = self.scheduler.step(noise_pred, t, latents, r_timestep=r, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs or []:
+ if k == "latents":
+ callback_kwargs[k] = latents
+ elif k == "prompt_embeds":
+ callback_kwargs[k] = prompt_embeds
+ elif k == "negative_prompt_embeds":
+ callback_kwargs[k] = negative_prompt_embeds
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ context_sequence: Optional[torch.Tensor] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 1.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ use_mean_velocity: bool = True,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds`
+ instead.
+ context_sequence (`torch.Tensor`, *optional*):
+ Pre-VAE conditioning frames of shape `(B, C, T, H, W)` in `[0, 1]`. When provided, the pipeline
+ VAE-encodes them and keeps the corresponding latent prefix fixed during sampling.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during video generation. Ignored when not using guidance
+                (`guidance_scale <= 1`).
+ height (`int`, defaults to `480`):
+ The height in pixels of the generated video.
+ width (`int`, defaults to `832`):
+ The width in pixels of the generated video.
+ num_frames (`int`, defaults to `81`):
+                The number of frames in the generated video. Should satisfy `(num_frames - 1) %
+                vae_scale_factor_temporal == 0`; otherwise it is rounded to the nearest valid value.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. Distilled AnyFlow checkpoints support any-step sampling, so
+ values as low as `1`, `2`, `4`, or `8` are typical.
+ guidance_scale (`float`, defaults to `1.0`):
+ Classifier-free guidance scale. The released AnyFlow checkpoints fuse CFG into the weights
+ during training; keep at `1.0` unless you know your checkpoint expects otherwise.
+ num_videos_per_prompt (`int`, *optional*, defaults to `1`):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents to use as inputs. If not provided, latents are sampled from the
+ supplied `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to tweak text inputs (e.g., prompt weighting). If
+ not provided, embeddings are generated from `prompt`.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format. One of `"pil"`, `"np"`, `"pt"`, or `"latent"`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return an [`AnyFlowPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+                Reserved for future use; currently stored on the pipeline but not forwarded to the transformer.
+ callback_on_step_end (`Callable`, *optional*):
+ A function or [`PipelineCallback`] called at the end of each inference step. See
+ [`callbacks`](../callbacks) for details.
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to `["latents"]`):
+ The tensor inputs forwarded to the callback. Must be a subset of
+ `self._callback_tensor_inputs`.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum text-encoder sequence length. Longer prompts are truncated.
+ use_mean_velocity (`bool`, defaults to `True`):
+ When `True`, the flow-map model is conditioned on both the source timestep `t` and the target
+ timestep `r` to predict a mean velocity, matching the training-time behavior. Disable to
+ mirror raw Euler stepping (`r = t`).
+
+ Examples:
+
+ Returns:
+ [`~AnyFlowPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`AnyFlowPipelineOutput`] is returned, otherwise a `tuple` whose
+ first element is the generated video.
+ """
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+ self._num_timesteps = num_inference_steps
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+        # 4. Prepare latent variables. ``prepare_latents`` returns the standard ``(B, C, T, H, W)``
+ # diffusers layout; the AnyFlow rollout expects ``(B, T, C, H, W)`` so we permute here.
+ num_channels_latents = self.transformer.config.in_channels
+ init_latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+ init_latents = init_latents.permute(0, 2, 1, 3, 4).to(transformer_dtype)
+
+        # optionally VAE-encode the conditioning frames; their latent prefix stays fixed during sampling
+ if context_sequence is not None:
+ context_sequence = self.vae_encode(context_sequence)
+ context_length = context_sequence.shape[1]
+
+ latents = self._denoise_rollout(
+ context_sequence=context_sequence,
+ num_inference_steps=num_inference_steps,
+ latents=init_latents,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ guidance_scale=guidance_scale,
+ use_mean_velocity=use_mean_velocity,
+ callback_on_step_end=callback_on_step_end,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
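+        # re-insert the encoded context latents verbatim, then restore the (B, C, T, H, W) layout for decoding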
+ if context_sequence is not None:
+ latents[:, :context_length, ...] = context_sequence
+ latents = latents.permute(0, 2, 1, 3, 4)
+
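+        # undo the latent normalization applied in encode_latents before decoding with the VAE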
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return AnyFlowPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py
new file mode 100644
index 000000000000..3b26904cb82f
--- /dev/null
+++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py
@@ -0,0 +1,874 @@
+# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Adapted from diffusers.pipelines.wan.pipeline_wan.WanPipeline (v0.35.1) for FAR causal flow-map sampling.
+
+import copy
+import html
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import regex as re
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import WanLoraLoaderMixin
+from ...models import AnyFlowFARTransformer3DModel, AutoencoderKLWan
+from ...models.autoencoders.vae import DiagonalGaussianDistribution
+from ...schedulers import FlowMapEulerDiscreteScheduler
+from ...utils import is_ftfy_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import AnyFlowPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import numpy as np
+ >>> import torch
+ >>> from diffusers import AnyFlowFARPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+
+ >>> pipe = AnyFlowFARPipeline.from_pretrained(
+ ... "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
+ ... ).to("cuda")
+
+ >>> # Single-frame I2V: wrap the conditioning image as a (1, 3, 1, H, W) tensor in [0, 1].
+ >>> first_frame = load_image("path/to/first_frame.png").resize((832, 480))
+ >>> arr = np.asarray(first_frame).astype("float32") / 255.0
+ >>> context = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).to("cuda")
+
+ >>> video = pipe(
+ ... prompt="a cat walks across a sunlit lawn",
+ ... context_sequence={"raw": context},
+ ... num_inference_steps=4,
+ ... num_frames=81,
+ ... ).frames[0]
+ >>> export_to_video(video, "anyflow_far.mp4", fps=16)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.wan.pipeline_wan.basic_clean
+def basic_clean(text):
+ if is_ftfy_available():
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+# Copied from diffusers.pipelines.wan.pipeline_wan.whitespace_clean
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+# Copied from diffusers.pipelines.wan.pipeline_wan.prompt_clean
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class AnyFlowFARPipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Causal (FAR-based) text-to-video / image-to-video / video-to-video pipeline for AnyFlow checkpoints.
+
+ The pipeline drives a frame-level autoregressive sampling loop over chunks: each chunk is denoised with
+    flow-map steps while attending only to past chunks via block-sparse causal attention, and the
+    intermediate KV cache is reused across chunks.
+
+ The task mode (T2V / I2V / V2V) is selected by the ``context_sequence`` argument passed to ``__call__``:
+
+ - ``context_sequence=None`` — pure text-to-video.
+ - ``context_sequence={"raw":