Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions src/maxdiffusion/maxdiffusion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,16 +357,29 @@ def get_dummy_ltx2_inputs(config, pipeline, batch_size):
)


def get_dummy_wan_inputs(config, pipeline, batch_size):
latents = pipeline.prepare_latents(
batch_size,
vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
height=config.height,
width=config.width,
num_frames=config.num_frames,
num_channels_latents=pipeline.transformer.config.in_channels,
def _get_wan_transformer_for_dummy_inputs(pipeline):
  """Return the first usable transformer attached to `pipeline`.

  The attributes `transformer`, `low_noise_transformer` and
  `high_noise_transformer` are probed in that order; an attribute that is
  missing or set to `None` is skipped.

  Args:
    pipeline: object that may expose any of the three transformer attributes.

  Returns:
    The first attribute value that is not `None`.

  Raises:
    ValueError: if none of the three attributes holds a non-`None` value.
  """
  attr_names = ("transformer", "low_noise_transformer", "high_noise_transformer")
  # Lazy generator: attributes after the first hit are never looked up.
  candidates = (getattr(pipeline, attr_name, None) for attr_name in attr_names)
  selected = next((candidate for candidate in candidates if candidate is not None), None)
  if selected is None:
    raise ValueError("WAN dummy inputs require a transformer, low_noise_transformer, or high_noise_transformer.")
  return selected


def _get_dummy_wan_latents(config, pipeline, batch_size):
  """Create random float32 latents sized from `config` and the pipeline's VAE factors.

  Args:
    config: object providing `num_frames`, `height`, `width` and `seed`.
    pipeline: object providing `vae_scale_factor_temporal`,
      `vae_scale_factor_spatial` and at least one WAN transformer whose
      `config.in_channels` gives the latent channel count.
    batch_size: size of the leading dimension of the result.

  Returns:
    A normally distributed array of shape
    `(batch_size, in_channels, latent_frames, latent_height, latent_width)`.
  """
  in_channels = _get_wan_transformer_for_dummy_inputs(pipeline).config.in_channels

  # Temporal axis: (num_frames - 1) // temporal_factor + 1 latent frames.
  latent_frames = (int(config.num_frames) - 1) // pipeline.vae_scale_factor_temporal + 1

  # Spatial axes: floor-divide each pixel extent by the spatial factor.
  latent_h, latent_w = [
      int(getattr(config, extent_name)) // pipeline.vae_scale_factor_spatial for extent_name in ("height", "width")
  ]

  rng_key = jax.random.key(config.seed)
  latent_shape = (batch_size, in_channels, latent_frames, latent_h, latent_w)
  return jax.random.normal(rng_key, latent_shape, dtype=jnp.float32)


def get_dummy_wan_inputs(config, pipeline, batch_size):
latents = _get_dummy_wan_latents(config, pipeline, batch_size)
bsz = latents.shape[0]
prompt_embeds = jax.random.normal(jax.random.key(config.seed), (batch_size, 512, 4096))
timesteps = jnp.array([0] * bsz, dtype=jnp.int32)
Expand Down
35 changes: 35 additions & 0 deletions src/maxdiffusion/tests/maxdiffusion_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""

import os
from types import SimpleNamespace
import unittest
from unittest.mock import Mock

from jax.sharding import Mesh

Expand All @@ -35,6 +37,39 @@ class MaxDiffusionUtilsTest(unittest.TestCase):
# unittest fixture hook: rebinds the class-level `dummy_data` attribute to a
# fresh empty dict so no state carries over between test methods.
def setUp(self):
MaxDiffusionUtilsTest.dummy_data = {}

def test_get_dummy_wan_inputs_generates_latents_without_pipeline_prepare_latents(self):
  """`get_dummy_wan_inputs` must produce latents without calling `prepare_latents`."""
  # Any call to this mock raises, so an accidental use fails the test immediately.
  forbidden_prepare_latents = Mock(side_effect=AssertionError("prepare_latents should not be called"))
  fake_transformer = SimpleNamespace(config=SimpleNamespace(in_channels=16))
  fake_pipeline = SimpleNamespace(
      transformer=fake_transformer,
      vae_scale_factor_temporal=4,
      vae_scale_factor_spatial=8,
      prepare_latents=forbidden_prepare_latents,
  )
  fake_config = SimpleNamespace(height=64, width=80, num_frames=9, seed=0)

  latents, prompt_embeds, timesteps = maxdiffusion_utils.get_dummy_wan_inputs(
      fake_config, fake_pipeline, batch_size=2
  )

  fake_pipeline.prepare_latents.assert_not_called()
  # 9 frames -> (9 - 1) // 4 + 1 = 3; 64 // 8 = 8; 80 // 8 = 10.
  self.assertEqual(latents.shape, (2, 16, 3, 8, 10))
  self.assertEqual(prompt_embeds.shape, (2, 512, 4096))
  self.assertEqual(timesteps.shape, (2,))

def test_get_dummy_wan_inputs_supports_two_expert_pipeline(self):
  """A pipeline with only low/high-noise transformers still yields dummy inputs."""

  def make_expert():
    # Each expert is a distinct stub exposing only `config.in_channels`.
    return SimpleNamespace(config=SimpleNamespace(in_channels=48))

  # Any call to this mock raises, so an accidental use fails the test immediately.
  forbidden_prepare_latents = Mock(side_effect=AssertionError("prepare_latents should not be called"))
  fake_pipeline = SimpleNamespace(
      low_noise_transformer=make_expert(),
      high_noise_transformer=make_expert(),
      vae_scale_factor_temporal=4,
      vae_scale_factor_spatial=8,
      prepare_latents=forbidden_prepare_latents,
  )
  fake_config = SimpleNamespace(height=64, width=80, num_frames=9, seed=0)

  latents, prompt_embeds, timesteps = maxdiffusion_utils.get_dummy_wan_inputs(
      fake_config, fake_pipeline, batch_size=2
  )

  fake_pipeline.prepare_latents.assert_not_called()
  # Channel count comes from the expert transformers (48), not a `transformer` attribute.
  self.assertEqual(latents.shape, (2, 48, 3, 8, 10))
  self.assertEqual(prompt_embeds.shape, (2, 512, 4096))
  self.assertEqual(timesteps.shape, (2,))

def test_create_scheduler(self):
"""Test create scheduler with different schedulers"""
pyconfig.initialize(
Expand Down