# Refactor Wan Model Training & Add Wan-VACE Training Support #352
Changes to the `WanCheckpointer2_1` checkpointer:

```diff
@@ -15,14 +15,15 @@
 """

 import json
-import jax
-import numpy as np
 from typing import Optional, Tuple
-from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
-from .. import max_logging
-import orbax.checkpoint as ocp
-from etils import epath
+import jax
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
+import numpy as np
+import orbax.checkpoint as ocp
+from .. import max_logging
+from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1


 class WanCheckpointer2_1(WanCheckpointer):

@@ -35,13 +36,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
       max_logging.log("No WAN checkpoint found.")
       return None, None
     max_logging.log(f"Loading WAN checkpoint from step {step}")

+    cpu_devices = np.array(jax.devices(backend="cpu"))
+    mesh = Mesh(cpu_devices, axis_names=("data",))
+    replicated_sharding = NamedSharding(mesh, P())
+
     metadatas = self.checkpoint_manager.item_metadata(step)
     transformer_metadata = metadatas.wan_state
     abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
+    state = metadatas.wan_state
+
+    def add_sharding_to_struct(leaf_struct, sharding):
+      return jax.ShapeDtypeStruct(
+          shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding
+      )
+
+    target_shardings = jax.tree_util.tree_map(
+        lambda x: replicated_sharding, state
+    )
+
+    with mesh:
+      abstract_train_state_with_sharding = jax.tree_util.tree_map(
+          add_sharding_to_struct, state, target_shardings
+      )
+
     params_restore = ocp.args.PyTreeRestore(
         restore_args=jax.tree.map(
             lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
-            abstract_tree_structure_params,
+            abstract_train_state_with_sharding,
         )
     )
```

**Collaborator** (on `restore_type=np.ndarray`): Passing `restore_type=np.ndarray` makes the JAX sharding applied above redundant (JAX sharding cannot work on `np.ndarray`s). Suggest making it `jax.Array` to ensure the checkpoint is loaded on the host in a sharded manner, if that's the intent.
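For reference, `NamedSharding(mesh, P())` with an empty `PartitionSpec` fully replicates each leaf across the CPU mesh. A minimal sketch of the reviewer's suggestion, reusing the variables from the diff above — this is an illustration of one way to apply it (via Orbax's `ArrayRestoreArgs`), not code from the PR:

```python
# Sketch of the suggested fix: ask Orbax for jax.Array leaves placed
# according to each leaf's sharding, instead of host-local np.ndarrays.
params_restore = ocp.args.PyTreeRestore(
    restore_args=jax.tree.map(
        lambda leaf: ocp.ArrayRestoreArgs(
            restore_type=jax.Array,   # jax.Array instead of np.ndarray
            sharding=leaf.sharding,   # the replicated sharding attached above
        ),
        abstract_train_state_with_sharding,
    )
)
```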
A new file adds a VACE variant of the WAN checkpointer (112 added lines):

```python
"""Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
from typing import Optional, Tuple
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
import numpy as np
import orbax.checkpoint as ocp
from .. import max_logging
from ..pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1


class WanVaceCheckpointer2_1(WanCheckpointer):

  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest WAN checkpoint step: {step}")
      if step is None:
        max_logging.log("No WAN checkpoint found.")
        return None, None
    max_logging.log(f"Loading WAN checkpoint from step {step}")

    cpu_devices = np.array(jax.devices(backend="cpu"))
    mesh = Mesh(cpu_devices, axis_names=("data",))
    replicated_sharding = NamedSharding(mesh, P())

    metadatas = self.checkpoint_manager.item_metadata(step)
    state = metadatas.wan_state

    def add_sharding_to_struct(leaf_struct, sharding):
      return jax.ShapeDtypeStruct(
          shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding
      )

    target_shardings = jax.tree_util.tree_map(
        lambda x: replicated_sharding, state
    )

    with mesh:
      abstract_train_state_with_sharding = jax.tree_util.tree_map(
          add_sharding_to_struct, state, target_shardings
      )

    max_logging.log("Restoring WAN checkpoint")
    restored_checkpoint = self.checkpoint_manager.restore(
        step=step,
        args=ocp.args.Composite(
            wan_config=ocp.args.JsonRestore(),
            wan_state=ocp.args.StandardRestore(
                abstract_train_state_with_sharding
            ),
        ),
    )
    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
    max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}")
    max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return restored_checkpoint, step

  def load_diffusers_checkpoint(self):
    pipeline = VaceWanPipeline2_1.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None) -> Tuple[VaceWanPipeline2_1, Optional[dict], Optional[int]]:
    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
    opt_state = None
    if restored_checkpoint:
      max_logging.log("Loading WAN pipeline from checkpoint")
      pipeline = VaceWanPipeline2_1.from_checkpoint(self.config, restored_checkpoint)
      if "opt_state" in restored_checkpoint.wan_state.keys():
        opt_state = restored_checkpoint.wan_state["opt_state"]
    else:
      max_logging.log("No checkpoint found, loading default pipeline.")
      pipeline = self.load_diffusers_checkpoint()

    return pipeline, opt_state, step

  def save_checkpoint(
      self, train_step, pipeline: VaceWanPipeline2_1, train_states: dict
  ):
    """Saves the training state and model configurations."""

    def config_to_json(model_or_config):
      return json.loads(model_or_config.to_json_string())

    max_logging.log(f"Saving checkpoint for step {train_step}")

    # Save the checkpoint
    self.checkpoint_manager.save(
        train_step,
        args=ocp.args.Composite(
            wan_config=ocp.args.JsonSave(config_to_json(pipeline.transformer)),
            wan_state=ocp.args.StandardSave(train_states),
        ),
    )

    max_logging.log(f"Checkpoint for step {train_step} is saved.")
```
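For orientation, a sketch of how a training loop might drive this class. The helpers `init_train_states` and `train_one_step` are hypothetical placeholders, not identifiers from this PR:

```python
def train(checkpointer, max_train_steps: int, checkpoint_every: int):
  # Resume from the latest Orbax checkpoint if one exists; otherwise
  # load_checkpoint falls back to the pretrained diffusers weights
  # and returns last_step = None.
  pipeline, opt_state, last_step = checkpointer.load_checkpoint()
  start_step = 0 if last_step is None else last_step + 1
  train_states = init_train_states(pipeline, opt_state)  # hypothetical helper

  for step in range(start_step, max_train_steps):
    train_states = train_one_step(pipeline, train_states)  # hypothetical helper
    if step % checkpoint_every == 0:
      checkpointer.save_checkpoint(step, pipeline, train_states)
```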
Changes to the training configuration (YAML):

```diff
@@ -72,7 +72,7 @@ mask_padding_tokens: True
 # 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
 # in cross attention q.
 attention_sharding_uniform: True
-dropout: 0.1
+dropout: 0.0

 flash_block_sizes: {
   "block_q" : 512,
@@ -145,6 +145,7 @@ diffusion_scheduler_config: {
 # Output directory
 # Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
 base_output_directory: ""
+tensorboard_dir: ""

 # Hardware
 hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
@@ -281,6 +282,7 @@ output_dir: 'sdxl-model-finetuned'
 per_device_batch_size: 1.0
 # If global_batch_size % jax.device_count is not 0, use FSDP sharding.
 global_batch_size: 0
+disable_training_weights: False # if True, disables the use of mid-point loss weighting

 # For creating tfrecords from dataset
 tfrecords_dir: ''
@@ -300,7 +302,7 @@ save_optimizer: False
 adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
-adam_weight_decay: 0 # AdamW Weight decay
+adam_weight_decay: 0.0 # AdamW Weight decay
 max_grad_norm: 1.0

 enable_profiler: False
```

**Collaborator** (on `tensorboard_dir: ""`): `tensorboard_dir` is created automatically inside the pyconfig. Is there a reason it needs to be in the config?
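For readers unfamiliar with these knobs, here is a hedged sketch of how such Adam settings are typically consumed with optax. The PR excerpt does not show the actual optimizer construction, and the learning rate below is an assumed placeholder:

```python
import optax

# Typical mapping of the config values above onto an optax optimizer.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),   # max_grad_norm: 1.0
    optax.adamw(
        learning_rate=1e-5,           # assumed value, not from this excerpt
        b1=0.9,                       # adam_b1
        b2=0.999,                     # adam_b2
        eps=1e-8,                     # adam_eps
        weight_decay=0.0,             # adam_weight_decay, now explicitly a float
    ),
)
```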
**Reviewer** (on the `add_sharding_to_struct` helper in the checkpointer diffs above): A safer way to do this, to prevent unexpected crashes for any elements that lack shape/dtype attributes:
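The reviewer's suggested snippet was cut off in the page capture; a defensive version might look like the following sketch (an assumption about the intent, not the reviewer's actual code):

```python
def add_sharding_to_struct(leaf_struct, sharding):
  # Only wrap leaves that actually carry shape/dtype metadata; pass anything
  # else (e.g. scalar step counts or opaque objects in opt_state) through
  # unchanged instead of crashing on a missing attribute.
  if hasattr(leaf_struct, "shape") and hasattr(leaf_struct, "dtype"):
    return jax.ShapeDtypeStruct(
        shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding
    )
  return leaf_struct
```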