From f6c36b9371132e63c0ecb5512a3cb3b764bb0e40 Mon Sep 17 00:00:00 2001
From: Thomas Ning
Date: Wed, 24 Jun 2026 08:59:48 +0800
Subject: [PATCH 1/2] Fix get_dummy_wan_inputs for two-expert Wan2.2 I2V pipeline

get_dummy_wan_inputs (used by WanPipeline.quantize_transformer) assumed a
single pipeline.transformer and the single-transformer prepare_latents()
signature. WanPipelineI2V_2_2 (Wan2.2-I2V-A14B) is two-expert: it has
low_noise_transformer / high_noise_transformer (no .transformer) and a
different prepare_latents() signature, so qwix quantization crashed with

  AttributeError: 'WanPipelineI2V_2_2' object has no attribute 'transformer'

and then

  TypeError: prepare_latents() got an unexpected keyword argument 'vae_scale_factor_temporal'.

When pipeline.transformer is absent, build dummy latents directly in the
(B, C, F, H, W) layout WanModel.__call__ expects, taking
num_channels_latents from an existing expert. The single-transformer path
is unchanged.

Validated on Wan2.2-I2V-A14B (v6e): quantization now proceeds past these
errors into qwix.quantize_model.
--- src/maxdiffusion/maxdiffusion_utils.py | 32 ++++++++++++++++++-------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index ef152b15..e06c53ea 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -358,15 +358,29 @@ def get_dummy_ltx2_inputs(config, pipeline, batch_size): def get_dummy_wan_inputs(config, pipeline, batch_size): - latents = pipeline.prepare_latents( - batch_size, - vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, - vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_channels_latents=pipeline.transformer.config.in_channels, - ) + if getattr(pipeline, "transformer", None) is not None: + latents = pipeline.prepare_latents( + batch_size, + vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, + vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_channels_latents=pipeline.transformer.config.in_channels, + ) + else: + # The two-expert Wan2.2 I2V pipeline (WanPipelineI2V_2_2) has no `.transformer` + # and a different `prepare_latents()` signature, so build dummy latents directly + # in the (B, C, F, H, W) layout WanModel.__call__ expects. 
+ transformer = getattr(pipeline, "low_noise_transformer", None) or pipeline.high_noise_transformer + num_channels_latents = transformer.config.in_channels + num_latent_frames = (config.num_frames - 1) // pipeline.vae_scale_factor_temporal + 1 + latent_height = config.height // pipeline.vae_scale_factor_spatial + latent_width = config.width // pipeline.vae_scale_factor_spatial + latents = jax.random.normal( + jax.random.key(config.seed), + (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width), + ) bsz = latents.shape[0] prompt_embeds = jax.random.normal(jax.random.key(config.seed), (batch_size, 512, 4096)) timesteps = jnp.array([0] * bsz, dtype=jnp.int32) From 62122128c7e2564d902dec060cb2ae825239dd30 Mon Sep 17 00:00:00 2001 From: ThomasNing Date: Thu, 25 Jun 2026 08:49:00 +0800 Subject: [PATCH 2/2] Generate WAN dummy latents uniformly --- src/maxdiffusion/maxdiffusion_utils.py | 45 +++++++++---------- .../tests/maxdiffusion_utils_test.py | 35 +++++++++++++++ 2 files changed, 57 insertions(+), 23 deletions(-) diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index e06c53ea..a9fcbdf0 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -357,30 +357,29 @@ def get_dummy_ltx2_inputs(config, pipeline, batch_size): ) +def _get_wan_transformer_for_dummy_inputs(pipeline): + for transformer_attr in ("transformer", "low_noise_transformer", "high_noise_transformer"): + transformer = getattr(pipeline, transformer_attr, None) + if transformer is not None: + return transformer + raise ValueError("WAN dummy inputs require a transformer, low_noise_transformer, or high_noise_transformer.") + + +def _get_dummy_wan_latents(config, pipeline, batch_size): + transformer = _get_wan_transformer_for_dummy_inputs(pipeline) + num_channels_latents = transformer.config.in_channels + num_latent_frames = (int(config.num_frames) - 1) // pipeline.vae_scale_factor_temporal + 1 + 
latent_height = int(config.height) // pipeline.vae_scale_factor_spatial + latent_width = int(config.width) // pipeline.vae_scale_factor_spatial + return jax.random.normal( + jax.random.key(config.seed), + (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width), + dtype=jnp.float32, + ) + + def get_dummy_wan_inputs(config, pipeline, batch_size): - if getattr(pipeline, "transformer", None) is not None: - latents = pipeline.prepare_latents( - batch_size, - vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, - vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_channels_latents=pipeline.transformer.config.in_channels, - ) - else: - # The two-expert Wan2.2 I2V pipeline (WanPipelineI2V_2_2) has no `.transformer` - # and a different `prepare_latents()` signature, so build dummy latents directly - # in the (B, C, F, H, W) layout WanModel.__call__ expects. - transformer = getattr(pipeline, "low_noise_transformer", None) or pipeline.high_noise_transformer - num_channels_latents = transformer.config.in_channels - num_latent_frames = (config.num_frames - 1) // pipeline.vae_scale_factor_temporal + 1 - latent_height = config.height // pipeline.vae_scale_factor_spatial - latent_width = config.width // pipeline.vae_scale_factor_spatial - latents = jax.random.normal( - jax.random.key(config.seed), - (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width), - ) + latents = _get_dummy_wan_latents(config, pipeline, batch_size) bsz = latents.shape[0] prompt_embeds = jax.random.normal(jax.random.key(config.seed), (batch_size, 512, 4096)) timesteps = jnp.array([0] * bsz, dtype=jnp.int32) diff --git a/src/maxdiffusion/tests/maxdiffusion_utils_test.py b/src/maxdiffusion/tests/maxdiffusion_utils_test.py index 4b29d365..65708494 100644 --- a/src/maxdiffusion/tests/maxdiffusion_utils_test.py +++ 
b/src/maxdiffusion/tests/maxdiffusion_utils_test.py @@ -15,7 +15,9 @@ """ import os +from types import SimpleNamespace import unittest +from unittest.mock import Mock from jax.sharding import Mesh @@ -35,6 +37,39 @@ class MaxDiffusionUtilsTest(unittest.TestCase): def setUp(self): MaxDiffusionUtilsTest.dummy_data = {} + def test_get_dummy_wan_inputs_generates_latents_without_pipeline_prepare_latents(self): + config = SimpleNamespace(height=64, width=80, num_frames=9, seed=0) + pipeline = SimpleNamespace( + transformer=SimpleNamespace(config=SimpleNamespace(in_channels=16)), + vae_scale_factor_temporal=4, + vae_scale_factor_spatial=8, + prepare_latents=Mock(side_effect=AssertionError("prepare_latents should not be called")), + ) + + latents, prompt_embeds, timesteps = maxdiffusion_utils.get_dummy_wan_inputs(config, pipeline, batch_size=2) + + pipeline.prepare_latents.assert_not_called() + self.assertEqual(latents.shape, (2, 16, 3, 8, 10)) + self.assertEqual(prompt_embeds.shape, (2, 512, 4096)) + self.assertEqual(timesteps.shape, (2,)) + + def test_get_dummy_wan_inputs_supports_two_expert_pipeline(self): + config = SimpleNamespace(height=64, width=80, num_frames=9, seed=0) + pipeline = SimpleNamespace( + low_noise_transformer=SimpleNamespace(config=SimpleNamespace(in_channels=48)), + high_noise_transformer=SimpleNamespace(config=SimpleNamespace(in_channels=48)), + vae_scale_factor_temporal=4, + vae_scale_factor_spatial=8, + prepare_latents=Mock(side_effect=AssertionError("prepare_latents should not be called")), + ) + + latents, prompt_embeds, timesteps = maxdiffusion_utils.get_dummy_wan_inputs(config, pipeline, batch_size=2) + + pipeline.prepare_latents.assert_not_called() + self.assertEqual(latents.shape, (2, 48, 3, 8, 10)) + self.assertEqual(prompt_embeds.shape, (2, 512, 4096)) + self.assertEqual(timesteps.shape, (2,)) + def test_create_scheduler(self): """Test create scheduler with different schedulers""" pyconfig.initialize(