Skip to content
36 changes: 28 additions & 8 deletions src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
"""

import json
import jax
import numpy as np
from typing import Optional, Tuple
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
from .. import max_logging
import orbax.checkpoint as ocp
from etils import epath
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
import numpy as np
import orbax.checkpoint as ocp
from .. import max_logging
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1


class WanCheckpointer2_1(WanCheckpointer):
Expand All @@ -35,13 +36,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
max_logging.log("No WAN checkpoint found.")
return None, None
max_logging.log(f"Loading WAN checkpoint from step {step}")

cpu_devices = np.array(jax.devices(backend="cpu"))
mesh = Mesh(cpu_devices, axis_names=("data",))
replicated_sharding = NamedSharding(mesh, P())

metadatas = self.checkpoint_manager.item_metadata(step)
transformer_metadata = metadatas.wan_state
abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
state = metadatas.wan_state

def add_sharding_to_struct(leaf_struct, sharding):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A safer way to do this, which prevents unexpected crashes for any elements that lack shape/dtype attributes:

def add_sharding_to_struct(leaf_struct, sharding):
      struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
      if hasattr(struct, "shape") and hasattr(struct, "dtype"): 
        return jax.ShapeDtypeStruct(
            shape=struct.shape, dtype=struct.dtype, sharding=sharding
        )
      return struct 

return jax.ShapeDtypeStruct(
shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding
)

target_shardings = jax.tree_util.tree_map(
lambda x: replicated_sharding, state
)

with mesh:
abstract_train_state_with_sharding = jax.tree_util.tree_map(
add_sharding_to_struct, state, target_shardings
)

params_restore = ocp.args.PyTreeRestore(
restore_args=jax.tree.map(
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing restore_type=np.ndarray makes the JAX sharding applied above redundant, because JAX shardings cannot be applied to np.ndarrays. I suggest using jax.Array instead so that the checkpoint is loaded onto the host in a sharded manner, if that is what's intended.

abstract_tree_structure_params,
abstract_train_state_with_sharding,
)
)

Expand Down
36 changes: 28 additions & 8 deletions src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
"""

import json
import jax
import numpy as np
from typing import Optional, Tuple
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
from .. import max_logging
import orbax.checkpoint as ocp
from etils import epath
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
import numpy as np
import orbax.checkpoint as ocp
from .. import max_logging
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1


class WanCheckpointerI2V_2_1(WanCheckpointer):
Expand All @@ -35,13 +36,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
max_logging.log("No WAN checkpoint found.")
return None, None
max_logging.log(f"Loading WAN checkpoint from step {step}")

cpu_devices = np.array(jax.devices(backend="cpu"))
mesh = Mesh(cpu_devices, axis_names=("data",))
replicated_sharding = NamedSharding(mesh, P())

metadatas = self.checkpoint_manager.item_metadata(step)
transformer_metadata = metadatas.wan_state
abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
state = metadatas.wan_state

def add_sharding_to_struct(leaf_struct, sharding):
return jax.ShapeDtypeStruct(
shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding
)

target_shardings = jax.tree_util.tree_map(
lambda x: replicated_sharding, state
)

with mesh:
abstract_train_state_with_sharding = jax.tree_util.tree_map(
add_sharding_to_struct, state, target_shardings
)

params_restore = ocp.args.PyTreeRestore(
restore_args=jax.tree.map(
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
abstract_tree_structure_params,
abstract_train_state_with_sharding,
)
)

Expand Down
112 changes: 112 additions & 0 deletions src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
from typing import Optional, Tuple
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
import numpy as np
import orbax.checkpoint as ocp
from .. import max_logging
from ..pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1


