From 507fd9b7a6ae921c7fe15f65e0eae6d52a743041 Mon Sep 17 00:00:00 2001 From: Enderfga Date: Wed, 6 May 2026 14:41:54 +0800 Subject: [PATCH 01/16] [Pipelines] AnyFlow: scaffold pipelines/anyflow + register all top-level imports This is the lazy-loader scaffolding only. Body files (pipeline_anyflow.py, pipeline_anyflow_causal.py, transformer_anyflow.py, scheduling_flow_map_euler_discrete.py) come in subsequent commits. --- src/diffusers/__init__.py | 8 ++++ src/diffusers/models/__init__.py | 1 + src/diffusers/models/transformers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 8 ++++ src/diffusers/pipelines/anyflow/__init__.py | 48 +++++++++++++++++++ .../pipelines/anyflow/pipeline_output.py | 20 ++++++++ src/diffusers/schedulers/__init__.py | 2 + 7 files changed, 88 insertions(+) create mode 100644 src/diffusers/pipelines/anyflow/__init__.py create mode 100644 src/diffusers/pipelines/anyflow/pipeline_output.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0c6083cafd0a..57f53de544cc 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -190,6 +190,7 @@ [ "AceStepTransformer1DModel", "AllegroTransformer3DModel", + "AnyFlowTransformer3DModel", "AsymmetricAutoencoderKL", "AttentionBackendName", "AuraFlowTransformer2DModel", @@ -377,6 +378,7 @@ "EDMEulerScheduler", "EulerAncestralDiscreteScheduler", "EulerDiscreteScheduler", + "FlowMapEulerDiscreteScheduler", "FlowMatchEulerDiscreteScheduler", "FlowMatchHeunDiscreteScheduler", "FlowMatchLCMScheduler", @@ -506,6 +508,8 @@ "AnimateDiffSparseControlNetPipeline", "AnimateDiffVideoToVideoControlNetPipeline", "AnimateDiffVideoToVideoPipeline", + "AnyFlowCausalPipeline", + "AnyFlowPipeline", "AudioLDM2Pipeline", "AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel", @@ -1007,6 +1011,7 @@ from .models import ( AceStepTransformer1DModel, AllegroTransformer3DModel, + AnyFlowTransformer3DModel, AsymmetricAutoencoderKL, AttentionBackendName, AuraFlowTransformer2DModel, @@ -1190,6 +1195,7 @@ EDMEulerScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, + FlowMapEulerDiscreteScheduler, FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler, FlowMatchLCMScheduler, @@ -1300,6 +1306,8 @@ AnimateDiffSparseControlNetPipeline, AnimateDiffVideoToVideoControlNetPipeline, AnimateDiffVideoToVideoPipeline, + AnyFlowCausalPipeline, + AnyFlowPipeline, AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index dc772fcc6d0c..29a733009eef 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -212,6 +212,7 @@ from .transformers import ( AceStepTransformer1DModel, AllegroTransformer3DModel, + AnyFlowTransformer3DModel, AuraFlowTransformer2DModel, BriaFiboTransformer2DModel, BriaTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index bbd7ecfa911b..11d6098e1fbf 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -18,6 +18,7 @@ from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel from .transformer_allegro import AllegroTransformer3DModel + from .transformer_anyflow import AnyFlowTransformer3DModel from .transformer_bria import BriaTransformer2DModel from .transformer_bria_fibo import BriaFiboTransformer2DModel from .transformer_chroma import ChromaTransformer2DModel diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c49ad3938cdc..7d4f258a0e0f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -164,6 +164,10 @@ "AnimateDiffVideoToVideoPipeline", "AnimateDiffVideoToVideoControlNetPipeline", ] + _import_structure["anyflow"] = [ + "AnyFlowPipeline", + "AnyFlowCausalPipeline", + ] _import_structure["bria"] = ["BriaPipeline"] _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"] _import_structure["flux2"] = [ @@ -595,6 +599,10 @@ AnimateDiffVideoToVideoControlNetPipeline, AnimateDiffVideoToVideoPipeline, ) + from .anyflow import ( + AnyFlowCausalPipeline, + AnyFlowPipeline, + ) from .audioldm2 import ( AudioLDM2Pipeline, AudioLDM2ProjectionModel, diff --git a/src/diffusers/pipelines/anyflow/__init__.py b/src/diffusers/pipelines/anyflow/__init__.py new file mode 100644 index 000000000000..3acc4bc0114b --- /dev/null +++ b/src/diffusers/pipelines/anyflow/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_anyflow"] = ["AnyFlowPipeline"] + _import_structure["pipeline_anyflow_causal"] = ["AnyFlowCausalPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_anyflow import AnyFlowPipeline + from .pipeline_anyflow_causal import AnyFlowCausalPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/anyflow/pipeline_output.py b/src/diffusers/pipelines/anyflow/pipeline_output.py new file mode 100644 index 000000000000..7cbe9a019013 --- /dev/null +++ b/src/diffusers/pipelines/anyflow/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class AnyFlowPipelineOutput(BaseOutput): + r""" + Output class for AnyFlow pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index b1f75bed7dc5..447586c6f436 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -59,6 +59,7 @@ _import_structure["scheduling_edm_euler"] = ["EDMEulerScheduler"] _import_structure["scheduling_euler_ancestral_discrete"] = ["EulerAncestralDiscreteScheduler"] _import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"] + _import_structure["scheduling_flow_map_euler_discrete"] = ["FlowMapEulerDiscreteScheduler"] _import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"] _import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"] _import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"] @@ -165,6 +166,7 @@ from .scheduling_edm_euler import EDMEulerScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler + from .scheduling_flow_map_euler_discrete import FlowMapEulerDiscreteScheduler from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler from .scheduling_flow_match_lcm import FlowMatchLCMScheduler From 29229d754098cb37cc77a3745f357ec8e9d3aec4 Mon Sep 17 00:00:00 2001 From: Enderfga Date: Wed, 6 May 2026 14:44:32 +0800 Subject: [PATCH 02/16] [Schedulers] AnyFlow: add FlowMapEulerDiscreteScheduler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The flow-map scheduler advances samples from timestep t to caller-provided target r in a single Euler step, supporting any-step sampling on flow-map- distilled checkpoints. It is a general-purpose scheduler — not specific to the AnyFlow checkpoints. Tests: 12 standalone tests covering instantiation, set_timesteps endpoints, shift identity/monotonicity, step shape preservation, zero-interval identity, one-shot sampling, train weight schemes, scale_noise endpoints. Docs: api/schedulers/flow_map_euler_discrete.md --- docs/source/en/_toctree.yml | 2 + .../api/schedulers/flow_map_euler_discrete.md | 28 ++++ .../scheduling_flow_map_euler_discrete.py | 148 ++++++++++++++++++ .../test_scheduler_flow_map_euler_discrete.py | 141 +++++++++++++++++ 4 files changed, 319 insertions(+) create mode 100644 docs/source/en/api/schedulers/flow_map_euler_discrete.md create mode 100644 src/diffusers/schedulers/scheduling_flow_map_euler_discrete.py create mode 100644 tests/schedulers/test_scheduler_flow_map_euler_discrete.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8e8776d4a8c2..6c6d36a7f483 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -727,6 +727,8 @@ title: EulerAncestralDiscreteScheduler - local: api/schedulers/euler title: EulerDiscreteScheduler + - local: api/schedulers/flow_map_euler_discrete + title: FlowMapEulerDiscreteScheduler - local: api/schedulers/flow_match_euler_discrete title: FlowMatchEulerDiscreteScheduler - local: api/schedulers/flow_match_heun_discrete diff --git a/docs/source/en/api/schedulers/flow_map_euler_discrete.md b/docs/source/en/api/schedulers/flow_map_euler_discrete.md new file mode 100644 index 000000000000..64cd9e60dae9 --- /dev/null +++ b/docs/source/en/api/schedulers/flow_map_euler_discrete.md @@ -0,0 +1,28 @@ + + +# FlowMapEulerDiscreteScheduler + +`FlowMapEulerDiscreteScheduler` is an Euler-style sampler designed for flow-map-distilled diffusion +models. Flow-map models learn arbitrary-interval transitions $\mathbf{z}_t \to \mathbf{z}_r$ rather than +the fixed $\mathbf{z}_t \to \mathbf{z}_0$ mapping of consistency models. Both endpoints of the step are +caller-provided, which is what enables any-step sampling: a single distilled checkpoint can be evaluated at +1, 2, 4, 8, 16... NFE without retraining. + +The scheduler was introduced in +[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/) +and ships with the `AnyFlowPipeline` and `AnyFlowCausalPipeline` integrations, but it is not +AnyFlow-specific — any flow-map-distilled checkpoint can use it. + +## FlowMapEulerDiscreteScheduler + +[[autodoc]] FlowMapEulerDiscreteScheduler diff --git a/src/diffusers/schedulers/scheduling_flow_map_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_map_euler_discrete.py new file mode 100644 index 000000000000..459e73be44ed --- /dev/null +++ b/src/diffusers/schedulers/scheduling_flow_map_euler_discrete.py @@ -0,0 +1,148 @@ +# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class FlowMapEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler-style sampler for flow-map-distilled diffusion models. + + Flow-map models learn arbitrary-interval transitions :math:`z_t \\to z_r` rather than the fixed + :math:`z_t \\to z_0` mapping of consistency models, so a single distilled checkpoint can be evaluated at + 1, 2, 4, 8, ... NFE without retraining. The `step` method advances the sample from `timestep` to + `r_timestep` along the predicted velocity. + + Introduced in + [AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/). + + This scheduler inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the + generic methods implemented for all schedulers (loading, saving, etc.). + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps used to train the underlying flow-map model. + shift (`float`, defaults to 1.0): + Multiplicative timestep shift applied to the inference schedule. ``shift=1.0`` is the identity; values + greater than 1.0 push the schedule toward more denoising at later steps (e.g., ``shift=5`` matches the + Wan2.1 default). + weight_type (`str`, defaults to `"gaussian"`): + Loss-weighting scheme for training. ``"gaussian"`` uses logit-normal weighting centered at + ``num_train_timesteps / 2``. ``"beta08"`` uses a beta(1.0, 0.5)-shaped weighting biased toward small + timesteps. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + weight_type: str = "gaussian", + ): + self.set_timesteps(num_train_timesteps, device="cpu") + self.set_train_weight(weight_type) + + def adaptive_weighting(self, loss, p=1.0, eps=1e-3): + """Inverse-loss reweighting used during distillation training.""" + weight = 1.0 / torch.pow(loss.detach() + eps, p) + return weight * loss + + def set_train_weight(self, weight_type): + """Precompute per-timestep training loss weights.""" + if self.config.weight_type == "gaussian": + x = self.timesteps + y = torch.exp(-2 * ((x - self.config.num_train_timesteps / 2) / self.config.num_train_timesteps) ** 2) + y_shifted = y - y.min() + bsmntw_weighing = y_shifted * (self.config.num_train_timesteps / y_shifted.sum()) + self.linear_timesteps_weights = bsmntw_weighing + elif self.config.weight_type == "beta08": + t = self.timesteps / self.config.num_train_timesteps + y = (t**1.0) * ((1 - t) ** 0.5) + self.linear_timesteps_weights = y * (self.config.num_train_timesteps / y.sum()) + else: + raise ValueError(f"Invalid weight type: {weight_type}") + + @torch.no_grad() + def get_train_weight(self, timesteps): + """Return the precomputed loss weight for each entry in ``timesteps``.""" + timestep_id = torch.argmin( + (self.timesteps.unsqueeze(1) - timesteps.flatten().unsqueeze(0).to(self.timesteps.device)).abs(), + dim=0, + ).reshape(timesteps.shape) + weights = self.linear_timesteps_weights[timestep_id] + return weights.to(timesteps.device) + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """Linearly interpolate ``sample`` toward ``noise`` according to the normalized ``timestep``.""" + timestep = timestep.to(device=sample.device, dtype=sample.dtype) + + timestep = timestep / self.config.num_train_timesteps + timestep = timestep.view(*timestep.shape, *([1] * (noise.ndim - timestep.ndim))) + sample = timestep * noise + (1.0 - timestep) * sample + return sample + + def apply_shift(self, sigmas): + """Apply the configured shift transformation to a sigma tensor.""" + if self.config.shift == 1.0: + return sigmas + return self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + ): + """Build the inference timestep schedule on ``device`` and store it on ``self.timesteps``.""" + timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, dtype=torch.float64, device=device) + timesteps = self.apply_shift(timesteps) + + self.timesteps = timesteps * self.config.num_train_timesteps + + def step( + self, + model_output: torch.FloatTensor, + sample: torch.FloatTensor, + timestep: Optional[Union[float, torch.FloatTensor]] = None, + r_timestep: Optional[Union[float, torch.FloatTensor]] = None, + ): + """ + Advance ``sample`` from ``timestep`` to ``r_timestep`` using the model-predicted velocity. + + Unlike a standard Euler scheduler, both endpoints of the interval are caller-provided so that any-step + sampling is possible: a single model call can step from `t` to any chosen target `r` (including `r=0` for + a one-shot generation). + """ + timestep = timestep / self.config.num_train_timesteps + r_timestep = r_timestep / self.config.num_train_timesteps + timestep = timestep.view(*timestep.shape, *([1] * (model_output.ndim - timestep.ndim))) + r_timestep = r_timestep.view(*r_timestep.shape, *([1] * (model_output.ndim - r_timestep.ndim))) + prev_sample = sample - (timestep - r_timestep) * model_output + return prev_sample.to(model_output.dtype) diff --git a/tests/schedulers/test_scheduler_flow_map_euler_discrete.py b/tests/schedulers/test_scheduler_flow_map_euler_discrete.py new file mode 100644 index 000000000000..4542988ab4d7 --- /dev/null +++ b/tests/schedulers/test_scheduler_flow_map_euler_discrete.py @@ -0,0 +1,141 @@ +# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import FlowMapEulerDiscreteScheduler + + +class FlowMapEulerDiscreteSchedulerTest(unittest.TestCase): + """ + The flow-map scheduler has a non-standard ``step`` signature that takes both ``timestep`` and + ``r_timestep`` (the target timestep), so it cannot use ``SchedulerCommonTest``. The tests below + exercise the contract that the scheduler exposes to ``AnyFlowPipeline`` and ``AnyFlowCausalPipeline``. + """ + + scheduler_class = FlowMapEulerDiscreteScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "shift": 1.0, + "weight_type": "gaussian", + } + config.update(**kwargs) + return config + + def test_instantiation_with_defaults(self): + scheduler = self.scheduler_class(**self.get_default_config()) + self.assertEqual(scheduler.config.num_train_timesteps, 1000) + self.assertEqual(scheduler.config.shift, 1.0) + + def test_set_timesteps_endpoints(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 2, 4, 8, 16]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.timesteps[0].item(), 1000.0, places=4) + self.assertAlmostEqual(scheduler.timesteps[-1].item(), 0.0, places=4) + + def test_apply_shift_identity(self): + scheduler = self.scheduler_class(**self.get_default_config(shift=1.0)) + sigmas = torch.linspace(0.0, 1.0, 10) + torch.testing.assert_close(scheduler.apply_shift(sigmas), sigmas) + + def test_apply_shift_monotonic(self): + scheduler = self.scheduler_class(**self.get_default_config(shift=5.0)) + sigmas = torch.linspace(0.01, 0.99, 16) + shifted = scheduler.apply_shift(sigmas) + # shift > 1 must monotonically map [0,1] to [0,1] and increase intermediate values + self.assertTrue(torch.all(shifted >= 0)) + self.assertTrue(torch.all(shifted <= 1)) + self.assertTrue(torch.all(shifted[1:] - shifted[:-1] >= -1e-6)) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 16, 21, 30, 52) # B, C, T, H, W (Wan2.1 latent shape) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + r_timestep = scheduler.timesteps[1:2] + + prev_sample = scheduler.step(model_output, sample, timestep=timestep, r_timestep=r_timestep) + self.assertEqual(prev_sample.shape, sample.shape) + self.assertEqual(prev_sample.dtype, model_output.dtype) + + def test_step_zero_interval_is_identity(self): + # When timestep == r_timestep the update collapses to the input sample. + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8, 8) + model_output = torch.randn_like(sample) + t = scheduler.timesteps[2:3] + + prev_sample = scheduler.step(model_output, sample, timestep=t, r_timestep=t) + torch.testing.assert_close(prev_sample, sample.to(model_output.dtype)) + + def test_step_one_shot_sampling(self): + # Flow-map promise: stepping straight from t=T to r=0 produces a clean sample in a single call. + scheduler = self.scheduler_class(**self.get_default_config(shift=5.0)) + scheduler.set_timesteps(num_inference_steps=1) + timesteps = scheduler.timesteps + + sample = torch.randn(1, 4, 4, 4) + model_output = torch.randn_like(sample) + + prev_sample = scheduler.step( + model_output, + sample, + timestep=timesteps[0:1], + r_timestep=timesteps[1:2], + ) + self.assertEqual(prev_sample.shape, sample.shape) + self.assertFalse(torch.allclose(prev_sample, sample)) + + def test_train_weight_gaussian_shape(self): + scheduler = self.scheduler_class(**self.get_default_config(weight_type="gaussian")) + weights = scheduler.linear_timesteps_weights + self.assertEqual(weights.shape, (scheduler.config.num_train_timesteps + 1,)) + self.assertTrue(torch.all(weights >= 0)) + + def test_train_weight_beta08_shape(self): + scheduler = self.scheduler_class(**self.get_default_config(weight_type="beta08")) + weights = scheduler.linear_timesteps_weights + self.assertEqual(weights.shape, (scheduler.config.num_train_timesteps + 1,)) + self.assertTrue(torch.all(weights >= 0)) + + def test_train_weight_invalid_raises(self): + with self.assertRaises(ValueError): + self.scheduler_class(**self.get_default_config(weight_type="not-a-real-type")) + + def test_get_train_weight_returns_per_timestep(self): + scheduler = self.scheduler_class(**self.get_default_config()) + timesteps = torch.tensor([0.0, 250.0, 500.0, 750.0, 1000.0]) + weights = scheduler.get_train_weight(timesteps) + self.assertEqual(weights.shape, timesteps.shape) + self.assertTrue(torch.all(weights >= 0)) + + def test_scale_noise_endpoints(self): + scheduler = self.scheduler_class(**self.get_default_config()) + sample = torch.zeros(2, 4, 4, 4) + noise = torch.ones_like(sample) + # t=0 -> all sample, t=num_train_timesteps -> all noise. + zero_t = torch.tensor([0.0]) + torch.testing.assert_close(scheduler.scale_noise(sample, zero_t, noise), sample) + full_t = torch.tensor([float(scheduler.config.num_train_timesteps)]) + torch.testing.assert_close(scheduler.scale_noise(sample, full_t, noise), noise) From 2d1e39c0e042cc1fdc7b91b9d7035e9d2f92f725 Mon Sep 17 00:00:00 2001 From: Enderfga Date: Wed, 6 May 2026 14:51:02 +0800 Subject: [PATCH 03/16] [Models] AnyFlow: add AnyFlowTransformer3DModel A 3D DiT extending the v0.35.1 Wan2.1 backbone with two config-toggled modules: * FAR causal blocks (init_far_model=True): block-sparse causal attention via flex_attention + compressed-frame patch embedding for frame-level autoregressive generation (Gu et al., 2025, arXiv:2503.19325). * Dual-timestep flow-map embedding (init_flowmap_model=True): adds a delta timestep embedder enabling flow-map sampling z_t -> z_r over arbitrary intervals (AnyFlow). With both flags off, the model reduces to stock Wan2.1. The class is intentionally self-contained rather than annotated with '# Copied from diffusers.models.transformers.transformer_wan' because upstream Wan has been refactored extensively since v0.35.1 (new WanAttention class, different processor architecture). Tests: 9 unit tests covering construction in 3 modes, bidi forward shape and determinism, return_dict variants, save/load round-trip with and without init_far_model, gradient checkpointing toggle. Docs: api/models/anyflow_transformer3d.md --- docs/source/en/_toctree.yml | 2 + .../en/api/models/anyflow_transformer3d.md | 44 + src/diffusers/models/__init__.py | 1 + .../transformers/transformer_anyflow.py | 1244 +++++++++++++++++ .../test_models_transformer_anyflow.py | 180 +++ 5 files changed, 1471 insertions(+) create mode 100644 docs/source/en/api/models/anyflow_transformer3d.md create mode 100644 src/diffusers/models/transformers/transformer_anyflow.py create mode 100644 tests/models/transformers/test_models_transformer_anyflow.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 6c6d36a7f483..290dd1ed164d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -328,6 +328,8 @@ title: AceStepTransformer1DModel - local: api/models/allegro_transformer3d title: AllegroTransformer3DModel + - local: api/models/anyflow_transformer3d + title: AnyFlowTransformer3DModel - local: api/models/aura_flow_transformer2d title: AuraFlowTransformer2DModel - local: api/models/transformer_bria_fibo diff --git a/docs/source/en/api/models/anyflow_transformer3d.md b/docs/source/en/api/models/anyflow_transformer3d.md new file mode 100644 index 000000000000..d30ceef91102 --- /dev/null +++ b/docs/source/en/api/models/anyflow_transformer3d.md @@ -0,0 +1,44 @@ + + +# AnyFlowTransformer3DModel + +A 3D Transformer used by `AnyFlowPipeline` and `AnyFlowCausalPipeline`. The architecture extends the +Wan2.1 3D DiT backbone with two optional modules controlled by config flags: + +1. **FAR causal blocks** (`init_far_model=True`) — block-sparse causal attention via + `torch.nn.attention.flex_attention` plus a compressed-frame patch embedding. Enables frame-level + autoregressive generation as introduced in [FAR (Gu et al., 2025)](https://arxiv.org/abs/2503.19325). +2. **Dual-timestep flow-map embedding** (`init_flowmap_model=True`) — adds a second timestep embedder + (`delta_embedder`) that conditions on the target timestep `r_timestep` in addition to the source + timestep, enabling flow-map sampling $\mathbf{z}_t \to \mathbf{z}_r$ over arbitrary intervals (introduced + in [AnyFlow](https://huggingface.co/papers/)). + +Setting both flags to `False` reduces this model to the v0.35.1 Wan2.1 transformer. + +```python +from diffusers import AnyFlowTransformer3DModel + +# Bidirectional AnyFlow checkpoint (T2V): +transformer = AnyFlowTransformer3DModel.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer" +) + +# Causal AnyFlow checkpoint (FAR): +transformer = AnyFlowTransformer3DModel.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", subfolder="transformer" +) +``` + +## AnyFlowTransformer3DModel + +[[autodoc]] AnyFlowTransformer3DModel diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 29a733009eef..1f27e1ec7f4c 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -95,6 +95,7 @@ _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] + _import_structure["transformers.transformer_anyflow"] = ["AnyFlowTransformer3DModel"] _import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"] _import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"] _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"] diff --git a/src/diffusers/models/transformers/transformer_anyflow.py b/src/diffusers/models/transformers/transformer_anyflow.py new file mode 100644 index 000000000000..326341352da6 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_anyflow.py @@ -0,0 +1,1244 @@ +# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file derives from the FAR architecture (Gu et al., 2025, arXiv:2503.19325) and adds AnyFlow's +# dual-timestep flow-map embedding (AnyFlowDualTimestepTextImageEmbedding). The base 3D DiT structure +# is adapted from the v0.35.1 Wan2.1 transformer (transformer_wan.py); upstream Wan has since been +# refactored, so this file is intentionally self-contained rather than annotated with `# Copied from`. + +import copy +import math +from types import SimpleNamespace +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.attention.flex_attention import create_block_mask, flex_attention + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import logging +from ..attention import FeedForward +from ..attention_processor import Attention +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +flex_attention = torch.compile(flex_attention, dynamic=True) + + +def build_block_mask(mask_2d, device): + assert mask_2d.dim() == 2 and mask_2d.dtype == torch.bool + mask_2d = mask_2d.contiguous() + + Q_LEN, KV_LEN = mask_2d.shape + + def mask_mod(b, h, q_idx, kv_idx): + return mask_2d[q_idx, kv_idx] + + return create_block_mask(mask_mod, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=KV_LEN, device=device, _compile=False) + + +def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): + x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) + x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) + return x_out.type_as(hidden_states) + + +class AnyFlowSelfAttnProcessor2_0: + def __init__(self): + if not hasattr(F, 'scaled_dot_product_attention'): + raise ImportError('AnyFlowSelfAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.') + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + kv_cache=None, + kv_cache_flag=None + ) -> torch.Tensor: + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if kv_cache is not None: + if kv_cache_flag['is_cache_step']: + kv_cache['compressed_cache'][0, :, :, :kv_cache_flag['num_compressed_tokens'], :] = key[:, :, :kv_cache_flag['num_compressed_tokens']] + kv_cache['compressed_cache'][1, :, :, :kv_cache_flag['num_compressed_tokens'], :] = value[:, :, :kv_cache_flag['num_compressed_tokens']] + kv_cache['full_cache'][0, :, :, :kv_cache_flag['num_full_tokens'], :] = key[:, :, kv_cache_flag['num_compressed_tokens']:] + kv_cache['full_cache'][1, :, :, :kv_cache_flag['num_full_tokens'], :] = value[:, :, kv_cache_flag['num_compressed_tokens']:] + else: + key = torch.cat([ + kv_cache['compressed_cache'][0, :, :, :kv_cache_flag['num_cached_compressed_tokens'], :], + kv_cache['full_cache'][0, :, :, :kv_cache_flag['num_cached_full_tokens'], :], + key + ], dim=2) + value = torch.cat([ + kv_cache['compressed_cache'][1, :, :, :kv_cache_flag['num_cached_compressed_tokens'], :], + kv_cache['full_cache'][1, :, :, :kv_cache_flag['num_cached_full_tokens'], :], + value + ], dim=2) + + if rotary_emb is not None: + query = apply_rotary_emb(query, rotary_emb['query']) + key = apply_rotary_emb(key, rotary_emb['key']) + + if attention_mask is None: + hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) + else: + seq_len = query.shape[2] + padded_length = int(math.ceil(seq_len / 128.0) * 128.0 - seq_len) + query = torch.cat([query, torch.zeros([query.shape[0], query.shape[1], padded_length, query.shape[3]], device=query.device, dtype=query.dtype)], dim=2) # noqa: E501 + key = torch.cat([key, torch.zeros([key.shape[0], key.shape[1], padded_length, key.shape[3]], device=key.device, dtype=key.dtype)], dim=2) # noqa: E501 + value = torch.cat([value, torch.zeros([value.shape[0], value.shape[1], padded_length, value.shape[3]], device=value.device, dtype=value.dtype)], dim=2) # noqa: E501 + + hidden_states = flex_attention(query, key, value, block_mask=attention_mask)[:, :, :seq_len] + + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class AnyFlowCrossAttnProcessor2_0: + def __init__(self): + if not hasattr(F, 'scaled_dot_product_attention'): + raise ImportError('AnyFlowCrossAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.') + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if rotary_emb is not None: + query = apply_rotary_emb(query, rotary_emb['query']) + key = apply_rotary_emb(key, rotary_emb['key']) + + hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) + + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class AnyFlowImageEmbedding(torch.nn.Module): + def __init__(self, in_features: int, out_features: int): + super().__init__() + + self.norm1 = FP32LayerNorm(in_features) + self.ff = FeedForward(in_features, out_features, mult=1, activation_fn='gelu') + self.norm2 = FP32LayerNorm(out_features) + + def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm1(encoder_hidden_states_image) + hidden_states = self.ff(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +class AnyFlowTimeTextImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: Optional[int] = None, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn='gelu_tanh') + + self.image_embedder = None + if image_embed_dim is not None: + self.image_embedder = AnyFlowImageEmbedding(image_embed_dim, dim) + + def forward_timestep( + self, + timestep: torch.Tensor, + encoder_hidden_states, + token_per_frame + ): + batch_size, num_frames = timestep.shape + timestep = rearrange(timestep, 'b t -> (b t)') + + timestep = self.timesteps_proj(timestep) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + temb = rearrange(temb, '(b t) c -> b t c', b=batch_size).repeat_interleave(token_per_frame, dim=1) + timestep_proj = rearrange(timestep_proj, '(b t) c -> b t c', b=batch_size).repeat_interleave(token_per_frame, dim=1) + + return temb, timestep_proj + + def forward( + self, + timestep: torch.Tensor, + r_timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + far_cfg=None, + clean_timestep=None, + is_causal=True + ): + + if is_causal: + full_frame_timestep, full_frame_timestep_proj = self.forward_timestep(timestep[:, -far_cfg['num_full_frames']:], encoder_hidden_states, far_cfg['full_token_per_frame']) # noqa: E501 + compressed_frame_timestep, compressed_frame_timestep_proj = self.forward_timestep(timestep[:, :-far_cfg['num_full_frames']], encoder_hidden_states, far_cfg['compressed_token_per_frame']) # noqa: E501 + + if clean_timestep is not None: + clean_timestep, clean_timestep_proj = self.forward_timestep(clean_timestep, clean_timestep, encoder_hidden_states, far_cfg['full_token_per_frame']) # noqa: E501 + timestep = torch.cat([compressed_frame_timestep, full_frame_timestep, clean_timestep], dim=1) + timestep_proj = torch.cat([compressed_frame_timestep_proj, full_frame_timestep_proj, clean_timestep_proj], dim=1) + else: + timestep = torch.cat([compressed_frame_timestep, full_frame_timestep], dim=1) + timestep_proj = torch.cat([compressed_frame_timestep_proj, full_frame_timestep_proj], dim=1) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + else: + timestep, timestep_proj = self.forward_timestep(timestep, encoder_hidden_states, far_cfg['full_token_per_frame']) # noqa: E501 + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return timestep, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + + +class AnyFlowDualTimestepTextImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + gate_value: float, + deltatime_type: str, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: Optional[int] = None, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.delta_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn='gelu_tanh') + + self.image_embedder = None + if image_embed_dim is not None: + self.image_embedder = AnyFlowImageEmbedding(image_embed_dim, dim) + + self.register_buffer('delta_emb_gate', torch.tensor([gate_value], dtype=torch.float32), persistent=False) + self.deltatime_type = deltatime_type + + def forward_timestep( + self, + timestep: torch.Tensor, + delta_timestep: torch.Tensor, + encoder_hidden_states, + token_per_frame + ): + batch_size, num_frames = timestep.shape + timestep = rearrange(timestep, 'b t -> (b t)') + delta_timestep = rearrange(delta_timestep, 'b t -> (b t)') + + timestep = self.timesteps_proj(timestep) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + + delta_timestep = self.timesteps_proj(delta_timestep) + + delta_embedder_dtype = next(iter(self.delta_embedder.parameters())).dtype + if delta_timestep.dtype != delta_embedder_dtype and delta_embedder_dtype != torch.int8: + delta_timestep = delta_timestep.to(delta_embedder_dtype) + delta_emb = self.delta_embedder(delta_timestep).type_as(encoder_hidden_states) + + gate = self.delta_emb_gate.to(delta_embedder_dtype) + self.gate_track = float(gate) + + rt_emb = (1 - gate) * temb + gate * delta_emb + timestep_proj = self.time_proj(self.act_fn(rt_emb)) + + rt_emb = rearrange(rt_emb, '(b t) c -> b t c', b=batch_size).repeat_interleave(token_per_frame, dim=1) + timestep_proj = rearrange(timestep_proj, '(b t) c -> b t c', b=batch_size).repeat_interleave(token_per_frame, dim=1) + + return rt_emb, timestep_proj + + def forward( + self, + timestep: torch.Tensor, + r_timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + far_cfg=None, + clean_timestep=None, + is_causal=True + ): + if self.deltatime_type == 'r': + delta_timestep = r_timestep + elif self.deltatime_type == 't-r': + delta_timestep = timestep - r_timestep + else: + raise NotImplementedError + + if is_causal: + full_frame_timestep, full_frame_timestep_proj = self.forward_timestep(timestep[:, -far_cfg['num_full_frames']:], delta_timestep[:, -far_cfg['num_full_frames']:], encoder_hidden_states, far_cfg['full_token_per_frame']) # noqa: E501 + compressed_frame_timestep, compressed_frame_timestep_proj = self.forward_timestep(timestep[:, :-far_cfg['num_full_frames']], delta_timestep[:, :-far_cfg['num_full_frames']], encoder_hidden_states, far_cfg['compressed_token_per_frame']) # noqa: E501 + + if clean_timestep is not None: + clean_timestep, clean_timestep_proj = self.forward_timestep(clean_timestep, clean_timestep, encoder_hidden_states, far_cfg['full_token_per_frame']) # noqa: E501 + timestep = torch.cat([compressed_frame_timestep, full_frame_timestep, clean_timestep], dim=1) + timestep_proj = torch.cat([compressed_frame_timestep_proj, full_frame_timestep_proj, clean_timestep_proj], dim=1) + else: + timestep = torch.cat([compressed_frame_timestep, full_frame_timestep], dim=1) + timestep_proj = torch.cat([compressed_frame_timestep_proj, full_frame_timestep_proj], dim=1) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + else: + timestep, timestep_proj = self.forward_timestep(timestep, delta_timestep, encoder_hidden_states, far_cfg['full_token_per_frame']) # noqa: E501 + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return timestep, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + + +class AnyFlowRotaryPosEmbed(nn.Module): + def __init__( + self, attention_head_dim: int, patch_size: Tuple[int, int, int], compressed_patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0 + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.compressed_patch_size = compressed_patch_size + self.max_seq_len = max_seq_len + + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + + freqs = [] + for dim in [t_dim, h_dim, w_dim]: + freq = get_1d_rotary_pos_embed( + dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64 + ) + freqs.append(freq) + self.freqs = torch.cat(freqs, dim=1) + + def avg_pool_complex(self, freq: torch.Tensor, kernel_size: int, stride: int): + + real = freq.real # [B, C, L], float + real = real.transpose(0, 1).unsqueeze(0) + imag = freq.imag # [B, C, L], float + imag = imag.transpose(0, 1).unsqueeze(0) + + pr = F.avg_pool1d(real, kernel_size, stride) + pi = F.avg_pool1d(imag, kernel_size, stride) + + pr = pr.squeeze(0).transpose(0, 1) + pi = pi.squeeze(0).transpose(0, 1) + + norm = torch.sqrt(pr**2 + pi**2) + pr_unit = pr / norm + pi_unit = pi / norm + + return torch.complex(pr_unit, pi_unit) + + def _forward_compressed_frame(self, num_frames, height, width, device): + ppf, pph, ppw = num_frames, height, width + downscale = [self.compressed_patch_size[i] // self.patch_size[i] for i in range(len(self.patch_size))] + + self.freqs = self.freqs.to(device) + freqs = self.freqs.split_with_sizes( + [ + self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), + self.attention_head_dim // 6, + self.attention_head_dim // 6, + ], + dim=1, + ) + + freqs_f = self.avg_pool_complex(freqs[0], kernel_size=downscale[0], stride=downscale[0]) + freqs_h = self.avg_pool_complex(freqs[1], kernel_size=downscale[1], stride=downscale[1]) + freqs_w = self.avg_pool_complex(freqs[2], kernel_size=downscale[2], stride=downscale[2]) + + freqs_f = freqs_f[:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_h = freqs_h[:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_w = freqs_w[:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1) + return freqs + + def _forward_full_frame(self, num_frames, height, width, device) -> torch.Tensor: + ppf, pph, ppw = num_frames, height, width + + self.freqs = self.freqs.to(device) + freqs = self.freqs.split_with_sizes( + [ + self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), + self.attention_head_dim // 6, + self.attention_head_dim // 6, + ], + dim=1, + ) + + freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1) + return freqs + + def forward(self, far_cfg, device, clean_hidden_states=None, is_causal=True): + if is_causal: + full_frame_freqs = self._forward_full_frame( + num_frames=far_cfg['total_frames'], + height=far_cfg['full_frame_shape'][0], + width=far_cfg['full_frame_shape'][1], + device=device + ) + compressed_frame_freqs = self._forward_compressed_frame( + num_frames=far_cfg['total_frames'], + height=far_cfg['compressed_frame_shape'][0], + width=far_cfg['compressed_frame_shape'][1], + device=device + ) + + compressed_frame_freqs, full_frame_freqs = compressed_frame_freqs[:far_cfg['num_compressed_frames']], full_frame_freqs[far_cfg['num_compressed_frames']:] # noqa: E501 + + compressed_frame_freqs = compressed_frame_freqs.flatten(start_dim=0, end_dim=2) + full_frame_freqs = full_frame_freqs.flatten(start_dim=0, end_dim=2) + + if clean_hidden_states is not None: + freqs = torch.cat([compressed_frame_freqs, full_frame_freqs, full_frame_freqs], dim=0) + else: + freqs = torch.cat([compressed_frame_freqs, full_frame_freqs], dim=0) + + freqs = freqs[None, None, ...] + + return {'query': freqs, 'key': freqs} + else: + freqs = self._forward_full_frame( + num_frames=far_cfg['total_frames'], + height=far_cfg['full_frame_shape'][0], + width=far_cfg['full_frame_shape'][1], + device=device + ) + freqs = freqs.flatten(start_dim=0, end_dim=2) + freqs = freqs[None, None, ...] + return {'query': freqs, 'key': freqs} + + +class AnyFlowTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = 'rms_norm_across_heads', + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: Optional[int] = None, + ): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = Attention( + query_dim=dim, + heads=num_heads, + kv_heads=num_heads, + dim_head=dim // num_heads, + qk_norm=qk_norm, + eps=eps, + bias=True, + cross_attention_dim=None, + out_bias=True, + processor=AnyFlowSelfAttnProcessor2_0(), + ) + + # 2. Cross-attention + self.attn2 = Attention( + query_dim=dim, + heads=num_heads, + kv_heads=num_heads, + dim_head=dim // num_heads, + qk_norm=qk_norm, + eps=eps, + bias=True, + cross_attention_dim=None, + out_bias=True, + added_kv_proj_dim=added_kv_proj_dim, + added_proj_bias=True, + processor=AnyFlowCrossAttnProcessor2_0(), + ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 3. Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn='gelu-approximate') + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + attention_mask: torch.Tensor, + kv_cache=None, + kv_cache_flag=None, + ) -> torch.Tensor: + + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (self.scale_shift_table + temb.float()).chunk(6, dim=2) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), c_shift_msa.squeeze(2), c_scale_msa.squeeze(2), c_gate_msa.squeeze(2) # noqa: E501 + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask, kv_cache=kv_cache, kv_cache_flag=kv_cache_flag) # noqa: E501 + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + + return hidden_states + + +class AnyFlowTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): + r""" + A 3D Transformer for any-step video diffusion. The architecture extends the Wan2.1 3D DiT backbone with two + optional modules controlled by config flags: + + 1. **FAR causal blocks** (``init_far_model=True``) — block-sparse causal attention via + ``torch.nn.attention.flex_attention`` plus a compressed-frame patch embedding. This enables frame-level + autoregressive generation as introduced in FAR ([Gu et al., 2025](https://arxiv.org/abs/2503.19325)). + 2. **Dual-timestep flow-map embedding** (``init_flowmap_model=True``) — adds a second timestep embedder + (``delta_embedder``) that conditions on the target timestep ``r_timestep`` in addition to the source + timestep, enabling flow-map sampling :math:`z_t \to z_r` over arbitrary intervals (AnyFlow). + + With both flags off, this model reduces to the v0.35.1 Wan2.1 transformer. + + Args: + patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + compressed_patch_size (`Tuple[int]`, defaults to `(1, 4, 4)`): + Larger patch dimensions used by the FAR-compressed branch for context frames. Only consulted when + ``init_far_model=True``. + full_chunk_limit (`int`, defaults to `3`): + Maximum number of full-resolution chunks before earlier chunks are demoted to compressed FAR context. + Only consulted when ``init_far_model=True``. + num_attention_heads (`int`, defaults to `40`): + Number of attention heads. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input latent. + out_channels (`int`, defaults to `16`): + The number of channels in the output latent. + text_dim (`int`, defaults to `4096`): + Input dimension for text embeddings (UMT5). + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `13824`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `40`): + Number of transformer blocks. + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + qk_norm (`Optional[str]`, defaults to `'rms_norm_across_heads'`): + Query/key normalization scheme. + eps (`float`, defaults to `1e-6`): + Epsilon for normalization layers. + image_dim (`Optional[int]`, *optional*, defaults to `None`): + Image embedding dimension for I2V conditioning (`1280` for the original Wan2.1-I2V model). + added_kv_proj_dim (`Optional[int]`, *optional*, defaults to `None`): + The number of channels to use for the added key/value projections. If `None`, no projection is added. + rope_max_seq_len (`int`, defaults to `1024`): + Maximum sequence length used to precompute rotary position frequencies. + chunk_partition (optional list of int, *optional*): + Default chunk partition for FAR (overridable per ``forward`` call). + init_far_model (`bool`, defaults to `False`): + Toggle the FAR causal-attention components. + init_flowmap_model (`bool`, defaults to `False`): + Toggle the dual-timestep flow-map embedding. Required by the AnyFlow distilled checkpoints. + gate_value (`float`, defaults to `0`): + Initial mixing gate between source-timestep and delta-timestep embeddings. Only consulted when + ``init_flowmap_model=True``. + deltatime_type (`str`, defaults to `'r'`): + Either ``"r"`` (delta is the target timestep) or ``"t-r"`` (delta is the absolute interval). Only + consulted when ``init_flowmap_model=True``. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ['patch_embedding', 'condition_embedder', 'norm'] + _no_split_modules = ['AnyFlowTransformerBlock'] + _keep_in_fp32_modules = ['time_embedder', 'scale_shift_table', 'norm1', 'norm2', 'norm3'] + _keys_to_ignore_on_load_unexpected = ['norm_added_q'] + + @register_to_config + def __init__( + self, + patch_size: Tuple[int] = (1, 2, 2), + compressed_patch_size: Tuple[int] = (1, 4, 4), + full_chunk_limit: int = 3, + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: Optional[str] = 'rms_norm_across_heads', + eps: float = 1e-6, + image_dim: Optional[int] = None, + added_kv_proj_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + chunk_partition=None, + init_far_model=False, + init_flowmap_model=False, + gate_value=0, + deltatime_type='r' + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Patch & position embedding + self.rope = AnyFlowRotaryPosEmbed(attention_head_dim, patch_size, compressed_patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. Condition embeddings + # image_embedding_dim=1280 for I2V model + self.condition_embedder = AnyFlowTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList( + [ + AnyFlowTransformerBlock( + inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim + ) + for _ in range(num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + if init_far_model: + self.setup_far_model() + if init_flowmap_model: + self.setup_flowmap_model(gate_value=self.config.gate_value, deltatime_type=self.config.deltatime_type) + + def setup_flowmap_model(self, gate_value=0, deltatime_type='r'): + inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + condition_embedder = AnyFlowDualTimestepTextImageEmbedding( + dim=inner_dim, + gate_value=gate_value, + deltatime_type=deltatime_type, + time_freq_dim=self.config.freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=self.config.text_dim, + image_embed_dim=self.config.image_dim, + ) + condition_embedder.time_embedder = copy.deepcopy(self.condition_embedder.time_embedder) + condition_embedder.delta_embedder = copy.deepcopy(self.condition_embedder.time_embedder) + condition_embedder.time_proj = copy.deepcopy(self.condition_embedder.time_proj) + condition_embedder.text_embedder = copy.deepcopy(self.condition_embedder.text_embedder) + condition_embedder.image_embedder = copy.deepcopy(self.condition_embedder.image_embedder) + del self.condition_embedder + + self.condition_embedder = condition_embedder + + def setup_far_model(self): + inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.far_patch_embedding = nn.Conv3d( + self.config.in_channels, inner_dim, kernel_size=self.config.compressed_patch_size, stride=self.config.compressed_patch_size) + + # init far patch embedding + original_weight = self.patch_embedding.weight.data.view(-1, 1, *self.config.patch_size) + + new_weight = F.interpolate( + original_weight, size=self.config.compressed_patch_size, mode='trilinear', align_corners=False + ) + new_weight = new_weight.view(inner_dim, self.config.in_channels, *self.config.compressed_patch_size) + + with torch.no_grad(): + self.far_patch_embedding.weight.copy_(new_weight) + self.far_patch_embedding.bias.copy_(self.patch_embedding.bias) + + def _unpack_latent_sequence(self, latents, num_frames, height, width, patch_size): + batch_size, num_patches, channels = latents.shape + height, width = height // patch_size, width // patch_size + + latents = latents.view(batch_size * num_frames, height, width, patch_size, patch_size, channels // (patch_size * patch_size)) + + latents = latents.permute(0, 5, 1, 3, 2, 4) + latents = latents.reshape(batch_size, num_frames, channels // (patch_size * patch_size), height * patch_size, width * patch_size) + return latents + + def forward_far_patchify(self, hidden_states, far_cfg, clean_hidden_states=None): + + full_hidden_states, compressed_hidden_states = hidden_states[:, :, far_cfg['num_compressed_frames']:], hidden_states[:, :, :far_cfg['num_compressed_frames']] # noqa: E501 + + patchified_full_hidden_states = self.patch_embedding(full_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + if clean_hidden_states is not None: + clean_hidden_states = self.patch_embedding(clean_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + patchified_full_hidden_states = torch.cat([patchified_full_hidden_states, clean_hidden_states], dim=1) + + if far_cfg['num_compressed_frames'] > 0: + patchified_compressed_hidden_states = self.far_patch_embedding(compressed_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + hidden_states = torch.cat([patchified_compressed_hidden_states, patchified_full_hidden_states], dim=1) + else: + hidden_states = patchified_full_hidden_states + return hidden_states + + def forward_far_patchify_inference(self, hidden_states): + hidden_states = self.patch_embedding(hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + return hidden_states + + def _build_causal_mask(self, far_cfg, clean_hidden_states, device, dtype): + chunk_partition = far_cfg['chunk_partition'] + + noise_seq_len = clean_seq_len = far_cfg['num_full_frames'] * far_cfg['full_token_per_frame'] + context_seq_len = far_cfg['num_compressed_frames'] * far_cfg['compressed_token_per_frame'] + + noise_start = context_seq_len + noise_end = noise_start + noise_seq_len + + clean_start = context_seq_len + noise_seq_len + clean_end = clean_start + clean_seq_len + + if clean_hidden_states is not None: + real_seq_len = context_seq_len + noise_seq_len + clean_seq_len + else: + real_seq_len = context_seq_len + noise_seq_len + + padded_seq_len = int(math.ceil(real_seq_len / 128.0) * 128.0) + + if clean_hidden_states is not None: + context_chunk_partition, noise_chunk_partition = chunk_partition[:far_cfg['num_compressed_chunk']], chunk_partition[far_cfg['num_compressed_chunk']:] # noqa: E501 + + if len(context_chunk_partition) != 0: + context_frame_idx = torch.cat([torch.ones(chunk_len * far_cfg['compressed_token_per_frame'], device=device) * chunk_idx for chunk_idx, chunk_len in enumerate(context_chunk_partition)]) # noqa: E501 + else: + context_frame_idx = None + noise_frame_idx = clean_frame_idx = torch.cat([torch.ones(chunk_len * far_cfg['full_token_per_frame'], device=device) * (chunk_idx + len(context_chunk_partition)) for chunk_idx, chunk_len in enumerate(noise_chunk_partition)]) # noqa: E501 + pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device) + + if len(context_chunk_partition) != 0: + frame_idx = torch.cat([context_frame_idx, noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0) + else: + frame_idx = torch.cat([noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0) + + def mask_mod(b, h, q_idx, kv_idx): + # q_idx, kv_idx: LongTensor, range: [0, padded_seq_len) + + # 1) whether is padding + is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len) + + # 3) chunk casual + base = frame_idx[q_idx] >= frame_idx[kv_idx] + + # 4) interval mask + q_is_context = q_idx < context_seq_len # noqa: F841 + q_is_noise = (q_idx >= noise_start) & (q_idx < noise_end) + q_is_clean = (q_idx >= clean_start) & (q_idx < clean_end) + + k_is_context = kv_idx < context_seq_len # noqa: F841 + k_is_noise = (kv_idx >= noise_start) & (kv_idx < noise_end) + k_is_clean = (kv_idx >= clean_start) & (kv_idx < clean_end) + + # 5) clean -> noise: disallowed + is_clean_to_noise = q_is_clean & k_is_noise + + # 6) noise -> noise: only same frame + same_frame_idx = frame_idx[q_idx] == frame_idx[kv_idx] + + noise_to_noise = q_is_noise & k_is_noise + noise_to_clean = q_is_noise & k_is_clean + + noise_to_noise_allow = noise_to_noise & same_frame_idx + noise_to_noise_mask = (~noise_to_noise) | noise_to_noise_allow + + noise_to_clean_same = noise_to_clean & same_frame_idx + noise_to_clean_disallow = noise_to_clean_same + + # attention mask is chunk casual + allowed = base & ~is_padding & ~is_clean_to_noise & noise_to_noise_mask & ~noise_to_clean_disallow + return allowed + + return create_block_mask( + mask_mod, + B=None, + H=None, + Q_LEN=padded_seq_len, + KV_LEN=padded_seq_len, + device=device, + _compile=False, + ) + else: + context_chunk_partition, noise_chunk_partition = chunk_partition[:far_cfg['num_compressed_chunk']], chunk_partition[far_cfg['num_compressed_chunk']:] # noqa: E501 + + if len(context_chunk_partition) != 0: + context_frame_idx = torch.cat([torch.ones(chunk_len * far_cfg['compressed_token_per_frame'], device=device) * chunk_idx for chunk_idx, chunk_len in enumerate(context_chunk_partition)]) # noqa: E501 + else: + context_frame_idx = None + + noise_frame_idx = torch.cat([torch.ones(chunk_len * far_cfg['full_token_per_frame'], device=device) * (chunk_idx + len(context_chunk_partition)) for chunk_idx, chunk_len in enumerate(noise_chunk_partition)]) # noqa: E501 + pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device) + + if len(context_chunk_partition) != 0: + frame_idx = torch.cat([context_frame_idx, noise_frame_idx, pad_frame_idx], dim=0) + else: + frame_idx = torch.cat([noise_frame_idx, pad_frame_idx], dim=0) + + def mask_mod(b, h, q_idx, kv_idx): + is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len) + base = base = frame_idx[q_idx] >= frame_idx[kv_idx] + return base & ~is_padding + + return create_block_mask( + mask_mod, + B=None, + H=None, + Q_LEN=padded_seq_len, + KV_LEN=padded_seq_len, + device=device, + _compile=False, + ) + + def forward(self, *args, **kwargs): + if kwargs.get('is_causal', True): + if kwargs.get('kv_cache', None) is not None: + if kwargs['kv_cache_flag'].get('is_cache_step'): + return self._forward_cache(*args, **kwargs) + else: + return self._forward_inference(*args, **kwargs) + else: + return self._forward_train(*args, **kwargs) + else: + return self._forward_bidirection(*args, **kwargs) + + def _forward_inference( + self, + hidden_states: torch.Tensor, + chunk_partition, + timestep: torch.LongTensor, + r_timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + kv_cache=None, + kv_cache_flag=None + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + + hidden_states = rearrange(hidden_states, 'b f c h w -> b c f h w') + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + full_token_per_frame = (height // self.config.patch_size[1]) * (width // self.config.patch_size[2]) + compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * (width // self.config.compressed_patch_size[2]) + + total_chunks = 1 + kv_cache_flag['num_cached_chunks'] + + if total_chunks >= self.config.full_chunk_limit: + num_full_chunk, num_compressed_chunk = self.config.full_chunk_limit, total_chunks - self.config.full_chunk_limit + else: + num_full_chunk, num_compressed_chunk = total_chunks, 0 + + kv_cache_flag['num_cached_full_tokens'] = sum(chunk_partition[num_compressed_chunk: num_compressed_chunk + (num_full_chunk - 1)]) * full_token_per_frame # noqa: E501 + kv_cache_flag['num_cached_compressed_tokens'] = sum(chunk_partition[:num_compressed_chunk]) * compressed_token_per_frame + + far_cfg = { + 'total_frames': sum(chunk_partition), + 'num_full_frames': sum(chunk_partition[num_compressed_chunk:]), + 'num_compressed_frames': sum(chunk_partition[:num_compressed_chunk]), + 'full_frame_shape': (height // self.config.patch_size[1], width // self.config.patch_size[2]), + 'compressed_frame_shape': (height // self.config.compressed_patch_size[1], width // self.config.compressed_patch_size[2]), + 'full_token_per_frame': full_token_per_frame, + 'compressed_token_per_frame': compressed_token_per_frame + } + + # step 3: generate attention mask + attention_mask = None + hidden_states = self.forward_far_patchify_inference(hidden_states) + + rotary_emb = self.rope(far_cfg=far_cfg, device=hidden_states.device) + rotary_emb['query'] = rotary_emb['query'][:, :, -hidden_states.shape[1]:] + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, r_timestep, encoder_hidden_states, encoder_hidden_states_image, far_cfg=far_cfg # noqa: E501 + ) + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 4. Transformer blocks + for index_block, block in enumerate(self.blocks): + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask, kv_cache[index_block], kv_cache_flag + ) + else: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask, kv_cache[index_block], kv_cache_flag) + + # 5. Output norm, projection & unpatchify + shift, scale = (self.scale_shift_table + temb.unsqueeze(2)).chunk(2, dim=2) + shift, scale = shift.squeeze(2), scale.squeeze(2) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + + output = self.proj_out(hidden_states) + output = self._unpack_latent_sequence(output, num_frames=chunk_partition[-1], height=height, width=width, patch_size=self.config.patch_size[1]) + + if not return_dict: + return output, kv_cache + + return SimpleNamespace(sample=output, kv_cache=kv_cache) + + def _forward_cache( + self, + hidden_states: torch.Tensor, + chunk_partition, + timestep: torch.LongTensor, + r_timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + clean_hidden_states=None, + clean_timestep=None, + kv_cache=None, + kv_cache_flag=None + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + + hidden_states = rearrange(hidden_states, 'b f c h w -> b c f h w') + if clean_hidden_states is not None: + clean_hidden_states = rearrange(clean_hidden_states, 'b f c h w -> b c f h w') + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + full_token_per_frame = (height // self.config.patch_size[1]) * (width // self.config.patch_size[2]) + compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * (width // self.config.compressed_patch_size[2]) + total_chunks = len(chunk_partition) + + full_chunk_limit = self.config.full_chunk_limit - 1 + + if total_chunks > full_chunk_limit: + num_full_chunk, num_compressed_chunk = full_chunk_limit, total_chunks - full_chunk_limit + else: + num_full_chunk, num_compressed_chunk = total_chunks, 0 + + far_cfg = { + 'total_frames': sum(chunk_partition), + 'num_full_chunk': num_full_chunk, + 'num_full_frames': sum(chunk_partition[num_compressed_chunk:]), + 'num_compressed_chunk': num_compressed_chunk, + 'num_compressed_frames': sum(chunk_partition[:num_compressed_chunk]), + 'full_frame_shape': (height // self.config.patch_size[1], width // self.config.patch_size[2]), + 'compressed_frame_shape': (height // self.config.compressed_patch_size[1], width // self.config.compressed_patch_size[2]), + 'full_token_per_frame': full_token_per_frame, + 'compressed_token_per_frame': compressed_token_per_frame, + 'chunk_partition': chunk_partition + } + + kv_cache_flag['num_full_tokens'] = far_cfg['num_full_frames'] * far_cfg['full_token_per_frame'] + kv_cache_flag['num_compressed_tokens'] = far_cfg['num_compressed_frames'] * far_cfg['compressed_token_per_frame'] + + # step 3: generate attention mask + attention_mask = self._build_causal_mask(far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype) + + rotary_emb = self.rope(far_cfg=far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device) + hidden_states = self.forward_far_patchify(hidden_states, far_cfg=far_cfg, clean_hidden_states=clean_hidden_states) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, r_timestep, encoder_hidden_states, encoder_hidden_states_image, far_cfg=far_cfg, clean_timestep=clean_timestep + ) + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 4. Transformer blocks + for index_block, block in enumerate(self.blocks): + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask, kv_cache[index_block], kv_cache_flag + ) + else: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask, kv_cache[index_block], kv_cache_flag) + + return None, kv_cache + + def _forward_train( + self, + hidden_states: torch.Tensor, + chunk_partition, + timestep: torch.LongTensor, + r_timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + clean_hidden_states=None, + clean_timestep=None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + + hidden_states = rearrange(hidden_states, 'b f c h w -> b c f h w') + if clean_hidden_states is not None: + clean_hidden_states = rearrange(clean_hidden_states, 'b f c h w -> b c f h w') + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + full_token_per_frame = (height // self.config.patch_size[1]) * (width // self.config.patch_size[2]) + compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * (width // self.config.compressed_patch_size[2]) + total_chunks = len(chunk_partition) + + if total_chunks > self.config.full_chunk_limit: + num_full_chunk, num_compressed_chunk = self.config.full_chunk_limit, total_chunks - self.config.full_chunk_limit + else: + num_full_chunk, num_compressed_chunk = total_chunks, 0 + + far_cfg = { + 'total_frames': sum(chunk_partition), + 'num_full_chunk': num_full_chunk, + 'num_full_frames': sum(chunk_partition[num_compressed_chunk:]), + 'num_compressed_chunk': num_compressed_chunk, + 'num_compressed_frames': sum(chunk_partition[:num_compressed_chunk]), + 'full_frame_shape': (height // self.config.patch_size[1], width // self.config.patch_size[2]), + 'compressed_frame_shape': (height // self.config.compressed_patch_size[1], width // self.config.compressed_patch_size[2]), + 'full_token_per_frame': full_token_per_frame, + 'compressed_token_per_frame': compressed_token_per_frame, + 'chunk_partition': chunk_partition + } + + # step 3: generate attention mask + attention_mask = self._build_causal_mask(far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype) + + rotary_emb = self.rope(far_cfg=far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device) + + hidden_states = self.forward_far_patchify(hidden_states, far_cfg=far_cfg, clean_hidden_states=clean_hidden_states) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, r_timestep, encoder_hidden_states, encoder_hidden_states_image, far_cfg=far_cfg, clean_timestep=clean_timestep + ) + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 4. Transformer blocks + for index_block, block in enumerate(self.blocks): + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask, + ) + else: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask) + + # 5. Output norm, projection & unpatchify + shift, scale = (self.scale_shift_table + temb.unsqueeze(2)).chunk(2, dim=2) + shift, scale = shift.squeeze(2), scale.squeeze(2) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + + if clean_hidden_states is not None: + hidden_states = hidden_states[:, :-(far_cfg['num_full_frames'] * far_cfg['full_token_per_frame'])] # remove clean copy + output = self.proj_out(hidden_states[:, far_cfg['num_compressed_frames'] * far_cfg['compressed_token_per_frame']:]) # remove far context + output = self._unpack_latent_sequence(output, num_frames=far_cfg['num_full_frames'], height=height, width=width, patch_size=self.config.patch_size[1]) # noqa: E501 + + if not return_dict: + return output + + return SimpleNamespace(sample=output) + + def _forward_bidirection( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + r_timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + is_causal=False + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + hidden_states = rearrange(hidden_states, 'b f c h w -> b c f h w') + + assert is_causal is False + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + full_token_per_frame = (height * width) // (self.config.patch_size[1] * self.config.patch_size[2]) + + far_cfg = { + 'total_frames': num_frames, + 'full_frame_shape': (height // self.config.patch_size[1], width // self.config.patch_size[2]), + 'full_token_per_frame': full_token_per_frame, + } + + rotary_emb = self.rope(far_cfg=far_cfg, device=hidden_states.device, is_causal=is_causal) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, r_timestep, encoder_hidden_states, encoder_hidden_states_image, is_causal=is_causal, far_cfg=far_cfg + ) + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + + attention_mask = None + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask + ) + else: + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask) + + # 5. Output norm, projection & unpatchify + if temb.ndim == 3: + # batch_size, seq_len, inner_dim (wan 2.2 ti2v) + shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + # batch_size, inner_dim + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + output = self._unpack_latent_sequence(hidden_states, num_frames=far_cfg['total_frames'], height=height, width=width, patch_size=self.config.patch_size[1]) # noqa: E501 + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/tests/models/transformers/test_models_transformer_anyflow.py b/tests/models/transformers/test_models_transformer_anyflow.py new file mode 100644 index 000000000000..2d44a476c472 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_anyflow.py @@ -0,0 +1,180 @@ +# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import torch + +from diffusers import AnyFlowTransformer3DModel +from diffusers.models.modeling_outputs import Transformer2DModelOutput + +from ...testing_utils import enable_full_determinism + + +# AnyFlow's rotary position embeddings use float64 buffers for numerical precision; the model is exercised +# on CPU/CUDA in production and is not validated on MPS. Tests pin all tensors to CPU to keep CI green on +# any backend. + +enable_full_determinism() + + +class AnyFlowTransformer3DModelTest(unittest.TestCase): + """ + Unit tests for ``AnyFlowTransformer3DModel``. + + The model has a non-standard ``forward`` signature (``is_causal``, ``r_timestep``, ``chunk_partition``, + ``kv_cache``, ``kv_cache_flag``) and dispatches between four code paths (bidirectional inference, causal + training, causal cache prefill, causal autoregressive inference). Fast unit tests cover the bidirectional + path here; the causal paths are exercised end-to-end by ``AnyFlowCausalPipelineIntegrationTests`` in + ``tests/pipelines/anyflow/test_anyflow_causal.py``. + """ + + @staticmethod + def _tiny_init_kwargs(**overrides): + kwargs = dict( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=4, + out_channels=4, + text_dim=16, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + kwargs.update(overrides) + return kwargs + + @staticmethod + def _tiny_bidi_inputs(batch_size=1, num_frames=2, height=16, width=16, text_seq_len=12, text_dim=16): + return dict( + hidden_states=torch.randn(batch_size, num_frames, 4, height, width, device="cpu"), + timestep=torch.full((batch_size, num_frames), 500.0, device="cpu"), + r_timestep=torch.full((batch_size, num_frames), 250.0, device="cpu"), + encoder_hidden_states=torch.randn(batch_size, text_seq_len, text_dim, device="cpu"), + is_causal=False, + return_dict=True, + ) + + def test_construction_base_wan(self): + m = AnyFlowTransformer3DModel(**self._tiny_init_kwargs()) + self.assertEqual(type(m.condition_embedder).__name__, "AnyFlowTimeTextImageEmbedding") + self.assertFalse(hasattr(m, "far_patch_embedding")) + + def test_construction_flowmap_only(self): + m = AnyFlowTransformer3DModel( + **self._tiny_init_kwargs(init_flowmap_model=True, gate_value=0.25, deltatime_type="r") + ) + self.assertEqual(type(m.condition_embedder).__name__, "AnyFlowDualTimestepTextImageEmbedding") + self.assertFalse(hasattr(m, "far_patch_embedding")) + + def test_construction_far_plus_flowmap(self): + m = AnyFlowTransformer3DModel( + **self._tiny_init_kwargs( + compressed_patch_size=(1, 4, 4), + full_chunk_limit=3, + init_far_model=True, + init_flowmap_model=True, + gate_value=0.25, + deltatime_type="r", + ) + ) + self.assertEqual(type(m.condition_embedder).__name__, "AnyFlowDualTimestepTextImageEmbedding") + self.assertTrue(hasattr(m, "far_patch_embedding")) + + def test_bidi_forward_shape_preserved(self): + torch.manual_seed(0) + m = AnyFlowTransformer3DModel( + **self._tiny_init_kwargs(init_flowmap_model=True, gate_value=0.25, deltatime_type="r") + ).to("cpu").eval() + + inputs = self._tiny_bidi_inputs() + with torch.no_grad(): + out = m(**inputs) + self.assertIsInstance(out, Transformer2DModelOutput) + self.assertEqual(out.sample.shape, inputs["hidden_states"].shape) + + def test_bidi_forward_return_dict_false(self): + torch.manual_seed(0) + m = AnyFlowTransformer3DModel( + **self._tiny_init_kwargs(init_flowmap_model=True, gate_value=0.25, deltatime_type="r") + ).to("cpu").eval() + + inputs = self._tiny_bidi_inputs() + inputs["return_dict"] = False + with torch.no_grad(): + out = m(**inputs) + self.assertIsInstance(out, tuple) + self.assertEqual(out[0].shape, inputs["hidden_states"].shape) + + def test_bidi_forward_determinism(self): + torch.manual_seed(0) + m = AnyFlowTransformer3DModel( + **self._tiny_init_kwargs(init_flowmap_model=True, gate_value=0.25, deltatime_type="r") + ).to("cpu").eval() + + inputs_a = self._tiny_bidi_inputs() + inputs_b = {k: v.clone() if torch.is_tensor(v) else v for k, v in inputs_a.items()} + + with torch.no_grad(): + out_a = m(**inputs_a).sample + out_b = m(**inputs_b).sample + torch.testing.assert_close(out_a, out_b) + + def test_save_load_pretrained_roundtrip(self): + torch.manual_seed(0) + m = AnyFlowTransformer3DModel( + **self._tiny_init_kwargs(init_flowmap_model=True, gate_value=0.25, deltatime_type="r") + ) + + with tempfile.TemporaryDirectory() as tmpdir: + m.save_pretrained(tmpdir) + m2 = AnyFlowTransformer3DModel.from_pretrained(tmpdir) + self.assertEqual(type(m2.condition_embedder).__name__, "AnyFlowDualTimestepTextImageEmbedding") + # Compare a parameter to ensure weights round-tripped. + torch.testing.assert_close( + m.condition_embedder.delta_embedder.linear_1.weight, + m2.condition_embedder.delta_embedder.linear_1.weight, + ) + + def test_save_load_pretrained_far_plus_flowmap(self): + torch.manual_seed(0) + m = AnyFlowTransformer3DModel( + **self._tiny_init_kwargs( + compressed_patch_size=(1, 4, 4), + full_chunk_limit=3, + init_far_model=True, + init_flowmap_model=True, + gate_value=0.25, + deltatime_type="r", + ) + ) + + with tempfile.TemporaryDirectory() as tmpdir: + m.save_pretrained(tmpdir) + m2 = AnyFlowTransformer3DModel.from_pretrained(tmpdir) + self.assertTrue(hasattr(m2, "far_patch_embedding")) + torch.testing.assert_close(m.far_patch_embedding.weight, m2.far_patch_embedding.weight) + + def test_gradient_checkpointing_toggle(self): + m = AnyFlowTransformer3DModel(**self._tiny_init_kwargs(init_flowmap_model=True)) + self.assertFalse(m.gradient_checkpointing) + m.enable_gradient_checkpointing() + self.assertTrue(m.gradient_checkpointing) + m.disable_gradient_checkpointing() + self.assertFalse(m.gradient_checkpointing) From c0e8b12967236b5aa910921381097d9a005270d7 Mon Sep 17 00:00:00 2001 From: Enderfga Date: Wed, 6 May 2026 14:54:26 +0800 Subject: [PATCH 04/16] [Pipelines] AnyFlow: add AnyFlowPipeline and AnyFlowCausalPipeline * AnyFlowPipeline (pipeline_anyflow.py, ~590 LOC): bidirectional T2V using flow-map sampling. Loads checkpoints from nvidia/AnyFlow-Wan2.1-T2V-{1.3B,14B}. * AnyFlowCausalPipeline (pipeline_anyflow_causal.py, ~700 LOC): FAR-based causal pipeline supporting T2V/I2V/TV2V via task_type kwarg. Loads checkpoints from nvidia/AnyFlow-FAR-Wan2.1-{1.3B,14B}-Diffusers. Both pipelines reuse stock WanLoraLoaderMixin, AutoencoderKLWan, UMT5EncoderModel, and AutoTokenizer from upstream. The transformer is the AnyFlowTransformer3DModel introduced in the previous commit. The scheduler is FlowMapEulerDiscreteScheduler. Tests: * tests/pipelines/anyflow/test_anyflow.py: PipelineTesterMixin fast tests + slow integration test against nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers. * tests/pipelines/anyflow/test_anyflow_causal.py: same structure for FAR variant. Reference slices for slow integration tests are deferred to Phase 7 (Final quality pass) where the user runs them on a real GPU. --- .../pipelines/anyflow/pipeline_anyflow.py | 593 +++++++++++++++ .../anyflow/pipeline_anyflow_causal.py | 708 ++++++++++++++++++ tests/pipelines/anyflow/__init__.py | 0 tests/pipelines/anyflow/test_anyflow.py | 184 +++++ .../pipelines/anyflow/test_anyflow_causal.py | 194 +++++ 5 files changed, 1679 insertions(+) create mode 100644 src/diffusers/pipelines/anyflow/pipeline_anyflow.py create mode 100644 src/diffusers/pipelines/anyflow/pipeline_anyflow_causal.py create mode 100644 tests/pipelines/anyflow/__init__.py create mode 100644 tests/pipelines/anyflow/test_anyflow.py create mode 100644 tests/pipelines/anyflow/test_anyflow_causal.py diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py new file mode 100644 index 000000000000..ebddfc0b10b7 --- /dev/null +++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py @@ -0,0 +1,593 @@ +# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Adapted from diffusers.pipelines.wan.pipeline_wan.WanPipeline (v0.35.1) for any-step flow-map sampling. + +import html +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +from einops import rearrange +from tqdm import tqdm +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import WanLoraLoaderMixin +from ...models import AnyFlowTransformer3DModel, AutoencoderKLWan +from ...models.autoencoders.vae import DiagonalGaussianDistribution +from ...schedulers import FlowMapEulerDiscreteScheduler +from ...utils import is_ftfy_available, logging +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import AnyFlowPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class AnyFlowPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Bidirectional text-to-video generation pipeline for AnyFlow flow-map-distilled checkpoints. + + AnyFlow learns arbitrary-interval transitions :math:`z_t \to z_r` rather than the fixed + :math:`z_t \to z_0` mapping of consistency models, so a single distilled checkpoint can be evaluated at + 1, 2, 4, 8, 16... NFE without retraining. This pipeline operates over the full video tensor in one + bidirectional pass; for frame-level autoregressive (causal) generation use ``AnyFlowCausalPipeline``. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`AutoTokenizer`]): + Tokenizer from [google/umt5-xxl](https://huggingface.co/google/umt5-xxl). + text_encoder ([`UMT5EncoderModel`]): + [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) text encoder. + transformer ([`AnyFlowTransformer3DModel`]): + Conditional 3D Transformer (must be configured with ``init_flowmap_model=True``). + vae ([`AutoencoderKLWan`]): + VAE that encodes/decodes videos to and from latent representations. + scheduler ([`FlowMapEulerDiscreteScheduler`]): + Flow-map sampler. The pipeline drives ``scheduler.step(..., timestep, r_timestep)`` per inference + step. + use_mean_velocity (`bool`, defaults to `True`): + When ``True`` the model output is averaged across two anchor times to reduce discretization error + (the default training-time behavior). Disable to mirror raw Euler stepping. + """ + + model_cpu_offload_seq = 'text_encoder->transformer->vae' + _callback_tensor_inputs = ['latents', 'prompt_embeds', 'negative_prompt_embeds'] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: AnyFlowTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMapEulerDiscreteScheduler, + use_mean_velocity: bool = True + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, 'vae', None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, 'vae', None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.use_mean_velocity = use_mean_velocity + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding='max_length', + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors='pt', + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or '' + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f'`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=' + f' {type(prompt)}.' + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:' + f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches' + ' the batch size of `prompt`.' + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f'`height` and `width` have to be divisible by 16 but are {height} and {width}.') + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f'`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}' # noqa: E501 + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f'Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to' + ' only forward one of the two.' + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f'Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to' + ' only forward one of the two.' + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + 'Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.' + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f'`prompt` has to be of type `str` or `list` but is {type(prompt)}') + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f'`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}') + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f'You have passed a list of generators of length {len(generator)}, but requested an effective batch' + f' size of {batch_size}. Make sure the batch size matches the length of the generators.' + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = rearrange(latents, 'b c t h w -> b t c h w') + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + def vae_encode(self, context_sequence): + # normalize: [0, 1] -> [-1, 1] + context_sequence = context_sequence * 2 - 1 + context_sequence = self.encode_latents(context_sequence.to(dtype=self.vae.dtype, device=self._execution_device), sample=False) + context_sequence = rearrange(context_sequence, 'b c t h w -> b t c h w') + return context_sequence + + def _normalize_latents(self, latents, latents_mean, latents_std): + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device) + latents = ((latents.float() - latents_mean) * latents_std).to(latents) + return latents + + @torch.no_grad() + def encode_latents(self, videos, sample=True): + videos = rearrange(videos, 'b t c h w -> b c t h w') + moments = self.vae._encode(videos) + + latents_mean = torch.tensor(self.vae.config.latents_mean) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std) + + mu, logvar = torch.chunk(moments, 2, dim=1) + mu = self._normalize_latents(mu, latents_mean, latents_std) + + if sample: + logvar = self._normalize_latents(logvar, latents_mean, latents_std) + + latents = torch.cat([mu, logvar], dim=1) + posterior = DiagonalGaussianDistribution(latents) + latents = posterior.sample(generator=None) + del posterior + else: + latents = mu + return latents + + def training_rollout( + self, + context_sequence=None, + num_inference_steps: int = 50, + grad_timestep: int = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + guidance_scale: float = 1.0, + ): + self._guidance_scale = guidance_scale + + if negative_prompt_embeds is not None: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # setup start sequence + if context_sequence is not None: + context_length = context_sequence.shape[1] + + def inference_range(latents, timesteps): + + for i, t in enumerate(tqdm(timesteps[:-1])): + r = timesteps[i + 1] + + if t == r: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1) + timestep = timestep.repeat((1, latent_model_input.shape[1])) + + if self.use_mean_velocity: + r_timestep = r.expand(latent_model_input.shape[0]).unsqueeze(-1) + r_timestep = r_timestep.repeat((1, latent_model_input.shape[1])) + else: + r_timestep = timestep + + if context_sequence is not None: + latent_model_input[:, :context_length, ...] = context_sequence + timestep[:, :context_length] = 0 + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + r_timestep=r_timestep, + encoder_hidden_states=prompt_embeds, + return_dict=False, + is_causal=False + )[0] + + if self.do_classifier_free_guidance: + noise_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + latents = self.scheduler.step(noise_pred, latents, t, r) + + return latents + + device = self._execution_device + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + if grad_timestep is None: + x_final_val = inference_range(latents, timesteps) + return x_final_val + + # 6. Denoising loop + self._num_timesteps = len(timesteps) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + prev_timestep = [timesteps[0], timesteps[grad_timestep]] + current_timestep = [timesteps[grad_timestep], timesteps[grad_timestep + 1]] + post_timestep = [timesteps[grad_timestep + 1], timesteps[-1]] + + # 1. Fast-forward to the target timestep without tracking gradients + latents = inference_range(latents, prev_timestep) + + # 2. Execute a single differentiable step to anchor the gradient flow + x_next_grad = inference_range(latents, current_timestep) + + # 3. Complete the rollout to x0 in no_grad mode to save VRAM + x_final_val = inference_range(x_next_grad, post_timestep) + return x_final_val + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + context_sequence=None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 1.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = 'np', + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ['latents'], + max_sequence_length: int = 512, + ): + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f'`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number.' + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = torch.bfloat16 + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + init_latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + init_latents = init_latents.to(transformer_dtype) + + # setup start sequence + if context_sequence is not None: + context_sequence = self.vae_encode(context_sequence) + context_length = context_sequence.shape[1] + + latents = self.training_rollout( + context_sequence=context_sequence, + num_inference_steps=num_inference_steps, + grad_timestep=None, + latents=init_latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale + ) + if context_sequence is not None: + latents[:, :context_length, ...] = context_sequence + latents = rearrange(latents, 'b f c h w -> b c f h w') + + if not output_type == 'latent': + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnyFlowPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow_causal.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow_causal.py new file mode 100644 index 000000000000..04ea594e2127 --- /dev/null +++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow_causal.py @@ -0,0 +1,708 @@ +# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Adapted from diffusers.pipelines.wan.pipeline_wan.WanPipeline (v0.35.1) for FAR causal flow-map sampling. + +import copy +import html +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +from einops import rearrange +from tqdm import tqdm +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import WanLoraLoaderMixin +from ...models import AnyFlowTransformer3DModel, AutoencoderKLWan +from ...models.autoencoders.vae import DiagonalGaussianDistribution +from ...schedulers import FlowMapEulerDiscreteScheduler +from ...utils import is_ftfy_available, logging +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import AnyFlowPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class AnyFlowCausalPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Causal (FAR-based) text-to-video / image-to-video / text+video-to-video pipeline for AnyFlow checkpoints. + + The pipeline drives a frame-level autoregressive sampling loop over chunks: each chunk is denoised with + flow-map steps while attending only to past chunks via block-sparse causal attention, and intermediate + KV cache is reused across chunks. Set ``task_type`` per call to switch between ``"t2v"``, ``"i2v"``, and + ``"tv2v"``. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`AutoTokenizer`]): + Tokenizer from [google/umt5-xxl](https://huggingface.co/google/umt5-xxl). + text_encoder ([`UMT5EncoderModel`]): + [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) text encoder. + transformer ([`AnyFlowTransformer3DModel`]): + Conditional 3D Transformer (must be configured with ``init_far_model=True`` and + ``init_flowmap_model=True``). + vae ([`AutoencoderKLWan`]): + VAE that encodes/decodes videos to and from latent representations. + scheduler ([`FlowMapEulerDiscreteScheduler`]): + Flow-map sampler. + use_mean_velocity (`bool`, defaults to `True`): + When ``True`` the model output is averaged across two anchor times to reduce discretization error. + """ + + model_cpu_offload_seq = 'text_encoder->transformer->vae' + _callback_tensor_inputs = ['latents', 'prompt_embeds', 'negative_prompt_embeds'] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: AnyFlowTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMapEulerDiscreteScheduler, + use_mean_velocity: bool = True + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, 'vae', None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, 'vae', None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.use_mean_velocity = use_mean_velocity + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding='max_length', + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors='pt', + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or '' + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f'`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=' + f' {type(prompt)}.' + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:' + f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches' + ' the batch size of `prompt`.' + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f'`height` and `width` have to be divisible by 16 but are {height} and {width}.') + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f'`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}' # noqa: E501 + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f'Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to' + ' only forward one of the two.' + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f'Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to' + ' only forward one of the two.' + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + 'Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.' + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f'`prompt` has to be of type `str` or `list` but is {type(prompt)}') + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f'`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}') + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f'You have passed a list of generators of length {len(generator)}, but requested an effective batch' + f' size of {batch_size}. Make sure the batch size matches the length of the generators.' + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = rearrange(latents, 'b c t h w -> b t c h w') + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + def vae_encode(self, context_sequence): + # normalize: [0, 1] -> [-1, 1] + context_sequence = context_sequence * 2 - 1 + context_sequence = self.encode_latents(context_sequence.to(dtype=self.vae.dtype, device=self._execution_device), sample=False) + context_sequence = rearrange(context_sequence, 'b c t h w -> b t c h w') + return context_sequence + + def _normalize_latents(self, latents, latents_mean, latents_std): + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device) + latents = ((latents.float() - latents_mean) * latents_std).to(latents) + return latents + + @torch.no_grad() + def encode_latents(self, videos, sample=True): + videos = rearrange(videos, 'b t c h w -> b c t h w') + moments = self.vae._encode(videos) + + latents_mean = torch.tensor(self.vae.config.latents_mean) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std) + + mu, logvar = torch.chunk(moments, 2, dim=1) + mu = self._normalize_latents(mu, latents_mean, latents_std) + + if sample: + logvar = self._normalize_latents(logvar, latents_mean, latents_std) + + latents = torch.cat([mu, logvar], dim=1) + posterior = DiagonalGaussianDistribution(latents) + latents = posterior.sample(generator=None) + del posterior + else: + latents = mu + return latents + + def inference( + self, + num_inference_steps: int = 50, + guidance_scale: float = 1.0, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = 'np', + return_dict: bool = True, + kv_cache=None, + kv_cache_flag=None, + grad_timestep=None, + chunk_partition=None + ): + if negative_prompt_embeds is not None: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + def inference_range(latents, timesteps): + + for i, t in enumerate(timesteps[:-1]): + r = timesteps[i + 1] + + if t == r: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1) + timestep = timestep.repeat((1, latent_model_input.shape[1])) + + if self.use_mean_velocity: + r_timestep = r.expand(latent_model_input.shape[0]).unsqueeze(-1) + r_timestep = r_timestep.repeat((1, latent_model_input.shape[1])) + else: + r_timestep = timestep + + noise_pred, _ = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + r_timestep=r_timestep, + encoder_hidden_states=prompt_embeds, + return_dict=False, + chunk_partition=chunk_partition, + # kv-cache related + kv_cache=kv_cache, + kv_cache_flag=copy.deepcopy(kv_cache_flag) + ) + if self.do_classifier_free_guidance: + noise_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + latents = self.scheduler.step(noise_pred, latents, t, r) + + return latents + + device = self._execution_device + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + if grad_timestep is None: + x_final_val = inference_range(latents, timesteps) + return x_final_val + + # 6. Denoising loop + self._num_timesteps = len(timesteps) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + prev_timestep = [timesteps[0], timesteps[grad_timestep]] + current_timestep = [timesteps[grad_timestep], timesteps[grad_timestep + 1]] + post_timestep = [timesteps[grad_timestep + 1], timesteps[-1]] + + # 1. Fast-forward to the target timestep without tracking gradients + latents = inference_range(latents, prev_timestep) + + # 2. Execute a single differentiable step to anchor the gradient flow + x_next_grad = inference_range(latents, current_timestep) + + # 3. Complete the rollout to x0 in no_grad mode to save VRAM + x_final_val = inference_range(x_next_grad, post_timestep) + return x_final_val + + def training_rollout( + self, + context_sequence=None, + num_inference_steps: int = 50, + grad_timestep: int = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + guidance_scale: float = 1.0, + use_kv_cache=True, + ): + self._guidance_scale = guidance_scale + + latents = rearrange(latents, 'b c t h w -> b t c h w') + batch_size, num_frame, _, height, width = latents.shape + + # 5. Prepare latent variables + init_latents = latents + + chunk_partition = self.transformer.config.chunk_partition + + assert init_latents.shape[1] == sum(chunk_partition), 'please check the chunk_partition equal to num_smaple_frames' + + full_token_per_frame = (init_latents.shape[3] // self.transformer.config.patch_size[1]) * (init_latents.shape[4] // self.transformer.config.patch_size[2]) # noqa: E501 + compressed_token_per_frame = (init_latents.shape[3] // self.transformer.config.compressed_patch_size[1]) * (init_latents.shape[4] // self.transformer.config.compressed_patch_size[2]) # noqa: E501 + + # init kv cache + if use_kv_cache: + kv_cache = {} + + batch_size = batch_size * 2 if self.do_classifier_free_guidance else batch_size + + for layer_idx in range(self.transformer.config.num_layers): + kv_cache[layer_idx] = { + 'full_cache': torch.zeros(( + 2, batch_size, self.transformer.config.num_attention_heads, + self.transformer.config.full_chunk_limit * max(chunk_partition) * full_token_per_frame, + self.transformer.config.attention_head_dim + ), device=init_latents.device, dtype=init_latents.dtype), + 'compressed_cache': torch.zeros(( + 2, batch_size, self.transformer.config.num_attention_heads, + (len(chunk_partition) - self.transformer.config.full_chunk_limit + 1) * max(chunk_partition) * compressed_token_per_frame, + self.transformer.config.attention_head_dim + ), device=init_latents.device, dtype=init_latents.dtype) + } + + kv_cache_flag = { + 'num_cached_chunks': 0, + 'is_cache_step': False, + } + else: + kv_cache = None + kv_cache_flag = None + + output = torch.zeros_like(init_latents) + + # setup start sequence + if context_sequence is not None: + if 'latent' in context_sequence: + latents = rearrange(context_sequence['latent'], 'b c t h w -> b t c h w') + else: + assert (context_sequence['raw'].shape[1] - 1) % 4 == 0, 'require 4n+1 frames' + latents = self.vae_encode(context_sequence['raw']) + current_context_length = latents.shape[1] + output[:, :current_context_length] = latents + num_context_chunks = next(i + 1 for i in range(len(chunk_partition)) if sum(chunk_partition[:i + 1]) >= current_context_length) + else: + num_context_chunks = 0 + + for chunk_idx in tqdm(range(len(chunk_partition))): + + if chunk_idx >= num_context_chunks: + pred_latents = self.inference( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + kv_cache=kv_cache, + kv_cache_flag=kv_cache_flag, + latents=init_latents[:, sum(chunk_partition[:chunk_idx]): sum(chunk_partition[:chunk_idx + 1])], + num_inference_steps=num_inference_steps, + grad_timestep=grad_timestep, + guidance_scale=guidance_scale, + chunk_partition=chunk_partition[:chunk_idx + 1] + ) + output[:, sum(chunk_partition[:chunk_idx]): sum(chunk_partition[:chunk_idx + 1])] = pred_latents + + # step1: save to kv cache + if chunk_idx < len(chunk_partition) - 1: + kv_cache = self.encode_kv_cache(kv_cache, kv_cache_flag, chunk_partition=chunk_partition[:chunk_idx + 1], chunk_idx=chunk_idx, output=output, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds) # noqa: E501 + + output = rearrange(output, 'b f c h w -> b c f h w') + return output + + @torch.no_grad() + def encode_kv_cache(self, kv_cache, kv_cache_flag, chunk_partition, chunk_idx, output, prompt_embeds, negative_prompt_embeds): + kv_cache_flag['is_cache_step'] = True + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + latents = output[:, :sum(chunk_partition)] + latent_model_input = torch.cat([latents] * 2).to(torch.bfloat16) if self.do_classifier_free_guidance else latents.to(torch.bfloat16) + + timestep = torch.tensor([0], device=latents.device).expand(latent_model_input.shape[0]).unsqueeze(-1) + timestep = timestep.repeat((1, latent_model_input.shape[1])) + + r_timestep = torch.tensor([0], device=latents.device).expand(latent_model_input.shape[0]).unsqueeze(-1) + r_timestep = r_timestep.repeat((1, latent_model_input.shape[1])) + + _, kv_cache = self.transformer( + hidden_states=latent_model_input, + chunk_partition=chunk_partition, + timestep=timestep, + r_timestep=r_timestep, + encoder_hidden_states=prompt_embeds, + return_dict=False, + # kv-cache related + kv_cache=kv_cache, + kv_cache_flag=copy.deepcopy(kv_cache_flag) + ) + + kv_cache_flag['num_cached_chunks'] += 1 + kv_cache_flag['is_cache_step'] = False + + return kv_cache + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + context_sequence=None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 1.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = 'np', + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ['latents'], + max_sequence_length: int = 512, + show_progress=True, + use_kv_cache=True + ): + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f'`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number.' + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = torch.bfloat16 + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + init_latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + init_latents = init_latents.to(transformer_dtype) + init_latents = rearrange(init_latents, 'b f c h w -> b c f h w') + + latents = self.training_rollout( + context_sequence=context_sequence, num_inference_steps=num_inference_steps, + grad_timestep=None, + latents=init_latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale + ) + + if not output_type == 'latent': + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnyFlowPipelineOutput(frames=video) diff --git a/tests/pipelines/anyflow/__init__.py b/tests/pipelines/anyflow/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/anyflow/test_anyflow.py b/tests/pipelines/anyflow/test_anyflow.py new file mode 100644 index 000000000000..831b1f46823b --- /dev/null +++ b/tests/pipelines/anyflow/test_anyflow.py @@ -0,0 +1,184 @@ +# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch +from transformers import AutoConfig, AutoTokenizer, T5EncoderModel + +from diffusers import ( + AnyFlowPipeline, + AnyFlowTransformer3DModel, + AutoencoderKLWan, + FlowMapEulerDiscreteScheduler, +) + +from ...testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + slow, + torch_device, +) +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class AnyFlowPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = AnyFlowPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMapEulerDiscreteScheduler(num_train_timesteps=1000, shift=5.0, weight_type="gaussian") + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5") + text_encoder = T5EncoderModel(config) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = AnyFlowTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + init_flowmap_model=True, + gate_value=0.25, + deltatime_type="r", + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + + @unittest.skip("AnyFlow uses mixed-precision flow-map sampling; FP16 round-trip is not numerically stable.") + def test_save_load_float16(self): + pass + + @unittest.skip("AnyFlow has no optional components.") + def test_save_load_optional_components(self): + pass + + +@slow +@require_torch_accelerator +class AnyFlowPipelineIntegrationTests(unittest.TestCase): + """End-to-end integration tests against released NVIDIA AnyFlow checkpoints. Run with ``RUN_SLOW=1``.""" + + prompt = "A cat walks on the grass, realistic style." + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_anyflow_t2v_1_3b(self): + pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", + torch_dtype=torch.bfloat16, + ) + pipe.to(torch_device) + + generator = torch.Generator(device=torch_device).manual_seed(0) + video = pipe( + prompt=self.prompt, + num_inference_steps=4, + num_frames=33, + height=480, + width=832, + generator=generator, + output_type="pt", + ).frames + + self.assertEqual(video[0].shape, (33, 3, 480, 832)) + # TODO(Phase 7): capture reference slice on real GPU and add tolerance assertion. Until then, we only + # assert the output tensor's shape is correct (catches regressions in the sampling loop's frame count). diff --git a/tests/pipelines/anyflow/test_anyflow_causal.py b/tests/pipelines/anyflow/test_anyflow_causal.py new file mode 100644 index 000000000000..3bce0cd6d19a --- /dev/null +++ b/tests/pipelines/anyflow/test_anyflow_causal.py @@ -0,0 +1,194 @@ +# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch +from transformers import AutoConfig, AutoTokenizer, T5EncoderModel + +from diffusers import ( + AnyFlowCausalPipeline, + AnyFlowTransformer3DModel, + AutoencoderKLWan, + FlowMapEulerDiscreteScheduler, +) + +from ...testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + slow, + torch_device, +) +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class AnyFlowCausalPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + """ + Fast tests for the FAR-causal AnyFlow pipeline. Only T2V is exercised here; the I2V / TV2V branches are + only meaningful at the spatial resolutions used by released checkpoints and are covered in the slow + integration tests below. + """ + + pipeline_class = AnyFlowCausalPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMapEulerDiscreteScheduler(num_train_timesteps=1000, shift=5.0, weight_type="gaussian") + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5") + text_encoder = T5EncoderModel(config) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = AnyFlowTransformer3DModel( + patch_size=(1, 2, 2), + compressed_patch_size=(1, 4, 4), + full_chunk_limit=3, + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + init_far_model=True, + init_flowmap_model=True, + gate_value=0.25, + deltatime_type="r", + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", + "task_type": "t2v", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + + @unittest.skip("AnyFlow uses mixed-precision flow-map sampling; FP16 round-trip is not numerically stable.") + def test_save_load_float16(self): + pass + + @unittest.skip("AnyFlow has no optional components.") + def test_save_load_optional_components(self): + pass + + +@slow +@require_torch_accelerator +class AnyFlowCausalPipelineIntegrationTests(unittest.TestCase): + """End-to-end integration tests against released NVIDIA AnyFlow-FAR checkpoints. Run with ``RUN_SLOW=1``.""" + + prompt = "A cat walks on the grass, realistic style." + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_anyflow_far_t2v_1_3b(self): + pipe = AnyFlowCausalPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", + torch_dtype=torch.bfloat16, + ) + pipe.to(torch_device) + + generator = torch.Generator(device=torch_device).manual_seed(0) + video = pipe( + prompt=self.prompt, + task_type="t2v", + num_inference_steps=4, + num_frames=33, + height=480, + width=832, + generator=generator, + output_type="pt", + ).frames + + self.assertEqual(video[0].shape, (33, 3, 480, 832)) + # TODO(Phase 7): capture reference slice on real GPU and add tolerance assertion. From c650f705bc280bc3ba676c7caaba224442bc8fba Mon Sep 17 00:00:00 2001 From: Enderfga Date: Wed, 6 May 2026 14:55:35 +0800 Subject: [PATCH 05/16] [Docs] AnyFlow: add main pipeline documentation page Modeled on the Helios pipeline doc (PR #13208). Sections: paper link + abstract, supported checkpoints table, memory/speed optimization tabs, T2V/I2V/TV2V examples for both bidirectional and causal variants, autodoc trailers. --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/pipelines/anyflow.md | 192 ++++++++++++++++++++++++ 2 files changed, 194 insertions(+) create mode 100644 docs/source/en/api/pipelines/anyflow.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 290dd1ed164d..41be879d8173 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -504,6 +504,8 @@ - sections: - local: api/pipelines/animatediff title: AnimateDiff + - local: api/pipelines/anyflow + title: AnyFlow - local: api/pipelines/aura_flow title: AuraFlow - local: api/pipelines/bria_3_2 diff --git a/docs/source/en/api/pipelines/anyflow.md b/docs/source/en/api/pipelines/anyflow.md new file mode 100644 index 000000000000..e62bb2cb2ebd --- /dev/null +++ b/docs/source/en/api/pipelines/anyflow.md @@ -0,0 +1,192 @@ + + +
+
+ + LoRA + +
+
+ +# AnyFlow + +[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/) by Yuchao Gu et al. + +*Few-step video generation has been significantly advanced by consistency models. However, their performance often degrades in any-step video diffusion models due to the fixed-point formulation. To address this limitation, we present AnyFlow, the first any-step video diffusion distillation framework built on flow maps. Instead of learning only the mapping z_t → z_0, AnyFlow learns transitions z_t → z_r over arbitrary time intervals, enabling a single model to adapt to different inference budgets. We design an improved forward flow map training recipe that fine-tunes pretrained video diffusion models into flow map models, and introduce Flow Map Backward Simulation to enable on-policy distillation for flow map models. Extensive experiments across both bidirectional and causal architectures, at scales ranging from 1.3B to 14B, on text-to-video and image-to-video tasks demonstrate that AnyFlow outperforms consistency-based baselines while preserving high fidelity and flexible sampling under varying step budgets.* + +The original code can be found at [Enderfga/AnyFlow](https://github.com/Enderfga/AnyFlow). The project page is at [anyflow.github.io](https://anyflow.github.io/). + +The following AnyFlow checkpoints are supported: + +| Checkpoint | Backbone | Description | +|------------|----------|-------------| +| [`nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers) | Wan2.1 1.3B | Bidirectional T2V, lightweight | +| [`nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers) | Wan2.1 14B | Bidirectional T2V, full quality | +| [`nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers) | FAR + Wan2.1 1.3B | Causal T2V / I2V / TV2V | +| [`nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers) | FAR + Wan2.1 14B | Causal T2V / I2V / TV2V | + +> [!TIP] +> Choose `AnyFlowPipeline` for traditional bidirectional text-to-video generation. Choose `AnyFlowCausalPipeline` for streaming I2V, video continuation (TV2V), or any setup that benefits from frame-by-frame autoregressive sampling. + +> [!TIP] +> AnyFlow supports any-step sampling: a single distilled checkpoint can be evaluated at 1, 2, 4, 8, 16... NFE without retraining. Quality scales monotonically with steps in our benchmarks. + +### Optimizing Memory and Inference Speed + + + + +```py +import torch +from diffusers import AnyFlowPipeline +from diffusers.hooks import apply_group_offloading + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 +) +apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level") +pipe.vae.enable_slicing() +pipe.vae.enable_tiling() +``` + + + + +```py +import torch +from diffusers import AnyFlowPipeline + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") +pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") +``` + + + + +### Generation with AnyFlow (Bidirectional T2V) + + + + +```py +import torch +from diffusers import AnyFlowPipeline +from diffusers.utils import export_to_video + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +prompt = "A red panda eating bamboo in a forest, cinematic lighting" +video = pipe(prompt, num_inference_steps=4, num_frames=33).frames[0] +export_to_video(video, "out.mp4", fps=16) +``` + + + + +### Generation with AnyFlow (FAR Causal) + + + + +```py +import torch +from diffusers import AnyFlowCausalPipeline +from diffusers.utils import export_to_video + +pipe = AnyFlowCausalPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +video = pipe( + prompt="A cat surfing a wave, sunset", + task_type="t2v", + num_inference_steps=4, + num_frames=33, +).frames[0] +export_to_video(video, "out.mp4", fps=16) +``` + + + + +```py +import torch +from diffusers import AnyFlowCausalPipeline +from diffusers.utils import export_to_video, load_image + +pipe = AnyFlowCausalPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +img = load_image("path/to/first_frame.png") +video = pipe( + prompt="a cat walks across a sunlit lawn", + image=img, + task_type="i2v", + num_inference_steps=4, + num_frames=33, +).frames[0] +export_to_video(video, "out.mp4", fps=16) +``` + + + + +```py +import torch +from diffusers import AnyFlowCausalPipeline +from diffusers.utils import export_to_video, load_video + +pipe = AnyFlowCausalPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +context = load_video("path/to/context.mp4") +video = pipe( + prompt="continue the story", + video=context, + task_type="tv2v", + num_inference_steps=4, + num_frames=33, +).frames[0] +export_to_video(video, "out.mp4", fps=16) +``` + + + + +## Notes + +- `FlowMapEulerDiscreteScheduler` is general-purpose. You can attach it to any flow-map-distilled checkpoint via `from_pretrained(..., scheduler=FlowMapEulerDiscreteScheduler.from_config(...))`. +- The bidirectional pipeline accepts any `AnyFlowTransformer3DModel` configured with `init_flowmap_model=True`. The causal pipeline additionally requires `init_far_model=True`. +- LoRA training is supported via `WanLoraLoaderMixin`, the same mixin used by the upstream Wan pipelines. + +## AnyFlowPipeline + +[[autodoc]] AnyFlowPipeline + - all + - __call__ + +## AnyFlowCausalPipeline + +[[autodoc]] AnyFlowCausalPipeline + - all + - __call__ + +## AnyFlowPipelineOutput + +[[autodoc]] pipelines.anyflow.pipeline_output.AnyFlowPipelineOutput From 3276d0ae53fb06846053fca4f0d4ef20f3553ae2 Mon Sep 17 00:00:00 2001 From: Enderfga Date: Wed, 6 May 2026 14:56:28 +0800 Subject: [PATCH 06/16] [Auto/Scripts] AnyFlow: register AutoPipelineForText2Video + add conversion script * Register AnyFlowPipeline in AUTO_TEXT2VIDEO_PIPELINES_MAPPING. * AnyFlowCausalPipeline is intentionally NOT registered for AutoPipeline because its task switch (t2v / i2v / tv2v) is too rich for a single auto-resolve key. * scripts/convert_anyflow_to_diffusers.py: convert .pt training checkpoints (with 'ema' state dict) into a diffusers save_pretrained layout. Supports all 4 released NVIDIA AnyFlow variants. Replaces the omegaconf-based config in the upstream repo with argparse to match other diffusers conversion scripts. --- scripts/convert_anyflow_to_diffusers.py | 150 +++++++++++++++++++++++ src/diffusers/pipelines/auto_pipeline.py | 2 + 2 files changed, 152 insertions(+) create mode 100644 scripts/convert_anyflow_to_diffusers.py diff --git a/scripts/convert_anyflow_to_diffusers.py b/scripts/convert_anyflow_to_diffusers.py new file mode 100644 index 000000000000..e58b0bd0b9e4 --- /dev/null +++ b/scripts/convert_anyflow_to_diffusers.py @@ -0,0 +1,150 @@ +# Copyright 2026 The AnyFlow Team, NVIDIA Corp., and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert AnyFlow training checkpoints to the diffusers ``save_pretrained`` layout. + +The AnyFlow training pipeline emits ``.pt`` files containing an ``ema`` key whose value is a flat state +dict for the transformer. This script: + +1. Loads the matching base Wan2.1 pipeline from the Hub (provides VAE, tokenizer, and text encoder). +2. Constructs an ``AnyFlowTransformer3DModel`` with the right config flags for the chosen variant. +3. Loads the ``ema`` weights into the transformer. +4. Wraps everything in an ``AnyFlowPipeline`` (bidirectional) or ``AnyFlowCausalPipeline`` (FAR causal). +5. Calls ``pipeline.save_pretrained(output_dir)``. + +Example: + +```bash +python scripts/convert_anyflow_to_diffusers.py \\ + --variant AnyFlow-FAR-Wan2.1-1.3B-Diffusers \\ + --ckpt /path/to/anyflow-checkpoint.pt \\ + --output-dir /path/to/output/AnyFlow-FAR-Wan2.1-1.3B-Diffusers +``` +""" + +import argparse +import os + +import torch + +from diffusers import ( + AnyFlowCausalPipeline, + AnyFlowPipeline, + AnyFlowTransformer3DModel, + FlowMapEulerDiscreteScheduler, +) + + +# Per-variant configuration. ``base_model`` is fetched from the Hub to source the matching VAE / text encoder. +VARIANTS = { + "AnyFlow-FAR-Wan2.1-1.3B-Diffusers": { + "base_model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + "init_far_model": True, + "init_flowmap_model": True, + "transformer_kwargs": { + "chunk_partition": [1, 3, 3, 3, 3, 3, 3, 2], + "full_chunk_limit": 3, + "compressed_patch_size": [1, 4, 4], + }, + "pipeline_cls": AnyFlowCausalPipeline, + }, + "AnyFlow-FAR-Wan2.1-14B-Diffusers": { + "base_model": "Wan-AI/Wan2.1-T2V-14B-Diffusers", + "init_far_model": True, + "init_flowmap_model": True, + "transformer_kwargs": { + "chunk_partition": [1, 3, 3, 3, 3, 3, 3, 2], + "full_chunk_limit": 3, + "compressed_patch_size": [1, 4, 4], + }, + "pipeline_cls": AnyFlowCausalPipeline, + }, + "AnyFlow-Wan2.1-T2V-1.3B-Diffusers": { + "base_model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + "init_far_model": False, + "init_flowmap_model": True, + "transformer_kwargs": {}, + "pipeline_cls": AnyFlowPipeline, + }, + "AnyFlow-Wan2.1-T2V-14B-Diffusers": { + "base_model": "Wan-AI/Wan2.1-T2V-14B-Diffusers", + "init_far_model": False, + "init_flowmap_model": True, + "transformer_kwargs": {}, + "pipeline_cls": AnyFlowPipeline, + }, +} + + +def build_pipeline(variant: str, ckpt_path: str): + if variant not in VARIANTS: + raise ValueError(f"Unknown variant {variant!r}. Choices: {list(VARIANTS)}.") + spec = VARIANTS[variant] + + transformer = AnyFlowTransformer3DModel.from_pretrained( + spec["base_model"], + subfolder="transformer", + init_far_model=spec["init_far_model"], + init_flowmap_model=spec["init_flowmap_model"], + gate_value=0.25, + deltatime_type="r", + **spec["transformer_kwargs"], + ) + state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["ema"] + missing, unexpected = transformer.load_state_dict(state_dict, strict=False) + if unexpected: + print(f"[warn] unexpected keys in state dict (ignored): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}") + if missing: + print(f"[warn] missing keys not loaded from state dict: {missing[:5]}{'...' if len(missing) > 5 else ''}") + + scheduler = FlowMapEulerDiscreteScheduler(num_train_timesteps=1000, shift=5.0) + + pipeline = spec["pipeline_cls"].from_pretrained( + spec["base_model"], + transformer=transformer, + scheduler=scheduler, + ) + return pipeline + + +def main(): + parser = argparse.ArgumentParser( + description="Convert an AnyFlow training checkpoint into a diffusers pipeline directory." + ) + parser.add_argument( + "--variant", + required=True, + choices=list(VARIANTS), + help="Which AnyFlow variant the checkpoint corresponds to.", + ) + parser.add_argument( + "--ckpt", + required=True, + help="Path to the AnyFlow training checkpoint (a .pt file containing an 'ema' key).", + ) + parser.add_argument( + "--output-dir", + required=True, + help="Destination directory for pipeline.save_pretrained.", + ) + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + pipeline = build_pipeline(args.variant, args.ckpt) + pipeline.save_pretrained(args.output_dir) + print(f"Saved {args.variant} pipeline to {args.output_dir}") + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 2876798e14bd..e758dac95fc5 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -125,6 +125,7 @@ StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) +from .anyflow import AnyFlowPipeline from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline from .z_image import ( ZImageControlNetInpaintPipeline, @@ -249,6 +250,7 @@ AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict( [ + ("anyflow", AnyFlowPipeline), ("wan", WanPipeline), ] ) From 41b2d9e6e924fe0a447c8f7c485c22ff53bf3467 Mon Sep 17 00:00:00 2001 From: Enderfga Date: Wed, 6 May 2026 14:58:05 +0800 Subject: [PATCH 07/16] [Quality] AnyFlow: ruff-format + regenerated dummy stubs * ruff format pass on all 5 source files (long lines + trailing comma fixes) * check_dummies.py --fix_and_overwrite regenerated: - dummy_pt_objects.py: AnyFlowTransformer3DModel + FlowMapEulerDiscreteScheduler - dummy_torch_and_transformers_objects.py: AnyFlowPipeline + AnyFlowCausalPipeline Local fast tests: 21/21 passed - 12 scheduler tests (FlowMapEulerDiscreteScheduler) - 9 transformer tests (AnyFlowTransformer3DModel construction + bidi forward + save/load) The pipeline fast tests in tests/pipelines/anyflow/ require a local dev install that matches the diffusers main branch's transformers >= compatibility floor. The reference slices for slow integration tests (real GPU + 1.3B/14B checkpoints) are intentionally left as TODO stubs to be captured by the user on a real GPU machine before opening the PR. --- scripts/convert_anyflow_to_diffusers.py | 4 +- .../transformers/transformer_anyflow.py | 624 ++++++++++++------ .../pipelines/anyflow/pipeline_anyflow.py | 74 ++- .../anyflow/pipeline_anyflow_causal.py | 186 +++--- src/diffusers/utils/dummy_pt_objects.py | 30 + .../dummy_torch_and_transformers_objects.py | 30 + .../test_models_transformer_anyflow.py | 74 ++- 7 files changed, 692 insertions(+), 330 deletions(-) diff --git a/scripts/convert_anyflow_to_diffusers.py b/scripts/convert_anyflow_to_diffusers.py index e58b0bd0b9e4..c4193b424976 100644 --- a/scripts/convert_anyflow_to_diffusers.py +++ b/scripts/convert_anyflow_to_diffusers.py @@ -104,7 +104,9 @@ def build_pipeline(variant: str, ckpt_path: str): state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["ema"] missing, unexpected = transformer.load_state_dict(state_dict, strict=False) if unexpected: - print(f"[warn] unexpected keys in state dict (ignored): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}") + print( + f"[warn] unexpected keys in state dict (ignored): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}" + ) if missing: print(f"[warn] missing keys not loaded from state dict: {missing[:5]}{'...' if len(missing) > 5 else ''}") diff --git a/src/diffusers/models/transformers/transformer_anyflow.py b/src/diffusers/models/transformers/transformer_anyflow.py index 326341352da6..26c8c882112a 100644 --- a/src/diffusers/models/transformers/transformer_anyflow.py +++ b/src/diffusers/models/transformers/transformer_anyflow.py @@ -65,8 +65,10 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): class AnyFlowSelfAttnProcessor2_0: def __init__(self): - if not hasattr(F, 'scaled_dot_product_attention'): - raise ImportError('AnyFlowSelfAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.') + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AnyFlowSelfAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) def __call__( self, @@ -76,7 +78,7 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, rotary_emb: Optional[torch.Tensor] = None, kv_cache=None, - kv_cache_flag=None + kv_cache_flag=None, ) -> torch.Tensor: if encoder_hidden_states is None: @@ -96,35 +98,79 @@ def __call__( value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) if kv_cache is not None: - if kv_cache_flag['is_cache_step']: - kv_cache['compressed_cache'][0, :, :, :kv_cache_flag['num_compressed_tokens'], :] = key[:, :, :kv_cache_flag['num_compressed_tokens']] - kv_cache['compressed_cache'][1, :, :, :kv_cache_flag['num_compressed_tokens'], :] = value[:, :, :kv_cache_flag['num_compressed_tokens']] - kv_cache['full_cache'][0, :, :, :kv_cache_flag['num_full_tokens'], :] = key[:, :, kv_cache_flag['num_compressed_tokens']:] - kv_cache['full_cache'][1, :, :, :kv_cache_flag['num_full_tokens'], :] = value[:, :, kv_cache_flag['num_compressed_tokens']:] + if kv_cache_flag["is_cache_step"]: + kv_cache["compressed_cache"][0, :, :, : kv_cache_flag["num_compressed_tokens"], :] = key[ + :, :, : kv_cache_flag["num_compressed_tokens"] + ] + kv_cache["compressed_cache"][1, :, :, : kv_cache_flag["num_compressed_tokens"], :] = value[ + :, :, : kv_cache_flag["num_compressed_tokens"] + ] + kv_cache["full_cache"][0, :, :, : kv_cache_flag["num_full_tokens"], :] = key[ + :, :, kv_cache_flag["num_compressed_tokens"] : + ] + kv_cache["full_cache"][1, :, :, : kv_cache_flag["num_full_tokens"], :] = value[ + :, :, kv_cache_flag["num_compressed_tokens"] : + ] else: - key = torch.cat([ - kv_cache['compressed_cache'][0, :, :, :kv_cache_flag['num_cached_compressed_tokens'], :], - kv_cache['full_cache'][0, :, :, :kv_cache_flag['num_cached_full_tokens'], :], - key - ], dim=2) - value = torch.cat([ - kv_cache['compressed_cache'][1, :, :, :kv_cache_flag['num_cached_compressed_tokens'], :], - kv_cache['full_cache'][1, :, :, :kv_cache_flag['num_cached_full_tokens'], :], - value - ], dim=2) + key = torch.cat( + [ + kv_cache["compressed_cache"][0, :, :, : kv_cache_flag["num_cached_compressed_tokens"], :], + kv_cache["full_cache"][0, :, :, : kv_cache_flag["num_cached_full_tokens"], :], + key, + ], + dim=2, + ) + value = torch.cat( + [ + kv_cache["compressed_cache"][1, :, :, : kv_cache_flag["num_cached_compressed_tokens"], :], + kv_cache["full_cache"][1, :, :, : kv_cache_flag["num_cached_full_tokens"], :], + value, + ], + dim=2, + ) if rotary_emb is not None: - query = apply_rotary_emb(query, rotary_emb['query']) - key = apply_rotary_emb(key, rotary_emb['key']) + query = apply_rotary_emb(query, rotary_emb["query"]) + key = apply_rotary_emb(key, rotary_emb["key"]) if attention_mask is None: - hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) else: seq_len = query.shape[2] padded_length = int(math.ceil(seq_len / 128.0) * 128.0 - seq_len) - query = torch.cat([query, torch.zeros([query.shape[0], query.shape[1], padded_length, query.shape[3]], device=query.device, dtype=query.dtype)], dim=2) # noqa: E501 - key = torch.cat([key, torch.zeros([key.shape[0], key.shape[1], padded_length, key.shape[3]], device=key.device, dtype=key.dtype)], dim=2) # noqa: E501 - value = torch.cat([value, torch.zeros([value.shape[0], value.shape[1], padded_length, value.shape[3]], device=value.device, dtype=value.dtype)], dim=2) # noqa: E501 + query = torch.cat( + [ + query, + torch.zeros( + [query.shape[0], query.shape[1], padded_length, query.shape[3]], + device=query.device, + dtype=query.dtype, + ), + ], + dim=2, + ) # noqa: E501 + key = torch.cat( + [ + key, + torch.zeros( + [key.shape[0], key.shape[1], padded_length, key.shape[3]], device=key.device, dtype=key.dtype + ), + ], + dim=2, + ) # noqa: E501 + value = torch.cat( + [ + value, + torch.zeros( + [value.shape[0], value.shape[1], padded_length, value.shape[3]], + device=value.device, + dtype=value.dtype, + ), + ], + dim=2, + ) # noqa: E501 hidden_states = flex_attention(query, key, value, block_mask=attention_mask)[:, :, :seq_len] @@ -139,8 +185,10 @@ def __call__( class AnyFlowCrossAttnProcessor2_0: def __init__(self): - if not hasattr(F, 'scaled_dot_product_attention'): - raise ImportError('AnyFlowCrossAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.') + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AnyFlowCrossAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) def __call__( self, @@ -165,10 +213,12 @@ def __call__( value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) if rotary_emb is not None: - query = apply_rotary_emb(query, rotary_emb['query']) - key = apply_rotary_emb(key, rotary_emb['key']) + query = apply_rotary_emb(query, rotary_emb["query"]) + key = apply_rotary_emb(key, rotary_emb["key"]) - hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.type_as(query) @@ -183,7 +233,7 @@ def __init__(self, in_features: int, out_features: int): super().__init__() self.norm1 = FP32LayerNorm(in_features) - self.ff = FeedForward(in_features, out_features, mult=1, activation_fn='gelu') + self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") self.norm2 = FP32LayerNorm(out_features) def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: @@ -208,20 +258,15 @@ def __init__( self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) self.act_fn = nn.SiLU() self.time_proj = nn.Linear(dim, time_proj_dim) - self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn='gelu_tanh') + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") self.image_embedder = None if image_embed_dim is not None: self.image_embedder = AnyFlowImageEmbedding(image_embed_dim, dim) - def forward_timestep( - self, - timestep: torch.Tensor, - encoder_hidden_states, - token_per_frame - ): + def forward_timestep(self, timestep: torch.Tensor, encoder_hidden_states, token_per_frame): batch_size, num_frames = timestep.shape - timestep = rearrange(timestep, 'b t -> (b t)') + timestep = rearrange(timestep, "b t -> (b t)") timestep = self.timesteps_proj(timestep) @@ -231,8 +276,10 @@ def forward_timestep( temb = self.time_embedder(timestep).type_as(encoder_hidden_states) timestep_proj = self.time_proj(self.act_fn(temb)) - temb = rearrange(temb, '(b t) c -> b t c', b=batch_size).repeat_interleave(token_per_frame, dim=1) - timestep_proj = rearrange(timestep_proj, '(b t) c -> b t c', b=batch_size).repeat_interleave(token_per_frame, dim=1) + temb = rearrange(temb, "(b t) c -> b t c", b=batch_size).repeat_interleave(token_per_frame, dim=1) + timestep_proj = rearrange(timestep_proj, "(b t) c -> b t c", b=batch_size).repeat_interleave( + token_per_frame, dim=1 + ) return temb, timestep_proj @@ -244,17 +291,27 @@ def forward( encoder_hidden_states_image: Optional[torch.Tensor] = None, far_cfg=None, clean_timestep=None, - is_causal=True + is_causal=True, ): if is_causal: - full_frame_timestep, full_frame_timestep_proj = self.forward_timestep(timestep[:, -far_cfg['num_full_frames']:], encoder_hidden_states, far_cfg['full_token_per_frame']) # noqa: E501 - compressed_frame_timestep, compressed_frame_timestep_proj = self.forward_timestep(timestep[:, :-far_cfg['num_full_frames']], encoder_hidden_states, far_cfg['compressed_token_per_frame']) # noqa: E501 + full_frame_timestep, full_frame_timestep_proj = self.forward_timestep( + timestep[:, -far_cfg["num_full_frames"] :], encoder_hidden_states, far_cfg["full_token_per_frame"] + ) # noqa: E501 + compressed_frame_timestep, compressed_frame_timestep_proj = self.forward_timestep( + timestep[:, : -far_cfg["num_full_frames"]], + encoder_hidden_states, + far_cfg["compressed_token_per_frame"], + ) # noqa: E501 if clean_timestep is not None: - clean_timestep, clean_timestep_proj = self.forward_timestep(clean_timestep, clean_timestep, encoder_hidden_states, far_cfg['full_token_per_frame']) # noqa: E501 + clean_timestep, clean_timestep_proj = self.forward_timestep( + clean_timestep, clean_timestep, encoder_hidden_states, far_cfg["full_token_per_frame"] + ) # noqa: E501 timestep = torch.cat([compressed_frame_timestep, full_frame_timestep, clean_timestep], dim=1) - timestep_proj = torch.cat([compressed_frame_timestep_proj, full_frame_timestep_proj, clean_timestep_proj], dim=1) + timestep_proj = torch.cat( + [compressed_frame_timestep_proj, full_frame_timestep_proj, clean_timestep_proj], dim=1 + ) else: timestep = torch.cat([compressed_frame_timestep, full_frame_timestep], dim=1) timestep_proj = torch.cat([compressed_frame_timestep_proj, full_frame_timestep_proj], dim=1) @@ -263,7 +320,9 @@ def forward( if encoder_hidden_states_image is not None: encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) else: - timestep, timestep_proj = self.forward_timestep(timestep, encoder_hidden_states, far_cfg['full_token_per_frame']) # noqa: E501 + timestep, timestep_proj = self.forward_timestep( + timestep, encoder_hidden_states, far_cfg["full_token_per_frame"] + ) # noqa: E501 encoder_hidden_states = self.text_embedder(encoder_hidden_states) if encoder_hidden_states_image is not None: @@ -290,25 +349,21 @@ def __init__( self.delta_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) self.act_fn = nn.SiLU() self.time_proj = nn.Linear(dim, time_proj_dim) - self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn='gelu_tanh') + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") self.image_embedder = None if image_embed_dim is not None: self.image_embedder = AnyFlowImageEmbedding(image_embed_dim, dim) - self.register_buffer('delta_emb_gate', torch.tensor([gate_value], dtype=torch.float32), persistent=False) + self.register_buffer("delta_emb_gate", torch.tensor([gate_value], dtype=torch.float32), persistent=False) self.deltatime_type = deltatime_type def forward_timestep( - self, - timestep: torch.Tensor, - delta_timestep: torch.Tensor, - encoder_hidden_states, - token_per_frame + self, timestep: torch.Tensor, delta_timestep: torch.Tensor, encoder_hidden_states, token_per_frame ): batch_size, num_frames = timestep.shape - timestep = rearrange(timestep, 'b t -> (b t)') - delta_timestep = rearrange(delta_timestep, 'b t -> (b t)') + timestep = rearrange(timestep, "b t -> (b t)") + delta_timestep = rearrange(delta_timestep, "b t -> (b t)") timestep = self.timesteps_proj(timestep) @@ -330,8 +385,10 @@ def forward_timestep( rt_emb = (1 - gate) * temb + gate * delta_emb timestep_proj = self.time_proj(self.act_fn(rt_emb)) - rt_emb = rearrange(rt_emb, '(b t) c -> b t c', b=batch_size).repeat_interleave(token_per_frame, dim=1) - timestep_proj = rearrange(timestep_proj, '(b t) c -> b t c', b=batch_size).repeat_interleave(token_per_frame, dim=1) + rt_emb = rearrange(rt_emb, "(b t) c -> b t c", b=batch_size).repeat_interleave(token_per_frame, dim=1) + timestep_proj = rearrange(timestep_proj, "(b t) c -> b t c", b=batch_size).repeat_interleave( + token_per_frame, dim=1 + ) return rt_emb, timestep_proj @@ -343,23 +400,37 @@ def forward( encoder_hidden_states_image: Optional[torch.Tensor] = None, far_cfg=None, clean_timestep=None, - is_causal=True + is_causal=True, ): - if self.deltatime_type == 'r': + if self.deltatime_type == "r": delta_timestep = r_timestep - elif self.deltatime_type == 't-r': + elif self.deltatime_type == "t-r": delta_timestep = timestep - r_timestep else: raise NotImplementedError if is_causal: - full_frame_timestep, full_frame_timestep_proj = self.forward_timestep(timestep[:, -far_cfg['num_full_frames']:], delta_timestep[:, -far_cfg['num_full_frames']:], encoder_hidden_states, far_cfg['full_token_per_frame']) # noqa: E501 - compressed_frame_timestep, compressed_frame_timestep_proj = self.forward_timestep(timestep[:, :-far_cfg['num_full_frames']], delta_timestep[:, :-far_cfg['num_full_frames']], encoder_hidden_states, far_cfg['compressed_token_per_frame']) # noqa: E501 + full_frame_timestep, full_frame_timestep_proj = self.forward_timestep( + timestep[:, -far_cfg["num_full_frames"] :], + delta_timestep[:, -far_cfg["num_full_frames"] :], + encoder_hidden_states, + far_cfg["full_token_per_frame"], + ) # noqa: E501 + compressed_frame_timestep, compressed_frame_timestep_proj = self.forward_timestep( + timestep[:, : -far_cfg["num_full_frames"]], + delta_timestep[:, : -far_cfg["num_full_frames"]], + encoder_hidden_states, + far_cfg["compressed_token_per_frame"], + ) # noqa: E501 if clean_timestep is not None: - clean_timestep, clean_timestep_proj = self.forward_timestep(clean_timestep, clean_timestep, encoder_hidden_states, far_cfg['full_token_per_frame']) # noqa: E501 + clean_timestep, clean_timestep_proj = self.forward_timestep( + clean_timestep, clean_timestep, encoder_hidden_states, far_cfg["full_token_per_frame"] + ) # noqa: E501 timestep = torch.cat([compressed_frame_timestep, full_frame_timestep, clean_timestep], dim=1) - timestep_proj = torch.cat([compressed_frame_timestep_proj, full_frame_timestep_proj, clean_timestep_proj], dim=1) + timestep_proj = torch.cat( + [compressed_frame_timestep_proj, full_frame_timestep_proj, clean_timestep_proj], dim=1 + ) else: timestep = torch.cat([compressed_frame_timestep, full_frame_timestep], dim=1) timestep_proj = torch.cat([compressed_frame_timestep_proj, full_frame_timestep_proj], dim=1) @@ -368,7 +439,9 @@ def forward( if encoder_hidden_states_image is not None: encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) else: - timestep, timestep_proj = self.forward_timestep(timestep, delta_timestep, encoder_hidden_states, far_cfg['full_token_per_frame']) # noqa: E501 + timestep, timestep_proj = self.forward_timestep( + timestep, delta_timestep, encoder_hidden_states, far_cfg["full_token_per_frame"] + ) # noqa: E501 encoder_hidden_states = self.text_embedder(encoder_hidden_states) if encoder_hidden_states_image is not None: @@ -379,7 +452,12 @@ def forward( class AnyFlowRotaryPosEmbed(nn.Module): def __init__( - self, attention_head_dim: int, patch_size: Tuple[int, int, int], compressed_patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0 + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + compressed_patch_size: Tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, ): super().__init__() @@ -401,9 +479,9 @@ def __init__( def avg_pool_complex(self, freq: torch.Tensor, kernel_size: int, stride: int): - real = freq.real # [B, C, L], float + real = freq.real # [B, C, L], float real = real.transpose(0, 1).unsqueeze(0) - imag = freq.imag # [B, C, L], float + imag = freq.imag # [B, C, L], float imag = imag.transpose(0, 1).unsqueeze(0) pr = F.avg_pool1d(real, kernel_size, stride) @@ -465,19 +543,22 @@ def _forward_full_frame(self, num_frames, height, width, device) -> torch.Tensor def forward(self, far_cfg, device, clean_hidden_states=None, is_causal=True): if is_causal: full_frame_freqs = self._forward_full_frame( - num_frames=far_cfg['total_frames'], - height=far_cfg['full_frame_shape'][0], - width=far_cfg['full_frame_shape'][1], - device=device + num_frames=far_cfg["total_frames"], + height=far_cfg["full_frame_shape"][0], + width=far_cfg["full_frame_shape"][1], + device=device, ) compressed_frame_freqs = self._forward_compressed_frame( - num_frames=far_cfg['total_frames'], - height=far_cfg['compressed_frame_shape'][0], - width=far_cfg['compressed_frame_shape'][1], - device=device + num_frames=far_cfg["total_frames"], + height=far_cfg["compressed_frame_shape"][0], + width=far_cfg["compressed_frame_shape"][1], + device=device, ) - compressed_frame_freqs, full_frame_freqs = compressed_frame_freqs[:far_cfg['num_compressed_frames']], full_frame_freqs[far_cfg['num_compressed_frames']:] # noqa: E501 + compressed_frame_freqs, full_frame_freqs = ( + compressed_frame_freqs[: far_cfg["num_compressed_frames"]], + full_frame_freqs[far_cfg["num_compressed_frames"] :], + ) # noqa: E501 compressed_frame_freqs = compressed_frame_freqs.flatten(start_dim=0, end_dim=2) full_frame_freqs = full_frame_freqs.flatten(start_dim=0, end_dim=2) @@ -489,17 +570,17 @@ def forward(self, far_cfg, device, clean_hidden_states=None, is_causal=True): freqs = freqs[None, None, ...] - return {'query': freqs, 'key': freqs} + return {"query": freqs, "key": freqs} else: freqs = self._forward_full_frame( - num_frames=far_cfg['total_frames'], - height=far_cfg['full_frame_shape'][0], - width=far_cfg['full_frame_shape'][1], - device=device + num_frames=far_cfg["total_frames"], + height=far_cfg["full_frame_shape"][0], + width=far_cfg["full_frame_shape"][1], + device=device, ) freqs = freqs.flatten(start_dim=0, end_dim=2) freqs = freqs[None, None, ...] - return {'query': freqs, 'key': freqs} + return {"query": freqs, "key": freqs} class AnyFlowTransformerBlock(nn.Module): @@ -508,7 +589,7 @@ def __init__( dim: int, ffn_dim: int, num_heads: int, - qk_norm: str = 'rms_norm_across_heads', + qk_norm: str = "rms_norm_across_heads", cross_attn_norm: bool = False, eps: float = 1e-6, added_kv_proj_dim: Optional[int] = None, @@ -548,7 +629,7 @@ def __init__( self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() # 3. Feed-forward - self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn='gelu-approximate') + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) @@ -564,12 +645,27 @@ def forward( kv_cache_flag=None, ) -> torch.Tensor: - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (self.scale_shift_table + temb.float()).chunk(6, dim=2) - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), c_shift_msa.squeeze(2), c_scale_msa.squeeze(2), c_gate_msa.squeeze(2) # noqa: E501 + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=2) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + shift_msa.squeeze(2), + scale_msa.squeeze(2), + gate_msa.squeeze(2), + c_shift_msa.squeeze(2), + c_scale_msa.squeeze(2), + c_gate_msa.squeeze(2), + ) # noqa: E501 # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) - attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask, kv_cache=kv_cache, kv_cache_flag=kv_cache_flag) # noqa: E501 + attn_output = self.attn1( + hidden_states=norm_hidden_states, + rotary_emb=rotary_emb, + attention_mask=attention_mask, + kv_cache=kv_cache, + kv_cache_flag=kv_cache_flag, + ) # noqa: E501 hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) # 2. Cross-attention @@ -578,7 +674,9 @@ def forward( hidden_states = hidden_states + attn_output # 3. Feed-forward - norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states) + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + hidden_states + ) ff_output = self.ffn(norm_hidden_states) hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) @@ -651,10 +749,10 @@ class AnyFlowTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO """ _supports_gradient_checkpointing = True - _skip_layerwise_casting_patterns = ['patch_embedding', 'condition_embedder', 'norm'] - _no_split_modules = ['AnyFlowTransformerBlock'] - _keep_in_fp32_modules = ['time_embedder', 'scale_shift_table', 'norm1', 'norm2', 'norm3'] - _keys_to_ignore_on_load_unexpected = ['norm_added_q'] + _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] + _no_split_modules = ["AnyFlowTransformerBlock"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @register_to_config def __init__( @@ -671,7 +769,7 @@ def __init__( ffn_dim: int = 13824, num_layers: int = 40, cross_attn_norm: bool = True, - qk_norm: Optional[str] = 'rms_norm_across_heads', + qk_norm: Optional[str] = "rms_norm_across_heads", eps: float = 1e-6, image_dim: Optional[int] = None, added_kv_proj_dim: Optional[int] = None, @@ -680,7 +778,7 @@ def __init__( init_far_model=False, init_flowmap_model=False, gate_value=0, - deltatime_type='r' + deltatime_type="r", ) -> None: super().__init__() @@ -723,7 +821,7 @@ def __init__( if init_flowmap_model: self.setup_flowmap_model(gate_value=self.config.gate_value, deltatime_type=self.config.deltatime_type) - def setup_flowmap_model(self, gate_value=0, deltatime_type='r'): + def setup_flowmap_model(self, gate_value=0, deltatime_type="r"): inner_dim = self.config.num_attention_heads * self.config.attention_head_dim condition_embedder = AnyFlowDualTimestepTextImageEmbedding( @@ -748,13 +846,17 @@ def setup_far_model(self): inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.far_patch_embedding = nn.Conv3d( - self.config.in_channels, inner_dim, kernel_size=self.config.compressed_patch_size, stride=self.config.compressed_patch_size) + self.config.in_channels, + inner_dim, + kernel_size=self.config.compressed_patch_size, + stride=self.config.compressed_patch_size, + ) # init far patch embedding original_weight = self.patch_embedding.weight.data.view(-1, 1, *self.config.patch_size) new_weight = F.interpolate( - original_weight, size=self.config.compressed_patch_size, mode='trilinear', align_corners=False + original_weight, size=self.config.compressed_patch_size, mode="trilinear", align_corners=False ) new_weight = new_weight.view(inner_dim, self.config.in_channels, *self.config.compressed_patch_size) @@ -766,23 +868,36 @@ def _unpack_latent_sequence(self, latents, num_frames, height, width, patch_size batch_size, num_patches, channels = latents.shape height, width = height // patch_size, width // patch_size - latents = latents.view(batch_size * num_frames, height, width, patch_size, patch_size, channels // (patch_size * patch_size)) + latents = latents.view( + batch_size * num_frames, height, width, patch_size, patch_size, channels // (patch_size * patch_size) + ) latents = latents.permute(0, 5, 1, 3, 2, 4) - latents = latents.reshape(batch_size, num_frames, channels // (patch_size * patch_size), height * patch_size, width * patch_size) + latents = latents.reshape( + batch_size, num_frames, channels // (patch_size * patch_size), height * patch_size, width * patch_size + ) return latents def forward_far_patchify(self, hidden_states, far_cfg, clean_hidden_states=None): - full_hidden_states, compressed_hidden_states = hidden_states[:, :, far_cfg['num_compressed_frames']:], hidden_states[:, :, :far_cfg['num_compressed_frames']] # noqa: E501 + full_hidden_states, compressed_hidden_states = ( + hidden_states[:, :, far_cfg["num_compressed_frames"] :], + hidden_states[:, :, : far_cfg["num_compressed_frames"]], + ) # noqa: E501 - patchified_full_hidden_states = self.patch_embedding(full_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + patchified_full_hidden_states = ( + self.patch_embedding(full_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + ) if clean_hidden_states is not None: - clean_hidden_states = self.patch_embedding(clean_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + clean_hidden_states = ( + self.patch_embedding(clean_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + ) patchified_full_hidden_states = torch.cat([patchified_full_hidden_states, clean_hidden_states], dim=1) - if far_cfg['num_compressed_frames'] > 0: - patchified_compressed_hidden_states = self.far_patch_embedding(compressed_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + if far_cfg["num_compressed_frames"] > 0: + patchified_compressed_hidden_states = ( + self.far_patch_embedding(compressed_hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) + ) hidden_states = torch.cat([patchified_compressed_hidden_states, patchified_full_hidden_states], dim=1) else: hidden_states = patchified_full_hidden_states @@ -793,10 +908,10 @@ def forward_far_patchify_inference(self, hidden_states): return hidden_states def _build_causal_mask(self, far_cfg, clean_hidden_states, device, dtype): - chunk_partition = far_cfg['chunk_partition'] + chunk_partition = far_cfg["chunk_partition"] - noise_seq_len = clean_seq_len = far_cfg['num_full_frames'] * far_cfg['full_token_per_frame'] - context_seq_len = far_cfg['num_compressed_frames'] * far_cfg['compressed_token_per_frame'] + noise_seq_len = clean_seq_len = far_cfg["num_full_frames"] * far_cfg["full_token_per_frame"] + context_seq_len = far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"] noise_start = context_seq_len noise_end = noise_start + noise_seq_len @@ -812,13 +927,27 @@ def _build_causal_mask(self, far_cfg, clean_hidden_states, device, dtype): padded_seq_len = int(math.ceil(real_seq_len / 128.0) * 128.0) if clean_hidden_states is not None: - context_chunk_partition, noise_chunk_partition = chunk_partition[:far_cfg['num_compressed_chunk']], chunk_partition[far_cfg['num_compressed_chunk']:] # noqa: E501 + context_chunk_partition, noise_chunk_partition = ( + chunk_partition[: far_cfg["num_compressed_chunk"]], + chunk_partition[far_cfg["num_compressed_chunk"] :], + ) # noqa: E501 if len(context_chunk_partition) != 0: - context_frame_idx = torch.cat([torch.ones(chunk_len * far_cfg['compressed_token_per_frame'], device=device) * chunk_idx for chunk_idx, chunk_len in enumerate(context_chunk_partition)]) # noqa: E501 + context_frame_idx = torch.cat( + [ + torch.ones(chunk_len * far_cfg["compressed_token_per_frame"], device=device) * chunk_idx + for chunk_idx, chunk_len in enumerate(context_chunk_partition) + ] + ) # noqa: E501 else: context_frame_idx = None - noise_frame_idx = clean_frame_idx = torch.cat([torch.ones(chunk_len * far_cfg['full_token_per_frame'], device=device) * (chunk_idx + len(context_chunk_partition)) for chunk_idx, chunk_len in enumerate(noise_chunk_partition)]) # noqa: E501 + noise_frame_idx = clean_frame_idx = torch.cat( + [ + torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device) + * (chunk_idx + len(context_chunk_partition)) + for chunk_idx, chunk_len in enumerate(noise_chunk_partition) + ] + ) # noqa: E501 pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device) if len(context_chunk_partition) != 0: @@ -873,14 +1002,28 @@ def mask_mod(b, h, q_idx, kv_idx): _compile=False, ) else: - context_chunk_partition, noise_chunk_partition = chunk_partition[:far_cfg['num_compressed_chunk']], chunk_partition[far_cfg['num_compressed_chunk']:] # noqa: E501 + context_chunk_partition, noise_chunk_partition = ( + chunk_partition[: far_cfg["num_compressed_chunk"]], + chunk_partition[far_cfg["num_compressed_chunk"] :], + ) # noqa: E501 if len(context_chunk_partition) != 0: - context_frame_idx = torch.cat([torch.ones(chunk_len * far_cfg['compressed_token_per_frame'], device=device) * chunk_idx for chunk_idx, chunk_len in enumerate(context_chunk_partition)]) # noqa: E501 + context_frame_idx = torch.cat( + [ + torch.ones(chunk_len * far_cfg["compressed_token_per_frame"], device=device) * chunk_idx + for chunk_idx, chunk_len in enumerate(context_chunk_partition) + ] + ) # noqa: E501 else: context_frame_idx = None - noise_frame_idx = torch.cat([torch.ones(chunk_len * far_cfg['full_token_per_frame'], device=device) * (chunk_idx + len(context_chunk_partition)) for chunk_idx, chunk_len in enumerate(noise_chunk_partition)]) # noqa: E501 + noise_frame_idx = torch.cat( + [ + torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device) + * (chunk_idx + len(context_chunk_partition)) + for chunk_idx, chunk_len in enumerate(noise_chunk_partition) + ] + ) # noqa: E501 pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device) if len(context_chunk_partition) != 0: @@ -904,9 +1047,9 @@ def mask_mod(b, h, q_idx, kv_idx): ) def forward(self, *args, **kwargs): - if kwargs.get('is_causal', True): - if kwargs.get('kv_cache', None) is not None: - if kwargs['kv_cache_flag'].get('is_cache_step'): + if kwargs.get("is_causal", True): + if kwargs.get("kv_cache", None) is not None: + if kwargs["kv_cache_flag"].get("is_cache_step"): return self._forward_cache(*args, **kwargs) else: return self._forward_inference(*args, **kwargs) @@ -926,34 +1069,47 @@ def _forward_inference( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, kv_cache=None, - kv_cache_flag=None + kv_cache_flag=None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - hidden_states = rearrange(hidden_states, 'b f c h w -> b c f h w') + hidden_states = rearrange(hidden_states, "b f c h w -> b c f h w") batch_size, num_channels, num_frames, height, width = hidden_states.shape full_token_per_frame = (height // self.config.patch_size[1]) * (width // self.config.patch_size[2]) - compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * (width // self.config.compressed_patch_size[2]) + compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * ( + width // self.config.compressed_patch_size[2] + ) - total_chunks = 1 + kv_cache_flag['num_cached_chunks'] + total_chunks = 1 + kv_cache_flag["num_cached_chunks"] if total_chunks >= self.config.full_chunk_limit: - num_full_chunk, num_compressed_chunk = self.config.full_chunk_limit, total_chunks - self.config.full_chunk_limit + num_full_chunk, num_compressed_chunk = ( + self.config.full_chunk_limit, + total_chunks - self.config.full_chunk_limit, + ) else: num_full_chunk, num_compressed_chunk = total_chunks, 0 - kv_cache_flag['num_cached_full_tokens'] = sum(chunk_partition[num_compressed_chunk: num_compressed_chunk + (num_full_chunk - 1)]) * full_token_per_frame # noqa: E501 - kv_cache_flag['num_cached_compressed_tokens'] = sum(chunk_partition[:num_compressed_chunk]) * compressed_token_per_frame + kv_cache_flag["num_cached_full_tokens"] = ( + sum(chunk_partition[num_compressed_chunk : num_compressed_chunk + (num_full_chunk - 1)]) + * full_token_per_frame + ) # noqa: E501 + kv_cache_flag["num_cached_compressed_tokens"] = ( + sum(chunk_partition[:num_compressed_chunk]) * compressed_token_per_frame + ) far_cfg = { - 'total_frames': sum(chunk_partition), - 'num_full_frames': sum(chunk_partition[num_compressed_chunk:]), - 'num_compressed_frames': sum(chunk_partition[:num_compressed_chunk]), - 'full_frame_shape': (height // self.config.patch_size[1], width // self.config.patch_size[2]), - 'compressed_frame_shape': (height // self.config.compressed_patch_size[1], width // self.config.compressed_patch_size[2]), - 'full_token_per_frame': full_token_per_frame, - 'compressed_token_per_frame': compressed_token_per_frame + "total_frames": sum(chunk_partition), + "num_full_frames": sum(chunk_partition[num_compressed_chunk:]), + "num_compressed_frames": sum(chunk_partition[:num_compressed_chunk]), + "full_frame_shape": (height // self.config.patch_size[1], width // self.config.patch_size[2]), + "compressed_frame_shape": ( + height // self.config.compressed_patch_size[1], + width // self.config.compressed_patch_size[2], + ), + "full_token_per_frame": full_token_per_frame, + "compressed_token_per_frame": compressed_token_per_frame, } # step 3: generate attention mask @@ -961,10 +1117,14 @@ def _forward_inference( hidden_states = self.forward_far_patchify_inference(hidden_states) rotary_emb = self.rope(far_cfg=far_cfg, device=hidden_states.device) - rotary_emb['query'] = rotary_emb['query'][:, :, -hidden_states.shape[1]:] + rotary_emb["query"] = rotary_emb["query"][:, :, -hidden_states.shape[1] :] temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( - timestep, r_timestep, encoder_hidden_states, encoder_hidden_states_image, far_cfg=far_cfg # noqa: E501 + timestep, + r_timestep, + encoder_hidden_states, + encoder_hidden_states_image, + far_cfg=far_cfg, # noqa: E501 ) timestep_proj = timestep_proj.unflatten(2, (6, -1)) @@ -973,13 +1133,27 @@ def _forward_inference( # 4. Transformer blocks for index_block, block in enumerate(self.blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask, kv_cache[index_block], kv_cache_flag + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + attention_mask, + kv_cache[index_block], + kv_cache_flag, ) else: - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask, kv_cache[index_block], kv_cache_flag) + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + attention_mask, + kv_cache[index_block], + kv_cache_flag, + ) # 5. Output norm, projection & unpatchify shift, scale = (self.scale_shift_table + temb.unsqueeze(2)).chunk(2, dim=2) @@ -995,7 +1169,9 @@ def _forward_inference( hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) output = self.proj_out(hidden_states) - output = self._unpack_latent_sequence(output, num_frames=chunk_partition[-1], height=height, width=width, patch_size=self.config.patch_size[1]) + output = self._unpack_latent_sequence( + output, num_frames=chunk_partition[-1], height=height, width=width, patch_size=self.config.patch_size[1] + ) if not return_dict: return output, kv_cache @@ -1015,17 +1191,19 @@ def _forward_cache( clean_hidden_states=None, clean_timestep=None, kv_cache=None, - kv_cache_flag=None + kv_cache_flag=None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - hidden_states = rearrange(hidden_states, 'b f c h w -> b c f h w') + hidden_states = rearrange(hidden_states, "b f c h w -> b c f h w") if clean_hidden_states is not None: - clean_hidden_states = rearrange(clean_hidden_states, 'b f c h w -> b c f h w') + clean_hidden_states = rearrange(clean_hidden_states, "b f c h w -> b c f h w") batch_size, num_channels, num_frames, height, width = hidden_states.shape full_token_per_frame = (height // self.config.patch_size[1]) * (width // self.config.patch_size[2]) - compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * (width // self.config.compressed_patch_size[2]) + compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * ( + width // self.config.compressed_patch_size[2] + ) total_chunks = len(chunk_partition) full_chunk_limit = self.config.full_chunk_limit - 1 @@ -1036,29 +1214,43 @@ def _forward_cache( num_full_chunk, num_compressed_chunk = total_chunks, 0 far_cfg = { - 'total_frames': sum(chunk_partition), - 'num_full_chunk': num_full_chunk, - 'num_full_frames': sum(chunk_partition[num_compressed_chunk:]), - 'num_compressed_chunk': num_compressed_chunk, - 'num_compressed_frames': sum(chunk_partition[:num_compressed_chunk]), - 'full_frame_shape': (height // self.config.patch_size[1], width // self.config.patch_size[2]), - 'compressed_frame_shape': (height // self.config.compressed_patch_size[1], width // self.config.compressed_patch_size[2]), - 'full_token_per_frame': full_token_per_frame, - 'compressed_token_per_frame': compressed_token_per_frame, - 'chunk_partition': chunk_partition + "total_frames": sum(chunk_partition), + "num_full_chunk": num_full_chunk, + "num_full_frames": sum(chunk_partition[num_compressed_chunk:]), + "num_compressed_chunk": num_compressed_chunk, + "num_compressed_frames": sum(chunk_partition[:num_compressed_chunk]), + "full_frame_shape": (height // self.config.patch_size[1], width // self.config.patch_size[2]), + "compressed_frame_shape": ( + height // self.config.compressed_patch_size[1], + width // self.config.compressed_patch_size[2], + ), + "full_token_per_frame": full_token_per_frame, + "compressed_token_per_frame": compressed_token_per_frame, + "chunk_partition": chunk_partition, } - kv_cache_flag['num_full_tokens'] = far_cfg['num_full_frames'] * far_cfg['full_token_per_frame'] - kv_cache_flag['num_compressed_tokens'] = far_cfg['num_compressed_frames'] * far_cfg['compressed_token_per_frame'] + kv_cache_flag["num_full_tokens"] = far_cfg["num_full_frames"] * far_cfg["full_token_per_frame"] + kv_cache_flag["num_compressed_tokens"] = ( + far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"] + ) # step 3: generate attention mask - attention_mask = self._build_causal_mask(far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = self._build_causal_mask( + far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype + ) rotary_emb = self.rope(far_cfg=far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device) - hidden_states = self.forward_far_patchify(hidden_states, far_cfg=far_cfg, clean_hidden_states=clean_hidden_states) + hidden_states = self.forward_far_patchify( + hidden_states, far_cfg=far_cfg, clean_hidden_states=clean_hidden_states + ) temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( - timestep, r_timestep, encoder_hidden_states, encoder_hidden_states_image, far_cfg=far_cfg, clean_timestep=clean_timestep + timestep, + r_timestep, + encoder_hidden_states, + encoder_hidden_states_image, + far_cfg=far_cfg, + clean_timestep=clean_timestep, ) timestep_proj = timestep_proj.unflatten(2, (6, -1)) @@ -1067,13 +1259,27 @@ def _forward_cache( # 4. Transformer blocks for index_block, block in enumerate(self.blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask, kv_cache[index_block], kv_cache_flag + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + attention_mask, + kv_cache[index_block], + kv_cache_flag, ) else: - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask, kv_cache[index_block], kv_cache_flag) + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + attention_mask, + kv_cache[index_block], + kv_cache_flag, + ) return None, kv_cache @@ -1091,43 +1297,60 @@ def _forward_train( clean_timestep=None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - hidden_states = rearrange(hidden_states, 'b f c h w -> b c f h w') + hidden_states = rearrange(hidden_states, "b f c h w -> b c f h w") if clean_hidden_states is not None: - clean_hidden_states = rearrange(clean_hidden_states, 'b f c h w -> b c f h w') + clean_hidden_states = rearrange(clean_hidden_states, "b f c h w -> b c f h w") batch_size, num_channels, num_frames, height, width = hidden_states.shape full_token_per_frame = (height // self.config.patch_size[1]) * (width // self.config.patch_size[2]) - compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * (width // self.config.compressed_patch_size[2]) + compressed_token_per_frame = (height // self.config.compressed_patch_size[1]) * ( + width // self.config.compressed_patch_size[2] + ) total_chunks = len(chunk_partition) if total_chunks > self.config.full_chunk_limit: - num_full_chunk, num_compressed_chunk = self.config.full_chunk_limit, total_chunks - self.config.full_chunk_limit + num_full_chunk, num_compressed_chunk = ( + self.config.full_chunk_limit, + total_chunks - self.config.full_chunk_limit, + ) else: num_full_chunk, num_compressed_chunk = total_chunks, 0 far_cfg = { - 'total_frames': sum(chunk_partition), - 'num_full_chunk': num_full_chunk, - 'num_full_frames': sum(chunk_partition[num_compressed_chunk:]), - 'num_compressed_chunk': num_compressed_chunk, - 'num_compressed_frames': sum(chunk_partition[:num_compressed_chunk]), - 'full_frame_shape': (height // self.config.patch_size[1], width // self.config.patch_size[2]), - 'compressed_frame_shape': (height // self.config.compressed_patch_size[1], width // self.config.compressed_patch_size[2]), - 'full_token_per_frame': full_token_per_frame, - 'compressed_token_per_frame': compressed_token_per_frame, - 'chunk_partition': chunk_partition + "total_frames": sum(chunk_partition), + "num_full_chunk": num_full_chunk, + "num_full_frames": sum(chunk_partition[num_compressed_chunk:]), + "num_compressed_chunk": num_compressed_chunk, + "num_compressed_frames": sum(chunk_partition[:num_compressed_chunk]), + "full_frame_shape": (height // self.config.patch_size[1], width // self.config.patch_size[2]), + "compressed_frame_shape": ( + height // self.config.compressed_patch_size[1], + width // self.config.compressed_patch_size[2], + ), + "full_token_per_frame": full_token_per_frame, + "compressed_token_per_frame": compressed_token_per_frame, + "chunk_partition": chunk_partition, } # step 3: generate attention mask - attention_mask = self._build_causal_mask(far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = self._build_causal_mask( + far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype + ) rotary_emb = self.rope(far_cfg=far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device) - hidden_states = self.forward_far_patchify(hidden_states, far_cfg=far_cfg, clean_hidden_states=clean_hidden_states) + hidden_states = self.forward_far_patchify( + hidden_states, far_cfg=far_cfg, clean_hidden_states=clean_hidden_states + ) temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( - timestep, r_timestep, encoder_hidden_states, encoder_hidden_states_image, far_cfg=far_cfg, clean_timestep=clean_timestep + timestep, + r_timestep, + encoder_hidden_states, + encoder_hidden_states_image, + far_cfg=far_cfg, + clean_timestep=clean_timestep, ) timestep_proj = timestep_proj.unflatten(2, (6, -1)) @@ -1136,10 +1359,14 @@ def _forward_train( # 4. Transformer blocks for index_block, block in enumerate(self.blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask, + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + attention_mask, ) else: hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, attention_mask) @@ -1158,9 +1385,19 @@ def _forward_train( hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) if clean_hidden_states is not None: - hidden_states = hidden_states[:, :-(far_cfg['num_full_frames'] * far_cfg['full_token_per_frame'])] # remove clean copy - output = self.proj_out(hidden_states[:, far_cfg['num_compressed_frames'] * far_cfg['compressed_token_per_frame']:]) # remove far context - output = self._unpack_latent_sequence(output, num_frames=far_cfg['num_full_frames'], height=height, width=width, patch_size=self.config.patch_size[1]) # noqa: E501 + hidden_states = hidden_states[ + :, : -(far_cfg["num_full_frames"] * far_cfg["full_token_per_frame"]) + ] # remove clean copy + output = self.proj_out( + hidden_states[:, far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"] :] + ) # remove far context + output = self._unpack_latent_sequence( + output, + num_frames=far_cfg["num_full_frames"], + height=height, + width=width, + patch_size=self.config.patch_size[1], + ) # noqa: E501 if not return_dict: return output @@ -1176,9 +1413,9 @@ def _forward_bidirection( encoder_hidden_states_image: Optional[torch.Tensor] = None, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, - is_causal=False + is_causal=False, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - hidden_states = rearrange(hidden_states, 'b f c h w -> b c f h w') + hidden_states = rearrange(hidden_states, "b f c h w -> b c f h w") assert is_causal is False batch_size, num_channels, num_frames, height, width = hidden_states.shape @@ -1186,9 +1423,9 @@ def _forward_bidirection( full_token_per_frame = (height * width) // (self.config.patch_size[1] * self.config.patch_size[2]) far_cfg = { - 'total_frames': num_frames, - 'full_frame_shape': (height // self.config.patch_size[1], width // self.config.patch_size[2]), - 'full_token_per_frame': full_token_per_frame, + "total_frames": num_frames, + "full_frame_shape": (height // self.config.patch_size[1], width // self.config.patch_size[2]), + "full_token_per_frame": full_token_per_frame, } rotary_emb = self.rope(far_cfg=far_cfg, device=hidden_states.device, is_causal=is_causal) @@ -1197,7 +1434,12 @@ def _forward_bidirection( hidden_states = hidden_states.flatten(2).transpose(1, 2) temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( - timestep, r_timestep, encoder_hidden_states, encoder_hidden_states_image, is_causal=is_causal, far_cfg=far_cfg + timestep, + r_timestep, + encoder_hidden_states, + encoder_hidden_states_image, + is_causal=is_causal, + far_cfg=far_cfg, ) timestep_proj = timestep_proj.unflatten(2, (6, -1)) @@ -1236,7 +1478,13 @@ def _forward_bidirection( hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) hidden_states = self.proj_out(hidden_states) - output = self._unpack_latent_sequence(hidden_states, num_frames=far_cfg['total_frames'], height=height, width=width, patch_size=self.config.patch_size[1]) # noqa: E501 + output = self._unpack_latent_sequence( + hidden_states, + num_frames=far_cfg["total_frames"], + height=height, + width=width, + patch_size=self.config.patch_size[1], + ) # noqa: E501 if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py index ebddfc0b10b7..74cda6137ee8 100644 --- a/src/diffusers/pipelines/anyflow/pipeline_anyflow.py +++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py @@ -48,7 +48,7 @@ def basic_clean(text): def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) text = text.strip() return text @@ -87,8 +87,8 @@ class AnyFlowPipeline(DiffusionPipeline, WanLoraLoaderMixin): (the default training-time behavior). Disable to mirror raw Euler stepping. """ - model_cpu_offload_seq = 'text_encoder->transformer->vae' - _callback_tensor_inputs = ['latents', 'prompt_embeds', 'negative_prompt_embeds'] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -97,7 +97,7 @@ def __init__( transformer: AnyFlowTransformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMapEulerDiscreteScheduler, - use_mean_velocity: bool = True + use_mean_velocity: bool = True, ): super().__init__() @@ -109,8 +109,8 @@ def __init__( scheduler=scheduler, ) - self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, 'vae', None) else 4 - self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, 'vae', None) else 8 + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.use_mean_velocity = use_mean_velocity @@ -131,12 +131,12 @@ def _get_t5_prompt_embeds( text_inputs = self.tokenizer( prompt, - padding='max_length', + padding="max_length", max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_attention_mask=True, - return_tensors='pt', + return_tensors="pt", ) text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() @@ -211,19 +211,19 @@ def encode_prompt( ) if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or '' + negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( - f'`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=' - f' {type(prompt)}.' + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( - f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:' - f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches' - ' the batch size of `prompt`.' + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." ) negative_prompt_embeds = self._get_t5_prompt_embeds( @@ -247,35 +247,35 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if height % 16 != 0 or width % 16 != 0: - raise ValueError(f'`height` and `width` have to be divisible by 16 but are {height} and {width}.') + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( - f'`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}' # noqa: E501 + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" # noqa: E501 ) if prompt is not None and prompt_embeds is not None: raise ValueError( - f'Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to' - ' only forward one of the two.' + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." ) elif negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( - f'Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to' - ' only forward one of the two.' + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( - 'Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.' + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f'`prompt` has to be of type `str` or `list` but is {type(prompt)}') + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif negative_prompt is not None and ( not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) ): - raise ValueError(f'`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}') + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") def prepare_latents( self, @@ -302,12 +302,12 @@ def prepare_latents( ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( - f'You have passed a list of generators of length {len(generator)}, but requested an effective batch' - f' size of {batch_size}. Make sure the batch size matches the length of the generators.' + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = rearrange(latents, 'b c t h w -> b t c h w') + latents = rearrange(latents, "b c t h w -> b t c h w") return latents @property @@ -338,8 +338,10 @@ def attention_kwargs(self): def vae_encode(self, context_sequence): # normalize: [0, 1] -> [-1, 1] context_sequence = context_sequence * 2 - 1 - context_sequence = self.encode_latents(context_sequence.to(dtype=self.vae.dtype, device=self._execution_device), sample=False) - context_sequence = rearrange(context_sequence, 'b c t h w -> b t c h w') + context_sequence = self.encode_latents( + context_sequence.to(dtype=self.vae.dtype, device=self._execution_device), sample=False + ) + context_sequence = rearrange(context_sequence, "b c t h w -> b t c h w") return context_sequence def _normalize_latents(self, latents, latents_mean, latents_std): @@ -350,7 +352,7 @@ def _normalize_latents(self, latents, latents_mean, latents_std): @torch.no_grad() def encode_latents(self, videos, sample=True): - videos = rearrange(videos, 'b t c h w -> b c t h w') + videos = rearrange(videos, "b t c h w -> b c t h w") moments = self.vae._encode(videos) latents_mean = torch.tensor(self.vae.config.latents_mean) @@ -418,7 +420,7 @@ def inference_range(latents, timesteps): r_timestep=r_timestep, encoder_hidden_states=prompt_embeds, return_dict=False, - is_causal=False + is_causal=False, )[0] if self.do_classifier_free_guidance: @@ -475,13 +477,13 @@ def __call__( latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - output_type: Optional[str] = 'np', + output_type: Optional[str] = "np", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, - callback_on_step_end_tensor_inputs: List[str] = ['latents'], + callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, ): @@ -498,7 +500,7 @@ def __call__( if num_frames % self.vae_scale_factor_temporal != 1: logger.warning( - f'`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number.' + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." ) num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) @@ -562,13 +564,13 @@ def __call__( latents=init_latents, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - guidance_scale=guidance_scale + guidance_scale=guidance_scale, ) if context_sequence is not None: latents[:, :context_length, ...] = context_sequence - latents = rearrange(latents, 'b f c h w -> b c f h w') + latents = rearrange(latents, "b f c h w -> b c f h w") - if not output_type == 'latent': + if not output_type == "latent": latents = latents.to(self.vae.dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow_causal.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow_causal.py index 04ea594e2127..a5de05cf4c84 100644 --- a/src/diffusers/pipelines/anyflow/pipeline_anyflow_causal.py +++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow_causal.py @@ -49,7 +49,7 @@ def basic_clean(text): def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) text = text.strip() return text @@ -87,8 +87,8 @@ class AnyFlowCausalPipeline(DiffusionPipeline, WanLoraLoaderMixin): When ``True`` the model output is averaged across two anchor times to reduce discretization error. """ - model_cpu_offload_seq = 'text_encoder->transformer->vae' - _callback_tensor_inputs = ['latents', 'prompt_embeds', 'negative_prompt_embeds'] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -97,7 +97,7 @@ def __init__( transformer: AnyFlowTransformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMapEulerDiscreteScheduler, - use_mean_velocity: bool = True + use_mean_velocity: bool = True, ): super().__init__() @@ -109,8 +109,8 @@ def __init__( scheduler=scheduler, ) - self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, 'vae', None) else 4 - self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, 'vae', None) else 8 + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.use_mean_velocity = use_mean_velocity @@ -131,12 +131,12 @@ def _get_t5_prompt_embeds( text_inputs = self.tokenizer( prompt, - padding='max_length', + padding="max_length", max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_attention_mask=True, - return_tensors='pt', + return_tensors="pt", ) text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() @@ -211,19 +211,19 @@ def encode_prompt( ) if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or '' + negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( - f'`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=' - f' {type(prompt)}.' + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( - f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:' - f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches' - ' the batch size of `prompt`.' + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." ) negative_prompt_embeds = self._get_t5_prompt_embeds( @@ -247,35 +247,35 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if height % 16 != 0 or width % 16 != 0: - raise ValueError(f'`height` and `width` have to be divisible by 16 but are {height} and {width}.') + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( - f'`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}' # noqa: E501 + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" # noqa: E501 ) if prompt is not None and prompt_embeds is not None: raise ValueError( - f'Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to' - ' only forward one of the two.' + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." ) elif negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( - f'Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to' - ' only forward one of the two.' + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( - 'Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.' + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f'`prompt` has to be of type `str` or `list` but is {type(prompt)}') + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif negative_prompt is not None and ( not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) ): - raise ValueError(f'`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}') + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") def prepare_latents( self, @@ -302,12 +302,12 @@ def prepare_latents( ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( - f'You have passed a list of generators of length {len(generator)}, but requested an effective batch' - f' size of {batch_size}. Make sure the batch size matches the length of the generators.' + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = rearrange(latents, 'b c t h w -> b t c h w') + latents = rearrange(latents, "b c t h w -> b t c h w") return latents @property @@ -338,8 +338,10 @@ def attention_kwargs(self): def vae_encode(self, context_sequence): # normalize: [0, 1] -> [-1, 1] context_sequence = context_sequence * 2 - 1 - context_sequence = self.encode_latents(context_sequence.to(dtype=self.vae.dtype, device=self._execution_device), sample=False) - context_sequence = rearrange(context_sequence, 'b c t h w -> b t c h w') + context_sequence = self.encode_latents( + context_sequence.to(dtype=self.vae.dtype, device=self._execution_device), sample=False + ) + context_sequence = rearrange(context_sequence, "b c t h w -> b t c h w") return context_sequence def _normalize_latents(self, latents, latents_mean, latents_std): @@ -350,7 +352,7 @@ def _normalize_latents(self, latents, latents_mean, latents_std): @torch.no_grad() def encode_latents(self, videos, sample=True): - videos = rearrange(videos, 'b t c h w -> b c t h w') + videos = rearrange(videos, "b t c h w -> b c t h w") moments = self.vae._encode(videos) latents_mean = torch.tensor(self.vae.config.latents_mean) @@ -377,12 +379,12 @@ def inference( latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - output_type: Optional[str] = 'np', + output_type: Optional[str] = "np", return_dict: bool = True, kv_cache=None, kv_cache_flag=None, grad_timestep=None, - chunk_partition=None + chunk_partition=None, ): if negative_prompt_embeds is not None: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) @@ -415,7 +417,7 @@ def inference_range(latents, timesteps): chunk_partition=chunk_partition, # kv-cache related kv_cache=kv_cache, - kv_cache_flag=copy.deepcopy(kv_cache_flag) + kv_cache_flag=copy.deepcopy(kv_cache_flag), ) if self.do_classifier_free_guidance: noise_uncond, noise_pred = noise_pred.chunk(2) @@ -468,7 +470,7 @@ def training_rollout( ): self._guidance_scale = guidance_scale - latents = rearrange(latents, 'b c t h w -> b t c h w') + latents = rearrange(latents, "b c t h w -> b t c h w") batch_size, num_frame, _, height, width = latents.shape # 5. Prepare latent variables @@ -476,10 +478,16 @@ def training_rollout( chunk_partition = self.transformer.config.chunk_partition - assert init_latents.shape[1] == sum(chunk_partition), 'please check the chunk_partition equal to num_smaple_frames' + assert init_latents.shape[1] == sum(chunk_partition), ( + "please check the chunk_partition equal to num_smaple_frames" + ) - full_token_per_frame = (init_latents.shape[3] // self.transformer.config.patch_size[1]) * (init_latents.shape[4] // self.transformer.config.patch_size[2]) # noqa: E501 - compressed_token_per_frame = (init_latents.shape[3] // self.transformer.config.compressed_patch_size[1]) * (init_latents.shape[4] // self.transformer.config.compressed_patch_size[2]) # noqa: E501 + full_token_per_frame = (init_latents.shape[3] // self.transformer.config.patch_size[1]) * ( + init_latents.shape[4] // self.transformer.config.patch_size[2] + ) # noqa: E501 + compressed_token_per_frame = (init_latents.shape[3] // self.transformer.config.compressed_patch_size[1]) * ( + init_latents.shape[4] // self.transformer.config.compressed_patch_size[2] + ) # noqa: E501 # init kv cache if use_kv_cache: @@ -489,21 +497,35 @@ def training_rollout( for layer_idx in range(self.transformer.config.num_layers): kv_cache[layer_idx] = { - 'full_cache': torch.zeros(( - 2, batch_size, self.transformer.config.num_attention_heads, - self.transformer.config.full_chunk_limit * max(chunk_partition) * full_token_per_frame, - self.transformer.config.attention_head_dim - ), device=init_latents.device, dtype=init_latents.dtype), - 'compressed_cache': torch.zeros(( - 2, batch_size, self.transformer.config.num_attention_heads, - (len(chunk_partition) - self.transformer.config.full_chunk_limit + 1) * max(chunk_partition) * compressed_token_per_frame, - self.transformer.config.attention_head_dim - ), device=init_latents.device, dtype=init_latents.dtype) + "full_cache": torch.zeros( + ( + 2, + batch_size, + self.transformer.config.num_attention_heads, + self.transformer.config.full_chunk_limit * max(chunk_partition) * full_token_per_frame, + self.transformer.config.attention_head_dim, + ), + device=init_latents.device, + dtype=init_latents.dtype, + ), + "compressed_cache": torch.zeros( + ( + 2, + batch_size, + self.transformer.config.num_attention_heads, + (len(chunk_partition) - self.transformer.config.full_chunk_limit + 1) + * max(chunk_partition) + * compressed_token_per_frame, + self.transformer.config.attention_head_dim, + ), + device=init_latents.device, + dtype=init_latents.dtype, + ), } kv_cache_flag = { - 'num_cached_chunks': 0, - 'is_cache_step': False, + "num_cached_chunks": 0, + "is_cache_step": False, } else: kv_cache = None @@ -513,49 +535,64 @@ def training_rollout( # setup start sequence if context_sequence is not None: - if 'latent' in context_sequence: - latents = rearrange(context_sequence['latent'], 'b c t h w -> b t c h w') + if "latent" in context_sequence: + latents = rearrange(context_sequence["latent"], "b c t h w -> b t c h w") else: - assert (context_sequence['raw'].shape[1] - 1) % 4 == 0, 'require 4n+1 frames' - latents = self.vae_encode(context_sequence['raw']) + assert (context_sequence["raw"].shape[1] - 1) % 4 == 0, "require 4n+1 frames" + latents = self.vae_encode(context_sequence["raw"]) current_context_length = latents.shape[1] output[:, :current_context_length] = latents - num_context_chunks = next(i + 1 for i in range(len(chunk_partition)) if sum(chunk_partition[:i + 1]) >= current_context_length) + num_context_chunks = next( + i + 1 for i in range(len(chunk_partition)) if sum(chunk_partition[: i + 1]) >= current_context_length + ) else: num_context_chunks = 0 for chunk_idx in tqdm(range(len(chunk_partition))): - if chunk_idx >= num_context_chunks: pred_latents = self.inference( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, kv_cache=kv_cache, kv_cache_flag=kv_cache_flag, - latents=init_latents[:, sum(chunk_partition[:chunk_idx]): sum(chunk_partition[:chunk_idx + 1])], + latents=init_latents[:, sum(chunk_partition[:chunk_idx]) : sum(chunk_partition[: chunk_idx + 1])], num_inference_steps=num_inference_steps, grad_timestep=grad_timestep, guidance_scale=guidance_scale, - chunk_partition=chunk_partition[:chunk_idx + 1] + chunk_partition=chunk_partition[: chunk_idx + 1], ) - output[:, sum(chunk_partition[:chunk_idx]): sum(chunk_partition[:chunk_idx + 1])] = pred_latents + output[:, sum(chunk_partition[:chunk_idx]) : sum(chunk_partition[: chunk_idx + 1])] = pred_latents # step1: save to kv cache if chunk_idx < len(chunk_partition) - 1: - kv_cache = self.encode_kv_cache(kv_cache, kv_cache_flag, chunk_partition=chunk_partition[:chunk_idx + 1], chunk_idx=chunk_idx, output=output, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds) # noqa: E501 + kv_cache = self.encode_kv_cache( + kv_cache, + kv_cache_flag, + chunk_partition=chunk_partition[: chunk_idx + 1], + chunk_idx=chunk_idx, + output=output, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) # noqa: E501 - output = rearrange(output, 'b f c h w -> b c f h w') + output = rearrange(output, "b f c h w -> b c f h w") return output @torch.no_grad() - def encode_kv_cache(self, kv_cache, kv_cache_flag, chunk_partition, chunk_idx, output, prompt_embeds, negative_prompt_embeds): - kv_cache_flag['is_cache_step'] = True + def encode_kv_cache( + self, kv_cache, kv_cache_flag, chunk_partition, chunk_idx, output, prompt_embeds, negative_prompt_embeds + ): + kv_cache_flag["is_cache_step"] = True if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - latents = output[:, :sum(chunk_partition)] - latent_model_input = torch.cat([latents] * 2).to(torch.bfloat16) if self.do_classifier_free_guidance else latents.to(torch.bfloat16) + latents = output[:, : sum(chunk_partition)] + latent_model_input = ( + torch.cat([latents] * 2).to(torch.bfloat16) + if self.do_classifier_free_guidance + else latents.to(torch.bfloat16) + ) timestep = torch.tensor([0], device=latents.device).expand(latent_model_input.shape[0]).unsqueeze(-1) timestep = timestep.repeat((1, latent_model_input.shape[1])) @@ -572,11 +609,11 @@ def encode_kv_cache(self, kv_cache, kv_cache_flag, chunk_partition, chunk_idx, o return_dict=False, # kv-cache related kv_cache=kv_cache, - kv_cache_flag=copy.deepcopy(kv_cache_flag) + kv_cache_flag=copy.deepcopy(kv_cache_flag), ) - kv_cache_flag['num_cached_chunks'] += 1 - kv_cache_flag['is_cache_step'] = False + kv_cache_flag["num_cached_chunks"] += 1 + kv_cache_flag["is_cache_step"] = False return kv_cache @@ -596,16 +633,16 @@ def __call__( latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - output_type: Optional[str] = 'np', + output_type: Optional[str] = "np", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, - callback_on_step_end_tensor_inputs: List[str] = ['latents'], + callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, show_progress=True, - use_kv_cache=True + use_kv_cache=True, ): # 1. Check inputs. Raise error if not correct @@ -621,7 +658,7 @@ def __call__( if num_frames % self.vae_scale_factor_temporal != 1: logger.warning( - f'`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number.' + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." ) num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) @@ -672,18 +709,19 @@ def __call__( latents, ) init_latents = init_latents.to(transformer_dtype) - init_latents = rearrange(init_latents, 'b f c h w -> b c f h w') + init_latents = rearrange(init_latents, "b f c h w -> b c f h w") latents = self.training_rollout( - context_sequence=context_sequence, num_inference_steps=num_inference_steps, + context_sequence=context_sequence, + num_inference_steps=num_inference_steps, grad_timestep=None, latents=init_latents, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - guidance_scale=guidance_scale + guidance_scale=guidance_scale, ) - if not output_type == 'latent': + if not output_type == "latent": latents = latents.to(self.vae.dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 60222c2b6fca..91019a88fe90 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -435,6 +435,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AnyFlowTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AsymmetricAutoencoderKL(metaclass=DummyObject): _backends = ["torch"] @@ -2972,6 +2987,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class FlowMapEulerDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class FlowMatchEulerDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 6511345e9511..eb7438c8f497 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -887,6 +887,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class AnyFlowCausalPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AnyFlowPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AudioLDM2Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_anyflow.py b/tests/models/transformers/test_models_transformer_anyflow.py index 2d44a476c472..9df4de8e5767 100644 --- a/tests/models/transformers/test_models_transformer_anyflow.py +++ b/tests/models/transformers/test_models_transformer_anyflow.py @@ -43,33 +43,33 @@ class AnyFlowTransformer3DModelTest(unittest.TestCase): @staticmethod def _tiny_init_kwargs(**overrides): - kwargs = dict( - patch_size=(1, 2, 2), - num_attention_heads=2, - attention_head_dim=12, - in_channels=4, - out_channels=4, - text_dim=16, - freq_dim=256, - ffn_dim=32, - num_layers=2, - cross_attn_norm=True, - qk_norm="rms_norm_across_heads", - rope_max_seq_len=32, - ) + kwargs = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 12, + "in_channels": 4, + "out_channels": 4, + "text_dim": 16, + "freq_dim": 256, + "ffn_dim": 32, + "num_layers": 2, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "rope_max_seq_len": 32, + } kwargs.update(overrides) return kwargs @staticmethod def _tiny_bidi_inputs(batch_size=1, num_frames=2, height=16, width=16, text_seq_len=12, text_dim=16): - return dict( - hidden_states=torch.randn(batch_size, num_frames, 4, height, width, device="cpu"), - timestep=torch.full((batch_size, num_frames), 500.0, device="cpu"), - r_timestep=torch.full((batch_size, num_frames), 250.0, device="cpu"), - encoder_hidden_states=torch.randn(batch_size, text_seq_len, text_dim, device="cpu"), - is_causal=False, - return_dict=True, - ) + return { + "hidden_states": torch.randn(batch_size, num_frames, 4, height, width, device="cpu"), + "timestep": torch.full((batch_size, num_frames), 500.0, device="cpu"), + "r_timestep": torch.full((batch_size, num_frames), 250.0, device="cpu"), + "encoder_hidden_states": torch.randn(batch_size, text_seq_len, text_dim, device="cpu"), + "is_causal": False, + "return_dict": True, + } def test_construction_base_wan(self): m = AnyFlowTransformer3DModel(**self._tiny_init_kwargs()) @@ -99,9 +99,13 @@ def test_construction_far_plus_flowmap(self): def test_bidi_forward_shape_preserved(self): torch.manual_seed(0) - m = AnyFlowTransformer3DModel( - **self._tiny_init_kwargs(init_flowmap_model=True, gate_value=0.25, deltatime_type="r") - ).to("cpu").eval() + m = ( + AnyFlowTransformer3DModel( + **self._tiny_init_kwargs(init_flowmap_model=True, gate_value=0.25, deltatime_type="r") + ) + .to("cpu") + .eval() + ) inputs = self._tiny_bidi_inputs() with torch.no_grad(): @@ -111,9 +115,13 @@ def test_bidi_forward_shape_preserved(self): def test_bidi_forward_return_dict_false(self): torch.manual_seed(0) - m = AnyFlowTransformer3DModel( - **self._tiny_init_kwargs(init_flowmap_model=True, gate_value=0.25, deltatime_type="r") - ).to("cpu").eval() + m = ( + AnyFlowTransformer3DModel( + **self._tiny_init_kwargs(init_flowmap_model=True, gate_value=0.25, deltatime_type="r") + ) + .to("cpu") + .eval() + ) inputs = self._tiny_bidi_inputs() inputs["return_dict"] = False @@ -124,9 +132,13 @@ def test_bidi_forward_return_dict_false(self): def test_bidi_forward_determinism(self): torch.manual_seed(0) - m = AnyFlowTransformer3DModel( - **self._tiny_init_kwargs(init_flowmap_model=True, gate_value=0.25, deltatime_type="r") - ).to("cpu").eval() + m = ( + AnyFlowTransformer3DModel( + **self._tiny_init_kwargs(init_flowmap_model=True, gate_value=0.25, deltatime_type="r") + ) + .to("cpu") + .eval() + ) inputs_a = self._tiny_bidi_inputs() inputs_b = {k: v.clone() if torch.is_tensor(v) else v for k, v in inputs_a.items()} From 74a89aead6bbaf3da6d1ba97b449fcd79b4419ce Mon Sep 17 00:00:00 2001 From: Enderfga Date: Wed, 6 May 2026 17:06:16 +0800 Subject: [PATCH 08/16] [AnyFlow] address review feedback: bug fixes + DMD wording + EN/ZH tutorials MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Critical bug fixes (verified against precision-validation review): * pipeline_anyflow.py / pipeline_anyflow_causal.py: replace hardcoded transformer_dtype = torch.bfloat16 with self.transformer.dtype, so pipe.to("cpu") and PipelineTesterMixin save/load tests do not crash on a dtype mismatch in the patch_embedding conv3d. * transformer_anyflow.py: drop the duplicate `base = base = ...` assignment in _build_causal_mask (was a copy-paste typo carried over from FAR-Dev). * transformer_anyflow.py: drop unused `q_is_context` / `k_is_context` locals and the `# noqa: F841` markers that were silencing the dead-store warning. * transformer_anyflow.py: remove `CacheMixin` from the inheritance list — the pipeline manages KV cache directly, the mixin's interface is unused. * transformer_anyflow.py: guard the module-level `torch.compile(flex_attention)` with try/except so the file imports cleanly on CPU CI / no-Triton machines. * convert_anyflow_to_diffusers.py: replace ad-hoc print warnings with the stdlib logger (warning_once-style) and a module-level basicConfig. Documentation accuracy: * AnyFlowCausalPipeline class docstring + main pipeline doc + EN/ZH tutorial: drop the fictitious `task_type` / `image` / `video` arguments and document the real API: pass `context_sequence={"raw": tensor}` (or `{"latent": ...}`) to switch between T2V (None) / I2V (1-frame) / TV2V (4n+1-frame) modes. * Pipeline class docstrings + main doc: explicitly describe AnyFlow's two-stage LoRA distillation including DMD reverse-divergence supervision with Flow-Map backward simulation in stage 2 (was previously implicit). * training_rollout: add detailed docstring explaining its role as the 3-segment Flow-Map backward simulation entry point used during DMD training. * Long-form tutorial doc `using-diffusers/anyflow.md` (EN, 239 LOC) and Chinese mirror `docs/source/zh/using-diffusers/anyflow.md` (224 LOC) added and registered in both `_toctree.yml` files. Tests: * Skip `test_attention_slicing_forward_pass` in both pipeline test classes with a clear rationale (custom attention processor does not support slicing). * All 21 standalone tests still pass (12 scheduler + 9 transformer). Quality gates: * `ruff check` clean across all AnyFlow files. * `ruff format --check` reports 6 files already formatted. * `python utils/check_copies.py` reports no diff. Out of scope for this commit (deferred until reviewer feedback): * Splitting AnyFlowTransformer3DModel into bidi + causal subclasses * Unifying _forward_inference / _forward_cache return types * Migrating model tests from plain unittest to BaseModelTesterConfig + mixins * HF model card / config.json metadata updates on the nvidia/* repos (push to Hub manually before opening the PR) --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/pipelines/anyflow.md | 29 ++- docs/source/en/using-diffusers/anyflow.md | 239 ++++++++++++++++++ docs/source/zh/_toctree.yml | 2 + docs/source/zh/using-diffusers/anyflow.md | 224 ++++++++++++++++ scripts/convert_anyflow_to_diffusers.py | 19 +- .../transformers/transformer_anyflow.py | 18 +- .../pipelines/anyflow/pipeline_anyflow.py | 70 ++++- .../anyflow/pipeline_anyflow_causal.py | 74 ++++-- tests/pipelines/anyflow/test_anyflow.py | 4 + .../pipelines/anyflow/test_anyflow_causal.py | 4 + 11 files changed, 640 insertions(+), 45 deletions(-) create mode 100644 docs/source/en/using-diffusers/anyflow.md create mode 100644 docs/source/zh/using-diffusers/anyflow.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 41be879d8173..a5dd4aa26117 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -198,6 +198,8 @@ title: Model accelerators and hardware - isExpanded: false sections: + - local: using-diffusers/anyflow + title: AnyFlow - local: using-diffusers/helios title: Helios - local: using-diffusers/consisid diff --git a/docs/source/en/api/pipelines/anyflow.md b/docs/source/en/api/pipelines/anyflow.md index e62bb2cb2ebd..bed636b76da8 100644 --- a/docs/source/en/api/pipelines/anyflow.md +++ b/docs/source/en/api/pipelines/anyflow.md @@ -99,6 +99,11 @@ export_to_video(video, "out.mp4", fps=16) ### Generation with AnyFlow (FAR Causal) +The causal pipeline selects between T2V / I2V / TV2V via the ``context_sequence`` argument: pass ``None`` +for plain text-to-video, or a dict with a ``"raw"`` key holding a video tensor of shape +``(B, C, T, H, W)`` with ``T = 4n + 1`` to condition on existing frames. Use a single conditioning frame +for I2V and a longer clip for TV2V continuation. + @@ -113,7 +118,6 @@ pipe = AnyFlowCausalPipeline.from_pretrained( video = pipe( prompt="A cat surfing a wave, sunset", - task_type="t2v", num_inference_steps=4, num_frames=33, ).frames[0] @@ -127,16 +131,20 @@ export_to_video(video, "out.mp4", fps=16) import torch from diffusers import AnyFlowCausalPipeline from diffusers.utils import export_to_video, load_image +from torchvision import transforms pipe = AnyFlowCausalPipeline.from_pretrained( "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 ).to("cuda") -img = load_image("path/to/first_frame.png") +# Wrap the conditioning image as a one-frame video tensor: (1, 3, 1, H, W) +first_frame = load_image("path/to/first_frame.png") +to_tensor = transforms.Compose([transforms.Resize((480, 832)), transforms.ToTensor()]) +first_frame = to_tensor(first_frame).unsqueeze(0).unsqueeze(2).to("cuda") # (1, 3, 1, 480, 832) + video = pipe( prompt="a cat walks across a sunlit lawn", - image=img, - task_type="i2v", + context_sequence={"raw": first_frame}, num_inference_steps=4, num_frames=33, ).frames[0] @@ -150,16 +158,21 @@ export_to_video(video, "out.mp4", fps=16) import torch from diffusers import AnyFlowCausalPipeline from diffusers.utils import export_to_video, load_video +from torchvision import transforms pipe = AnyFlowCausalPipeline.from_pretrained( "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 ).to("cuda") -context = load_video("path/to/context.mp4") +# Provide a context clip whose frame count is 4n + 1 (e.g., 9, 13, 17). +context_frames = load_video("path/to/context.mp4") # list of PIL frames +to_tensor = transforms.Compose([transforms.Resize((480, 832)), transforms.ToTensor()]) +context_tensor = torch.stack([to_tensor(f) for f in context_frames[:9]], dim=1).unsqueeze(0).to("cuda") +# Shape: (1, 3, 9, 480, 832) + video = pipe( prompt="continue the story", - video=context, - task_type="tv2v", + context_sequence={"raw": context_tensor}, num_inference_steps=4, num_frames=33, ).frames[0] @@ -171,9 +184,11 @@ export_to_video(video, "out.mp4", fps=16) ## Notes +- The released NVIDIA checkpoints went through a two-stage LoRA distillation: forward Flow-Map training plus on-policy distillation that combines Flow-Map backward simulation with **DMD reverse-divergence supervision** over the student's own rollouts. CFG was fused into the model weights during stage 1 (`fuse_guidance_scale = 3.0`), so inference does not run a second classifier-free guidance pass — quality is recovered from the distilled weights themselves. - `FlowMapEulerDiscreteScheduler` is general-purpose. You can attach it to any flow-map-distilled checkpoint via `from_pretrained(..., scheduler=FlowMapEulerDiscreteScheduler.from_config(...))`. - The bidirectional pipeline accepts any `AnyFlowTransformer3DModel` configured with `init_flowmap_model=True`. The causal pipeline additionally requires `init_far_model=True`. - LoRA training is supported via `WanLoraLoaderMixin`, the same mixin used by the upstream Wan pipelines. +- For continued on-policy fine-tuning with DMD, both pipelines expose a `training_rollout` method that drives the three-segment Flow-Map backward simulation used in the original AnyFlow stage-2 trainer. ## AnyFlowPipeline diff --git a/docs/source/en/using-diffusers/anyflow.md b/docs/source/en/using-diffusers/anyflow.md new file mode 100644 index 000000000000..dc89f17a9a7a --- /dev/null +++ b/docs/source/en/using-diffusers/anyflow.md @@ -0,0 +1,239 @@ + + +# AnyFlow + +[AnyFlow](https://huggingface.co/papers/) is a video diffusion **distillation** framework that turns +a pretrained Wan2.1 teacher into an *any-step* student under standard Euler sampling. A single distilled +checkpoint can be evaluated at 1, 2, 4, 8, 16... NFE without retraining and quality scales **monotonically** +with steps — unlike consistency models, which often degrade as NFE grows. + +The key idea is to learn the **flow map** $\Phi_{r\leftarrow t}: \mathbf{z}_t \to \mathbf{z}_r$ for arbitrary +$1 \ge t \ge r \ge 0$ instead of the fixed endpoint map $\mathbf{z}_t \to \mathbf{z}_0$ used by consistency +models. Composability of the flow map removes re-noising between sampling steps; on-policy distillation with +**DMD reverse-divergence supervision** plus **Flow-Map backward simulation** (3-segment shortcut) closes the +exposure-bias gap that consistency-based distillation leaves open. + +This guide walks through the practical decisions: which pipeline to pick, how to use any-step sampling, and +how to plug AnyFlow into typical T2V / I2V / TV2V workflows. + +## Bidirectional vs causal — pick a pipeline + +AnyFlow ships in two flavors that share the same scheduler and the same flow-map distillation but differ in +how they sample frames: + +- [`AnyFlowPipeline`](../api/pipelines/anyflow#anyflowpipeline) — **bidirectional** T2V. Denoises the entire + video tensor in one pass with global self-attention. Use this when the input is a single text prompt and you + do not need streaming output. +- [`AnyFlowCausalPipeline`](../api/pipelines/anyflow#anyflowcausalpipeline) — **causal (FAR)**. Denoises the + video chunk by chunk with block-sparse causal attention and reuses KV cache across chunks. Use this for + image-to-video (I2V), text+video-to-video (TV2V) continuation, or any setup that benefits from frame-level + autoregressive sampling. The same model handles all three task modes via a `task_type` argument. + +A quick selector: + +| Scenario | Pipeline | How to invoke | +|----------|----------|---------------| +| Pure text-to-video, max quality at fixed NFE | `AnyFlowPipeline` | `pipe(prompt, ...)` | +| Image-to-video (start from a still image) | `AnyFlowCausalPipeline` | `pipe(prompt, context_sequence={"raw": }, ...)` | +| Video continuation / TV2V | `AnyFlowCausalPipeline` | `pipe(prompt, context_sequence={"raw": }, ...)` | +| Streaming / progressive generation | `AnyFlowCausalPipeline` | — | + +The bidirectional variant is faster per token at high resolution; the causal variant trades that for the +ability to start sampling before all latent frames are allocated, useful for very long sequences. + +## Loading checkpoints + +NVIDIA released four AnyFlow checkpoints, one per pipeline + scale combination: + +```py +import torch +from diffusers import AnyFlowPipeline, AnyFlowCausalPipeline + +# Bidirectional, lightweight +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Bidirectional, full quality +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Causal (FAR), 1.3B +pipe = AnyFlowCausalPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Causal (FAR), 14B +pipe = AnyFlowCausalPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") +``` + +All four use the same [`FlowMapEulerDiscreteScheduler`](../api/schedulers/flow_map_euler_discrete) with +`shift=5.0` baked in. + +## Any-step sampling + +The defining feature of AnyFlow is that the same checkpoint produces increasing quality as you raise NFE, +with no schedule retuning. Sweep step counts on a fixed prompt to see how the model trades latency for +fidelity: + +```py +import torch +from diffusers import AnyFlowPipeline +from diffusers.utils import export_to_video + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +prompt = "A red panda eating bamboo in a forest, cinematic lighting" + +for nfe in [1, 2, 4, 8, 16, 32]: + generator = torch.Generator("cuda").manual_seed(0) + video = pipe(prompt, num_inference_steps=nfe, num_frames=33, generator=generator).frames[0] + export_to_video(video, f"out_nfe{nfe}.mp4", fps=16) +``` + +In our benchmarks (paper Tab 3 / Fig 1) every AnyFlow checkpoint improves monotonically from 4 → 32 NFE +on VBench Quality, while consistency-based baselines (rCM, Self-Forcing) degrade in the same regime. + +> [!TIP] +> Classifier-free guidance (CFG) was *fused* into the model weights during distillation +> (`fuse_guidance_scale = 3.0`). The pipeline does not run a second guided forward pass at inference time — +> guidance comes from the distilled weights themselves. Leave `guidance_scale=1.0` (the default) for the +> released checkpoints. + +## Image-to-video and text+video-to-video + +The causal pipeline supports three task modes from a single distilled model. The mode is selected +implicitly by the ``context_sequence`` argument (a dict with a ``"raw"`` video tensor or ``"latent"`` +pre-encoded latents). Frame counts in the context tensor must satisfy ``T = 4n + 1`` to align with the +VAE temporal stride. + +```py +import torch +from diffusers import AnyFlowCausalPipeline +from diffusers.utils import export_to_video, load_image, load_video +from torchvision import transforms + +pipe = AnyFlowCausalPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") +to_tensor = transforms.Compose([transforms.Resize((480, 832)), transforms.ToTensor()]) + +# 1) Text-to-video (no context) +video = pipe(prompt="A cat surfing a wave at sunset", num_inference_steps=4, num_frames=33).frames[0] +export_to_video(video, "t2v.mp4", fps=16) + +# 2) Image-to-video — wrap the still as a one-frame video (1, 3, 1, H, W) +first_frame = load_image("path/to/first_frame.png") +first_frame = to_tensor(first_frame).unsqueeze(0).unsqueeze(2).to("cuda") +video = pipe( + prompt="a cat walks across a sunlit lawn", + context_sequence={"raw": first_frame}, + num_inference_steps=4, + num_frames=33, +).frames[0] +export_to_video(video, "i2v.mp4", fps=16) + +# 3) Text + video → continuation. Context length must be 4n + 1 (e.g., 9 frames). +context_frames = load_video("path/to/context.mp4") +context_tensor = torch.stack([to_tensor(f) for f in context_frames[:9]], dim=1).unsqueeze(0).to("cuda") +video = pipe( + prompt="continue the story", + context_sequence={"raw": context_tensor}, + num_inference_steps=4, + num_frames=33, +).frames[0] +export_to_video(video, "tv2v.mp4", fps=16) +``` + +Internally, the patchification chunk schedule depends on whether (and how long) ``context_sequence`` is set: +without context the model uses kernel sizes 2 (full) and 4 (compressed); with a context clip the first chunk +uses kernel size 1 so the conditioning frames keep full resolution. + +If you already have VAE-encoded latents, pass them via ``context_sequence={"latent": ...}`` to skip the +``vae_encode`` step. + +## Memory and inference speed + +A 14B AnyFlow model fits on a single 40 GB device with group offloading + VAE slicing: + +```py +import torch +from diffusers import AnyFlowPipeline +from diffusers.hooks import apply_group_offloading + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 +) +apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level") +pipe.vae.enable_slicing() +pipe.vae.enable_tiling() +``` + +For latency, `torch.compile` works well on the transformer (the heaviest module by far): + +```py +pipe = pipe.to("cuda") +pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") +``` + +Compile costs are amortized after a few steps; combined with low NFE (4–8 for AnyFlow) you typically see +2–3× speedup vs the eager 14B path. + +## LoRA fine-tuning + +Both pipelines reuse [`WanLoraLoaderMixin`](../api/loaders/lora), so any LoRA adapter trained for the +matching Wan2.1 backbone loads directly: + +```py +pipe.load_lora_weights("path/or/repo/with/wan_lora") +``` + +For continued **on-policy** fine-tuning with DMD-style reverse-divergence supervision (the same recipe used +to produce the released checkpoints), both pipelines expose a `training_rollout` method that drives the +3-segment Flow-Map backward simulation. End users training a new LoRA can call it under autograd to compose +their own DMD trainer; the original AnyFlow trainer that built the released checkpoints is in +`Enderfga/AnyFlow` (out of scope for diffusers). + +## Common gotchas + +- **Always-1.0 `guidance_scale`.** The distilled checkpoints already encode CFG. Setting `guidance_scale > 1` + will run a redundant unconditional pass, double the latency, and slightly hurt quality. +- **Bidirectional pipeline does not stream.** All `num_frames` worth of latents are denoised together. Use + the causal pipeline if you want to start playback before sampling completes. +- **Causal pipeline KV cache assumes the chunk schedule is consistent across calls.** Rebuilding the cache + mid-generation is not supported by the released model. +- **`num_frames` must satisfy the VAE temporal stride.** Use values of the form `(N - 1) % 4 == 0` (e.g., 9, + 17, 33, 81) for the released checkpoints. + +## Citation + +```bibtex +@article{gu2026anyflow, + title = {AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation}, + author = {Gu, Yuchao and others}, + journal = {arXiv preprint arXiv:}, + year = {2026} +} + +@article{gu2025long, + title={Long-Context Autoregressive Video Modeling with Next-Frame Prediction}, + author={Gu, Yuchao and Mao, Weijia and Shou, Mike Zheng}, + journal={arXiv preprint arXiv:2503.19325}, + year={2025} +} +``` diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml index af51506746b2..b49820dd76e7 100644 --- a/docs/source/zh/_toctree.yml +++ b/docs/source/zh/_toctree.yml @@ -130,6 +130,8 @@ - title: Specific pipeline examples isExpanded: false sections: + - local: using-diffusers/anyflow + title: AnyFlow - local: using-diffusers/consisid title: ConsisID - local: using-diffusers/helios diff --git a/docs/source/zh/using-diffusers/anyflow.md b/docs/source/zh/using-diffusers/anyflow.md new file mode 100644 index 000000000000..d132d3075297 --- /dev/null +++ b/docs/source/zh/using-diffusers/anyflow.md @@ -0,0 +1,224 @@ + + +# AnyFlow + +[AnyFlow](https://huggingface.co/papers/) 是一个视频扩散**蒸馏**框架,把预训练的 Wan2.1 教师 +模型蒸馏成在标准 Euler 采样下支持*任意步数 (any-step)* 的学生模型。同一个蒸馏出来的 checkpoint 可以 +在 1、2、4、8、16... NFE 下推理,**质量随步数单调提升** —— 这一点和 consistency models 不同,后者 +NFE 增加反而经常掉点。 + +核心思路是学习 **flow map** $\Phi_{r\leftarrow t}: \mathbf{z}_t \to \mathbf{z}_r$(任意 $1 \ge t \ge r \ge 0$), +而不是 consistency models 学的固定端点映射 $\mathbf{z}_t \to \mathbf{z}_0$。Flow map 的可组合性消除了 +采样步之间的 re-noising;on-policy 蒸馏阶段额外用 **DMD 反向散度监督** + **Flow-Map backward simulation** +(3 段 shortcut)补上 consistency 蒸馏遗留的 exposure-bias 缺口。 + +本文档梳理实战要点:怎么选 pipeline、怎么用 any-step 采样、怎么把 AnyFlow 嵌进 T2V / I2V / TV2V 工作流。 + +## Bidirectional 还是 Causal —— 怎么选 pipeline + +AnyFlow 提供两个 pipeline 形态,scheduler 和蒸馏方法相同,区别在于**怎么对帧采样**: + +- [`AnyFlowPipeline`](../api/pipelines/anyflow#anyflowpipeline) —— **bidirectional** T2V。一次性对整个 + 视频张量去噪,全局自注意力。**纯 prompt 输入、不要流式输出**时选这个。 +- [`AnyFlowCausalPipeline`](../api/pipelines/anyflow#anyflowcausalpipeline) —— **causal (FAR)**。 + 按 chunk 分段去噪,块稀疏因果注意力 + 跨 chunk 复用 KV cache。**图生视频 (I2V)**、**视频续写 (TV2V)**、 + 或任何受益于逐帧自回归采样的场景选这个。同一个模型通过传入 `context_sequence` 来切换三种任务模式。 + +简化对照表: + +| 场景 | Pipeline | 调用方式 | +|------|----------|----------| +| 纯文生视频,固定 NFE 求最大质量 | `AnyFlowPipeline` | `pipe(prompt, ...)` | +| 图生视频(首帧给定) | `AnyFlowCausalPipeline` | `pipe(prompt, context_sequence={"raw": <单帧 tensor>}, ...)` | +| 视频续写 / TV2V | `AnyFlowCausalPipeline` | `pipe(prompt, context_sequence={"raw": <多帧 tensor>}, ...)` | +| 流式 / 渐进式生成 | `AnyFlowCausalPipeline` | — | + +高分辨率下 bidirectional 单 token 更快;causal 牺牲一点单步速度,换来在所有 latent 帧分配前就能开始 +采样的能力,对超长序列尤其有用。 + +## 加载 checkpoint + +NVIDIA 发布了 4 个 AnyFlow checkpoint,pipeline × 规模各一份: + +```py +import torch +from diffusers import AnyFlowPipeline, AnyFlowCausalPipeline + +# Bidirectional, 轻量 +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Bidirectional, 满血 +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Causal (FAR), 1.3B +pipe = AnyFlowCausalPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +# Causal (FAR), 14B +pipe = AnyFlowCausalPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") +``` + +四个 checkpoint 共用同一份 [`FlowMapEulerDiscreteScheduler`](../api/schedulers/flow_map_euler_discrete), +默认 `shift=5.0`。 + +## Any-step 采样 + +AnyFlow 最关键的特性是同一个 checkpoint **不需重新调度**,NFE 越大质量越高。固定 prompt、扫一下步数 +就能看出模型怎么在延迟和保真度之间权衡: + +```py +import torch +from diffusers import AnyFlowPipeline +from diffusers.utils import export_to_video + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") + +prompt = "森林里一只小熊猫在啃竹子,电影感光照" + +for nfe in [1, 2, 4, 8, 16, 32]: + generator = torch.Generator("cuda").manual_seed(0) + video = pipe(prompt, num_inference_steps=nfe, num_frames=33, generator=generator).frames[0] + export_to_video(video, f"out_nfe{nfe}.mp4", fps=16) +``` + +paper 的 Tab 3 / Fig 1 表明:每个 AnyFlow checkpoint 在 4 → 32 NFE 范围 VBench Quality 都单调上升,而 +consistency 类基线(rCM、Self-Forcing)在同区间反而掉点。 + +> [!TIP] +> Classifier-free guidance (CFG) 已经在蒸馏阶段融进权重 (`fuse_guidance_scale = 3.0`)。pipeline 推理 +> 时**不会**再跑一次 unconditional 前向 —— guidance 直接由蒸馏后的权重带出。release 出来的 checkpoint +> 都用默认的 `guidance_scale=1.0` 即可。 + +## 图生视频 与 视频续写 + +Causal pipeline 用同一个蒸馏模型支持三种任务模式,**通过 `context_sequence` 隐式选择**(dict,含 +`"raw"` 视频张量或 `"latent"` 已编码 latent)。Context tensor 的帧数必须满足 `T = 4n + 1`,跟 VAE +时间步长对齐。 + +```py +import torch +from diffusers import AnyFlowCausalPipeline +from diffusers.utils import export_to_video, load_image, load_video +from torchvision import transforms + +pipe = AnyFlowCausalPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 +).to("cuda") +to_tensor = transforms.Compose([transforms.Resize((480, 832)), transforms.ToTensor()]) + +# 1) 文生视频(无 context) +video = pipe(prompt="一只猫在夕阳下冲浪", num_inference_steps=4, num_frames=33).frames[0] +export_to_video(video, "t2v.mp4", fps=16) + +# 2) 图生视频 —— 把首帧包成 (1, 3, 1, H, W) 单帧视频 +first_frame = load_image("path/to/first_frame.png") +first_frame = to_tensor(first_frame).unsqueeze(0).unsqueeze(2).to("cuda") +video = pipe( + prompt="一只猫走过阳光下的草坪", + context_sequence={"raw": first_frame}, + num_inference_steps=4, + num_frames=33, +).frames[0] +export_to_video(video, "i2v.mp4", fps=16) + +# 3) 视频续写。Context 帧数必须是 4n + 1(比如 9 帧)。 +context_frames = load_video("path/to/context.mp4") +context_tensor = torch.stack([to_tensor(f) for f in context_frames[:9]], dim=1).unsqueeze(0).to("cuda") +video = pipe( + prompt="继续这个故事", + context_sequence={"raw": context_tensor}, + num_inference_steps=4, + num_frames=33, +).frames[0] +export_to_video(video, "tv2v.mp4", fps=16) +``` + +底层 patchify chunk 调度根据 `context_sequence` 自动调整:纯文生用 kernel 2 (full) 和 4 (compressed); +有 context 时第一个 chunk 改成 kernel 1,让条件帧保留全分辨率。 + +如果你已经有 VAE 编码过的 latent,可以直接传 `context_sequence={"latent": ...}` 跳过 `vae_encode` 步骤。 + +## 显存与推理速度 + +14B 的 AnyFlow 模型用 group offload + VAE slicing 单卡 40 GB 能跑: + +```py +import torch +from diffusers import AnyFlowPipeline +from diffusers.hooks import apply_group_offloading + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 +) +apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level") +pipe.vae.enable_slicing() +pipe.vae.enable_tiling() +``` + +延迟方面,`torch.compile` 对 transformer(最重的模块)效果很好: + +```py +pipe = pipe.to("cuda") +pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") +``` + +编译开销跑几步就摊销掉;配合 AnyFlow 的低 NFE(4-8 步),相比 eager 模式 14B 通常能拿 2-3× 加速。 + +## LoRA 微调 + +两个 pipeline 都复用 [`WanLoraLoaderMixin`](../api/loaders/lora),因此为对应 Wan2.1 backbone 训练的 +LoRA adapter 直接加载即可: + +```py +pipe.load_lora_weights("path/or/repo/with/wan_lora") +``` + +如果想做**继续 on-policy 蒸馏微调**(用论文里相同的 DMD 反向散度监督配方训新 LoRA),两个 pipeline +都暴露了 `training_rollout` 方法,驱动 3 段 Flow-Map backward simulation。普通用户可以在 autograd +模式下调它,配合自己的 DMD trainer 用。produce 出 release checkpoint 的原始训练框架在 +`Enderfga/AnyFlow`(不在 diffusers 范围内)。 + +## 常见坑 + +- **永远 `guidance_scale=1.0`。** 蒸馏后的 checkpoint 已经把 CFG 融进权重。设 `> 1` 会多跑一遍 + unconditional 前向、延迟翻倍、质量微降。 +- **Bidirectional pipeline 不支持流式。** 所有 `num_frames` 一起去噪。需要边采边播请用 causal pipeline。 +- **Causal pipeline KV cache 假设 chunk 调度跨调用一致。** 中途重建 cache 不被 release 模型支持。 +- **`num_frames` 必须满足 VAE 时间步长。** release checkpoint 用 `(N - 1) % 4 == 0` 的值(如 9、17、33、81)。 + +## 引用 + +```bibtex +@article{gu2026anyflow, + title = {AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation}, + author = {Gu, Yuchao and others}, + journal = {arXiv preprint arXiv:}, + year = {2026} +} + +@article{gu2025long, + title={Long-Context Autoregressive Video Modeling with Next-Frame Prediction}, + author={Gu, Yuchao and Mao, Weijia and Shou, Mike Zheng}, + journal={arXiv preprint arXiv:2503.19325}, + year={2025} +} +``` diff --git a/scripts/convert_anyflow_to_diffusers.py b/scripts/convert_anyflow_to_diffusers.py index c4193b424976..e10dbe8b30ee 100644 --- a/scripts/convert_anyflow_to_diffusers.py +++ b/scripts/convert_anyflow_to_diffusers.py @@ -34,6 +34,7 @@ """ import argparse +import logging import os import torch @@ -46,6 +47,10 @@ ) +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") + + # Per-variant configuration. ``base_model`` is fetched from the Hub to source the matching VAE / text encoder. VARIANTS = { "AnyFlow-FAR-Wan2.1-1.3B-Diffusers": { @@ -104,11 +109,17 @@ def build_pipeline(variant: str, ckpt_path: str): state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["ema"] missing, unexpected = transformer.load_state_dict(state_dict, strict=False) if unexpected: - print( - f"[warn] unexpected keys in state dict (ignored): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}" + logger.warning( + "Unexpected keys in state dict (ignored): %s%s", + unexpected[:5], + "..." if len(unexpected) > 5 else "", ) if missing: - print(f"[warn] missing keys not loaded from state dict: {missing[:5]}{'...' if len(missing) > 5 else ''}") + logger.warning( + "Missing keys not loaded from state dict: %s%s", + missing[:5], + "..." if len(missing) > 5 else "", + ) scheduler = FlowMapEulerDiscreteScheduler(num_train_timesteps=1000, shift=5.0) @@ -145,7 +156,7 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) pipeline = build_pipeline(args.variant, args.ckpt) pipeline.save_pretrained(args.output_dir) - print(f"Saved {args.variant} pipeline to {args.output_dir}") + logger.info("Saved %s pipeline to %s", args.variant, args.output_dir) if __name__ == "__main__": diff --git a/src/diffusers/models/transformers/transformer_anyflow.py b/src/diffusers/models/transformers/transformer_anyflow.py index 26c8c882112a..d29ee4ffadc9 100644 --- a/src/diffusers/models/transformers/transformer_anyflow.py +++ b/src/diffusers/models/transformers/transformer_anyflow.py @@ -33,7 +33,6 @@ from ...utils import logging from ..attention import FeedForward from ..attention_processor import Attention -from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -42,7 +41,16 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -flex_attention = torch.compile(flex_attention, dynamic=True) + +# `flex_attention` is JIT-compiled lazily on first call so that importing this module does not require +# Triton or a CUDA-capable device (CPU CI / older PyTorch builds otherwise fail at import time). +try: + flex_attention = torch.compile(flex_attention, dynamic=True) +except Exception as e: # pragma: no cover - environment-dependent + logger.warning( + "Failed to torch.compile flex_attention; falling back to the eager kernel. Error: %s", + e, + ) def build_block_mask(mask_2d, device): @@ -683,7 +691,7 @@ def forward( return hidden_states -class AnyFlowTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): +class AnyFlowTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A 3D Transformer for any-step video diffusion. The architecture extends the Wan2.1 3D DiT backbone with two optional modules controlled by config flags: @@ -965,11 +973,9 @@ def mask_mod(b, h, q_idx, kv_idx): base = frame_idx[q_idx] >= frame_idx[kv_idx] # 4) interval mask - q_is_context = q_idx < context_seq_len # noqa: F841 q_is_noise = (q_idx >= noise_start) & (q_idx < noise_end) q_is_clean = (q_idx >= clean_start) & (q_idx < clean_end) - k_is_context = kv_idx < context_seq_len # noqa: F841 k_is_noise = (kv_idx >= noise_start) & (kv_idx < noise_end) k_is_clean = (kv_idx >= clean_start) & (kv_idx < clean_end) @@ -1033,7 +1039,7 @@ def mask_mod(b, h, q_idx, kv_idx): def mask_mod(b, h, q_idx, kv_idx): is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len) - base = base = frame_idx[q_idx] >= frame_idx[kv_idx] + base = frame_idx[q_idx] >= frame_idx[kv_idx] return base & ~is_padding return create_block_mask( diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py index 74cda6137ee8..854c60f41f47 100644 --- a/src/diffusers/pipelines/anyflow/pipeline_anyflow.py +++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py @@ -41,18 +41,22 @@ import ftfy +# Copied from diffusers.pipelines.wan.pipeline_wan.basic_clean def basic_clean(text): - text = ftfy.fix_text(text) + if is_ftfy_available(): + text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) return text.strip() +# Copied from diffusers.pipelines.wan.pipeline_wan.whitespace_clean def whitespace_clean(text): text = re.sub(r"\s+", " ", text) text = text.strip() return text +# Copied from diffusers.pipelines.wan.pipeline_wan.prompt_clean def prompt_clean(text): text = whitespace_clean(basic_clean(text)) return text @@ -67,6 +71,14 @@ class AnyFlowPipeline(DiffusionPipeline, WanLoraLoaderMixin): 1, 2, 4, 8, 16... NFE without retraining. This pipeline operates over the full video tensor in one bidirectional pass; for frame-level autoregressive (causal) generation use ``AnyFlowCausalPipeline``. + The released NVIDIA checkpoints loaded by this pipeline went through a two-stage LoRA distillation: + (1) forward Flow-Map training with the MeanFlow identity as a stop-grad regression target, and + (2) on-policy distillation that combines Flow-Map backward simulation with DMD reverse-divergence + supervision over the student's own rollouts. Sampling at inference is plain Euler in mean-velocity + form (``z_r = z_t - (t - r) * u``) with no re-noising and no CFG (guidance was fused into the model + weights during stage 1). See ``training_rollout`` for the rollout entry point reused during DMD + fine-tuning. + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). @@ -114,13 +126,14 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.use_mean_velocity = use_mean_velocity + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, - prompt: Union[str, List[str]] = None, + prompt: str | list[str] = None, num_videos_per_prompt: int = 1, max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -155,25 +168,26 @@ def _get_t5_prompt_embeds( return prompt_embeds + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt def encode_prompt( self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `list[str]`, *optional*): prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): + negative_prompt (`str` or `list[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). @@ -382,6 +396,36 @@ def training_rollout( negative_prompt_embeds: Optional[torch.Tensor] = None, guidance_scale: float = 1.0, ): + r""" + Three-segment Flow-Map backward simulation used as the on-policy rollout for stage-2 DMD + distillation. Not part of the standard inference path — end users should call ``__call__``. + + When ``grad_timestep`` is ``None`` the method reduces to a plain (no-grad) multi-step rollout. + When ``grad_timestep`` is set, the rollout is split into three segments (``z_T -> z_t``, the + gradient anchor ``z_t -> z_r`` step, then ``z_r -> z_0``); every segment contributes to the + autograd graph, matching Algorithm 2 of the AnyFlow paper. This is the entry point that the + on-policy trainer composes with a frozen ``real_score`` and a trainable discriminator to compute + the DMD KL-gradient surrogate. + + Args: + context_sequence (`torch.Tensor`, *optional*): + Clean prefix latents to keep fixed during the rollout (used by I2V / TV2V variants). + num_inference_steps (`int`, defaults to 50): + Number of inference steps used to discretize the rollout schedule. + grad_timestep (`int`, *optional*): + Index into the inference schedule that becomes the gradient anchor. ``None`` disables + gradients and turns the call into a plain rollout. + latents (`torch.Tensor`, *optional*): + Initial Gaussian latents at ``t = T``. + prompt_embeds, negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed text encoder embeddings. + guidance_scale (`float`, defaults to 1.0): + CFG scale applied at runtime. Set to 1.0 (default) for distilled checkpoints since CFG + was fused into the weights during stage 1. + + Returns: + `torch.Tensor`: the rolled-out latents at the final timestep. + """ self._guidance_scale = guidance_scale if negative_prompt_embeds is not None: @@ -532,7 +576,7 @@ def __call__( device=device, ) - transformer_dtype = torch.bfloat16 + transformer_dtype = self.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow_causal.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow_causal.py index a5de05cf4c84..992c1ab1bc44 100644 --- a/src/diffusers/pipelines/anyflow/pipeline_anyflow_causal.py +++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow_causal.py @@ -42,18 +42,22 @@ import ftfy +# Copied from diffusers.pipelines.wan.pipeline_wan.basic_clean def basic_clean(text): - text = ftfy.fix_text(text) + if is_ftfy_available(): + text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) return text.strip() +# Copied from diffusers.pipelines.wan.pipeline_wan.whitespace_clean def whitespace_clean(text): text = re.sub(r"\s+", " ", text) text = text.strip() return text +# Copied from diffusers.pipelines.wan.pipeline_wan.prompt_clean def prompt_clean(text): text = whitespace_clean(basic_clean(text)) return text @@ -65,8 +69,21 @@ class AnyFlowCausalPipeline(DiffusionPipeline, WanLoraLoaderMixin): The pipeline drives a frame-level autoregressive sampling loop over chunks: each chunk is denoised with flow-map steps while attending only to past chunks via block-sparse causal attention, and intermediate - KV cache is reused across chunks. Set ``task_type`` per call to switch between ``"t2v"``, ``"i2v"``, and - ``"tv2v"``. + KV cache is reused across chunks. + + The task mode (T2V / I2V / TV2V) is selected by the ``context_sequence`` argument passed to ``__call__``: + + - ``context_sequence=None`` — pure text-to-video. + - ``context_sequence={"raw":