From fb25b23ff242b5982c1de67cdc8f92d0bc6fe314 Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Wed, 11 Mar 2026 11:17:19 +0000 Subject: [PATCH 1/9] Update wan configs for training - Ensure `adam_weight_decay` is a float. - Add `tensorboard_dir` parameter for logging. Co-authored-by: martinarroyo --- src/maxdiffusion/configs/base_wan_14b.yml | 3 ++- src/maxdiffusion/configs/base_wan_1_3b.yml | 3 ++- src/maxdiffusion/configs/base_wan_27b.yml | 3 ++- src/maxdiffusion/configs/base_wan_i2v_14b.yml | 3 ++- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 3 ++- 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 91a3e092a..d7d5a1cbd 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -145,6 +145,7 @@ diffusion_scheduler_config: { # Output directory # Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" base_output_directory: "" +tensorboard_dir: "" # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' @@ -300,7 +301,7 @@ save_optimizer: False adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. -adam_weight_decay: 0 # AdamW Weight decay +adam_weight_decay: 0.0 # AdamW Weight decay max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index ffd2864a8..85201cda0 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -122,6 +122,7 @@ diffusion_scheduler_config: { # Output directory # Create a GCS bucket, e.g. 
my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" base_output_directory: "" +tensorboard_dir: "" # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' @@ -256,7 +257,7 @@ save_optimizer: False adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. -adam_weight_decay: 0 # AdamW Weight decay +adam_weight_decay: 0.0 # AdamW Weight decay max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 022b18c91..09d9175d7 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -133,6 +133,7 @@ diffusion_scheduler_config: { # Output directory # Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" base_output_directory: "" +tensorboard_dir: "" # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' @@ -267,7 +268,7 @@ save_optimizer: False adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. -adam_weight_decay: 0 # AdamW Weight decay +adam_weight_decay: 0.0 # AdamW Weight decay max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index 2a5b0338c..f5e66c0c6 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -128,6 +128,7 @@ diffusion_scheduler_config: { # Output directory # Create a GCS bucket, e.g. 
my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" base_output_directory: "" +tensorboard_dir: "" # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' @@ -262,7 +263,7 @@ save_optimizer: False adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. -adam_weight_decay: 0 # AdamW Weight decay +adam_weight_decay: 0.0 # AdamW Weight decay max_grad_norm: 1.0 enable_profiler: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 0bd6a27f2..a9c0fa9e1 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -129,6 +129,7 @@ diffusion_scheduler_config: { # Output directory # Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" base_output_directory: "" +tensorboard_dir: "" # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' @@ -263,7 +264,7 @@ save_optimizer: False adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. -adam_weight_decay: 0 # AdamW Weight decay +adam_weight_decay: 0.0 # AdamW Weight decay max_grad_norm: 1.0 enable_profiler: False From e205aa1d306e81876ba9d3356558a3c207613f8f Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Wed, 11 Mar 2026 12:55:17 +0000 Subject: [PATCH 2/9] Wan training: Resolve training mode bug with dropout and layer_forward - Conditionally apply dropout only when rate > 0. - Use standard list initialization. 
- Add rngs parameter to layer_forward (essential for gradient checkpointing with dropout > 0) Co-authored-by: martinarroyo --- src/maxdiffusion/models/attention_flax.py | 5 ++++- .../models/wan/transformers/transformer_wan.py | 5 ++++- .../models/wan/transformers/transformer_wan_vace.py | 12 +++++++----- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index b583171d7..b442a9399 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -1239,7 +1239,10 @@ def __call__( with jax.named_scope("proj_attn"): hidden_states = self.proj_attn(attn_output) - hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) + if self.drop_out.rate > 0: + hidden_states = self.drop_out( + hidden_states, deterministic=deterministic, rngs=rngs + ) return hidden_states diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 887cb0d06..68f13706e 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -262,7 +262,10 @@ def conditional_named_scope(self, name: str): def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array: hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824) hidden_states = checkpoint_name(hidden_states, "ffn_activation") - hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) + if self.drop_out.rate > 0: + hidden_states = self.drop_out( + hidden_states, deterministic=deterministic, rngs=rngs + ) with jax.named_scope("proj_out"): return self.proj_out(hidden_states) # output is (4, 75600, 5120) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py 
index fc3e67e39..7be3697be 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py @@ -487,10 +487,10 @@ def __call__( raise NotImplementedError("scan_layers is not supported yet") else: # Prepare VACE hints - control_hidden_states_list = nnx.List([]) + control_hidden_states_list = [] for i, vace_block in enumerate(self.vace_blocks): - def layer_forward(hidden_states, control_hidden_states): + def layer_forward(hidden_states, control_hidden_states, rngs): return vace_block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -507,14 +507,16 @@ def layer_forward(hidden_states, control_hidden_states): self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers, ) - conditioning_states, control_hidden_states = rematted_layer_forward(hidden_states, control_hidden_states) + conditioning_states, control_hidden_states = rematted_layer_forward( + hidden_states, control_hidden_states, rngs + ) control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i])) control_hidden_states_list = control_hidden_states_list[::-1] for i, block in enumerate(self.blocks): - def layer_forward_vace(hidden_states): + def layer_forward_vace(hidden_states, rngs): return block( hidden_states, encoder_hidden_states, @@ -530,7 +532,7 @@ def layer_forward_vace(hidden_states): self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers, ) - hidden_states = rematted_layer_forward(hidden_states) + hidden_states = rematted_layer_forward(hidden_states, rngs) if i in self.config.vace_layers: control_hint, scale = control_hidden_states_list.pop() hidden_states = hidden_states + control_hint * scale From 1fe4ce0bb01aaeb126f27655c2b8a89cb19a9b76 Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Wed, 11 Mar 2026 12:58:17 +0000 Subject: [PATCH 3/9] Wan training: use learning rate from config Replaces the hardcoded learning rate in the optimizer 
creation with the value from `config.learning_rate`. Co-authored-by: martinarroyo --- src/maxdiffusion/trainers/wan_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 8d865e589..62a73467a 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -257,7 +257,9 @@ def start_training(self): scheduler, scheduler_state = self.create_scheduler() pipeline.scheduler = scheduler pipeline.scheduler_state = scheduler_state - optimizer, learning_rate_scheduler = self.checkpointer._create_optimizer(pipeline.transformer, self.config, 1e-5) + optimizer, learning_rate_scheduler = self.checkpointer._create_optimizer( + pipeline.transformer, self.config, self.config.learning_rate + ) # Returns pipeline with trained transformer state pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args) From 5c6f65fabb099f64d8fd18be4fe193318006dd0d Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Wed, 11 Mar 2026 12:59:32 +0000 Subject: [PATCH 4/9] Fix: Ensure prepare_sample_fn is used for 'tfrecord' dataset type Co-authored-by: martinarroyo --- .../input_pipeline/_tfds_data_processing.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index dae9a3a1e..3c40ba240 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -113,7 +113,11 @@ def _make_tfrecord_iterator( "clip_embeddings": tf.io.FixedLenFeature([], tf.string), } - used_feature_description = feature_description_fn if make_cached_tfrecord_iterator else feature_description + used_feature_description = ( + feature_description_fn + if (make_cached_tfrecord_iterator or config.dataset_type == "tfrecord") 
+ else feature_description + ) def _parse_tfrecord_fn(example): return tf.io.parse_single_example(example, used_feature_description) @@ -141,7 +145,11 @@ def prepare_sample(features): ds = ds.concatenate(padding_ds) max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.") - used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample + used_prepare_sample = ( + prepare_sample_fn + if (make_cached_tfrecord_iterator or config.dataset_type == "tfrecord") + else prepare_sample + ) ds = ( ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index) .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE) From 610138694fd9a8c01fe6525703783c64ce862c59 Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Wed, 11 Mar 2026 13:04:04 +0000 Subject: [PATCH 5/9] Wan training: Set default dropout to 0.0 in Wan configs Co-authored-by: martinarroyo --- src/maxdiffusion/configs/base_wan_14b.yml | 2 +- src/maxdiffusion/configs/base_wan_1_3b.yml | 2 +- src/maxdiffusion/configs/base_wan_27b.yml | 2 +- src/maxdiffusion/configs/base_wan_i2v_14b.yml | 2 +- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index d7d5a1cbd..3856dd919 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -72,7 +72,7 @@ mask_padding_tokens: True # 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded # in cross attention q. 
attention_sharding_uniform: True -dropout: 0.1 +dropout: 0.0 flash_block_sizes: { "block_q" : 512, diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index 85201cda0..b43f256cd 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -72,7 +72,7 @@ mask_padding_tokens: True # 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded # in cross attention q. attention_sharding_uniform: True -dropout: 0.1 +dropout: 0.0 flash_block_sizes: { "block_q" : 512, diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 09d9175d7..5739e6767 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -71,7 +71,7 @@ mask_padding_tokens: True # 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded # in cross attention q. attention_sharding_uniform: True -dropout: 0.1 +dropout: 0.0 flash_block_sizes: { "block_q" : 512, diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index f5e66c0c6..a2e236487 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -62,7 +62,7 @@ from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring flash_min_seq_length: 4096 -dropout: 0.1 +dropout: 0.0 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. 
diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index a9c0fa9e1..080f90531 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -62,7 +62,7 @@ from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring flash_min_seq_length: 4096 -dropout: 0.1 +dropout: 0.0 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. From efbc91d60d648f9d837713624fd23c3625a7b1e5 Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Wed, 11 Mar 2026 13:07:38 +0000 Subject: [PATCH 6/9] Wan 2.1 training: Resolve checkpoint loading issues with larger TPU slices and different topologies Co-authored-by: martinarroyo --- .../checkpointing/wan_checkpointer_2_1.py | 36 ++++++++++++++----- .../checkpointing/wan_checkpointer_i2v_2p1.py | 36 ++++++++++++++----- .../pipelines/wan/wan_pipeline.py | 23 ++++++++++-- .../pipelines/wan/wan_vace_pipeline_2_1.py | 21 +++++++++-- 4 files changed, 95 insertions(+), 21 deletions(-) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py index da30567bc..91c73804e 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py @@ -15,14 +15,15 @@ """ import json -import jax -import numpy as np from typing import Optional, Tuple -from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 -from .. import max_logging -import orbax.checkpoint as ocp from etils import epath +import jax +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer +import numpy as np +import orbax.checkpoint as ocp +from .. 
import max_logging +from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 class WanCheckpointer2_1(WanCheckpointer): @@ -35,13 +36,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic max_logging.log("No WAN checkpoint found.") return None, None max_logging.log(f"Loading WAN checkpoint from step {step}") + + cpu_devices = np.array(jax.devices(backend="cpu")) + mesh = Mesh(cpu_devices, axis_names=("data",)) + replicated_sharding = NamedSharding(mesh, P()) + metadatas = self.checkpoint_manager.item_metadata(step) - transformer_metadata = metadatas.wan_state - abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata) + state = metadatas.wan_state + + def add_sharding_to_struct(leaf_struct, sharding): + return jax.ShapeDtypeStruct( + shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding + ) + + target_shardings = jax.tree_util.tree_map( + lambda x: replicated_sharding, state + ) + + with mesh: + abstract_train_state_with_sharding = jax.tree_util.tree_map( + add_sharding_to_struct, state, target_shardings + ) + params_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_params, + abstract_train_state_with_sharding, ) ) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py index 5850692f3..a6dacb0bf 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py @@ -15,14 +15,15 @@ """ import json -import jax -import numpy as np from typing import Optional, Tuple -from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 -from .. 
import max_logging -import orbax.checkpoint as ocp from etils import epath +import jax +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer +import numpy as np +import orbax.checkpoint as ocp +from .. import max_logging +from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 class WanCheckpointerI2V_2_1(WanCheckpointer): @@ -35,13 +36,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic max_logging.log("No WAN checkpoint found.") return None, None max_logging.log(f"Loading WAN checkpoint from step {step}") + + cpu_devices = np.array(jax.devices(backend="cpu")) + mesh = Mesh(cpu_devices, axis_names=("data",)) + replicated_sharding = NamedSharding(mesh, P()) + metadatas = self.checkpoint_manager.item_metadata(step) - transformer_metadata = metadatas.wan_state - abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata) + state = metadatas.wan_state + + def add_sharding_to_struct(leaf_struct, sharding): + return jax.ShapeDtypeStruct( + shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding + ) + + target_shardings = jax.tree_util.tree_map( + lambda x: replicated_sharding, state + ) + + with mesh: + abstract_train_state_with_sharding = jax.tree_util.tree_map( + add_sharding_to_struct, state, target_shardings + ) + params_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_params, + abstract_train_state_with_sharding, ) ) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 7c0314b40..0bafe3791 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -168,9 +168,26 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): ) for path, val in 
flax.traverse_util.flatten_dict(params).items(): if restored_checkpoint: - path = path[:-1] + if path[-1] == "value": + path = path[:-1] # remove 'value' + + try: + # Convert block indices to integers, as they might have been loaded as strings from the checkpoint. + path = path[:1] + (int(path[1]),) + path[2:] + except Exception: + pass + sharding = logical_state_sharding[path].value - state[path].value = device_put_replicated(val, sharding) + try: + state[path].value = device_put_replicated(val, sharding) + except Exception as e: + max_logging.log(f"Failed to device_put_replicated {path}: {e}") + max_logging.log(f"Trying to use process_allgather for {path}") + val_on_host = jax.experimental.multihost_utils.process_allgather( + val, tiled=True + ) + state[path].value = device_put_replicated(val_on_host, sharding) + del val_on_host state = nnx.from_flat_state(state) wan_transformer = nnx.merge(graphdef, state, rest_of_state) @@ -470,7 +487,6 @@ def encode_prompt( negative_prompt_embeds: jax.Array = None, ): prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None: prompt_embeds = self._get_t5_prompt_embeds( prompt=prompt, @@ -480,6 +496,7 @@ def encode_prompt( prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32) if negative_prompt_embeds is None: + batch_size = len(prompt_embeds) negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_embeds = self._get_t5_prompt_embeds( diff --git a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py index 487cc85e6..39b74bae4 100644 --- a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py @@ -119,9 +119,26 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): ) for path, val in 
flax.traverse_util.flatten_dict(params).items(): if restored_checkpoint: - path = path[:-1] + if path[-1] == "value": + path = path[:-1] # remove 'value' + + try: + # Convert block indices to integers, as they might have been loaded as strings from the checkpoint. + path = path[:1] + (int(path[1]),) + path[2:] + except Exception: + pass + sharding = logical_state_sharding[path].value - state[path].value = device_put_replicated(val, sharding) + try: + state[path].value = device_put_replicated(val, sharding) + except Exception as e: + max_logging.log(f"Failed to device_put_replicated {path}: {e}") + max_logging.log(f"Trying to use process_allgather for {path}") + val_on_host = jax.experimental.multihost_utils.process_allgather( + val, tiled=True + ) + state[path].value = device_put_replicated(val_on_host, sharding) + del val_on_host state = nnx.from_flat_state(state) wan_transformer = nnx.merge(graphdef, state, rest_of_state) From f30daacb4a54033192f4944dbb595b90d1649c80 Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Wed, 11 Mar 2026 13:36:44 +0000 Subject: [PATCH 7/9] Wan training: Fix WAN training timestep sampling with continuous sampling and introduce disable_training_weights, add max_grad_norm and max_abs_grad logging. - Switched timestep sampling from discrete to continuous. - Add max_grad_norm and max_abs_grad calculation and logging. - Introduced `config.disable_training_weights` to optionally disable mid-point loss weighting. 
Co-authored-by: martinarroyo --- src/maxdiffusion/configs/base_wan_14b.yml | 1 + src/maxdiffusion/configs/base_wan_1_3b.yml | 1 + .../schedulers/scheduling_flow_match_flax.py | 58 +++++++++++++++++-- src/maxdiffusion/trainers/wan_trainer.py | 55 ++++++++++++------ 4 files changed, 91 insertions(+), 24 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 3856dd919..85d690ff5 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -282,6 +282,7 @@ output_dir: 'sdxl-model-finetuned' per_device_batch_size: 1.0 # If global_batch_size % jax.device_count is not 0, use FSDP sharding. global_batch_size: 0 +disable_training_weights: False # if True, disables the use of mid-point loss weighting # For creating tfrecords from dataset tfrecords_dir: '' diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index b43f256cd..d2382bd33 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -238,6 +238,7 @@ output_dir: 'sdxl-model-finetuned' per_device_batch_size: 1.0 # If global_batch_size % jax.device_count is not 0, use FSDP sharding. 
global_batch_size: 0 +disable_training_weights: False # if True, disables the use of mid-point loss weighting # For creating tfrecords from dataset tfrecords_dir: '' diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py index 1f9c3a78e..a1991310e 100644 --- a/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py @@ -150,11 +150,9 @@ def set_timesteps( linear_timesteps_weights = None if training: - x = timesteps - y = jnp.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2) - y_shifted = y - jnp.min(y) - bsmntw_weighing = y_shifted * (num_inference_steps / jnp.sum(y_shifted)) - linear_timesteps_weights = bsmntw_weighing + linear_timesteps_weights = self._calculate_training_weights( + timesteps, num_inference_steps + ) return state.replace( sigmas=sigmas, @@ -164,6 +162,56 @@ def set_timesteps( num_inference_steps=num_inference_steps, ) + def _calculate_training_weights( + self, timesteps: jnp.ndarray, num_inference_steps: int + ) -> jnp.ndarray: + """Calculates the training weight for a given timestep.""" + x = timesteps + y = jnp.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2) + y_shifted = y - jnp.min(y) + bsmntw_weighing = y_shifted * (num_inference_steps / jnp.sum(y_shifted)) + linear_timesteps_weights = bsmntw_weighing + return linear_timesteps_weights + + def sample_timesteps(self, timestep_rng, batch_size): + # 1. Sample continuous timesteps t in [0, 1] + t = jax.random.uniform(timestep_rng, (batch_size,)) + + # 2. Apply the "Shift" weighting (Time shifting) + t_shifted = (t * self.config.shift) / (1 + (self.config.shift - 1) * t) + + # 3. 
Scale t to [0, self.config.num_train_timesteps] + timesteps = t_shifted.squeeze() * self.config.num_train_timesteps + + return timesteps + + def apply_flow_match( + self, + noise: jnp.ndarray, + batch_images: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Apply flow match to the batch of images. + + Replaces: scheduler.add_noise + scheduler.training_target + + scheduler.training_weight + """ + + t = timesteps.astype(jnp.float32) / self.config.num_train_timesteps + broadcast_shape = (-1,) + (1,) * (batch_images.ndim - 1) + t = t.reshape(broadcast_shape) + + sigma = (1 - t) * self.config.sigma_min + t * self.config.sigma_max + + noisy_latents = (1 - sigma) * batch_images + sigma * noise + target = noise - batch_images + + training_weights = self._calculate_training_weights( + timesteps, self.config.num_train_timesteps + ) + + return noisy_latents, target, training_weights + def _find_timestep_id(self, state: FlowMatchSchedulerState, timestep: jnp.ndarray) -> jnp.ndarray: """Finds the index of the closest timestep in the scheduler's `timesteps` array.""" timestep = jnp.asarray(timestep, dtype=state.timesteps.dtype) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 62a73467a..3b27a933e 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -24,6 +24,7 @@ import tensorflow as tf import jax.numpy as jnp import jax +import jaxopt from jax.sharding import PartitionSpec as P from flax import nnx from maxdiffusion.schedulers import FlaxFlowMatchScheduler @@ -453,38 +454,53 @@ def loss_fn(params): model = nnx.merge(state.graphdef, params, state.rest_of_state) latents = data["latents"].astype(config.weights_dtype) encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype) + bsz = latents.shape[0] - timesteps = jax.random.randint( - timestep_rng, - (bsz,), - 0, - scheduler.config.num_train_timesteps, + timesteps = 
scheduler.sample_timesteps(timestep_rng, bsz) + noise = jax.random.normal( + key=new_rng, shape=latents.shape, dtype=latents.dtype + ) + noisy_latents, training_target, training_weight = ( + scheduler.apply_flow_match(noise, latents, timesteps) ) - noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) - noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) - with jax.named_scope("forward_pass"): model_pred = model( hidden_states=noisy_latents, timestep=timesteps, encoder_hidden_states=encoder_hidden_states, deterministic=False, - rngs=nnx.Rngs(dropout_rng), + rngs=nnx.Rngs(dropout=dropout_rng), ) with jax.named_scope("loss"): - training_target = scheduler.training_target(latents, noise, timesteps) - training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4)) loss = (training_target - model_pred) ** 2 - loss = loss * training_weight + if not config.disable_training_weights: + training_weight = jnp.expand_dims(training_weight, axis=(1, 2, 3, 4)) + loss = loss * training_weight loss = jnp.mean(loss) return loss grad_fn = nnx.value_and_grad(loss_fn) loss, grads = grad_fn(state.params) + max_grad_norm = jaxopt.tree_util.tree_l2_norm(grads) + + max_abs_grad = jax.tree_util.tree_reduce( + lambda max_val, arr: jnp.maximum(max_val, jnp.max(jnp.abs(arr))), + grads, + initializer=-1.0, + ) + + metrics = { + "scalar": { + "learning/loss": loss, + "learning/max_grad_norm": max_grad_norm, + "learning/max_abs_grad": max_abs_grad, + }, + "scalars": {}, + } + new_state = state.apply_gradients(grads=grads) - metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} return new_state, scheduler_state, metrics, new_rng @@ -495,14 +511,14 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config): # The loss function logic is identical to training. We are evaluating the model's # ability to perform its core training objective (e.g., denoising). 
- @jax.jit def loss_fn(params, latents, encoder_hidden_states, timesteps, rng): # Reconstruct the model from its definition and parameters model = nnx.merge(state.graphdef, params, state.rest_of_state) noise = jax.random.normal(key=rng, shape=latents.shape, dtype=latents.dtype) - noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) - + noisy_latents, training_target, training_weight = ( + scheduler.apply_flow_match(noise, latents, timesteps) + ) # Get the model's prediction model_pred = model( hidden_states=noisy_latents, @@ -512,10 +528,11 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps, rng): ) # Calculate the loss against the target - training_target = scheduler.training_target(latents, noise, timesteps) - training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4)) loss = (training_target - model_pred) ** 2 - loss = loss * training_weight + if not config.disable_training_weights: + training_weight = jnp.expand_dims(training_weight, axis=(1, 2, 3, 4)) + loss = loss * training_weight + # Calculate the mean loss per sample across all non-batch dimensions. loss = loss.reshape(loss.shape[0], -1).mean(axis=1) From 28fbabb347de6f7b18d48ba48ad453149827a91b Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Wed, 11 Mar 2026 13:45:18 +0000 Subject: [PATCH 8/9] Abstract common WAN training components into BaseWanTrainer The following key functionalities have been moved from WanTrainer to the new `BaseWanTrainer` ABC: - Initialization and config handling - Scheduler creation - TFLOPs calculation - Core training and evaluation loops (`start_training`, `training_loop`, `eval`) - Abstract methods for checkpointer, data loading, sharding, and step functions. 
Co-authored-by: martinarroyo --- src/maxdiffusion/trainers/base_wan_trainer.py | 374 ++++++++++++++++++ src/maxdiffusion/trainers/wan_trainer.py | 342 +--------------- 2 files changed, 391 insertions(+), 325 deletions(-) create mode 100644 src/maxdiffusion/trainers/base_wan_trainer.py diff --git a/src/maxdiffusion/trainers/base_wan_trainer.py b/src/maxdiffusion/trainers/base_wan_trainer.py new file mode 100644 index 000000000..a5974c191 --- /dev/null +++ b/src/maxdiffusion/trainers/base_wan_trainer.py @@ -0,0 +1,374 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import abc +from concurrent.futures import ThreadPoolExecutor +from contextlib import nullcontext +import datetime +import os +import pprint +import threading +from flax import nnx +from flax.linen import partitioning as nn_partitioning +from flax.training import train_state +import jax +from jax.experimental import multihost_utils +import jax.numpy as jnp +from maxdiffusion import max_logging, max_utils, train_utils +from maxdiffusion.generate_wan import inference_generate_video +from maxdiffusion.generate_wan import run as generate_wan +from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline +from maxdiffusion.schedulers import FlaxFlowMatchScheduler +from maxdiffusion.train_utils import ( _metrics_queue,_tensorboard_writer_worker, load_next_batch) +from maxdiffusion.utils import load_video +from maxdiffusion.video_processor import VideoProcessor +import numpy as np +from skimage.metrics import structural_similarity as ssim + + +class TrainState(train_state.TrainState): + graphdef: nnx.GraphDef + rest_of_state: nnx.State + + +def _to_array(x): + if not isinstance(x, jax.Array): + x = jnp.asarray(x) + return x + + +def generate_sample(config, pipeline, filename_prefix): + """ + Generates a video to validate training did not corrupt the model + """ + if not hasattr(pipeline, "vae"): + wan_vae, vae_cache = WanPipeline.load_vae( + pipeline.mesh.devices, pipeline.mesh, nnx.Rngs(jax.random.key(config.seed)), config + ) + pipeline.vae = wan_vae + pipeline.vae_cache = vae_cache + return generate_wan(config, pipeline, filename_prefix) + + +def print_ssim(pretrained_video_path, posttrained_video_path): + video_processor = VideoProcessor() + pretrained_video = load_video(pretrained_video_path[0]) + pretrained_video = video_processor.preprocess_video(pretrained_video) + pretrained_video = np.array(pretrained_video) + pretrained_video = np.transpose(pretrained_video, (0, 2, 3, 4, 1)) + pretrained_video = np.uint8((pretrained_video + 1) * 255 / 2) + + 
posttrained_video = load_video(posttrained_video_path[0]) + posttrained_video = video_processor.preprocess_video(posttrained_video) + posttrained_video = np.array(posttrained_video) + posttrained_video = np.transpose(posttrained_video, (0, 2, 3, 4, 1)) + posttrained_video = np.uint8((posttrained_video + 1) * 255 / 2) + + ssim_compare = ssim(pretrained_video[0], posttrained_video[0], multichannel=True, channel_axis=-1, data_range=255) + + max_logging.log(f"SSIM score after training is {ssim_compare}") + + +class BaseWanTrainer(abc.ABC): + + def __init__(self, config): + if config.train_text_encoder: + raise ValueError("this script currently doesn't support training text_encoders") + self.config = config + self.checkpointer = self._get_checkpointer() + + @abc.abstractmethod + def _get_checkpointer(self): + """Returns the checkpointer for the trainer.""" + + def post_training_steps(self, pipeline, params, train_states, msg=""): + pass + + def create_scheduler(self): + """Creates and initializes the Flow Match scheduler for training.""" + noise_scheduler = FlaxFlowMatchScheduler(dtype=jnp.float32) + noise_scheduler_state = noise_scheduler.create_state() + noise_scheduler_state = noise_scheduler.set_timesteps(noise_scheduler_state, num_inference_steps=1000, training=True) + return noise_scheduler, noise_scheduler_state + + @staticmethod + def calculate_tflops(pipeline): + maxdiffusion_config = pipeline.config + # Model configuration + height = pipeline.config.height + width = pipeline.config.width + num_frames = pipeline.config.num_frames + + # Transformer dimensions + transformer_config = pipeline.transformer.config + num_layers = transformer_config.num_layers + heads = pipeline.transformer.config.num_attention_heads + head_dim = pipeline.transformer.config.attention_head_dim + ffn_dim = transformer_config.ffn_dim + seq_len = int(((height / 8) * (width / 8) * ((num_frames - 1) // pipeline.vae_scale_factor_temporal + 1)) / 4) + text_encoder_dim = 512 + # Attention FLOPS 
+ # Self + self_attn_qkv_proj_flops = 3 * (2 * seq_len * (heads * head_dim) ** 2) + self_attn_qk_v_flops = 2 * (2 * seq_len**2 * (heads * head_dim)) + # Cross + cross_attn_kv_proj_flops = 3 * (2 * text_encoder_dim * (heads * head_dim) ** 2) + cross_attn_q_proj_flops = 1 * (2 * seq_len * (heads * head_dim) ** 2) + cross_attention_qk_v_flops = 2 * (2 * seq_len * text_encoder_dim * (heads * head_dim)) + + # Output_projection from attention + attn_output_proj_flops = 2 * (2 * seq_len * (heads * head_dim) ** 2) + + total_attn_flops = ( + self_attn_qkv_proj_flops + + self_attn_qk_v_flops + + cross_attn_kv_proj_flops + + cross_attn_q_proj_flops + + cross_attention_qk_v_flops + + attn_output_proj_flops + ) + + # FFN + ffn_flops = 2 * (2 * seq_len * (heads * head_dim) * ffn_dim) + + flops_per_block = total_attn_flops + ffn_flops + + total_transformer_flops = flops_per_block * num_layers + + tflops = maxdiffusion_config.per_device_batch_size * total_transformer_flops / 1e12 + train_tflops = 3 * tflops + + max_logging.log(f"Calculated TFLOPs per pass: {train_tflops:.4f}") + return train_tflops, total_attn_flops, seq_len + + @abc.abstractmethod + def get_data_shardings(self, mesh): + """Returns data shardings for training.""" + + @abc.abstractmethod + def get_eval_data_shardings(self, mesh): + """Returns data shardings for evaluation.""" + + @abc.abstractmethod + def load_dataset(self, mesh, pipeline=None, is_training=True): + """Loads the dataset.""" + + @abc.abstractmethod + def get_train_step(self, pipeline, mesh, state_shardings, data_shardings): + """Returns the training step function.""" + + @abc.abstractmethod + def get_eval_step(self, pipeline, mesh, state_shardings, eval_data_shardings): + """Returns the evaluation step function.""" + + def start_training(self): + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + pipeline, opt_state, step = self.checkpointer.load_checkpoint() + restore_args = {} + if opt_state and step: + restore_args = {"opt_state": 
opt_state, "step": step} + del opt_state + if self.config.enable_ssim: + # Generate a sample before training to compare against generated sample after training. + pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") + + if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval): + # save some memory. + del pipeline.vae + del pipeline.vae_cache + + mesh = pipeline.mesh + train_data_iterator = self.load_dataset(mesh, pipeline=pipeline, is_training=True) + + # Load FlowMatch scheduler + scheduler, scheduler_state = self.create_scheduler() + pipeline.scheduler = scheduler + pipeline.scheduler_state = scheduler_state + optimizer, learning_rate_scheduler = self.checkpointer._create_optimizer( + pipeline.transformer, self.config, self.config.learning_rate + ) + # Returns pipeline with trained transformer state + pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args) + + if self.config.enable_ssim: + posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-") + print_ssim(pretrained_video_path, posttrained_video_path) + + def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer): + eval_data_iterator = self.load_dataset(mesh, is_training=False) + eval_rng = eval_rng_key + eval_losses_by_timestep = {} + # Loop indefinitely until the iterator is exhausted + while True: + try: + eval_start_time = datetime.datetime.now() + eval_batch = load_next_batch(eval_data_iterator, None, self.config) + with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state) + metrics["scalar"]["learning/eval_loss"].block_until_ready() + losses = metrics["scalar"]["learning/eval_loss"] + timesteps = eval_batch["timesteps"] + gathered_losses = multihost_utils.process_allgather(losses, tiled=True) + gathered_losses = 
jax.device_get(gathered_losses) + gathered_timesteps = multihost_utils.process_allgather(timesteps, tiled=True) + gathered_timesteps = jax.device_get(gathered_timesteps) + if jax.process_index() == 0: + for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()): + timestep = int(t) + if timestep not in eval_losses_by_timestep: + eval_losses_by_timestep[timestep] = [] + eval_losses_by_timestep[timestep].append(l) + eval_end_time = datetime.datetime.now() + eval_duration = eval_end_time - eval_start_time + max_logging.log(f"Eval time: {eval_duration.total_seconds():.2f} seconds.") + except StopIteration: + # This block is executed when the iterator has no more data + break + # Check if any evaluation was actually performed + if eval_losses_by_timestep and jax.process_index() == 0: + mean_per_timestep = [] + if jax.process_index() == 0: + max_logging.log(f"Step {step}, calculating mean loss per timestep...") + for timestep, losses in sorted(eval_losses_by_timestep.items()): + losses = jnp.array(losses) + losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))] + mean_loss = jnp.mean(losses) + max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}") + mean_per_timestep.append(mean_loss) + final_eval_loss = jnp.mean(jnp.array(mean_per_timestep)) + max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}") + if writer: + writer.add_scalar("learning/eval_loss", final_eval_loss, step) + + def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args: dict = {}): + mesh = pipeline.mesh + graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...) 
+ + with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + state = TrainState.create( + apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state + ) + if restore_args: + step = restore_args.get("step", 0) + max_logging.log(f"Restoring optimizer and resuming from step {step}") + state.replace(opt_state=restore_args.get("opt_state"), step=restore_args.get("step", 0)) + del restore_args["opt_state"] + del optimizer + state = jax.tree.map(_to_array, state) + state_spec = nnx.get_partition_spec(state) + state = jax.lax.with_sharding_constraint(state, state_spec) + state_shardings = nnx.get_named_sharding(state, mesh) + if jax.process_index() == 0 and restore_args: + max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---") + pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60) + max_logging.log(pretty_string) + max_logging.log("------------------------------------------------") + if self.config.hardware != "gpu": + max_utils.delete_pytree(params) + data_shardings = self.get_data_shardings(mesh) + eval_data_shardings = self.get_eval_data_shardings(mesh) + + writer = max_utils.initialize_summary_writer(self.config) + writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True) + writer_thread.start() + + num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params) + max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer) + max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer) + max_utils.add_config_to_summary_writer(self.config, writer) + + if jax.process_index() == 0: + max_logging.log("***** Running training *****") + max_logging.log(f" Instantaneous batch size per device = {self.config.per_device_batch_size}") + max_logging.log(f" Total train batch size (w. 
parallel & distributed) = {self.config.global_batch_size_to_train_on}") + max_logging.log(f" Total optimization steps = {self.config.max_train_steps}") + + p_train_step = self.get_train_step( + pipeline, mesh, state_shardings, data_shardings + ) + p_eval_step = self.get_eval_step( + pipeline, mesh, state_shardings, eval_data_shardings + ) + + rng = jax.random.key(self.config.seed) + rng, eval_rng_key = jax.random.split(rng) + start_step = 0 + last_step_completion = datetime.datetime.now() + local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None + running_gcs_metrics = [] if self.config.gcs_metrics else None + first_profiling_step = self.config.skip_first_n_steps_for_profiler + if self.config.enable_profiler and first_profiling_step >= self.config.max_train_steps: + raise ValueError("Profiling requested but initial profiling step set past training final step") + last_profiling_step = np.clip( + first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1 + ) + if restore_args.get("step", 0): + max_logging.log(f"Resuming training from step {step}") + start_step = restore_args.get("step", 0) + per_device_tflops, _, _ = BaseWanTrainer.calculate_tflops(pipeline) + scheduler_state = pipeline.scheduler_state + example_batch = load_next_batch(train_data_iterator, None, self.config) + + with ThreadPoolExecutor(max_workers=1) as executor: + for step in np.arange(start_step, self.config.max_train_steps): + if self.config.enable_profiler and step == first_profiling_step: + max_utils.activate_profiler(self.config) + start_step_time = datetime.datetime.now() + + next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config) + with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules( + self.config.logical_axis_rules + ): + state, scheduler_state, train_metric, rng = p_train_step(state, 
example_batch, rng, scheduler_state) + train_metric["scalar"]["learning/loss"].block_until_ready() + last_step_completion = datetime.datetime.now() + + if self.config.enable_profiler and step == last_profiling_step: + max_utils.deactivate_profiler(self.config) + + train_utils.record_scalar_metrics( + train_metric, last_step_completion - start_step_time, per_device_tflops, learning_rate_scheduler(step) + ) + if self.config.write_metrics: + train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) + + if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0: + if self.config.enable_generate_video_for_eval: + pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state) + inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-") + # Re-create the iterator each time you start evaluation to reset it + # This assumes your data loading logic can be called to get a fresh iterator. 
+ self.eval(mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer) + + example_batch = next_batch_future.result() + if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0: + max_logging.log(f"Saving checkpoint for step {step}") + if self.config.save_optimizer: + self.checkpointer.save_checkpoint(step, pipeline, state) + else: + self.checkpointer.save_checkpoint(step, pipeline, state.params) + + _metrics_queue.put(None) + writer_thread.join() + if writer: + writer.flush() + if self.config.save_final_checkpoint: + max_logging.log(f"Saving final checkpoint for step {step}") + self.checkpointer.save_checkpoint(self.config.max_train_steps - 1, pipeline, state.params) + self.checkpointer.checkpoint_manager.wait_until_finished() + # load new state for trained transformer + pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state) + return pipeline diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 3b27a933e..1bc8371cc 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -14,145 +14,26 @@ limitations under the License. 
""" -import os -import datetime import functools -import pprint -import numpy as np -import threading -from concurrent.futures import ThreadPoolExecutor -import tensorflow as tf + +from flax import nnx import jax.numpy as jnp import jax -import jaxopt from jax.sharding import PartitionSpec as P -from flax import nnx -from maxdiffusion.schedulers import FlaxFlowMatchScheduler -from flax.linen import partitioning as nn_partitioning -from maxdiffusion import max_utils, max_logging, train_utils +import jaxopt from maxdiffusion.checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1 -from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) -from maxdiffusion.generate_wan import run as generate_wan -from maxdiffusion.generate_wan import inference_generate_video -from maxdiffusion.train_utils import (_tensorboard_writer_worker, load_next_batch, _metrics_queue) -from maxdiffusion.video_processor import VideoProcessor -from maxdiffusion.utils import load_video -from skimage.metrics import structural_similarity as ssim -from flax.training import train_state -from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline -from jax.experimental import multihost_utils - - -class TrainState(train_state.TrainState): - graphdef: nnx.GraphDef - rest_of_state: nnx.State - - -def _to_array(x): - if not isinstance(x, jax.Array): - x = jnp.asarray(x) - return x - - -def generate_sample(config, pipeline, filename_prefix): - """ - Generates a video to validate training did not corrupt the model - """ - if not hasattr(pipeline, "vae"): - wan_vae, vae_cache = WanPipeline.load_vae( - pipeline.mesh.devices, pipeline.mesh, nnx.Rngs(jax.random.key(config.seed)), config - ) - pipeline.vae = wan_vae - pipeline.vae_cache = vae_cache - return generate_wan(config, pipeline, filename_prefix) - - -def print_ssim(pretrained_video_path, posttrained_video_path): - video_processor = VideoProcessor() - pretrained_video = load_video(pretrained_video_path[0]) - 
pretrained_video = video_processor.preprocess_video(pretrained_video) - pretrained_video = np.array(pretrained_video) - pretrained_video = np.transpose(pretrained_video, (0, 2, 3, 4, 1)) - pretrained_video = np.uint8((pretrained_video + 1) * 255 / 2) - - posttrained_video = load_video(posttrained_video_path[0]) - posttrained_video = video_processor.preprocess_video(posttrained_video) - posttrained_video = np.array(posttrained_video) - posttrained_video = np.transpose(posttrained_video, (0, 2, 3, 4, 1)) - posttrained_video = np.uint8((posttrained_video + 1) * 255 / 2) - - ssim_compare = ssim(pretrained_video[0], posttrained_video[0], multichannel=True, channel_axis=-1, data_range=255) - - max_logging.log(f"SSIM score after training is {ssim_compare}") - - -class WanTrainer: - - def __init__(self, config): - if config.train_text_encoder: - raise ValueError("this script currently doesn't support training text_encoders") - self.config = config - self.checkpointer = WanCheckpointer2_1(config=config) - - def post_training_steps(self, pipeline, params, train_states, msg=""): - pass - - def create_scheduler(self): - """Creates and initializes the Flow Match scheduler for training.""" - noise_scheduler = FlaxFlowMatchScheduler(dtype=jnp.float32) - noise_scheduler_state = noise_scheduler.create_state() - noise_scheduler_state = noise_scheduler.set_timesteps(noise_scheduler_state, num_inference_steps=1000, training=True) - return noise_scheduler, noise_scheduler_state - - @staticmethod - def calculate_tflops(pipeline): - maxdiffusion_config = pipeline.config - # Model configuration - height = pipeline.config.height - width = pipeline.config.width - num_frames = pipeline.config.num_frames - - # Transformer dimensions - transformer_config = pipeline.transformer.config - num_layers = transformer_config.num_layers - heads = pipeline.transformer.config.num_attention_heads - head_dim = pipeline.transformer.config.attention_head_dim - ffn_dim = transformer_config.ffn_dim - seq_len = 
int(((height / 8) * (width / 8) * ((num_frames - 1) // pipeline.vae_scale_factor_temporal + 1)) / 4) - text_encoder_dim = 512 - # Attention FLOPS - # Self - self_attn_qkv_proj_flops = 3 * (2 * seq_len * (heads * head_dim) ** 2) - self_attn_qk_v_flops = 2 * (2 * seq_len**2 * (heads * head_dim)) - # Cross - cross_attn_kv_proj_flops = 3 * (2 * text_encoder_dim * (heads * head_dim) ** 2) - cross_attn_q_proj_flops = 1 * (2 * seq_len * (heads * head_dim) ** 2) - cross_attention_qk_v_flops = 2 * (2 * seq_len * text_encoder_dim * (heads * head_dim)) - - # Output_projection from attention - attn_output_proj_flops = 2 * (2 * seq_len * (heads * head_dim) ** 2) - - total_attn_flops = ( - self_attn_qkv_proj_flops - + self_attn_qk_v_flops - + cross_attn_kv_proj_flops - + cross_attn_q_proj_flops - + cross_attention_qk_v_flops - + attn_output_proj_flops - ) - - # FFN - ffn_flops = 2 * (2 * seq_len * (heads * head_dim) * ffn_dim) - - flops_per_block = total_attn_flops + ffn_flops +from maxdiffusion.input_pipeline.input_pipeline_interface import make_data_iterator +from maxdiffusion.trainers.base_wan_trainer import ( + BaseWanTrainer, + _to_array, +) +import tensorflow as tf - total_transformer_flops = flops_per_block * num_layers - tflops = maxdiffusion_config.per_device_batch_size * total_transformer_flops / 1e12 - train_tflops = 3 * tflops +class WanTrainer(BaseWanTrainer): - max_logging.log(f"Calculated TFLOPs per pass: {train_tflops:.4f}") - return train_tflops, total_attn_flops, seq_len + def _get_checkpointer(self): + return WanCheckpointer2_1(config=self.config) def get_data_shardings(self, mesh): data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding)) @@ -235,210 +116,21 @@ def prepare_sample_eval(features): ) return data_iterator - def start_training(self): - with nn_partitioning.axis_rules(self.config.logical_axis_rules): - pipeline, opt_state, step = self.checkpointer.load_checkpoint() - restore_args = {} - if opt_state and step: - restore_args = 
{"opt_state": opt_state, "step": step} - del opt_state - if self.config.enable_ssim: - # Generate a sample before training to compare against generated sample after training. - pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") - - if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval): - # save some memory. - del pipeline.vae - del pipeline.vae_cache - - mesh = pipeline.mesh - train_data_iterator = self.load_dataset(mesh, pipeline=pipeline, is_training=True) - - # Load FlowMatch scheduler - scheduler, scheduler_state = self.create_scheduler() - pipeline.scheduler = scheduler - pipeline.scheduler_state = scheduler_state - optimizer, learning_rate_scheduler = self.checkpointer._create_optimizer( - pipeline.transformer, self.config, self.config.learning_rate - ) - # Returns pipeline with trained transformer state - pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args) - - if self.config.enable_ssim: - posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-") - print_ssim(pretrained_video_path, posttrained_video_path) - - def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer): - eval_data_iterator = self.load_dataset(mesh, is_training=False) - eval_rng = eval_rng_key - eval_losses_by_timestep = {} - # Loop indefinitely until the iterator is exhausted - while True: - try: - eval_start_time = datetime.datetime.now() - eval_batch = load_next_batch(eval_data_iterator, None, self.config) - with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state) - metrics["scalar"]["learning/eval_loss"].block_until_ready() - losses = metrics["scalar"]["learning/eval_loss"] - timesteps = eval_batch["timesteps"] - gathered_losses = multihost_utils.process_allgather(losses, tiled=True) - 
gathered_losses = jax.device_get(gathered_losses) - gathered_timesteps = multihost_utils.process_allgather(timesteps, tiled=True) - gathered_timesteps = jax.device_get(gathered_timesteps) - if jax.process_index() == 0: - for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()): - timestep = int(t) - if timestep not in eval_losses_by_timestep: - eval_losses_by_timestep[timestep] = [] - eval_losses_by_timestep[timestep].append(l) - eval_end_time = datetime.datetime.now() - eval_duration = eval_end_time - eval_start_time - max_logging.log(f"Eval time: {eval_duration.total_seconds():.2f} seconds.") - except StopIteration: - # This block is executed when the iterator has no more data - break - # Check if any evaluation was actually performed - if eval_losses_by_timestep and jax.process_index() == 0: - mean_per_timestep = [] - if jax.process_index() == 0: - max_logging.log(f"Step {step}, calculating mean loss per timestep...") - for timestep, losses in sorted(eval_losses_by_timestep.items()): - losses = jnp.array(losses) - losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))] - mean_loss = jnp.mean(losses) - max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}") - mean_per_timestep.append(mean_loss) - final_eval_loss = jnp.mean(jnp.array(mean_per_timestep)) - max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}") - if writer: - writer.add_scalar("learning/eval_loss", final_eval_loss, step) - - def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args: dict = {}): - mesh = pipeline.mesh - graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...) 
- - with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - state = TrainState.create( - apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state - ) - if restore_args: - step = restore_args.get("step", 0) - max_logging.log(f"Restoring optimizer and resuming from step {step}") - state.replace(opt_state=restore_args.get("opt_state"), step=restore_args.get("step", 0)) - del restore_args["opt_state"] - del optimizer - state = jax.tree.map(_to_array, state) - state_spec = nnx.get_partition_spec(state) - state = jax.lax.with_sharding_constraint(state, state_spec) - state_shardings = nnx.get_named_sharding(state, mesh) - if jax.process_index() == 0 and restore_args: - max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---") - pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60) - max_logging.log(pretty_string) - max_logging.log("------------------------------------------------") - if self.config.hardware != "gpu": - max_utils.delete_pytree(params) - data_shardings = self.get_data_shardings(mesh) - eval_data_shardings = self.get_eval_data_shardings(mesh) - - writer = max_utils.initialize_summary_writer(self.config) - writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True) - writer_thread.start() - - num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params) - max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer) - max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer) - max_utils.add_config_to_summary_writer(self.config, writer) - - if jax.process_index() == 0: - max_logging.log("***** Running training *****") - max_logging.log(f" Instantaneous batch size per device = {self.config.per_device_batch_size}") - max_logging.log(f" Total train batch size (w. 
parallel & distributed) = {self.config.global_batch_size_to_train_on}") - max_logging.log(f" Total optimization steps = {self.config.max_train_steps}") - - p_train_step = jax.jit( + def get_train_step(self, pipeline, mesh, state_shardings, data_shardings): + return jax.jit( functools.partial(train_step, scheduler=pipeline.scheduler, config=self.config), in_shardings=(state_shardings, data_shardings, None, None), out_shardings=(state_shardings, None, None, None), donate_argnums=(0,), ) - p_eval_step = jax.jit( + + def get_eval_step(self, pipeline, mesh, state_shardings, eval_data_shardings): + return jax.jit( functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config), in_shardings=(state_shardings, eval_data_shardings, None, None), out_shardings=(None, None), ) - rng = jax.random.key(self.config.seed) - rng, eval_rng_key = jax.random.split(rng) - start_step = 0 - last_step_completion = datetime.datetime.now() - local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None - running_gcs_metrics = [] if self.config.gcs_metrics else None - first_profiling_step = self.config.skip_first_n_steps_for_profiler - if self.config.enable_profiler and first_profiling_step >= self.config.max_train_steps: - raise ValueError("Profiling requested but initial profiling step set past training final step") - last_profiling_step = np.clip( - first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1 - ) - if restore_args.get("step", 0): - max_logging.log(f"Resuming training from step {step}") - start_step = restore_args.get("step", 0) - per_device_tflops, _, _ = WanTrainer.calculate_tflops(pipeline) - scheduler_state = pipeline.scheduler_state - example_batch = load_next_batch(train_data_iterator, None, self.config) - - with ThreadPoolExecutor(max_workers=1) as executor: - for step in np.arange(start_step, self.config.max_train_steps): - if self.config.enable_profiler 
and step == first_profiling_step: - max_utils.activate_profiler(self.config) - start_step_time = datetime.datetime.now() - - next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config) - with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules( - self.config.logical_axis_rules - ): - state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state) - train_metric["scalar"]["learning/loss"].block_until_ready() - last_step_completion = datetime.datetime.now() - - if self.config.enable_profiler and step == last_profiling_step: - max_utils.deactivate_profiler(self.config) - - train_utils.record_scalar_metrics( - train_metric, last_step_completion - start_step_time, per_device_tflops, learning_rate_scheduler(step) - ) - if self.config.write_metrics: - train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) - - if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0: - if self.config.enable_generate_video_for_eval: - pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state) - inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-") - # Re-create the iterator each time you start evaluation to reset it - # This assumes your data loading logic can be called to get a fresh iterator. 
- self.eval(mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer) - - example_batch = next_batch_future.result() - if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0: - max_logging.log(f"Saving checkpoint for step {step}") - if self.config.save_optimizer: - self.checkpointer.save_checkpoint(step, pipeline, state) - else: - self.checkpointer.save_checkpoint(step, pipeline, state.params) - - _metrics_queue.put(None) - writer_thread.join() - if writer: - writer.flush() - if self.config.save_final_checkpoint: - max_logging.log(f"Saving final checkpoint for step {step}") - self.checkpointer.save_checkpoint(self.config.max_train_steps - 1, pipeline, state.params) - self.checkpointer.checkpoint_manager.wait_until_finished() - # load new state for trained transformer - pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state) - return pipeline - def train_step(state, data, rng, scheduler_state, scheduler, config): return step_optimizer(state, data, rng, scheduler_state, scheduler, config) From 7d8fdfb319517688292050745877d2852f2a80a6 Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Wed, 11 Mar 2026 13:55:33 +0000 Subject: [PATCH 9/9] Add WAN-VACE training functionality Introduces training support for WAN-VACE models. New files: - train_wan_vace.py: Main training script. - wan_vace_trainer.py: Trainer class for WAN-VACE. - wan_vace_checkpointer_2_1.py: Checkpointing logic for WAN-VACE.
Co-authored-by: martinarroyo --- .../wan_vace_checkpointer_2_1.py | 112 +++++++ .../pipelines/wan/wan_vace_pipeline_2_1.py | 72 ++++- src/maxdiffusion/train_wan_vace.py | 46 +++ src/maxdiffusion/trainers/wan_vace_trainer.py | 301 ++++++++++++++++++ 4 files changed, 522 insertions(+), 9 deletions(-) create mode 100644 src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py create mode 100644 src/maxdiffusion/train_wan_vace.py create mode 100644 src/maxdiffusion/trainers/wan_vace_trainer.py diff --git a/src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py b/src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py new file mode 100644 index 000000000..120dd6603 --- /dev/null +++ b/src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py @@ -0,0 +1,112 @@ +"""Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import json +from typing import Optional, Tuple +import jax +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer +import numpy as np +import orbax.checkpoint as ocp +from .. 
import max_logging +from ..pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1 + + +class WanVaceCheckpointer2_1(WanCheckpointer): + + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: + if step is None: + step = self.checkpoint_manager.latest_step() + max_logging.log(f"Latest WAN checkpoint step: {step}") + if step is None: + max_logging.log("No WAN checkpoint found.") + return None, None + max_logging.log(f"Loading WAN checkpoint from step {step}") + + cpu_devices = np.array(jax.devices(backend="cpu")) + mesh = Mesh(cpu_devices, axis_names=("data",)) + replicated_sharding = NamedSharding(mesh, P()) + + metadatas = self.checkpoint_manager.item_metadata(step) + state = metadatas.wan_state + + def add_sharding_to_struct(leaf_struct, sharding): + return jax.ShapeDtypeStruct( + shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding + ) + + target_shardings = jax.tree_util.tree_map( + lambda x: replicated_sharding, state + ) + + with mesh: + abstract_train_state_with_sharding = jax.tree_util.tree_map( + add_sharding_to_struct, state, target_shardings + ) + + max_logging.log("Restoring WAN checkpoint") + restored_checkpoint = self.checkpoint_manager.restore( + step=step, + args=ocp.args.Composite( + wan_config=ocp.args.JsonRestore(), + wan_state=ocp.args.StandardRestore( + abstract_train_state_with_sharding + ), + ), + ) + max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") + max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}") + max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}") + max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") + return restored_checkpoint, step + + def load_diffusers_checkpoint(self): + pipeline = VaceWanPipeline2_1.from_pretrained(self.config) + return pipeline + + def load_checkpoint(self, step=None) -> Tuple[VaceWanPipeline2_1, 
Optional[dict], Optional[int]]: + restored_checkpoint, step = self.load_wan_configs_from_orbax(step) + opt_state = None + if restored_checkpoint: + max_logging.log("Loading WAN pipeline from checkpoint") + pipeline = VaceWanPipeline2_1.from_checkpoint(self.config, restored_checkpoint) + if "opt_state" in restored_checkpoint.wan_state.keys(): + opt_state = restored_checkpoint.wan_state["opt_state"] + else: + max_logging.log("No checkpoint found, loading default pipeline.") + pipeline = self.load_diffusers_checkpoint() + + return pipeline, opt_state, step + + def save_checkpoint( + self, train_step, pipeline: VaceWanPipeline2_1, train_states: dict + ): + """Saves the training state and model configurations.""" + + def config_to_json(model_or_config): + return json.loads(model_or_config.to_json_string()) + + max_logging.log(f"Saving checkpoint for step {train_step}") + + # Save the checkpoint + self.checkpoint_manager.save( + train_step, + args=ocp.args.Composite( + wan_config=ocp.args.JsonSave(config_to_json(pipeline.transformer)), + wan_state=ocp.args.StandardSave(train_states), + ), + ) + + max_logging.log(f"Checkpoint for step {train_step} is saved.") diff --git a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py index 39b74bae4..4596ae6e3 100644 --- a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py @@ -338,7 +338,14 @@ def load_transformer( return wan_transformer @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): + def _load_and_init( + cls, + config: HyperParameters, + restored_checkpoint=None, + vae_only=False, + load_transformer=True, + load_common_components=True, + ): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) @@ -348,20 +355,31 @@ def from_pretrained(cls, config: HyperParameters, 
vae_only=False, load_transform scheduler = None scheduler_state = None text_encoder = None + wan_vae = None + vae_cache = None + if not vae_only: if load_transformer: with mesh: transformer = cls.load_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer" + devices_array=devices_array, + mesh=mesh, + rngs=rngs, + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer", ) + if load_common_components: + text_encoder = cls.load_text_encoder(config=config) + tokenizer = cls.load_tokenizer(config=config) - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) + scheduler, scheduler_state = cls.load_scheduler(config=config) - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + if load_common_components: + with mesh: + wan_vae, vae_cache = cls.load_vae( + devices_array=devices_array, mesh=mesh, rngs=rngs, config=config + ) pipeline = cls( tokenizer=tokenizer, @@ -376,7 +394,43 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform config=config, ) - pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, mesh) + return pipeline + + @classmethod + def from_pretrained( + cls, + config: HyperParameters, + vae_only=False, + load_transformer=True, + load_common_components=True, + ): + pipeline = cls._load_and_init( + config, None, vae_only, load_transformer, load_common_components + ) + pipeline.transformer = cls.quantize_transformer( + config, pipeline.transformer, pipeline, pipeline.mesh + ) + return pipeline + + @classmethod + def from_checkpoint( + cls, + config: HyperParameters, + restored_checkpoint=None, + vae_only=False, + load_transformer=True, + load_common_components=True, + ): + pipeline = cls._load_and_init( + config, + restored_checkpoint, + vae_only, + 
load_transformer, + load_common_components, + ) + pipeline.transformer = cls.quantize_transformer( + config, pipeline.transformer, pipeline, pipeline.mesh + ) return pipeline def check_inputs( diff --git a/src/maxdiffusion/train_wan_vace.py b/src/maxdiffusion/train_wan_vace.py new file mode 100644 index 000000000..c8d8acd78 --- /dev/null +++ b/src/maxdiffusion/train_wan_vace.py @@ -0,0 +1,46 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Sequence + +import jax +from absl import app +from maxdiffusion import max_logging, pyconfig +from maxdiffusion.train_utils import validate_train_config +import flax + + +def train(config): + from maxdiffusion.trainers.wan_vace_trainer import WanVaceTrainer + + trainer = WanVaceTrainer(config) + trainer.start_training() + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv, validate_training=True) + config = pyconfig.config + validate_train_config(config) + max_logging.log(f"Found {jax.device_count()} devices.") + try: + flax.config.update("flax_always_shard_variable", False) + except LookupError: + pass + train(config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/trainers/wan_vace_trainer.py b/src/maxdiffusion/trainers/wan_vace_trainer.py new file mode 100644 index 000000000..bfdf2a809 --- /dev/null +++ b/src/maxdiffusion/trainers/wan_vace_trainer.py @@ -0,0 +1,301 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, 
Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import functools + +from flax import nnx +import jax.numpy as jnp +import jax +from jax.sharding import PartitionSpec as P +import jaxopt +from maxdiffusion.checkpointing.wan_vace_checkpointer_2_1 import WanVaceCheckpointer2_1 +from maxdiffusion.input_pipeline.input_pipeline_interface import make_data_iterator +from maxdiffusion.trainers.base_wan_trainer import ( + BaseWanTrainer, + _to_array, +) +import tensorflow as tf + + +class WanVaceTrainer(BaseWanTrainer): + + def _get_checkpointer(self): + return WanVaceCheckpointer2_1(config=self.config) + + def post_training_steps(self, pipeline, params, train_states, msg=""): + pass + + def get_data_shardings(self, mesh): + data_sharding = jax.sharding.NamedSharding( + mesh, P(*self.config.data_sharding) + ) + data_sharding = { + "latents": data_sharding, + "encoder_hidden_states": data_sharding, + "conditioning_latents": data_sharding, + } + return data_sharding + + def get_eval_data_shardings(self, mesh): + data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding)) + data_sharding = { + "latents": data_sharding, + "encoder_hidden_states": data_sharding, + "timesteps": data_sharding, + "conditioning_latents": data_sharding, + } + return data_sharding + + def load_dataset(self, mesh, pipeline=None, is_training=True): + config = self.config + + # If using synthetic data + if config.dataset_type == "synthetic": + return make_data_iterator( + config, + jax.process_index(), + jax.process_count(), + mesh, + 
config.global_batch_size_to_load, + pipeline=pipeline, # Pass pipeline to extract dimensions + is_training=is_training, + ) + + config = self.config + if config.dataset_type != "tfrecord" and not config.cache_latents_text_encoder_outputs: + raise ValueError( + "Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True" + ) + feature_description = { + "latents": tf.io.FixedLenFeature([], tf.string), + "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string), + "conditioning_latents": tf.io.FixedLenFeature([], tf.string), + } + + if not is_training: + feature_description["timesteps"] = tf.io.FixedLenFeature([], tf.int64) + + def prepare_sample_train(features): + latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32) + encoder_hidden_states = tf.io.parse_tensor( + features["encoder_hidden_states"], out_type=tf.float32 + ) + conditioning_latents = tf.io.parse_tensor( + features["conditioning_latents"], out_type=tf.float32 + ) + return { + "latents": latents, + "encoder_hidden_states": encoder_hidden_states, + "conditioning_latents": conditioning_latents, + } + + def prepare_sample_eval(features): + latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32) + encoder_hidden_states = tf.io.parse_tensor( + features["encoder_hidden_states"], out_type=tf.float32 + ) + conditioning_latents = tf.io.parse_tensor( + features["conditioning_latents"], out_type=tf.float32 + ) + timesteps = features["timesteps"] + return { + "latents": latents, + "encoder_hidden_states": encoder_hidden_states, + "conditioning_latents": conditioning_latents, + "timesteps": timesteps, + } + + data_iterator = make_data_iterator( + config, + jax.process_index(), + jax.process_count(), + mesh, + config.global_batch_size_to_load, + feature_description=feature_description, + prepare_sample_fn=prepare_sample_train if is_training else prepare_sample_eval, + is_training=is_training, + ) + return data_iterator 
+ + def get_train_step(self, pipeline, mesh, state_shardings, data_shardings): + return jax.jit( + functools.partial( + train_step, scheduler=pipeline.scheduler, config=self.config + ), + in_shardings=(state_shardings, data_shardings, None, None), + out_shardings=(state_shardings, None, None, None), + donate_argnums=(0,), + ) + + def get_eval_step(self, pipeline, mesh, state_shardings, eval_data_shardings): + return jax.jit( + functools.partial( + eval_step, scheduler=pipeline.scheduler, config=self.config + ), + in_shardings=(state_shardings, eval_data_shardings, None, None), + out_shardings=(None, None), + ) + + +def train_step(state, data, rng, scheduler_state, scheduler, config): + return step_optimizer(state, data, rng, scheduler_state, scheduler, config) + + +def step_optimizer(state, data, rng, scheduler_state, scheduler, config): + _, new_rng, timestep_rng, dropout_rng = jax.random.split(rng, num=4) + + for k, v in data.items(): + data[k] = v[: config.global_batch_size_to_train_on, :] + + def loss_fn(params): + model = nnx.merge(state.graphdef, params, state.rest_of_state) + latents = data["latents"].astype(config.weights_dtype) + encoder_hidden_states = data["encoder_hidden_states"].astype( + config.weights_dtype + ) + control_hidden_states = data["conditioning_latents"].astype( + config.weights_dtype + ) + + bsz = latents.shape[0] + timesteps = scheduler.sample_timesteps(timestep_rng, bsz) + noise = jax.random.normal( + key=new_rng, shape=latents.shape, dtype=latents.dtype + ) + noisy_latents, training_target, training_weight = ( + scheduler.apply_flow_match(noise, latents, timesteps) + ) + with jax.named_scope("forward_pass"): + model_pred = model( + hidden_states=noisy_latents, + timestep=timesteps, + encoder_hidden_states=encoder_hidden_states, + control_hidden_states=control_hidden_states, + deterministic=False, + rngs=nnx.Rngs(dropout=dropout_rng), + ) + + with jax.named_scope("loss"): + model_pred = model_pred.astype(jnp.float32) + training_target = 
training_target.astype(jnp.float32) + loss = (training_target - model_pred) ** 2 + if not config.disable_training_weights: + training_weight = jnp.expand_dims(training_weight, axis=(1, 2, 3, 4)) + loss = loss * training_weight + loss = jnp.mean(loss) + + return loss + + grad_fn = nnx.value_and_grad(loss_fn) + loss, grads = grad_fn(state.params) + max_grad_norm = jaxopt.tree_util.tree_l2_norm(grads) + + max_abs_grad = jax.tree_util.tree_reduce( + lambda max_val, arr: jnp.maximum(max_val, jnp.max(jnp.abs(arr))), + grads, + initializer=-1.0, + ) + + metrics = { + "scalar": { + "learning/loss": loss, + "learning/max_grad_norm": max_grad_norm, + "learning/max_abs_grad": max_abs_grad, + }, + "scalars": {}, + } + + new_state = state.apply_gradients(grads=grads) + return new_state, scheduler_state, metrics, new_rng + + +def eval_step(state, data, rng, scheduler_state, scheduler, config): + """ + Computes the evaluation loss for a single batch without updating model weights. + """ + + # The loss function logic is identical to training. We are evaluating the model's + # ability to perform its core training objective (e.g., denoising). 
+ def loss_fn( + params, + latents, + encoder_hidden_states, + timesteps, + rng, + conditioning_latents, + ): + # Reconstruct the model from its definition and parameters + model = nnx.merge(state.graphdef, params, state.rest_of_state) + + noise = jax.random.normal(key=rng, shape=latents.shape, dtype=latents.dtype) + noisy_latents, training_target, training_weight = ( + scheduler.apply_flow_match(noise, latents, timesteps) + ) + # Get the model's prediction + model_pred = model( + hidden_states=noisy_latents, + timestep=timesteps, + encoder_hidden_states=encoder_hidden_states, + control_hidden_states=conditioning_latents, + deterministic=True, + ) + + # Calculate the loss against the target + model_pred = model_pred.astype(jnp.float32) + training_target = training_target.astype(jnp.float32) + + loss = (training_target - model_pred) ** 2 + if not config.disable_training_weights: + training_weight = jnp.expand_dims(training_weight, axis=(1, 2, 3, 4)) + loss = loss * training_weight + + # Calculate the mean loss per sample across all non-batch dimensions. + loss = loss.reshape(loss.shape[0], -1).mean(axis=1) + + return loss + + # --- Key Difference from train_step --- + # Directly compute the loss without calculating gradients. + # The model's state.params are used but not updated. 
+ # TODO(coolkp): Explore optimizing the creation of PRNGs in a vmap or statically outside of the loop + bs = len(data["latents"]) + single_batch_size = config.global_batch_size_to_train_on + losses = jnp.zeros(bs) + for i in range(0, bs, single_batch_size): + start = i + end = min(i + single_batch_size, bs) + latents = data["latents"][start:end, :].astype(config.weights_dtype) + encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype) + conditioning_latents = data["conditioning_latents"][start:end, :].astype( + config.weights_dtype + ) + timesteps = data["timesteps"][start:end].astype("int64") + _, new_rng = jax.random.split(rng, num=2) + loss = loss_fn( + state.params, + latents, + encoder_hidden_states, + timesteps, + new_rng, + conditioning_latents, + ) + losses = losses.at[start:end].set(loss) + + # Structure the metrics for logging and aggregation + metrics = {"scalar": {"learning/eval_loss": losses}} + + # Return the computed metrics and the new RNG key for the next eval step + return metrics, new_rng