class WanVaceCheckpointer2_1(WanCheckpointer):
  """Checkpointer for the WAN 2.1 VACE pipeline.

  Restores and saves the ``wan_config`` and ``wan_state`` checkpoint items
  through ``self.checkpoint_manager``.
  """

  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    """Restores the WAN checkpoint for ``step``, defaulting to the latest step.

    Args:
      step: Step to restore. When ``None``, the latest step reported by the
        checkpoint manager is used instead.

    Returns:
      A ``(restored_checkpoint, step)`` pair, or ``(None, None)`` when the
      checkpoint manager reports no step at all.
    """
    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest WAN checkpoint step: {step}")
      if step is None:
        max_logging.log("No WAN checkpoint found.")
        return None, None
    max_logging.log(f"Loading WAN checkpoint from step {step}")

    # One-axis mesh over the CPU devices; the empty PartitionSpec requests a
    # fully replicated sharding for every restored leaf.
    host_mesh = Mesh(np.array(jax.devices(backend="cpu")), axis_names=("data",))
    replicated = NamedSharding(host_mesh, P())

    wan_state_metadata = self.checkpoint_manager.item_metadata(step).wan_state

    def _as_replicated_struct(leaf):
      # Keep the stored shape/dtype and attach the replicated sharding.
      return jax.ShapeDtypeStruct(shape=leaf.shape, dtype=leaf.dtype, sharding=replicated)

    with host_mesh:
      abstract_state = jax.tree_util.tree_map(_as_replicated_struct, wan_state_metadata)

    max_logging.log("Restoring WAN checkpoint")
    restore_args = ocp.args.Composite(
        wan_config=ocp.args.JsonRestore(),
        wan_state=ocp.args.StandardRestore(abstract_state),
    )
    restored_checkpoint = self.checkpoint_manager.restore(step=step, args=restore_args)
    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
    max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}")
    max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return restored_checkpoint, step

  def load_diffusers_checkpoint(self):
    """Builds the default pipeline from ``self.config`` via ``from_pretrained``."""
    return VaceWanPipeline2_1.from_pretrained(self.config)

  def load_checkpoint(self, step=None) -> Tuple[VaceWanPipeline2_1, Optional[dict], Optional[int]]:
    """Loads the pipeline from a checkpoint, or the default pipeline if none exists.

    Args:
      step: Step to restore; ``None`` selects the latest available step.

    Returns:
      A ``(pipeline, opt_state, step)`` triple. ``opt_state`` is ``None`` when
      nothing was restored or the restored state has no ``"opt_state"`` entry.
    """
    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
    if not restored_checkpoint:
      max_logging.log("No checkpoint found, loading default pipeline.")
      return self.load_diffusers_checkpoint(), None, step

    max_logging.log("Loading WAN pipeline from checkpoint")
    pipeline = VaceWanPipeline2_1.from_checkpoint(self.config, restored_checkpoint)
    wan_state = restored_checkpoint.wan_state
    opt_state = wan_state["opt_state"] if "opt_state" in wan_state.keys() else None
    return pipeline, opt_state, step

  def save_checkpoint(self, train_step, pipeline: VaceWanPipeline2_1, train_states: dict):
    """Saves the training state and model configurations."""
    max_logging.log(f"Saving checkpoint for step {train_step}")

    # The transformer config is round-tripped through its JSON string so that
    # it is stored as a plain JSON-compatible object.
    transformer_config = json.loads(pipeline.transformer.to_json_string())
    save_args = ocp.args.Composite(
        wan_config=ocp.args.JsonSave(transformer_config),
        wan_state=ocp.args.StandardSave(train_states),
    )
    self.checkpoint_manager.save(train_step, args=save_args)

    max_logging.log(f"Checkpoint for step {train_step} is saved.")
6 changes: 4 additions & 2 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ mask_padding_tokens: True
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True
dropout: 0.1
dropout: 0.0

