From 0aea69b4badfbca28cb3bfef36308e46aaf9223b Mon Sep 17 00:00:00 2001 From: James Huang Date: Thu, 12 Mar 2026 07:57:36 +0000 Subject: [PATCH] CFG Cache For Wan 2.2 Signed-off-by: James Huang --- .../pipelines/wan/wan_pipeline_2_2.py | 210 ++++++++++--- src/maxdiffusion/tests/wan_cfg_cache_test.py | 287 ++++++++++++++++++ 2 files changed, 459 insertions(+), 38 deletions(-) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 16b601ba..b8f818e3 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .wan_pipeline import WanPipeline, transformer_forward_pass +from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache from ...models.wan.transformers.transformer_wan import WanModel from typing import List, Union, Optional from ...pyconfig import HyperParameters @@ -21,6 +21,7 @@ from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp +import numpy as np from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler @@ -32,7 +33,7 @@ def __init__( config: HyperParameters, low_noise_transformer: Optional[WanModel], high_noise_transformer: Optional[WanModel], - **kwargs + **kwargs, ): super().__init__(config=config, **kwargs) self.low_noise_transformer = low_noise_transformer @@ -109,7 +110,15 @@ def __call__( prompt_embeds: jax.Array = None, negative_prompt_embeds: jax.Array = None, vae_only: bool = False, + use_cfg_cache: bool = False, ): + if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): + raise ValueError( + f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 " + f"(got {guidance_scale_low}, 
{guidance_scale_high}). " + "CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases." + ) + latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( prompt, negative_prompt, @@ -138,6 +147,8 @@ def __call__( num_inference_steps=num_inference_steps, scheduler=self.scheduler, scheduler_state=scheduler_state, + use_cfg_cache=use_cfg_cache, + height=height, ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): @@ -172,51 +183,174 @@ def run_inference_2_2( num_inference_steps: int, scheduler: FlaxUniPCMultistepScheduler, scheduler_state, + use_cfg_cache: bool = False, + height: int = 480, ): + """Denoising loop for WAN 2.2 T2V with optional FasterCache CFG-Cache. + + Dual-transformer CFG-Cache strategy (enabled via use_cfg_cache=True): + - High-noise phase (t >= boundary): always full CFG — short phase, critical + for establishing video structure. + - Low-noise phase (t < boundary): FasterCache alternation — full CFG every N=cfg_cache_interval + steps inside [cfg_cache_start_step, cfg_cache_end_step), FFT frequency-domain compensation on cache steps (batch×1). + - Boundary transition: mandatory full CFG step to populate cache for the + low-noise transformer. + - FFT compensation identical to WAN 2.1 (Lv et al., ICLR 2025). 
+ """ do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 - if do_classifier_free_guidance: - prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) - - def low_noise_branch(operands): - latents, timestep, prompt_embeds = operands - return transformer_forward_pass( - low_noise_graphdef, - low_noise_state, - low_noise_rest, - latents, - timestep, - prompt_embeds, - do_classifier_free_guidance, - guidance_scale_low, - ) + bsz = latents.shape[0] - def high_noise_branch(operands): - latents, timestep, prompt_embeds = operands - return transformer_forward_pass( - high_noise_graphdef, - high_noise_state, - high_noise_rest, - latents, - timestep, - prompt_embeds, - do_classifier_free_guidance, - guidance_scale_high, + # ── CFG cache path ── + if use_cfg_cache and do_classifier_free_guidance: + # Get timesteps as numpy for Python-level scheduling decisions + timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) + step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)] + + # Resolution-dependent CFG cache config — adapted for Wan 2.2. + if height >= 720: + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = int(num_inference_steps * 0.9) + cfg_cache_alpha = 0.2 + else: + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = num_inference_steps - 1 + cfg_cache_alpha = 0.2 + + # Pre-split embeds once + prompt_cond_embeds = prompt_embeds + prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + + # Determine the first low-noise step (boundary transition). + # In Wan 2.2 the boundary IS the structural→detail transition, so + # all low-noise cache steps should emphasise high-frequency correction. 
+ first_low_step = next( + (s for s in range(num_inference_steps) if not step_uses_high[s]), + num_inference_steps, ) + t0_step = first_low_step # all cache steps get high-freq boost + + # Pre-compute cache schedule and phase-dependent weights. + first_full_in_low_seen = False + step_is_cache = [] + step_w1w2 = [] + for s in range(num_inference_steps): + if step_uses_high[s]: + # Never cache high-noise transformer steps + step_is_cache.append(False) + else: + is_cache = ( + first_full_in_low_seen + and s >= cfg_cache_start_step + and s < cfg_cache_end_step + and (s - cfg_cache_start_step) % cfg_cache_interval != 0 + ) + step_is_cache.append(is_cache) + if not is_cache: + first_full_in_low_seen = True + + # Phase-dependent weights: w = 1 + α·I(condition) + if s < t0_step: + step_w1w2.append((1.0 + cfg_cache_alpha, 1.0)) # high-noise: boost low-freq + else: + step_w1w2.append((1.0, 1.0 + cfg_cache_alpha)) # low-noise: boost high-freq + + # Cache tensors (on-device JAX arrays, initialised to None). 
+ cached_noise_cond = None + cached_noise_uncond = None + + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + is_cache_step = step_is_cache[step] + + # Select transformer and guidance scale based on precomputed schedule + if step_uses_high[step]: + graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + guidance_scale = guidance_scale_high + else: + graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest + guidance_scale = guidance_scale_low + + if is_cache_step: + # ── Cache step: cond-only forward + FFT frequency compensation ── + w1, w2 = step_w1w2[step] + timestep = jnp.broadcast_to(t, bsz) + noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache( + graphdef, + state, + rest, + latents, + timestep, + prompt_cond_embeds, + cached_noise_cond, + cached_noise_uncond, + guidance_scale=guidance_scale, + w1=jnp.float32(w1), + w2=jnp.float32(w2), + ) + else: + # ── Full CFG step: doubled batch, store raw cond/uncond for cache ── + latents_doubled = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, bsz * 2) + noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg( + graphdef, + state, + rest, + latents_doubled, + timestep, + prompt_embeds_combined, + guidance_scale=guidance_scale, + ) + + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents + + # ── Original non-cache path ── + # Uses same Python-level if/else transformer selection as the cache path + # so both paths compile to identical XLA graphs (critical for bfloat16 + # reproducibility in the PSNR comparison). 
+ timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) + step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)] + + prompt_embeds_combined = ( + jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds + ) for step in range(num_inference_steps): t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - if do_classifier_free_guidance: - latents = jnp.concatenate([latents] * 2) - timestep = jnp.broadcast_to(t, latents.shape[0]) - use_high_noise = jnp.greater_equal(t, boundary) + if step_uses_high[step]: + graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + guidance_scale = guidance_scale_high + else: + graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest + guidance_scale = guidance_scale_low - # Selects the model based on the current timestep: - # - high_noise_model: Used for early diffusion steps where t >= config.boundary_timestep (high noise). - # - low_noise_model: Used for later diffusion steps where t < config.boundary_timestep (low noise). 
- noise_pred, latents = jax.lax.cond( - use_high_noise, high_noise_branch, low_noise_branch, (latents, timestep, prompt_embeds) - ) + if do_classifier_free_guidance: + latents_doubled = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, bsz * 2) + noise_pred, _, _ = transformer_forward_pass_full_cfg( + graphdef, + state, + rest, + latents_doubled, + timestep, + prompt_embeds_combined, + guidance_scale=guidance_scale, + ) + else: + timestep = jnp.broadcast_to(t, bsz) + noise_pred, latents = transformer_forward_pass( + graphdef, + state, + rest, + latents, + timestep, + prompt_embeds, + do_classifier_free_guidance, + guidance_scale, + ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents diff --git a/src/maxdiffusion/tests/wan_cfg_cache_test.py b/src/maxdiffusion/tests/wan_cfg_cache_test.py index 3543cf69..a499bc54 100644 --- a/src/maxdiffusion/tests/wan_cfg_cache_test.py +++ b/src/maxdiffusion/tests/wan_cfg_cache_test.py @@ -23,6 +23,7 @@ from absl.testing import absltest from maxdiffusion.pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 +from maxdiffusion.pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -270,5 +271,291 @@ def test_cfg_cache_speedup_and_fidelity(self): self.assertGreaterEqual(mean_ssim, 0.95, f"Mean SSIM={mean_ssim:.4f} < 0.95") +class Wan22CfgCacheValidationTest(unittest.TestCase): + """Tests that use_cfg_cache=True with guidance_scale <= 1.0 raises ValueError for Wan 2.2.""" + + def _make_pipeline(self): + """Create a bare WanPipeline2_2 via __new__, skipping __init__ (no internals are set up or mocked).""" + pipeline = WanPipeline2_2.__new__(WanPipeline2_2) + return pipeline + + def test_cfg_cache_with_both_scales_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + guidance_scale_low=1.0, + guidance_scale_high=1.0, 
+ use_cfg_cache=True, + ) + self.assertIn("use_cfg_cache", str(ctx.exception)) + + def test_cfg_cache_with_low_scale_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + guidance_scale_low=0.5, + guidance_scale_high=4.0, + use_cfg_cache=True, + ) + self.assertIn("use_cfg_cache", str(ctx.exception)) + + def test_cfg_cache_with_high_scale_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + guidance_scale_low=3.0, + guidance_scale_high=1.0, + use_cfg_cache=True, + ) + self.assertIn("use_cfg_cache", str(ctx.exception)) + + def test_cfg_cache_with_valid_scales_no_validation_error(self): + """Both guidance_scales > 1.0 should pass validation (may fail later without model).""" + pipeline = self._make_pipeline() + try: + pipeline( + prompt=["test"], + guidance_scale_low=3.0, + guidance_scale_high=4.0, + use_cfg_cache=True, + ) + except ValueError as e: + if "use_cfg_cache" in str(e): + self.fail(f"Unexpected validation error: {e}") + except Exception: + pass + + def test_no_cfg_cache_with_low_scales_no_error(self): + """use_cfg_cache=False should never raise our ValueError.""" + pipeline = self._make_pipeline() + try: + pipeline( + prompt=["test"], + guidance_scale_low=0.5, + guidance_scale_high=0.5, + use_cfg_cache=False, + ) + except ValueError as e: + if "use_cfg_cache" in str(e): + self.fail(f"Unexpected validation error: {e}") + except Exception: + pass + + +class Wan22CfgCacheScheduleTest(unittest.TestCase): + """Tests the CFG cache schedule for Wan 2.2 dual-transformer architecture. + + Key difference from 2.1: high-noise steps are never cached, and the first + low-noise step always does full CFG to populate the cache. + """ + + def _get_cache_schedule_2_2(self, num_inference_steps, boundary_ratio=0.875, num_train_timesteps=1000, height=720): + """Recompute the cache schedule with a local copy of run_inference_2_2's logic (the pipeline itself is not called). 
+ + Returns (step_is_cache, step_uses_high) lists. + """ + boundary = boundary_ratio * num_train_timesteps + + # Simulate timesteps (linearly spaced, descending — simplified) + timesteps = np.linspace(num_train_timesteps - 1, 0, num_inference_steps, dtype=np.int32) + step_uses_high = [bool(timesteps[s] >= boundary) for s in range(num_inference_steps)] + + if height >= 720: + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = int(num_inference_steps * 0.9) + else: + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = num_inference_steps - 1 + + first_full_in_low_seen = False + step_is_cache = [] + for s in range(num_inference_steps): + if step_uses_high[s]: + step_is_cache.append(False) + else: + is_cache = ( + first_full_in_low_seen + and s >= cfg_cache_start_step + and s < cfg_cache_end_step + and (s - cfg_cache_start_step) % cfg_cache_interval != 0 + ) + step_is_cache.append(is_cache) + if not is_cache: + first_full_in_low_seen = True + + return step_is_cache, step_uses_high + + def test_high_noise_steps_never_cached(self): + """High-noise phase steps (t >= boundary) must never be cache steps.""" + step_is_cache, step_uses_high = self._get_cache_schedule_2_2(50) + for s in range(50): + if step_uses_high[s]: + self.assertFalse(step_is_cache[s], f"Step {s} is high-noise but marked as cache") + + def test_first_low_noise_step_is_full_cfg(self): + """The first low-noise step must be full CFG to populate the cache.""" + step_is_cache, step_uses_high = self._get_cache_schedule_2_2(50) + first_low = next(s for s in range(50) if not step_uses_high[s]) + self.assertFalse(step_is_cache[first_low], f"First low-noise step {first_low} should be full CFG") + + def test_has_cache_steps_in_low_noise_phase(self): + """There should be cache steps in the low-noise phase.""" + step_is_cache, step_uses_high = self._get_cache_schedule_2_2(50) + low_noise_cache_count = sum(1 for s in 
range(50) if not step_uses_high[s] and step_is_cache[s]) + self.assertGreater(low_noise_cache_count, 0, "Should have cache steps in the low-noise phase") + + def test_boundary_ratio_affects_high_noise_count(self): + """Lower boundary_ratio means more high-noise steps (easier threshold to exceed).""" + _, high_09 = self._get_cache_schedule_2_2(50, boundary_ratio=0.9) + _, high_05 = self._get_cache_schedule_2_2(50, boundary_ratio=0.5) + self.assertGreater(sum(high_05), sum(high_09), "Lower boundary_ratio should have more high-noise steps") + + def test_720p_more_conservative_than_480p(self): + """720p should have fewer cache steps than 480p.""" + cache_720, _ = self._get_cache_schedule_2_2(50, height=720) + cache_480, _ = self._get_cache_schedule_2_2(50, height=480) + self.assertGreater(sum(cache_480), sum(cache_720), "720p should be more conservative than 480p") + + def test_cache_interval_in_low_noise_phase(self): + """Every cfg_cache_interval-th step after start should be full CFG.""" + step_is_cache, step_uses_high = self._get_cache_schedule_2_2(50, height=480) + start = int(50 / 3) + end = 49 + for s in range(start, end): + if not step_uses_high[s] and (s - start) % 5 == 0: + self.assertFalse(step_is_cache[s], f"Step {s} should be full CFG (interval=5)") + + def test_short_run_no_cache(self): + """Very few steps should have no cache steps.""" + step_is_cache, _ = self._get_cache_schedule_2_2(3) + self.assertEqual(sum(step_is_cache), 0, "3 steps is too short for cache") + + def test_all_high_noise_no_cache(self): + """If boundary_ratio=0, all steps are high noise, no caching.""" + step_is_cache, step_uses_high = self._get_cache_schedule_2_2(50, boundary_ratio=0.0) + self.assertTrue(all(step_uses_high), "All steps should be high-noise") + self.assertEqual(sum(step_is_cache), 0, "No cache steps when all high-noise") + + +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Requires TPU v7-8 and model weights") +class Wan22CfgCacheSmokeTest(unittest.TestCase): + 
"""End-to-end smoke test: CFG cache for Wan 2.2 dual-transformer. + + Runs on TPU v7-8 (8 chips, context_parallelism=8) with WAN 2.2 27B, 720p. + Skipped in CI (GitHub Actions) — run locally with: + python -m pytest src/maxdiffusion/tests/wan_cfg_cache_test.py::Wan22CfgCacheSmokeTest -v + """ + + @classmethod + def setUpClass(cls): + from maxdiffusion import pyconfig + from maxdiffusion.checkpointing.wan_checkpointer_2_2 import WanCheckpointer2_2 + + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_27b.yml"), + "num_inference_steps=50", + "height=720", + "width=1280", + "num_frames=81", + "fps=24", + "guidance_scale_low=3.0", + "guidance_scale_high=4.0", + "boundary_ratio=0.875", + "flow_shift=3.0", + "seed=11234567893", + "attention=flash", + "remat_policy=FULL", + "allow_split_physical_axes=True", + "skip_jax_distributed_system=True", + "weights_dtype=bfloat16", + "activations_dtype=bfloat16", + "per_device_batch_size=0.125", + "ici_data_parallelism=1", + "ici_fsdp_parallelism=1", + "ici_context_parallelism=8", + "ici_tensor_parallelism=1", + "flash_min_seq_length=0", + 'flash_block_sizes={"block_q": 2048, "block_kv_compute": 1024, "block_kv": 2048, "block_q_dkv": 2048, "block_kv_dkv": 2048, "block_kv_dkv_compute": 2048, "use_fused_bwd_kernel": true}', + ], + unittest=True, + ) + cls.config = pyconfig.config + checkpoint_loader = WanCheckpointer2_2(config=cls.config) + cls.pipeline, _, _ = checkpoint_loader.load_checkpoint() + + cls.prompt = [cls.config.prompt] * cls.config.global_batch_size_to_train_on + cls.negative_prompt = [cls.config.negative_prompt] * cls.config.global_batch_size_to_train_on + + # Warmup both XLA code paths + for use_cache in [False, True]: + cls.pipeline( + prompt=cls.prompt, + negative_prompt=cls.negative_prompt, + height=cls.config.height, + width=cls.config.width, + num_frames=cls.config.num_frames, + num_inference_steps=cls.config.num_inference_steps, + 
guidance_scale_low=cls.config.guidance_scale_low, + guidance_scale_high=cls.config.guidance_scale_high, + use_cfg_cache=use_cache, + ) + + def _run_pipeline(self, use_cfg_cache): + t0 = time.perf_counter() + videos = self.pipeline( + prompt=self.prompt, + negative_prompt=self.negative_prompt, + height=self.config.height, + width=self.config.width, + num_frames=self.config.num_frames, + num_inference_steps=self.config.num_inference_steps, + guidance_scale_low=self.config.guidance_scale_low, + guidance_scale_high=self.config.guidance_scale_high, + use_cfg_cache=use_cfg_cache, + ) + return videos, time.perf_counter() - t0 + + def test_cfg_cache_speedup_and_fidelity(self): + """CFG cache must be faster than baseline with PSNR >= 30 dB and SSIM >= 0.95.""" + videos_baseline, t_baseline = self._run_pipeline(use_cfg_cache=False) + videos_cached, t_cached = self._run_pipeline(use_cfg_cache=True) + + # Speed check + speedup = t_baseline / t_cached + print(f"Baseline: {t_baseline:.2f}s, CFG cache: {t_cached:.2f}s, Speedup: {speedup:.3f}x") + self.assertGreater(speedup, 1.0, f"CFG cache should be faster. 
Speedup={speedup:.3f}x") + + # Fidelity checks + v1 = np.array(videos_baseline[0], dtype=np.float64) + v2 = np.array(videos_cached[0], dtype=np.float64) + + # PSNR + mse = np.mean((v1 - v2) ** 2) + psnr = 10.0 * np.log10(1.0 / mse) if mse > 0 else float("inf") + print(f"PSNR: {psnr:.2f} dB") + self.assertGreaterEqual(psnr, 30.0, f"PSNR={psnr:.2f} dB < 30 dB") + + # SSIM (per-frame) + C1, C2 = 0.01**2, 0.03**2 + ssim_scores = [] + for f in range(v1.shape[0]): + mu1, mu2 = np.mean(v1[f]), np.mean(v2[f]) + sigma1_sq, sigma2_sq = np.var(v1[f]), np.var(v2[f]) + sigma12 = np.mean((v1[f] - mu1) * (v2[f] - mu2)) + ssim = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2)) + ssim_scores.append(float(ssim)) + + mean_ssim = np.mean(ssim_scores) + print(f"SSIM: mean={mean_ssim:.4f}, min={np.min(ssim_scores):.4f}") + self.assertGreaterEqual(mean_ssim, 0.95, f"Mean SSIM={mean_ssim:.4f} < 0.95") + + if __name__ == "__main__": absltest.main()