diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py
index ef152b15..a9fcbdf0 100644
--- a/src/maxdiffusion/maxdiffusion_utils.py
+++ b/src/maxdiffusion/maxdiffusion_utils.py
@@ -357,16 +357,47 @@ def get_dummy_ltx2_inputs(config, pipeline, batch_size):
   )
 
 
-def get_dummy_wan_inputs(config, pipeline, batch_size):
-  latents = pipeline.prepare_latents(
-      batch_size,
-      vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
-      vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
-      height=config.height,
-      width=config.width,
-      num_frames=config.num_frames,
-      num_channels_latents=pipeline.transformer.config.in_channels,
+def _get_wan_transformer_for_dummy_inputs(pipeline):
+  """Returns the first transformer found on a WAN pipeline.
+
+  Checks `transformer`, then `low_noise_transformer`, then
+  `high_noise_transformer`, and returns the first attribute that is present and
+  not None, so pipelines exposing only the two expert transformers also work.
+
+  Raises:
+    ValueError: If none of those attributes holds a transformer.
+  """
+  for transformer_attr in ("transformer", "low_noise_transformer", "high_noise_transformer"):
+    transformer = getattr(pipeline, transformer_attr, None)
+    if transformer is not None:
+      return transformer
+  raise ValueError("WAN dummy inputs require a transformer, low_noise_transformer, or high_noise_transformer.")
+
+
+def _get_dummy_wan_latents(config, pipeline, batch_size):
+  """Builds random float32 latents for WAN without calling pipeline.prepare_latents.
+
+  The result has shape (batch_size, in_channels, latent_frames, latent_height,
+  latent_width): in_channels is read from the selected transformer's config,
+  latent_frames is (num_frames - 1) // vae_scale_factor_temporal + 1, and the
+  spatial sizes are height and width floor-divided by vae_scale_factor_spatial.
+  """
+  transformer = _get_wan_transformer_for_dummy_inputs(pipeline)
+  num_channels_latents = transformer.config.in_channels
+  num_latent_frames = (int(config.num_frames) - 1) // pipeline.vae_scale_factor_temporal + 1
+  latent_height = int(config.height) // pipeline.vae_scale_factor_spatial
+  latent_width = int(config.width) // pipeline.vae_scale_factor_spatial
+  # NOTE(review): get_dummy_wan_inputs draws prompt_embeds from the same jax.random.key(config.seed);
+  # presumably fine for dummy data, but split the key if the values ever need to be independent.
+  return jax.random.normal(
+      jax.random.key(config.seed),
+      (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width),
+      dtype=jnp.float32,
   )
+
+
+def get_dummy_wan_inputs(config, pipeline, batch_size):
+  latents = _get_dummy_wan_latents(config, pipeline, batch_size)
   bsz = latents.shape[0]
   prompt_embeds = jax.random.normal(jax.random.key(config.seed), (batch_size, 512, 4096))
   timesteps = jnp.array([0] * bsz, dtype=jnp.int32)
diff --git a/src/maxdiffusion/tests/maxdiffusion_utils_test.py b/src/maxdiffusion/tests/maxdiffusion_utils_test.py
index 4b29d365..65708494 100644
--- a/src/maxdiffusion/tests/maxdiffusion_utils_test.py
+++ b/src/maxdiffusion/tests/maxdiffusion_utils_test.py
@@ -15,7 +15,9 @@
 """
 
 import os
+from types import SimpleNamespace
 import unittest
+from unittest.mock import Mock
 
 from jax.sharding import Mesh
 
@@ -35,6 +37,49 @@ class MaxDiffusionUtilsTest(unittest.TestCase):
   def setUp(self):
     MaxDiffusionUtilsTest.dummy_data = {}
 
+  def test_get_dummy_wan_inputs_generates_latents_without_pipeline_prepare_latents(self):
+    """Latents are built from config and scale factors; prepare_latents is never called."""
+    config = SimpleNamespace(height=64, width=80, num_frames=9, seed=0)
+    pipeline = SimpleNamespace(
+        transformer=SimpleNamespace(config=SimpleNamespace(in_channels=16)),
+        vae_scale_factor_temporal=4,
+        vae_scale_factor_spatial=8,
+        prepare_latents=Mock(side_effect=AssertionError("prepare_latents should not be called")),
+    )
+
+    latents, prompt_embeds, timesteps = maxdiffusion_utils.get_dummy_wan_inputs(config, pipeline, batch_size=2)
+
+    pipeline.prepare_latents.assert_not_called()
+    self.assertEqual(latents.shape, (2, 16, 3, 8, 10))
+    self.assertEqual(prompt_embeds.shape, (2, 512, 4096))
+    self.assertEqual(timesteps.shape, (2,))
+
+  def test_get_dummy_wan_inputs_supports_two_expert_pipeline(self):
+    """A pipeline with only low/high noise transformers still yields dummy inputs."""
+    config = SimpleNamespace(height=64, width=80, num_frames=9, seed=0)
+    pipeline = SimpleNamespace(
+        low_noise_transformer=SimpleNamespace(config=SimpleNamespace(in_channels=48)),
+        high_noise_transformer=SimpleNamespace(config=SimpleNamespace(in_channels=48)),
+        vae_scale_factor_temporal=4,
+        vae_scale_factor_spatial=8,
+        prepare_latents=Mock(side_effect=AssertionError("prepare_latents should not be called")),
+    )
+
+    latents, prompt_embeds, timesteps = maxdiffusion_utils.get_dummy_wan_inputs(config, pipeline, batch_size=2)
+
+    pipeline.prepare_latents.assert_not_called()
+    self.assertEqual(latents.shape, (2, 48, 3, 8, 10))
+    self.assertEqual(prompt_embeds.shape, (2, 512, 4096))
+    self.assertEqual(timesteps.shape, (2,))
+
+  def test_get_dummy_wan_inputs_raises_without_any_transformer(self):
+    """A pipeline exposing none of the transformer attributes is rejected."""
+    config = SimpleNamespace(height=64, width=80, num_frames=9, seed=0)
+    pipeline = SimpleNamespace(vae_scale_factor_temporal=4, vae_scale_factor_spatial=8)
+
+    with self.assertRaises(ValueError):
+      maxdiffusion_utils.get_dummy_wan_inputs(config, pipeline, batch_size=2)
+
   def test_create_scheduler(self):
     """Test create scheduler with different schedulers"""
     pyconfig.initialize(