flash_block_sizes: {
"block_q" : 512,
Expand Down Expand Up @@ -145,6 +145,7 @@ diffusion_scheduler_config: {
# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""
tensorboard_dir: ""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tensorboard_dir is created automatically inside the pyconfig. Is there a reason it needs to be in the config?


# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
Expand Down Expand Up @@ -281,6 +282,7 @@ output_dir: 'sdxl-model-finetuned'
per_device_batch_size: 1.0
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
global_batch_size: 0
disable_training_weights: False # if True, disables the use of mid-point loss weighting

# For creating tfrecords from dataset
tfrecords_dir: ''
Expand All @@ -300,7 +302,7 @@ save_optimizer: False
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
adam_weight_decay: 0 # AdamW Weight decay
adam_weight_decay: 0.0 # AdamW Weight decay
max_grad_norm: 1.0

enable_profiler: False
Expand Down
6 changes: 4 additions & 2 deletions src/maxdiffusion/configs/base_wan_1_3b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ mask_padding_tokens: True
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True
dropout: 0.1
dropout: 0.0

flash_block_sizes: {
"block_q" : 512,
Expand Down Expand Up @@ -122,6 +122,7 @@ diffusion_scheduler_config: {
# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""
tensorboard_dir: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
Expand Down Expand Up @@ -237,6 +238,7 @@ output_dir: 'sdxl-model-finetuned'
per_device_batch_size: 1.0
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
global_batch_size: 0
disable_training_weights: False # if True, disables the use of mid-point loss weighting

# For creating tfrecords from dataset
tfrecords_dir: ''
Expand All @@ -256,7 +258,7 @@ save_optimizer: False
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
adam_weight_decay: 0 # AdamW Weight decay
adam_weight_decay: 0.0 # AdamW Weight decay
max_grad_norm: 1.0

enable_profiler: False
Expand Down
5 changes: 3 additions & 2 deletions src/maxdiffusion/configs/base_wan_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ mask_padding_tokens: True
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True
dropout: 0.1
dropout: 0.0

flash_block_sizes: {
"block_q" : 512,
Expand Down Expand Up @@ -133,6 +133,7 @@ diffusion_scheduler_config: {
# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""
tensorboard_dir: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
Expand Down Expand Up @@ -267,7 +268,7 @@ save_optimizer: False
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
adam_weight_decay: 0 # AdamW Weight decay
adam_weight_decay: 0.0 # AdamW Weight decay
max_grad_norm: 1.0

enable_profiler: False
Expand Down
5 changes: 3 additions & 2 deletions src/maxdiffusion/configs/base_wan_i2v_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
flash_min_seq_length: 4096
dropout: 0.1
dropout: 0.0

# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
Expand Down Expand Up @@ -128,6 +128,7 @@ diffusion_scheduler_config: {
# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""
tensorboard_dir: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
Expand Down Expand Up @@ -262,7 +263,7 @@ save_optimizer: False
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
adam_weight_decay: 0 # AdamW Weight decay
adam_weight_decay: 0.0 # AdamW Weight decay
max_grad_norm: 1.0

enable_profiler: False
Expand Down
5 changes: 3 additions & 2 deletions src/maxdiffusion/configs/base_wan_i2v_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
flash_min_seq_length: 4096
dropout: 0.1
dropout: 0.0

# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
Expand Down Expand Up @@ -129,6 +129,7 @@ diffusion_scheduler_config: {
# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""
tensorboard_dir: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
Expand Down Expand Up @@ -263,7 +264,7 @@ save_optimizer: False
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
adam_weight_decay: 0 # AdamW Weight decay
adam_weight_decay: 0.0 # AdamW Weight decay
max_grad_norm: 1.0

enable_profiler: False
Expand Down
12 changes: 10 additions & 2 deletions src/maxdiffusion/input_pipeline/_tfds_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ def _make_tfrecord_iterator(
"clip_embeddings": tf.io.FixedLenFeature([], tf.string),
}

used_feature_description = feature_description_fn if make_cached_tfrecord_iterator else feature_description
used_feature_description = (
feature_description_fn
if (make_cached_tfrecord_iterator or config.dataset_type == "tfrecord")
else feature_description
)

def _parse_tfrecord_fn(example):
return tf.io.parse_single_example(example, used_feature_description)
Expand Down Expand Up @@ -141,7 +145,11 @@ def prepare_sample(features):
ds = ds.concatenate(padding_ds)
max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")

used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
used_prepare_sample = (
prepare_sample_fn
if (make_cached_tfrecord_iterator or config.dataset_type == "tfrecord")
else prepare_sample
)
ds = (
ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
Expand Down
Loading
Loading