diff --git a/src/maxtext/checkpoint_conversion/lora_to_maxtext.py b/src/maxtext/checkpoint_conversion/lora_to_maxtext.py new file mode 100644 index 0000000000..d88d82079b --- /dev/null +++ b/src/maxtext/checkpoint_conversion/lora_to_maxtext.py @@ -0,0 +1,283 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script converts a HuggingFace LoRA adapter to MaxText LoRA adapter format. + +Key Parameters (to be set in the config file or as command-line overrides): + model_name: (Required) The name of the model (e.g., "llama3.1-8b"). + base_output_directory: (Required) The directory where the MaxText LoRA adapter + will be saved. Can be set in config file or as command-line override. + hf_lora_adapter_path: (Required) Path to the HF LoRA adapter directory or HuggingFace repo ID. + scan_layers: (bool) Whether the MaxText model uses scanned layers. + This must match the training configuration. + +Environment Variables: + HF_AUTH_TOKEN: (Optional) HuggingFace authentication token if needed for adapter. 
+
+Example Usage:
+  To convert HF LoRA to MaxText adapter:
+
+  python src/maxtext/checkpoint_conversion/lora_to_maxtext.py \
+    src/maxtext/configs/post_train/sft.yml model_name="llama3.1-8b" \
+    hf_lora_adapter_path="username/lora-adapter-repo" \
+    base_output_directory="/path/to/output/directory" \
+    scan_layers=False
+"""
+
+import argparse
+import json
+import os
+import shutil
+import sys
+from typing import Sequence
+
+import jax
+import jax.numpy as jnp
+from etils import epath
+from huggingface_hub import hf_hub_download
+from huggingface_hub import list_repo_files
+from safetensors import safe_open
+from transformers import AutoConfig
+
+from orbax import checkpoint as ocp
+from maxtext.checkpoint_conversion.utils.param_mapping import PARAM_MAPPING
+from maxtext.checkpoint_conversion.utils.utils import HF_IDS
+from maxtext.configs import pyconfig
+from maxtext.utils import max_logging
+from absl import logging
+
+
+def load_hf_lora_adapter(adapter_path: str, hf_model_id: str) -> dict:
+  """Load HF LoRA adapter weights directly from safetensors files."""
+  max_logging.log(f"Loading HF LoRA adapter from {adapter_path}")
+
+  # Check adapter compatibility
+  adapter_config = None
+  if os.path.isdir(adapter_path):
+    # Local directory
+    adapter_dir = epath.Path(adapter_path)
+    config_file = adapter_dir / "adapter_config.json"
+    if config_file.exists():
+      with open(config_file, "r", encoding="utf-8") as f:
+        adapter_config = json.load(f)
+  else:
+    # HF Hub repo
+    try:
+      config_file = hf_hub_download(adapter_path, "adapter_config.json", token=os.environ.get("HF_AUTH_TOKEN"))
+      with open(config_file, "r", encoding="utf-8") as f:
+        adapter_config = json.load(f)
+    except Exception as exc:  # pylint: disable=broad-exception-caught
+      max_logging.log(f"Warning: Could not load adapter_config.json from HF Hub: {exc}")
+
+  if adapter_config:
+    if adapter_config.get("base_model_name_or_path"):
+      max_logging.log(f"Adapter base model: {adapter_config['base_model_name_or_path']}")
+      # if base_model and 
base_model.replace("-Instruct", "") != hf_model_id.replace("-Instruct", ""): + # raise ValueError(f"Adapter base model '{base_model}' does not match expected model '{hf_model_id}'") + max_logging.log(f"Adapter compatible with model {hf_model_id}") + + # Handle both local paths and HF Hub paths + if os.path.isdir(adapter_path): + # Local directory + adapter_dir = epath.Path(adapter_path) + adapter_files = list(adapter_dir.glob("*.safetensors")) + if not adapter_files: + adapter_files = list(adapter_dir.glob("*.bin")) + if not adapter_files: + raise ValueError(f"No LoRA adapter files found in {adapter_path}") + adapter_file = adapter_files[0] + else: + # Assume it's a HF Hub repo ID + try: + files = list_repo_files(adapter_path, token=os.environ.get("HF_AUTH_TOKEN")) + safetensor_files = [f for f in files if f.endswith(".safetensors")] + if not safetensor_files: + bin_files = [f for f in files if f.endswith(".bin")] + if not bin_files: + raise ValueError(f"No LoRA adapter files found in {adapter_path}") + adapter_file = bin_files[0] + else: + adapter_file = safetensor_files[0] + + # Download the adapter file + adapter_file = hf_hub_download(adapter_path, adapter_file, token=os.environ.get("HF_AUTH_TOKEN")) + except Exception as e: + raise ValueError(f"Failed to load LoRA adapter from {adapter_path}: {e}") from e + + # Load the adapter weights + if adapter_file.endswith(".safetensors"): + with safe_open(adapter_file, framework="numpy") as f: + lora_weights = {k: f.get_tensor(k) for k in f.keys()} + else: + # For .bin files, we'd need torch.load, but safetensors is preferred + raise ValueError(f"Unsupported adapter file format: {adapter_file}") + + max_logging.log(f"Loaded {len(lora_weights)} LoRA parameters from adapter") + return lora_weights + + +def convert_hf_lora_key_to_maxtext(hf_key: str, param_mapping: dict) -> str: + """Convert HF LoRA key to MaxText parameter path using the mapping from to_maxtext.py.""" + # HF LoRA keys: 
base_model.model.layers.{layer}.{module}.lora_A/B.weight + + # 1. Clean up LoRA suffixes to get the base module path + # e.g. ...q_proj.lora_A.weight -> ...q_proj + hf_param_key = hf_key.replace(".lora_A.weight", "").replace(".lora_B.weight", "") + hf_param_key = hf_param_key.replace(".lora_A", "").replace(".lora_B", "") + + # 2. Handle prefix. Expected target is usually "model.layers..." + # Input could be "base_model.model.model.layers..." or "base_model.model.layers..." + if hf_param_key.startswith("base_model.model."): + hf_param_key = hf_param_key[len("base_model.model.") :] + + # 3. Search for the corresponding MaxText key + for mt_key, hf_keys in param_mapping.items(): + if isinstance(hf_keys, list): + for hf_k in hf_keys: + # Match disregarding .weight suffix on the base model param + if hf_k.replace(".weight", "") == hf_param_key: + return mt_key + elif isinstance(hf_keys, str): + if hf_keys.replace(".weight", "") == hf_param_key: + return mt_key + + return None + + +def convert_lora_to_maxtext_adapter(config, lora_weights: dict, output_path: str, hf_model_id: str): + """Converts HF LoRA weights to MaxText adapter format without merging.""" + + hf_token = config.hf_access_token + + # Get the parameter mapping (MT -> HF) + model_key = config.model_name + if "-Instruct" in model_key: + max_logging.log("Warning: You want an Instruct version, so we are using the base model architecture instead.") + model_key = model_key.replace("-Instruct", "") + hf_config_obj = AutoConfig.from_pretrained(hf_model_id, token=hf_token) + param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config, config.scan_layers) + + # 2. Initialize an empty dictionary for the MaxText Adapter + mt_adapter_tree = {} + mapped_count = 0 + + # 3. 
Map HF LoRA weights to MaxText keys + for hf_key, weight in lora_weights.items(): + # Identify the MaxText path for this specific HF weight + mt_key = convert_hf_lora_key_to_maxtext(hf_key, param_map_mt_to_hf) + + if mt_key: + # Determine if this is the 'A' or 'B' matrix + suffix = "lora_A" if "lora_A" in hf_key else "lora_B" + + # Construct a nested dictionary path in mt_adapter_tree + # MaxText expects: { 'decoder': { 'layers': { '0': { 'query': { 'lora_A': ... } } } } } + parts = mt_key.split("/") + current = mt_adapter_tree + for part in parts: + if part not in current: + current[part] = {} + current = current[part] + + # Convert weight to JAX array and store + current[suffix] = jnp.array(weight) + mapped_count += 1 + else: + max_logging.log(f"Warning: Could not map HF LoRA key {hf_key} to MaxText key") + + max_logging.log(f"Successfully mapped {mapped_count} out of {len(lora_weights)} LoRA parameters") + + # 4. Save as a standalone adapter checkpoint + max_logging.log(f"Saving MaxText LoRA adapter to {output_path}") + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + ckptr.save(epath.Path(output_path), mt_adapter_tree) + + max_logging.log("LoRA adapter conversion completed successfully") + + +def main(args: Sequence[str]) -> None: + # Set logging to INFO level to see max_logging.log messages + logging.set_verbosity(logging.INFO) + + # Check if the user is using an Instruct version. 
If so, use the base model architecture + original_model_name = None + for i, arg in enumerate(args): + if arg.startswith("model_name="): + model_name_arg = args[i].split("=")[1] + # Remove quotes if present + model_name_arg = model_name_arg.strip("'").strip('"') + original_model_name = model_name_arg + + if "-Instruct" in model_name_arg: + max_logging.log("Warning: You want an Instruct version, so we are using the base model architecture instead.") + model_name_arg = model_name_arg.replace("-Instruct", "") + args[i] = f"model_name={model_name_arg}" + break + + # Initialize maxtext config + config = pyconfig.initialize(args) + + if not hasattr(config, "hf_lora_adapter_path") or not config.hf_lora_adapter_path: + raise ValueError("hf_lora_adapter_path must be specified") + + # Determine HF model ID and check if supported + hf_model_id = HF_IDS.get(config.model_name) + if hf_model_id is None: + raise ValueError(f"Model '{config.model_name}' is not supported. Use a supported model_name from HF_IDS.") + + if not hasattr(config, "base_output_directory") or not config.base_output_directory: + raise ValueError("base_output_directory must be specified (in config file or as command-line argument)") + + output_dir = config.base_output_directory + + # Use original model name for output path + model_name_for_path = original_model_name or config.model_name + adapter_name = os.path.basename(config.hf_lora_adapter_path) + full_output_path = os.path.join(output_dir, model_name_for_path, adapter_name) + + os.makedirs(os.path.dirname(full_output_path), exist_ok=True) + + if os.path.exists(full_output_path): + max_logging.log(f"Output directory {full_output_path} exists. 
Removing it to allow Orbax to save.") + shutil.rmtree(full_output_path) + + # Load LoRA adapter and check compatibility + lora_weights = load_hf_lora_adapter(config.hf_lora_adapter_path, hf_model_id) + + # Convert LoRA to MaxText adapter format and save + convert_lora_to_maxtext_adapter(config, lora_weights, full_output_path, hf_model_id) + + # Verify output was created #using epath for local file and gcs compatibility + outputpath = epath.Path(full_output_path) + if not outputpath.exists(): + raise RuntimeError(f"Failed to create output directory {full_output_path}") + + +if __name__ == "__main__": + # Argument parsing similar to to_maxtext.py + parser = argparse.ArgumentParser() + parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16) + + # Parse local arguments + local_args, remaining_args = parser.parse_known_args() + + # Reconstruct model_args (script name + the args MaxText needs) + model_args = [sys.argv[0]] + remaining_args + + # Set jax environment + jax.config.update("jax_platforms", "cpu") + os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}" + + main(model_args) diff --git a/src/maxtext/checkpoint_conversion/maxtext_to_hf_lora.py b/src/maxtext/checkpoint_conversion/maxtext_to_hf_lora.py new file mode 100644 index 0000000000..7e2d993ebb --- /dev/null +++ b/src/maxtext/checkpoint_conversion/maxtext_to_hf_lora.py @@ -0,0 +1,143 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script converts a MaxText LoRA adapter (checkpoint) back to HuggingFace PEFT format.
+
+Key Parameters:
+  model_name: The name of the model (e.g., "llama3.1-8b").
+  maxtext_ckpt_path: Path to the MaxText checkpoint directory (e.g., .../checkpoints/100/model_params).
+  hf_model_id: The base HuggingFace model ID for config mapping.
+  output_dir: The directory where the HuggingFace adapter will be saved.
+  lora_r: The rank of the LoRA adapter.
+  lora_alpha: The alpha parameter for LoRA.
+
+Example Usage:
+  python src/maxtext/checkpoint_conversion/maxtext_to_hf_lora.py \
+    model_name="llama3.1-8b" \
+    maxtext_ckpt_path="/path/to/maxtext_lora/ckpt" \
+    hf_model_id="meta-llama/Llama-3.1-8B" \
+    output_dir="/path/to/hf/adapter/output"
+"""
+
+import os
+import json
+import numpy as np
+import sys
+from safetensors.numpy import save_file
+from orbax import checkpoint as ocp
+from etils import epath
+from transformers import AutoConfig
+from maxtext.checkpoint_conversion.utils.param_mapping import PARAM_MAPPING
+
+def parse_args(args):
+    """Parses command line arguments in the format key=value."""
+    parsed_args = {}
+    for arg in args:
+        if "=" in arg:
+            key, value = arg.split("=", 1)
+            parsed_args[key] = value
+    return parsed_args
+
+def convert(model_name, maxtext_ckpt_path, hf_model_id, output_dir, lora_r=16, lora_alpha=32):
+    print(f"[*] Starting conversion from {maxtext_ckpt_path}")
+
+    # Initialize Orbax Checkpointer
+    mngr = ocp.PyTreeCheckpointer()
+    mt_params = mngr.restore(epath.Path(maxtext_ckpt_path))
+
+    # Load HF Config for mapping
+    hf_config = AutoConfig.from_pretrained(hf_model_id).to_dict()
+    mt_model_name = model_name  # alias: a class body cannot read a function local it also assigns
+    class MockConfig:
+        scan_layers = True
+        model_name = mt_model_name  # was hard-coded to "llama3.1-8b"; honor the caller-supplied model_name
+
+    # Get the parameter mapping for the specific model
+    mapping = PARAM_MAPPING[model_name](hf_config, MockConfig(), scan_layers=True)
+    final_hf_weights = {}
+
+    def 
process_data(current_dict, parent_path="decoder/layers"): + """Recursive function to traverse MaxText params and map to HF.""" + for module_name, content in current_dict.items(): + path = f"{parent_path}/{module_name}" + + # Identify LoRA layers + if isinstance(content, dict) and 'kernel_lora_a' in content: + lookup_key = "params-" + path.replace("/", "-") + "-kernel" + + if lookup_key in mapping: + # Get the JAX values (as numpy) + data_a = np.array(content['kernel_lora_a']['value']) + data_b = np.array(content['kernel_lora_b']['value']) + hf_paths = mapping[lookup_key] + + # MaxText stacks multiple heads/projections, iterate through them + for i in range(data_a.shape[1]): + name = hf_paths[i].replace(".weight", "") + # Apply Transpose (.T) to match PyTorch dimension logic + final_hf_weights[f"base_model.model.{name}.lora_A.weight"] = data_a[:, i, :].T + final_hf_weights[f"base_model.model.{name}.lora_B.weight"] = data_b[:, i, :].T + + print(f"[DEBUG] Processed: {path}") + + elif isinstance(content, dict): + process_data(content, path) + + # Start recursion + process_data(mt_params['decoder']['layers']) + + # Save Safetensors + os.makedirs(output_dir, exist_ok=True) + adapter_file = os.path.join(output_dir, "adapter_model.safetensors") + save_file(final_hf_weights, adapter_file) + + # Create PEFT adapter_config.json + config_json = { + "base_model_name_or_path": hf_model_id, + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": int(lora_r), + "lora_alpha": int(lora_alpha), + "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + "lora_dropout": 0.0, + "bias": "none", + "inference_mode": True + } + + config_file = os.path.join(output_dir, "adapter_config.json") + with open(config_file, "w") as f: + json.dump(config_json, f, indent=4) + + print(f"\n[!] 
Conversion Complete!")
+    print(f"    Saved weights to: {adapter_file}")
+    print(f"    Saved config to: {config_file}")
+
+if __name__ == "__main__":
+    cli_args = parse_args(sys.argv[1:])
+
+    # Required parameters check
+    required = ["model_name", "maxtext_ckpt_path", "hf_model_id", "output_dir"]
+    if not all(k in cli_args for k in required):
+        print(__doc__)
+        sys.exit(1)
+
+    convert(
+        model_name=cli_args["model_name"],
+        maxtext_ckpt_path=cli_args["maxtext_ckpt_path"],
+        hf_model_id=cli_args["hf_model_id"],
+        output_dir=cli_args["output_dir"],
+        lora_r=cli_args.get("lora_r", 16),
+        lora_alpha=cli_args.get("lora_alpha", 32)
+    )
diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index d97b2f3256..b55a54e4e8 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -718,7 +718,7 @@ autoregressive_decode_assert: ""
 
 # For nsys profiler, pass the training command to nsys command
 # e.g. nsys profile -s none --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop {training command}
-profiler: "" # Supported profiler: '', xplane, nsys
+profiler: "" # Supported profiler: '', xplane, nsys
 # If set to true, upload all profiler results from all hosts. Otherwise, only upload the profiler result from the first host.
 upload_all_profiler_results: False
 # Skip first n steps for profiling, to omit things like compilation and to give
@@ -1074,7 +1074,8 @@ position_id_per_seconds: 25
 subslice_shape: ""
 
 # NNX
-enable_nnx: false
+enable_nnx: false
+pure_nnx_decoder: false
 
 ################################## Qwen3-Next Specific Configs ##################################
 # Kernel size for the 1D convolution in the Gated Delta Net
diff --git a/src/maxtext/configs/post_train/sft.yml b/src/maxtext/configs/post_train/sft.yml
index 32c86ddb31..e4dd3b25de 100644
--- a/src/maxtext/configs/post_train/sft.yml
+++ b/src/maxtext/configs/post_train/sft.yml
@@ -21,6 +21,22 @@ sft_train_on_completion_only: True
 packing: True
 learning_rate: 2.e-5
 
+# -------------- LoRA / QLoRA --------------
+# Enable LoRA/QLoRA by setting enable_lora: True and configuring the fields below.
+enable_lora: False
+lora_rank: 0
+lora_alpha: 0.0
+lora_module_path: ""
+# For QLoRA, set lora_weight_qtype (e.g., "nf4") and optionally lora_tile_size.
+lora_weight_qtype: null
+lora_tile_size: null
+# Optional NNX LoRA restore checkpoint path (direct `model_params` directory).
+lora_restore_path: "" + +# -------------- HF LoRA Adapter -------------- +# HuggingFace LoRA adapter repo ID (e.g., 'username/adapter-repo') or local path to directory containing adapter_model.safetensors +hf_lora_adapter_path: "" + # -------------- HF pipeline -------------- dataset_type: hf hf_path: 'HuggingFaceH4/ultrachat_200k' diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index fd0dcc7292..fce9d69577 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -287,6 +287,13 @@ class Checkpointing(BaseModel): load_parameters_path: PathStr = Field("", description="Loads only model parameters from a specific checkpoint path.") lora_input_adapters_path: PathStr = Field("", description="Input GCS path for LoRA adapters.") + hf_lora_adapter_path: PathStr = Field( + "", + description=( + "HuggingFace LoRA adapter repo ID (e.g., 'username/adapter-repo') or local " + "path to directory containing adapter_model.safetensors." + ), + ) load_full_state_path: PathStr = Field("", description="Loads the complete training state from a checkpoint path.") enable_checkpointing: bool = Field(True, description="If True, enables saving checkpoints during training.") load_checkpoint_only_once: bool = Field(False, description="If True, deep copy the reference model to the actor model.") @@ -777,6 +784,7 @@ class HardwareAndMesh(BaseModel): enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.") optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.") shardy: bool = Field(True, description="Whether to use shardy XLA backend.") + pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.") class LayoutAndSharding(BaseModel): @@ -795,7 +803,10 @@ class LayoutAndSharding(BaseModel): description="Allowed percentage of non-sharded parameters.", ) shard_optimizer_over_data: bool = Field(False, description="Enable ZeRO-1 optimizer 
sharding over the data axis.") - internal_compile: bool = Field(False, description="Use internal_compile to bypass open-source topology mappings.") + internal_compile: bool = Field( + False, + description="Use internal_compile to bypass open-source topology mappings.", + ) internal_compile_num_devices: int = Field(-1, description="Number of devices when using internal_compile.") @@ -1051,6 +1062,36 @@ class FineTuning(BaseModel): use_grpo: None | bool = Field(None, description="If True, enables Group Relative Policy Optimization.") +class LoRA(BaseModel): + """Configuration for LoRA / QLoRA adapters.""" + + enable_lora: bool = Field(False, description="If True, enables LoRA/QLoRA during fine-tuning.") + lora_rank: NonNegativeInt = Field(0, description="LoRA rank. Set >0 when LoRA is enabled.") + lora_alpha: NonNegativeFloat = Field(0.0, description="LoRA alpha scaling factor.") + lora_module_path: str = Field( + "", + description=( + "Regex identifying target modules for LoRA, e.g." " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'." + ), + ) + lora_weight_qtype: str | None = Field( + None, + description=("Optional quantization type for QLoRA (e.g., 'nf4'). If set, QLoRA is applied."), + ) + lora_tile_size: NonNegativeInt | None = Field( + None, + description="Optional tile size for QLoRA (e.g., 128 or 256).", + ) + lora_restore_path: PathStr = Field( + "", + description=( + "Optional Tunix NNX LoRA checkpoint path to restore adapter weights from." + " This may point to the checkpoint root, a numeric step directory," + " or a direct `model_params` path." 
+ ), + ) + + class Distillation(BaseModel): """Configuration for Knowledge Distillation.""" @@ -1383,7 +1424,11 @@ class Profiling(BaseModel): xprof_e2e_enable_fw_throttle_event: bool = Field(False, description="Enable FW throttle event.") xprof_e2e_enable_fw_power_level_event: bool = Field(False, description="Enable FW power level event.") xprof_e2e_enable_fw_thermal_event: bool = Field(False, description="Enable FW thermal event.") - profile_power_events: bool = Field(False, description="Enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.") + profile_power_events: bool = Field( + False, + description="Enable TPU-specific power/thermal profiling events." + " Defaults to False to avoid breaking GPU xplane tracing.", + ) class HloDump(BaseModel): @@ -1865,6 +1910,7 @@ class MaxTextConfig( AdamW, Muon, FineTuning, + LoRA, Distillation, # Reinforcement Learning RLHardware, diff --git a/src/maxtext/layers/multi_token_prediction.py b/src/maxtext/layers/multi_token_prediction.py index c9647b8368..d97a8e3592 100644 --- a/src/maxtext/layers/multi_token_prediction.py +++ b/src/maxtext/layers/multi_token_prediction.py @@ -108,12 +108,22 @@ def __init__( rngs=rngs, ) # Use MODEL_MODE_TRAIN for initialization; runtime model_mode is passed dynamically. - mtp_transformer_layer = transformer_layer_module( - config=cfg, - mesh=mesh, - model_mode=MODEL_MODE_TRAIN, - name=f"mtp_{k}_transformer_layer", - ) + if cfg.pure_nnx_decoder: + mtp_transformer_layer = transformer_layer_module( + config=cfg, + mesh=mesh, + model_mode=MODEL_MODE_TRAIN, + name=f"mtp_{k}_transformer_layer", + rngs=rngs, + ) + else: + mtp_transformer_layer = transformer_layer_module( + config=cfg, + mesh=mesh, + model_mode=MODEL_MODE_TRAIN, + name=f"mtp_{k}_transformer_layer", + ) + self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs) # ToNNX requires explicit initialization with sample inputs for proper parameter setup. 
diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py new file mode 100644 index 0000000000..647030e0c2 --- /dev/null +++ b/src/maxtext/layers/nnx_decoders.py @@ -0,0 +1,967 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for decoder layers""" +# pylint: disable=arguments-differ +# pylint: disable=no-name-in-module + +import functools +from typing import Any +import warnings +import inspect + +import jax +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name +from jax.sharding import Mesh + +from flax import linen as nn +from flax import nnx +from flax.nnx import wrappers as nnx_wrappers + +from maxtext.common.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT +from maxtext.common.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from maxtext.layers import linears +from maxtext.layers import mhc +from maxtext.layers import normalizations +from maxtext.layers import initializers +from maxtext.layers import quantizations +from maxtext.layers.attentions import Attention +from maxtext.layers.normalizations import RMSNorm +from maxtext.layers.embeddings import Embed, attend_on_embedding, PositionalEmbedding +from maxtext.layers.quantizations import AqtQuantization as Quant +from maxtext.models import ( + deepseek, + deepseek_batchsplit, + gemma, + gemma2, + gemma3, + gpt3, + gpt_oss, + llama2, + llama4, + mistral, + 
mixtral, + qwen3, + simple_layer, + olmo3, +) +from maxtext.multimodal import utils as mm_utils +from maxtext.utils.sharding import create_sharding +from maxtext.utils import max_logging +from maxtext.utils import sharding +from maxtext.utils import maxtext_utils +from maxtext.inference import page_manager + +# ------------------------------------------------------------------------------ +# The network: Decoder Definitions +# ------------------------------------------------------------------------------ + + +class NNXDecoderLayer(nnx.Module): + """ + Transformer decoder layer converted to NNX. + """ + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + quant: None | Quant = None, + name: str = "decoder_layer", + *, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + + cfg = self.config + + self.pre_self_attention_norm = RMSNorm( + num_features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=rngs, + ) + + self.self_attention = Attention( + config=self.config, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + inputs_q_shape=(1, 1, cfg.emb_dim), + inputs_kv_shape=(1, 1, cfg.emb_dim), + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + float32_qk_product=cfg.float32_qk_product, + float32_logits=cfg.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(cfg), + prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), + reshape_q=cfg.reshape_q, + 
use_mrope=cfg.use_mrope, + mrope_section=cfg.mrope_section, + model_mode=model_mode, + ) + + self.mlp = linears.MlpBlock( + in_features=cfg.emb_dim, + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + model_mode=model_mode, + config=cfg, + quant=self.quant, + mesh=self.mesh, + rngs=rngs, + ) + + self.dropout = linears.Dropout(rate=cfg.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + kv_cache: jax.Array | None = None, + attention_metadata: dict[str, Any] | None = None, + ): + cfg = self.config + mesh = self.mesh + _maybe_shard_with_logical = functools.partial( + sharding.maybe_shard_with_logical, + mesh=mesh, + shard_mode=cfg.shard_mode, + debug_sharding=cfg.debug_sharding, + ) + + if self.model_mode == MODEL_MODE_PREFILL: + logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") + elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: + logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed") + else: + logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed") + + inputs = _maybe_shard_with_logical(inputs, logical_axis_names) + inputs = checkpoint_name(inputs, "decoder_layer_input") + + lnx = self.pre_self_attention_norm(inputs) + lnx = _maybe_shard_with_logical(lnx, logical_axis_names) + + attention_lnx, kv_cache = self.self_attention( + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + attention_lnx = _maybe_shard_with_logical(attention_lnx, 
logical_axis_names) + + mlp_lnx = self.mlp(lnx, deterministic=deterministic) + mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names) + + next_layer_addition = mlp_lnx + attention_lnx + next_layer_addition_dropped_out = self.dropout(next_layer_addition, deterministic=deterministic) + + layer_output = next_layer_addition_dropped_out + inputs + layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names) + + if cfg.record_internal_nn_metrics: + self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output)) + self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output)) + self.sow( + nnx.Intermediate, + "activation_fraction_zero", + jnp.sum(layer_output == 0) / jnp.size(layer_output), + ) + + if cfg.scan_layers: + return layer_output, None + else: + return layer_output, kv_cache + + +def deepstack_process(hidden_states, bidirectional_mask, visual_embeds): + """Process deepstack visual embeddings by adding them to hidden states at visual token positions. 
+ + Args: + hidden_states: [batch, seq_len, hidden_dim] decoder hidden states + bidirectional_mask: [batch, seq_len] boolean mask marking visual token positions + visual_embeds: [batch, num_visual_tokens, hidden_dim] visual features from encoder layer + + Returns: + Updated hidden_states with visual features added at visual positions + """ + # Expand mask to [batch, seq_len, 1] for broadcasting + mask_expanded = bidirectional_mask[:, :, jnp.newaxis] + # Use cumsum to map each True position in mask to its index in visual_embeds + visual_token_idx = jnp.cumsum(bidirectional_mask, axis=1) - 1 # [batch, seq_len], 0-indexed + + # Gather visual tokens: for each position, get the corresponding visual token + batch_idx = jnp.arange(hidden_states.shape[0])[:, jnp.newaxis] # [batch, 1] + visual_embeds_scattered = visual_embeds[batch_idx, visual_token_idx, :] # [batch, seq_len, hidden] + + # Only add where mask is True: hidden_states += visual_embeds * mask + hidden_states = hidden_states + visual_embeds_scattered * mask_expanded + return hidden_states + + +class NNXDecoder(nnx.Module): + """A stack of decoder layers as a part of an encoder-decoder architecture, using NNX.""" + + def __init__( + self, + config: Config, + mesh: Mesh, + quant: None | Quant = None, + model_mode: str = MODEL_MODE_TRAIN, + *, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.quant = quant + self.model_mode = model_mode + self.rngs = rngs + + decoder_block_classes = self.get_decoder_layers() + + self.decoder_norm = self.get_norm_layer(num_features=config.emb_dim, rngs=rngs)( + dtype=config.dtype, + weight_dtype=config.weight_dtype, + epsilon=config.normalization_layer_epsilon, + kernel_axes=("norm",), + parameter_memory_host_offload=config.parameter_memory_host_offload, + ) + + if config.trainable_position_size > 0: + self.position_embedder = Embed( + num_embeddings=config.trainable_position_size, + num_features=config.emb_dim, + dtype=config.dtype, + 
embedding_init=nn.initializers.normal(stddev=1.0), + config=config, + mesh=self.mesh, + rngs=rngs, + ) + + self.dropout = linears.Dropout(rate=config.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) + + self.positional_embedding = PositionalEmbedding(embedding_dims=config.base_emb_dim) + + if not config.logits_via_embedding: + self.logits_dense = linears.DenseGeneral( + in_features_shape=config.emb_dim, + out_features_shape=config.vocab_size, + weight_dtype=config.weight_dtype, + dtype=jnp.float32 if config.logits_dot_in_fp32 else config.dtype, + kernel_axes=("embed", "vocab"), + shard_mode=config.shard_mode, + matmul_precision=self.config.matmul_precision, + parameter_memory_host_offload=config.parameter_memory_host_offload, + rngs=rngs, + ) + + self.scanned_layers = None + self.is_deepseek = self.config.decoder_block == DecoderBlockType.DEEPSEEK + self.is_gemma3 = self.config.decoder_block == DecoderBlockType.GEMMA3 + + if self.config.scan_layers: + if self.is_deepseek: + assert len(decoder_block_classes) == 2 + dense_cls, moe_cls = decoder_block_classes + + num_dense = config.first_num_dense_layers + self.dense_layers = self._create_scanned_layers(dense_cls, length=num_dense, rngs=rngs) + + num_moe = config.num_decoder_layers - config.first_num_dense_layers + + self.moe_stack = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs) + elif self.is_gemma3: + attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) + scan_length = config.num_decoder_layers // attention_pattern_length + num_remaining_layers = config.num_decoder_layers % attention_pattern_length + layer_kwargs = {"num_of_layers": attention_pattern_length} + + rem_layer_kwargs = {"num_of_layers": num_remaining_layers} + + RemattedGemma3Block = gemma3.Gemma3ScannableBlock + + if scan_length > 0: + self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs, **layer_kwargs) + self.layers_remainder = RemattedGemma3Block( + config=self.config, mesh=mesh, 
quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs + ) # pytype: disable=wrong-keyword-args + else: + layer_cls = decoder_block_classes[0] + num_layers = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval) + layer_kwargs = {} + if config.decoder_block == DecoderBlockType.LLAMA4: + layer_kwargs = { + "nope_layer_interval": self.config.nope_layer_interval, + "interleave_moe_layer_step": self.config.interleave_moe_layer_step, + } + + self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs) + else: + self.layers = nnx.List([]) + + if self.is_deepseek: + dense_cls, moe_cls = decoder_block_classes + for i in range(config.first_num_dense_layers): + self._create_and_register_layer(dense_cls, rngs, "dense_layer", i) + for i in range(config.num_decoder_layers - config.first_num_dense_layers): + self._create_and_register_layer(moe_cls, rngs, "moe_layer", i) + else: + layer_cls = decoder_block_classes[0] + + for lyr in range(config.num_decoder_layers): + layer_kwargs = {} + if config.decoder_block == DecoderBlockType.GEMMA3: + layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)} + elif config.decoder_block == DecoderBlockType.LLAMA4: + layer_kwargs = { + "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), + "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), + } + elif config.decoder_block == DecoderBlockType.QWEN3_NEXT: + layer_kwargs = {"layer_idx": lyr} + elif config.decoder_block == DecoderBlockType.GPT_OSS: + layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} + elif config.decoder_block == DecoderBlockType.OLMO3: + layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)} + + self._create_and_register_layer(layer_cls, rngs, "layers", lyr, **layer_kwargs) + + def _create_and_register_layer(self, layer_cls, rngs, base_name, i, **layer_kwargs): + 
attr_name = f"{base_name}_{i}" + layer = self._create_single_layer(layer_cls, rngs, **layer_kwargs) + setattr(self, attr_name, layer) + self.layers.append(layer) + + def _create_single_layer(self, decoder_layer_class, rngs, **kwargs): + """Helper to create a single layer (Linen or NNX).""" + if issubclass(decoder_layer_class, nnx.Module): + return decoder_layer_class( + config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs, **kwargs + ) + else: + layer_linen = decoder_layer_class( + config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, **kwargs + ) + return nnx_wrappers.ToNNX(layer_linen, rngs=rngs) + + def _create_scanned_layers(self, decoder_layer_class, length: int, rngs: nnx.Rngs, **layer_kwargs): + """Creates a VMapped stack of layers, forcing parameter init for Compact modules.""" + + def create_layer_fn(rng): + layer = decoder_layer_class( + config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rng, **layer_kwargs + ) + + return layer + + # Workaround for Deepseek MTP test failure. + # TODO: Handle this properly. 
+ try: + forked_rngs = rngs.fork(split=length) + + except: # pylint: disable=bare-except + pass + + out_axes = nnx.StateAxes({nnx.Param: self.config.param_scan_axis, ...: 0}) + layers_vmapped = nnx.vmap( + create_layer_fn, + in_axes=0, + out_axes=out_axes, + axis_name="layers", + transform_metadata={nnx.PARTITION_NAME: "layers"}, + )(forked_rngs) + + return layers_vmapped + + def _apply_layer_with_remat(self, layer: nnx.Module, y: jax.Array, policy: Any, prevent_cse: bool, **kwargs): + """Helper to cleanly apply jax.checkpoint to a single unscanned layer or block.""" + + graphdef, state = nnx.split(layer) + + def pure_layer_fn(state_in, y_in): + merged_layer = nnx.merge(graphdef, state_in) + out = merged_layer(y_in, **kwargs) + return out, nnx.state(merged_layer) + + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + out, new_state = checkpointed_fn(state, y) + nnx.update(layer, new_state) + + return out + + def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs): + """Runs the layer stack using nnx.scan.""" + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config) + graphdef, params, state = nnx.split( + layers, nnx.Param, ... + ) # state: the mutable state we carry (KV cache, RNGs, etc.) 
+ + scan_axis = self.config.param_scan_axis + if scan_axis != 0: + # Move scan_axis to 0 so scan can iterate over it + params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params) + + layer_cls = layers.__class__ + sig = inspect.signature(layer_cls.__call__) + valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} + + layer_cls = layers.__class__ # Access the underlying class + sig = inspect.signature(layer_cls.__call__) + # Filter kwargs to only include keys that exist in the layer's signature + valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} + + def layer_fn(carry, scanned_vars): + # Unpack the sliced variables for THIS layer + current_params, current_state = scanned_vars + + # Merge using the SLICED state + layer = nnx.merge(graphdef, current_params, current_state) + + # Run the layer (Filter kwargs if using the solution from previous turn) + layer_out = layer(carry, *args, **valid_kwargs) + + new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out + + # Extract the updated state to return it + # _, new_current_state = nnx.split(layer, nnx.Param, ...) + new_current_state = nnx.state(layer) + return new_carry, new_current_state + + layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) + + final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) + + if scan_axis != 0: + scanned_params, scanned_other = scanned_state.split(nnx.Param, ...) 
+ scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params) + scanned_state = nnx.State.merge(scanned_params, scanned_other) + + return final_carry, nnx.merge(graphdef, scanned_state) + + def get_decoder_layers(self): + """Retrieves decoder layer classes based on config using a dictionary lookup.""" + cfg = self.config + + def get_scannable(normal_cls, scannable_cls): + return [scannable_cls] if cfg.scan_layers else [normal_cls] + + def get_deepseek(): + if cfg.use_batch_split_schedule: + return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer] + return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer] + + layer_map = { + DecoderBlockType.DEFAULT: [NNXDecoderLayer], + DecoderBlockType.LLAMA2: [llama2.LlamaDecoderLayer], + DecoderBlockType.MISTRAL: [mistral.MistralDecoderLayer], + DecoderBlockType.MIXTRAL: [mixtral.MixtralDecoderLayer], + DecoderBlockType.GEMMA: [gemma.GemmaDecoderLayer], + DecoderBlockType.GEMMA2: [gemma2.Gemma2DecoderLayer], + DecoderBlockType.GEMMA3: [gemma3.Gemma3DecoderLayer], + DecoderBlockType.GPT3: [gpt3.Gpt3DecoderLayer], + DecoderBlockType.QWEN3: [qwen3.Qwen3DecoderLayer], + DecoderBlockType.QWEN3_MOE: [qwen3.Qwen3MoeDecoderLayer], + DecoderBlockType.SIMPLE: [simple_layer.SimpleDecoderLayer], + DecoderBlockType.SIMPLE_MLP: [simple_layer.SimpleMlpDecoderLayer], + DecoderBlockType.DEEPSEEK: get_deepseek(), + DecoderBlockType.GPT_OSS: get_scannable(gpt_oss.GptOssDecoderLayer, gpt_oss.GptOssScannableBlock), + DecoderBlockType.QWEN3_NEXT: get_scannable(qwen3.Qwen3NextDecoderLayer, qwen3.Qwen3NextScannableBlock), + DecoderBlockType.LLAMA4: get_scannable(llama4.Llama4DecoderLayer, llama4.Llama4ScannableBlock), + DecoderBlockType.OLMO3: get_scannable(olmo3.Olmo3DecoderLayer, olmo3.Olmo3ScannableBlock), + } + + if cfg.decoder_block not in layer_map: + raise ValueError(f"Incorrect decoder_block name {cfg.decoder_block.value=}") + + return layer_map[cfg.decoder_block] + + def 
minimal_policy(self, with_context=False, with_quantization=False): + """Helper for creating minimal checkpoint policies.""" + names = [ + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwi_0", + "mlpwi_1", + "mlpwi", + "mlpwo", + ] + if with_context: + names.append("context") + if with_quantization: + names.append("quantization") + return jax.checkpoint_policies.save_only_these_names(*names) + + def get_remat_policy(self): + """Get remat policy for jax.checkpoint.""" + policy = None + cfg = self.config + if cfg.remat_policy != "none": + if cfg.remat_policy in ("minimal_with_context", "minimal_flash"): + # save all + if cfg.remat_policy == "minimal_flash": + max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.") + policy = self.minimal_policy(with_context=True) + elif cfg.remat_policy == "minimal": + # save all except context + policy = self.minimal_policy() + elif cfg.remat_policy == "minimal_with_quantization": + if cfg.scan_layers: + warnings.warn( + "Scan layers can introduce overhead to checkpointed values that in some configurations is slower" + "than not checkpointing at all. If you are using scan layers, benchmark with and without quantization " + "checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is " + "beneficial for performance." + ) + policy = self.minimal_policy(with_context=False, with_quantization=True) + elif cfg.remat_policy == "minimal_with_context_and_quantization": + if cfg.scan_layers: + warnings.warn( + "Scan layers can introduce overhead to checkpointed values that in some configurations is slower" + "than not checkpointing at all. If you are using scan layers, benchmark with and without quantization " + "checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is " + "beneficial for performance." 
+ ) + policy = self.minimal_policy(with_context=True, with_quantization=True) + elif cfg.remat_policy == "save_dot_with_context_except_mlp": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "context", + "out_proj", + ) + elif cfg.remat_policy == "save_dot_except_mlpwi": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwo", + ) + elif cfg.remat_policy == "save_dot_except_mlp": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + ) + elif cfg.remat_policy == "save_qkv_proj": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + ) + elif cfg.remat_policy == "qkv_proj_offloaded": + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=[], + names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"], + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "minimal_offloaded": + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=[], + names_which_can_be_offloaded=[ + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwi_0", + "mlpwi_1", + "mlpwi", + "mlpwo", + ], + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "custom": + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=cfg.tensors_on_device, + names_which_can_be_offloaded=cfg.tensors_to_offload, + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "save_out_proj": + policy = jax.checkpoint_policies.save_only_these_names("out_proj") + else: + assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" + policy = None + return policy 
+ + def get_norm_layer(self, num_features: int, rngs: nnx.Rngs): + """get normalization layer (return type inherits from nn.Module)""" + if self.config.decoder_block in ( + DecoderBlockType.DEFAULT, + DecoderBlockType.LLAMA2, + DecoderBlockType.MISTRAL, + DecoderBlockType.MIXTRAL, + DecoderBlockType.DEEPSEEK, + DecoderBlockType.GEMMA, + DecoderBlockType.GEMMA2, + DecoderBlockType.GEMMA3, + DecoderBlockType.QWEN3, + DecoderBlockType.QWEN3_MOE, + DecoderBlockType.GPT_OSS, + DecoderBlockType.SIMPLE, + DecoderBlockType.SIMPLE_MLP, + DecoderBlockType.LLAMA4, + DecoderBlockType.OLMO3, + ): + return functools.partial(RMSNorm, num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs) + elif self.config.decoder_block == DecoderBlockType.GPT3: + return functools.partial( + gpt3.Gpt3LayerNorm, num_features=num_features, reductions_in_fp32=False, use_bias=True, rngs=rngs + ) + elif self.config.decoder_block == DecoderBlockType.QWEN3_NEXT: + return functools.partial( + normalizations.Qwen3NextRMSNorm, num_features=num_features, shard_mode=self.config.shard_mode + ) + else: + raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") + + def _apply_embedding( + self, + shared_embedding: nnx.Module, + decoder_input_tokens, + decoder_positions, + deterministic, + model_mode, + image_embeddings=None, + bidirectional_mask=None, + image_masks=None, + audio_embeddings=None, + audio_masks=None, + ): + """Applies token and positional embeddings to the input tokens.""" + cfg = self.config + + y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) + + if image_embeddings is not None and cfg.use_multimodal: + if cfg.model_name in [ + "gemma3-4b", + "gemma3-12b", + "gemma3-27b", + "llama4-17b-16e", + "llama4-17b-128e", + "qwen3-omni-30b-a3b", + ]: + y = mm_utils.merge_mm_embeddings( + text_embeddings=y, + multimodal_embeddings=image_embeddings, + mask=bidirectional_mask, + token_masks=image_masks, + ) + # TODO(hengtaoguo): 
Add support for other multimodal models such as Llama4, refactor if needed + else: + raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") + + if audio_embeddings is not None and cfg.use_audio: + if cfg.model_name in ["qwen3-omni-30b-a3b"]: + y = mm_utils.merge_mm_embeddings( + text_embeddings=y, + multimodal_embeddings=audio_embeddings, + mask=audio_masks, + token_masks=None, + ) + else: + raise ValueError(f"Unsupported model_name for audio: {cfg.model_name}") + + y = self.dropout(y, deterministic=deterministic) + y = y.astype(cfg.dtype) + + if cfg.use_untrainable_positional_embedding: + y += self.positional_embedding(y, decoder_positions) + + if cfg.trainable_position_size > 0 and self.position_embedder: + y += self.position_embedder(decoder_positions.astype("int32"), model_mode=model_mode) + + return y + + def apply_output_head(self, shared_embedding, y, deterministic, model_mode): + """Applies final normalization and projects hidden states to logits.""" + + cfg = self.config + if cfg.shard_mode == ShardMode.EXPLICIT: + norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) + else: + norm_out_sharding = None + + y = self.decoder_norm(y, out_sharding=norm_out_sharding) + y = self.dropout(y, deterministic=deterministic) # NNX call + + if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) + else: + out_sharding = create_sharding( + self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab") + ) + + if cfg.logits_via_embedding: + if isinstance(shared_embedding, nnx.Module): + embedding_table = shared_embedding.embedding.value + else: + embedding_table = shared_embedding.variables["params"]["embedding"] + if isinstance(embedding_table, nn.spmd.LogicallyPartitioned): + embedding_table = embedding_table.unbox() + attend_dtype = jnp.float32 if 
cfg.logits_dot_in_fp32 else cfg.dtype + logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding) + + if self.config.normalize_embedding_logits: + logits = logits / jnp.sqrt(y.shape[-1]) + if cfg.final_logits_soft_cap: + logits = logits / cfg.final_logits_soft_cap + logits = jnp.tanh(logits) * cfg.final_logits_soft_cap + else: + logits = self.logits_dense(y, out_sharding=out_sharding) + + if self.config.cast_logits_to_fp32: + logits = logits.astype(jnp.float32) + + return logits + + def __call__( + self, + shared_embedding: Any, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + deterministic=False, + model_mode=MODEL_MODE_TRAIN, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + bidirectional_mask: None | Any = None, + image_embeddings: None | jnp.ndarray = None, + image_masks: None | jnp.ndarray = None, + kv_caches: list[jax.Array] | None = None, + attention_metadata=None, + audio_embeddings: None | jnp.ndarray = None, + audio_masks: None | jnp.ndarray = None, + deepstack_visual_embeds: None | list[jnp.ndarray] = None, + ): + cfg = self.config + assert decoder_input_tokens.ndim == 2 # [batch, len] + + y = self._apply_embedding( + shared_embedding, + decoder_input_tokens, + decoder_positions, + deterministic, + model_mode, + image_embeddings, + bidirectional_mask, + image_masks, + audio_embeddings, + audio_masks, + ) + + mhc_expand, mhc_reduce = mhc.get_functions(cfg.mhc_expansion_rate) + if cfg.mhc_expansion_rate > 1: + # (batch, length, emb_dim) --> (batch, length, mhc_expansion_rate, emb_dim) + y = mhc_expand(y) + + layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) + + layer_kwargs = {} + if cfg.decoder_block == DecoderBlockType.GEMMA3: + layer_kwargs["bidirectional_mask"] = bidirectional_mask + + if cfg.scan_layers: + if self.is_deepseek: + layer_kwargs = { + "previous_chunk": previous_chunk, + "page_state": page_state, + 
"slot": slot, + } + y, self.dense_layers = self._apply_layers_sequentially( + self.dense_layers, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs + ) + + num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers + y, self.moe_stack = self._apply_layers_sequentially( + self.moe_stack, y, *layer_args, length=num_moe, **layer_kwargs + ) + elif self.is_gemma3: + y = self._apply_gemma3_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + bidirectional_mask, + previous_chunk, + page_state, + slot, + ) + else: + y, self.layers = self._apply_layers_sequentially( + self.layers, y, *layer_args, length=cfg.num_decoder_layers, **layer_kwargs + ) + else: + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) + + # Hoisted function to preserve XLA cache ID + def pure_layer_fn(graphdef, state_in, y_in, kv_in): + merged_layer = nnx.merge(graphdef, state_in) + out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **layer_kwargs) + return out_y, out_kv, nnx.state(merged_layer) + + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + + for lyr, layer in enumerate(self.layers): + graphdef, state = nnx.split(layer) + kv_cache = kv_caches[lyr] if kv_caches is not None else None + + y, kv_cache, new_state = checkpointed_fn(graphdef, state, y, kv_cache) + nnx.update(layer, new_state) + + if kv_caches is not None and kv_cache is not None: + kv_caches[lyr] = kv_cache + + if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds): + visual_embeds = deepstack_visual_embeds[lyr] + if bidirectional_mask is not None and visual_embeds is not None: + y = deepstack_process(y, bidirectional_mask, visual_embeds) + + assert isinstance(y, jax.Array) + + # After the final transformer layer, `y` holds the raw, un-normalized hidden state. 
+ if cfg.mhc_expansion_rate > 1: + # (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim) + hidden_state = mhc_reduce(y) + else: + hidden_state = y + + # When invoking from vLLM with RPA attention, logit computation is deferred to a later stage. + if cfg.attention == "vllm_rpa": + logits = None + + # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory + # Instead, we keep track on the hidden states, which has smaller size compared to full logits + if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + logits = None + self.sow(nnx.Intermediate, "hidden_states", hidden_state) + + else: + logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) + + return logits, hidden_state, kv_caches + + def _apply_gemma3_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + bidirectional_mask, + previous_chunk, + page_state, + slot, + ): + """Applies Gemma3 scanned decoder blocks, handling main scan and remainders.""" + + cfg = self.config + + # Define the repeating pattern length and calculate how many full blocks to scan + attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) + scan_length = cfg.num_decoder_layers // attention_pattern_length + + layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) + layer_kwargs = {"bidirectional_mask": bidirectional_mask} + + # Apply the main scan over the full blocks + if scan_length > 0: + y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) + + # Apply any remaining layers that did not fit into a full scanned block + num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length + if num_remaining_layers > 0: + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) + + def pure_gemma_fn(graphdef, state_in, y_in): + merged_layer = 
nnx.merge(graphdef, state_in) + out_y, _ = merged_layer(y_in, *layer_args, **layer_kwargs) + return out_y, nnx.state(merged_layer) + + checkpointed_gemma_fn = jax.checkpoint(pure_gemma_fn, policy=policy, prevent_cse=prevent_cse) + + graphdef, state = nnx.split(self.layers_remainder) + y, new_state = checkpointed_gemma_fn(graphdef, state, y) + nnx.update(self.layers_remainder, new_state) + + return y + + +def decoder_as_linen( + config: Config, + mesh: Mesh, + rngs: nnx.Rngs, + model_mode: str, + quant: None | Quant = None, +): + """Creates a Decoder module.""" + module = nnx_wrappers.to_linen( + NNXDecoder, + config=config, + mesh=mesh, + model_mode=model_mode, + rngs=rngs, + quant=quant, + name="decoder", + abstract_init=False, + metadata_fn=initializers.variable_to_logically_partitioned, + ) + return module diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index fe41af9b40..2c4a0da4d9 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -132,7 +132,9 @@ def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: col_name = variablelib.variable_name_from_type(v.type) v = to_linen_var(v) else: - raise ValueError(f"Cannot infer collection name from value: {v}") + # Skip non-variable attributes (e.g., submodules or metadata) when + # converting to Linen-style variables. 
+ continue linen_structured[(col_name, *kp)] = v variables = nnx.traversals.unflatten_mapping(linen_structured) return variables diff --git a/src/maxtext/models/gemma3.py b/src/maxtext/models/gemma3.py index 588ffa6db2..630497e224 100644 --- a/src/maxtext/models/gemma3.py +++ b/src/maxtext/models/gemma3.py @@ -91,7 +91,6 @@ def __init__( batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) - self.pre_self_attention_norm = RMSNorm( num_features=config.emb_dim, dtype=config.dtype, @@ -198,7 +197,6 @@ def __call__( inputs = inputs[0] inputs = nn.with_logical_constraint(inputs, self.activation_axis_names) inputs = checkpoint_name(inputs, "decoder_layer_input") - lnx = self.pre_self_attention_norm(inputs) lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index cfd837c6c5..f4e751554d 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -18,13 +18,16 @@ from typing import Any -from flax import linen as nn -from flax import nnx import jax import jax.numpy as jnp from jax.sharding import Mesh + +from flax import linen as nn +from flax import nnx + from maxtext.common.common_types import Config, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN from maxtext.inference import page_manager +from maxtext.layers.nnx_decoders import NNXDecoder, decoder_as_linen from maxtext.layers import initializers from maxtext.layers import nnx_wrappers from maxtext.layers.decoders import Decoder @@ -85,7 +88,13 @@ def setup(self): ) self.vision_encoder = vision_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_multimodal else None self.audio_encoder = audio_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_audio else None - self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) + if cfg.pure_nnx_decoder: + self.decoder = 
decoder_as_linen( + config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=nnx.Rngs(0) + ) + else: + self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) + # If MTP is enabled via config, set up the MTP block. if self.config.mtp_num_layers > 0: # Get the list of layer blueprints for the current model. @@ -328,9 +337,11 @@ def __init__( ) self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_multimodal else None self.audio_encoder = AudioEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_audio else None - - decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) - self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs) + if cfg.pure_nnx_decoder: + self.decoder = NNXDecoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs) + else: + self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) + self.decoder = nnx_wrappers.ToNNX(self.decoder, rngs=rngs) self.hidden_states = None batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode) @@ -356,12 +367,13 @@ def __init__( else: dummy_attention_metadata = None - self.decoder.lazy_init( - shared_embedding=self.token_embedder, - decoder_input_tokens=dummy_decoder_input_tokens, - decoder_positions=dummy_decoder_positions, - attention_metadata=dummy_attention_metadata, - ) + if not cfg.pure_nnx_decoder: + self.decoder.lazy_init( + shared_embedding=self.token_embedder, + decoder_input_tokens=dummy_decoder_input_tokens, + decoder_positions=dummy_decoder_positions, + attention_metadata=dummy_attention_metadata, + ) # If MTP is enabled via config, set up the MTP block. 
if self.config.mtp_num_layers > 0: @@ -483,26 +495,47 @@ def __call__( if self.config.distill_beta > 0.0 and "intermediates" not in mutable_collections: mutable_collections.append("intermediates") - logits, hidden_state, kv_caches = self.decoder( - shared_embedding=self.token_embedder, - decoder_input_tokens=decoder_input_tokens, - decoder_positions=decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=not enable_dropout, - model_mode=model_mode, - previous_chunk=previous_chunk, - slot=slot, - page_state=page_state, - bidirectional_mask=bidirectional_mask, - image_embeddings=image_embeddings, - image_masks=encoder_image_masks, - audio_embeddings=audio_embeddings, - audio_masks=audio_masks, - kv_caches=kv_caches, - attention_metadata=attention_metadata, - deepstack_visual_embeds=deepstack_visual_embeds, - mutable=mutable_collections, - ) + if self.config.pure_nnx_decoder: + logits, hidden_state, kv_caches = self.decoder( + shared_embedding=self.token_embedder, + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=not enable_dropout, + model_mode=model_mode, + previous_chunk=previous_chunk, + slot=slot, + page_state=page_state, + bidirectional_mask=bidirectional_mask, + image_embeddings=image_embeddings, + image_masks=encoder_image_masks, + audio_embeddings=audio_embeddings, + audio_masks=audio_masks, + kv_caches=kv_caches, + attention_metadata=attention_metadata, + deepstack_visual_embeds=deepstack_visual_embeds, + ) + else: + logits, hidden_state, kv_caches = self.decoder( + shared_embedding=self.token_embedder, + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=not enable_dropout, + model_mode=model_mode, + previous_chunk=previous_chunk, + slot=slot, + page_state=page_state, + bidirectional_mask=bidirectional_mask, + image_embeddings=image_embeddings, + 
image_masks=encoder_image_masks, + audio_embeddings=audio_embeddings, + audio_masks=audio_masks, + kv_caches=kv_caches, + attention_metadata=attention_metadata, + deepstack_visual_embeds=deepstack_visual_embeds, + mutable=mutable_collections, + ) # Materialize hidden state when vocab tiling is enabled if self.config.num_vocab_tiling > 1: diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index f6af7e26ec..5fe8c76c8f 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -15,22 +15,22 @@ """ SFT training script that calls a trainer in Tunix to run SFT on a MaxText model using `HuggingFaceH4/ultrachat_200k` dataset. The configurations for the dataset -are defined inside `src/maxtext/configs/post_train/sft.yml`. +are defined inside `src/MaxText/configs/sft.yml`. Example command: Training & Evaluation: python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \ - run_name=$RUN_NAME base_output_directory=$BASE_OUTPUT_DIRECTORY \ - model_name=$MODEL_NAME load_parameters_path=$CHECKPOINT_PATH \ - hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH \ + run_name=${RUN_NAME?} base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ + model_name=${MODEL_NAME?} load_parameters_path=${CHECKPOINT_PATH?} \ + hf_access_token=${HF_ACCESS_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} \ per_device_batch_size=1 max_target_length=1024 \ eval_interval=2 eval_steps=2 steps=10 profiler=xplane weight_dtype=bfloat16 Training: python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \ - run_name=$RUN_NAME base_output_directory=$BASE_OUTPUT_DIRECTORY \ - model_name=$MODEL_NAME load_parameters_path=$CHECKPOINT_PATH \ - hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH \ + run_name=${RUN_NAME?} base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ + model_name=${MODEL_NAME?} 
load_parameters_path=${CHECKPOINT_PATH?} \ + hf_access_token=${HF_ACCESS_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} \ per_device_batch_size=1 max_target_length=1024 \ eval_interval=-1 steps=10 profiler=xplane weight_dtype=bfloat16 """ @@ -38,8 +38,12 @@ from typing import Sequence from absl import app +import math import os +import re import jax +import jax.numpy as jnp +from flax import nnx import optax import pathwaysutils @@ -47,8 +51,12 @@ from orbax import checkpoint as ocp +from tunix.sft import checkpoint_manager as tunix_checkpoint_manager from tunix.sft import metrics_logger, peft_trainer, profiler +from tunix.sft import utils as tunix_sft_utils +from tunix.rl import reshard +from maxtext.optimizers import optimizers from maxtext.configs import pyconfig from maxtext.trainers.pre_train.train import loss_fn from maxtext.common.goodput import ( @@ -60,7 +68,6 @@ maybe_record_goodput, record_goodput, ) -from maxtext.optimizers import optimizers from maxtext.trainers.post_train.sft import hooks from maxtext.utils import max_utils from maxtext.utils import max_logging @@ -126,7 +133,15 @@ def use_maxtext_loss_function(trainer, mt_config): The trainer configured with the MaxText loss function. 
""" - def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targets_position, targets_segmentation): + def loss_func( + model, + inputs, + inputs_position, + inputs_segmentation, + targets, + targets_position, + targets_segmentation, + ): data = { "inputs": inputs, "inputs_position": inputs_position, @@ -141,12 +156,493 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ return trainer +def _validate_lora_config(mt_config): + """Validates required LoRA configuration fields.""" + if mt_config.lora_rank <= 0: + raise ValueError("enable_lora is True but lora_rank is not set to a positive value.") + if not mt_config.lora_module_path: + raise ValueError("enable_lora is True but lora_module_path is empty.") + + +def _build_lora_provider(mt_config, qwix): + """Builds a Qwix LoRA provider from MaxText LoRA settings.""" + lora_kwargs = { + "module_path": mt_config.lora_module_path, + "rank": mt_config.lora_rank, + "alpha": mt_config.lora_alpha, + } + if mt_config.lora_tile_size is not None: + lora_kwargs["tile_size"] = mt_config.lora_tile_size + if mt_config.lora_weight_qtype is not None: + lora_kwargs["weight_qtype"] = mt_config.lora_weight_qtype + max_logging.log( + f"QLoRA configured: module_path={mt_config.lora_module_path} " + f"rank={mt_config.lora_rank} alpha={mt_config.lora_alpha} " + f"weight_qtype={mt_config.lora_weight_qtype} " + f"tile_size={mt_config.lora_tile_size}" + ) + else: + max_logging.log( + f"LoRA configured: module_path={mt_config.lora_module_path} " + f"rank={mt_config.lora_rank} alpha={mt_config.lora_alpha} " + f"tile_size={mt_config.lora_tile_size}" + ) + return qwix.LoraProvider(**lora_kwargs) + + +def _patch_qwix_dot_general_with_3d(lora_provider, qwix_flax_util, qwix_lora, qwix_ptq, types): + """Patches Qwix LoRA dot_general to support selected 3D-kernel paths.""" + + def _dot_general_with_3d( + self, + lhs, + rhs, + dimension_numbers, + precision=None, + preferred_element_type=None, + 
out_sharding=None, + ): + res = qwix_ptq.PtqProvider.dot_general( + self, + lhs, + rhs, + dimension_numbers, + precision, + preferred_element_type, + out_sharding=out_sharding, + ) + + rule, _ = self._get_current_rule_and_op_id("dot_general", repeated_call=True) + if not isinstance(rule, qwix_lora.LoraRule): + return res + + weight_name = qwix_flax_util.find_param(rhs, qwix_lora.ptq.WithAux) + if weight_name is None: + return qwix_lora.LoraProvider.dot_general( + self, lhs, rhs, dimension_numbers, precision, preferred_element_type, out_sharding=out_sharding + ) + + try: + current_module = qwix_flax_util.get_current_module() + lora_a = getattr(current_module, f"{weight_name}_lora_a", None) + lora_b = getattr(current_module, f"{weight_name}_lora_b", None) + + if lora_a is None or lora_b is None: + return qwix_lora.LoraProvider.dot_general( + self, lhs, rhs, dimension_numbers, precision, preferred_element_type, out_sharding=out_sharding + ) + + if isinstance(lora_a, nnx.Variable): + lora_a = lora_a[...] + if isinstance(lora_b, nnx.Variable): + lora_b = lora_b[...] + except Exception: + return qwix_lora.LoraProvider.dot_general( + self, lhs, rhs, dimension_numbers, precision, preferred_element_type, out_sharding=out_sharding + ) + + if rule.dropout > 0: + lhs = nnx.Dropout(rule.dropout)(lhs, rngs=qwix_flax_util.make_rng("dropout")) + + contract_axes_lhs = tuple(dimension_numbers[0][0]) + contract_axes_rhs = tuple(dimension_numbers[0][1]) + + # If the default provider fails due to shape, we handle it universally here. 
+ if len(rhs.shape) > 2: + k = 1 + for axis in contract_axes_rhs: + k *= rhs.shape[axis] + + out_dim = lora_b.size // rule.rank + + # Validate that LoRA shapes make mathematical sense + if lora_a.size == k * rule.rank and lora_b.size == rule.rank * out_dim: + # Reshape A to 2D + lora_a_flat = jnp.reshape(lora_a, (k, rule.rank)) + + # Reshape B to 2D + lora_b_flat = jnp.reshape(lora_b, (rule.rank, out_dim)) + + # Flatten LHS to abstract over multiple batch/sequence dimensions + lhs_perm = [i for i in range(lhs.ndim) if i not in contract_axes_lhs] + list(contract_axes_lhs) + lhs_trans = jnp.transpose(lhs, lhs_perm) + lhs_shape = lhs_trans.shape + lhs_flat = jnp.reshape(lhs_trans, (-1, k)) + + # Do the 2D LoRA math + delta_flat = lhs_flat @ lora_a_flat @ lora_b_flat + + # Unflatten the delta to match the original result shape + delta = jnp.reshape(delta_flat, res.shape) + + return res + delta * (rule.alpha / rule.rank) + + return qwix_lora.LoraProvider.dot_general( + self, lhs, rhs, dimension_numbers, precision, preferred_element_type, out_sharding=out_sharding + ) + + lora_provider.dot_general = types.MethodType(_dot_general_with_3d, lora_provider) + +def _prepare_dummy_inputs(mt_config, mesh): + """Builds dummy decoder inputs used to materialize LoRA parameters.""" + batch_size = getattr(mt_config, "per_device_batch_size", 1) + seq_len = getattr(mt_config, "max_target_length", 1) + if batch_size <= 0 or seq_len <= 0: + raise ValueError("per_device_batch_size and max_target_length must be positive when LoRA is enabled.") + + devices_data_fsdp = 1 + if mesh is not None: + devices_data_fsdp = mesh.shape.get("data", 1) * mesh.shape.get("fsdp", 1) + + dummy_bs = (max(batch_size, devices_data_fsdp) + devices_data_fsdp - 1) // devices_data_fsdp + dummy_bs *= devices_data_fsdp + + decoder_input_tokens = jnp.zeros((dummy_bs, seq_len), dtype=jnp.int32) + decoder_positions = jnp.broadcast_to(jnp.arange(seq_len, dtype=jnp.int32), (dummy_bs, seq_len)) + return 
decoder_input_tokens, decoder_positions + + +def _precreate_lora_params(lora_model, lora_provider, mt_config, qwix_flax_util, qwix_lora, types): + """Pre-creates LoRA parameter tensors for modules matching the target regex.""" + rules = [rule for rule in getattr(lora_provider, "_rules", []) if isinstance(rule, qwix_lora.LoraRule)] + if not rules: + max_logging.log("LoRA precreate: no LoRA rules found on provider, skipping.") + return + + # MaxText uses a single LoRA rule from the provided module_path regex. + rule = rules[0] + compiled_module_path = re.compile(mt_config.lora_module_path) + num_decoder_layers = getattr(mt_config, "num_decoder_layers", None) + if num_decoder_layers is None: + num_decoder_layers = getattr(mt_config, "base_num_decoder_layers", None) + param_scan_axis = int(getattr(mt_config, "param_scan_axis", 0)) + + def _with_layer_axis(base_shape_or_transpose, layer_value): + axis = max(0, min(param_scan_axis, len(base_shape_or_transpose))) + values = list(base_shape_or_transpose) + values.insert(axis, layer_value) + return tuple(values) + + def _extract_kernel_shape(kernel_value): + kernel_shape = getattr(kernel_value, "shape", None) + if kernel_shape is None and hasattr(kernel_value, "array"): + kernel_shape = getattr(kernel_value.array, "shape", None) + if kernel_shape is None and hasattr(kernel_value.array, "qvalue"): + kernel_shape = getattr(kernel_value.array.qvalue, "shape", None) + if kernel_shape is None: + return None + return tuple(int(dim) for dim in kernel_shape) + + matched_modules = 0 + precreated_modules = 0 + skipped_modules = [] + precreated_shapes = [] + + def _process_param(module, module_path, param_name, param_obj, in_features_shape, out_features_shape): + nonlocal precreated_modules + try: + kernel_value = qwix_flax_util.unbox(param_obj) + except Exception: + if len(skipped_modules) < 10: + skipped_modules.append(f"{module_path}.{param_name}: cannot unbox kernel") + return False + + kernel_shape = 
_extract_kernel_shape(kernel_value) + if kernel_shape is None or len(kernel_shape) < 2: + if len(skipped_modules) < 10: + skipped_modules.append(f"{module_path}.{param_name}: unsupported kernel shape {kernel_shape}") + return False + + expected_suffix = in_features_shape + out_features_shape + layer_axis = None + base_kernel_shape = None + + # 1. Determine if this parameter is scanned over layers + if isinstance(num_decoder_layers, int) and len(kernel_shape) >= len(expected_suffix) + 1: + # Prefer param_scan_axis if it matches the expected layer count + if kernel_shape[param_scan_axis] == num_decoder_layers: + candidate_base = tuple(dim for i, dim in enumerate(kernel_shape) if i != param_scan_axis) + if candidate_base[-len(expected_suffix):] == expected_suffix: + layer_axis = param_scan_axis + base_kernel_shape = candidate_base + + # If not found at param_scan_axis, search other axes (for edge cases where scan axis might differ) + if layer_axis is None: + for axis in range(len(kernel_shape)): + if kernel_shape[axis] == num_decoder_layers: + candidate_base = tuple(dim for i, dim in enumerate(kernel_shape) if i != axis) + if candidate_base[-len(expected_suffix):] == expected_suffix: + layer_axis = axis + base_kernel_shape = candidate_base + break + + # 2. Check if it's an unscanned parameter + if layer_axis is None and len(kernel_shape) >= len(expected_suffix): + if kernel_shape[-len(expected_suffix):] == expected_suffix: + base_kernel_shape = kernel_shape + + # 3. If neither matched, skip this parameter + if base_kernel_shape is None: + if len(skipped_modules) < 10: + skipped_modules.append(f"{module_path}.{param_name}: kernel shape {kernel_shape} does not match expected suffix {expected_suffix}") + return False + + prefix_shape = base_kernel_shape[:-len(expected_suffix)] if len(expected_suffix) > 0 else base_kernel_shape + + # 4. 
Compute axes mapped sequentially for the base (unscanned) shape + prefix_axes_base = tuple(range(len(prefix_shape))) + input_axes_base = tuple(range(len(prefix_shape), len(prefix_shape) + len(in_features_shape))) + output_axes_base = tuple(range(len(prefix_shape) + len(in_features_shape), len(base_kernel_shape))) + + # 5. Shift axes to account for the layer_axis insertion + if layer_axis is not None: + def shift_axes(axes): + return tuple(axis if axis < layer_axis else axis + 1 for axis in axes) + + a_shape = _with_layer_axis(prefix_shape + in_features_shape + (rule.rank,), num_decoder_layers) + b_shape = _with_layer_axis(prefix_shape + (rule.rank,) + out_features_shape, num_decoder_layers) + a_sharding_transpose = _with_layer_axis(shift_axes(prefix_axes_base + input_axes_base) + (None,), layer_axis) + b_sharding_transpose = _with_layer_axis(shift_axes(prefix_axes_base) + (None,) + shift_axes(output_axes_base), layer_axis) + else: + a_shape = prefix_shape + in_features_shape + (rule.rank,) + b_shape = prefix_shape + (rule.rank,) + out_features_shape + a_sharding_transpose = prefix_axes_base + input_axes_base + (None,) + b_sharding_transpose = prefix_axes_base + (None,) + output_axes_base + + def _init_for_module( + self, + a_shape=a_shape, + b_shape=b_shape, + a_sharding_transpose=a_sharding_transpose, + b_sharding_transpose=b_sharding_transpose, + ): + qwix_lora._get_or_create_lora_params( # pylint: disable=protected-access + name=param_name, + rule=rule, + a_shape=a_shape, + b_shape=b_shape, + a_sharding_transpose=a_sharding_transpose, + b_sharding_transpose=b_sharding_transpose, + ) + + types.MethodType(_init_for_module, module)() + precreated_modules += 1 + if len(precreated_shapes) < 10: + precreated_shapes.append((f"{module_path}.{param_name}", a_shape, b_shape)) + return True + + + for path, module in nnx.iter_modules(lora_model): + module_path = "/".join(str(p) for p in path) + if not compiled_module_path.search(module_path): + continue + + matched_modules 
+= 1 + + # DenseGeneral-style layers (Standard, Vision, Audio) + if hasattr(module, "in_features_shape") and hasattr(module, "out_features_shape"): + in_features_shape = tuple(int(dim) for dim in getattr(module, "in_features_shape", ())) + out_features_shape = tuple(int(dim) for dim in getattr(module, "out_features_shape", ())) + if hasattr(module, "kernel"): + _process_param(module, module_path, "kernel", module.kernel, in_features_shape, out_features_shape) + + # MoE-style layers (RoutedMoE, RoutedAndSharedMoE) + elif type(module).__name__ in ["RoutedMoE", "RoutedAndSharedMoE"]: + emb_dim = getattr(getattr(module, "config", None), "emb_dim", None) + if emb_dim is not None: + intermediate_dim = getattr(module, "intermediate_dim", getattr(getattr(module, "config", None), "moe_mlp_dim", None)) + if intermediate_dim is not None: + if hasattr(module, "wi_0"): + _process_param(module, module_path, "wi_0", module.wi_0, (emb_dim,), (intermediate_dim,)) + if hasattr(module, "wi_1"): + _process_param(module, module_path, "wi_1", module.wi_1, (emb_dim,), (intermediate_dim,)) + if hasattr(module, "wo"): + _process_param(module, module_path, "wo", module.wo, (intermediate_dim,), (emb_dim,)) + + max_logging.log( + f"LoRA precreate: matched_modules={matched_modules} " + f"precreated_modules={precreated_modules} " + f"skipped_sample={skipped_modules} " + f"shape_sample={precreated_shapes}" + ) + + +def _verify_lora_parameters(lora_model, mt_config): + """Validates that LoRA is active or that target modules were matched.""" + compiled_module_path = re.compile(mt_config.lora_module_path) + matched_module_paths = [] + sample_module_paths = [] + + for path, _ in nnx.iter_modules(lora_model): + module_path = "/".join(str(p) for p in path) + if len(sample_module_paths) < 50: + sample_module_paths.append(module_path) + if compiled_module_path.search(module_path): + matched_module_paths.append(module_path) + + is_lora_enabled = tunix_sft_utils.is_lora_enabled(lora_model) + if 
is_lora_enabled: + max_logging.log("LoRA verification: tunix_sft_utils.is_lora_enabled=True") + return + + if not matched_module_paths: + max_logging.log( + f"LoRA module_path='{mt_config.lora_module_path}' did not match any weights. " + f"Sample module paths: {sample_module_paths}" + ) + raise ValueError("LoRA enabled but no LoRA parameters found in decoder/model state.") + + max_logging.log( + f"LoRA verification: matched {len(matched_module_paths)} target modules but " + "LoRA params are not yet materialized; continuing with lazy LoRA initialization. " + f"Sample matches: {matched_module_paths[:10]}" + ) + + +def maybe_apply_lora(model, mesh, mt_config): + """Optionally applies LoRA/QLoRA to a MaxText model using Qwix.""" + # Skip Qwix LoRA if MaxText LoRA adapters are loaded + if hasattr(mt_config, "lora_input_adapters_path") and mt_config.lora_input_adapters_path: + max_logging.log("MaxText LoRA adapters loaded, skipping Qwix LoRA application") + return model + + if not getattr(mt_config, "enable_lora", False): + return model + + import qwix # pylint: disable=import-outside-toplevel + import qwix._src.flax_util as qwix_flax_util # pylint: disable=import-outside-toplevel + import qwix._src.providers.lora as qwix_lora # pylint: disable=import-outside-toplevel + import qwix._src.providers.ptq as qwix_ptq # pylint: disable=import-outside-toplevel + import types # pylint: disable=import-outside-toplevel + + _validate_lora_config(mt_config) + lora_provider = _build_lora_provider(mt_config, qwix) + + _patch_qwix_dot_general_with_3d(lora_provider, qwix_flax_util, qwix_lora, qwix_ptq, types) + + decoder_input_tokens, decoder_positions = _prepare_dummy_inputs(mt_config, mesh) + lora_model = qwix.apply_lora_to_model( + model, + lora_provider, + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + skip_nnx_init=True, + ) + + # Materialize LoRA parameters. 
Qwix 0.1.5+ unsets RNGs after apply_lora_to_model, so set them here before pre-creating LoRA params and unset them again afterwards. + lora_model.set_attributes(qwix_rngs=nnx.Rngs(10003)) + _precreate_lora_params(lora_model, lora_provider, mt_config, qwix_flax_util, qwix_lora, types) + lora_model.set_attributes(qwix_rngs=None) + + if mesh is not None: + lora_model = reshard.reshard_model_to_mesh(lora_model, mesh) + + _verify_lora_parameters(lora_model, mt_config) + return lora_model + + +def _resolve_lora_restore_checkpoint(lora_restore_path): + """Normalizes lora_restore_path into Tunix checkpoint manager inputs.""" + normalized_path = os.path.normpath(lora_restore_path) + basename = os.path.basename(normalized_path) + + if basename == "model_params": + step_dir = os.path.dirname(normalized_path) + root_directory = os.path.dirname(step_dir) + step_name = os.path.basename(step_dir) + try: + return root_directory, int(step_name) + except ValueError as exc: + raise ValueError( + "lora_restore_path ending in 'model_params' must live under a numeric step directory." + ) from exc + + if basename.isdigit(): + return os.path.dirname(normalized_path), int(basename) + + return normalized_path, None + + +def maybe_restore_lora_from_path(model, mt_config, mesh=None): + """Optionally restores LoRA params from a dedicated adapter checkpoint path. + + If `lora_restore_path` is set and LoRA params have not yet been materialized on + the model, this function attempts to apply LoRA first (when enabled) before + restoring adapter weights. + + Returns: + A tuple of `(model, resume_step)` where `resume_step` is the step returned + by Tunix checkpoint restore.
+ """ + lora_restore_path = getattr(mt_config, "lora_restore_path", "") + if not lora_restore_path: + return model, None + + if not tunix_sft_utils.is_lora_enabled(model): + if getattr(mt_config, "enable_lora", False): + max_logging.log("lora_restore_path is set but model has no LoRA params yet; " "applying LoRA before restore.") + model = maybe_apply_lora(model, mesh, mt_config) + + if not tunix_sft_utils.is_lora_enabled(model): + raise ValueError( + "lora_restore_path is set but LoRA is not enabled on the model. " + "Set enable_lora=True and verify lora_module_path matches model modules." + ) + + if not os.path.exists(lora_restore_path): + raise ValueError(f"lora_restore_path does not exist: {lora_restore_path}") + + restore_root_directory, restore_step = _resolve_lora_restore_checkpoint(lora_restore_path) + max_logging.log( + f"Restoring LoRA params from checkpoint root '{restore_root_directory}' " + f"at step {restore_step if restore_step is not None else 'latest'}." + ) + + checkpoint_manager = tunix_checkpoint_manager.CheckpointManager( + root_directory=restore_root_directory, + ) + try: + restored_step, _ = checkpoint_manager.maybe_restore( + model, + step=restore_step, + restore_only_lora_params=True, + ) + finally: + checkpoint_manager.close() + + if restore_step is not None and restored_step != restore_step: + raise ValueError( + f"Expected LoRA restore from step {restore_step}, got step {restored_step}." 
+ ) + + if restored_step == 0: + raise ValueError(f"No LoRA checkpoint found for lora_restore_path: {lora_restore_path}") + + max_logging.log("LoRA restore complete.") + return model, restored_step + + +def _maybe_resume_trainer_from_step(trainer, resume_step, tunix_config, source): + """Applies a recovered step to a freshly initialized trainer if needed.""" + if not resume_step or getattr(trainer, "_train_steps", 0) != 0: + return trainer + + grad_accum = getattr(tunix_config, "gradient_accumulation_steps", None) or 1 + trainer._train_steps = resume_step + trainer._iter_steps = resume_step * grad_accum + if hasattr(trainer, "_prof") and trainer._prof: + trainer._prof.initial_step = trainer._iter_steps + max_logging.log(f"Resuming trainer manually from step {resume_step} based on {source}.") + return trainer + def setup_trainer_state(mt_config, goodput_recorder=None): """Set up prerequisites for training loop.""" tunix_config = get_tunix_config(mt_config) with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT): model, mesh = model_creation_utils.create_nnx_model(mt_config) + model = maybe_apply_lora(model, mesh, mt_config) + model, lora_resume_step = maybe_restore_lora_from_path(model, mt_config, mesh) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config) # pass in model for muon optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model) @@ -162,6 +658,13 @@ def setup_trainer_state(mt_config, goodput_recorder=None): data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder) trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config) + trainer = _maybe_resume_trainer_from_step( + trainer, + lora_resume_step, + tunix_config, + source="lora_restore_path", + ) + trainer.with_training_hooks(training_hooks) trainer.with_data_hooks(data_hooks) trainer = use_maxtext_loss_function(trainer, mt_config) @@ -172,7 +675,10 @@ def setup_trainer_state(mt_config, goodput_recorder=None): def 
train_model(mt_config, trainer, mesh): """Runs the SFT training loop in Tunix.""" with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): - trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator) + trainer.train( + trainer.data_hooks.train_data_iterator, + trainer.data_hooks.eval_data_iterator, + ) return trainer diff --git a/tests/checkpoint_compare.py b/tests/checkpoint_compare.py new file mode 100644 index 0000000000..112f524df0 --- /dev/null +++ b/tests/checkpoint_compare.py @@ -0,0 +1,179 @@ +"""Script for comparing parameters between two checkpoints.""" + +import jax +import jax.numpy as jnp +import orbax.checkpoint as ocp +from typing import Any, Dict, Sequence +from jax.tree_util import tree_flatten_with_path, keystr, tree_structure, tree_map_with_path +from absl import app +from absl import flags + + +_LINEN_CKPT_PATH = flags.DEFINE_string( + "linen_ckpt_path", None, "Path to the Linen model checkpoint items directory.", required=True +) +_NNX_CKPT_PATH = flags.DEFINE_string( + "nnx_ckpt_path", None, "Path to the NNX model checkpoint items directory.", required=True +) + + +def load_checkpoint_params(path: str) -> Dict[str, Any]: + """Loads parameters from an Orbax checkpoint path.""" + print(f"Loading checkpoint from: {path}") + checkpointer = ocp.PyTreeCheckpointer() + restored_state = checkpointer.restore(path) + if restored_state is None: + raise ValueError(f"Failed to restore checkpoint from {path}") + if isinstance(restored_state, dict) and "params" in restored_state: + return restored_state["params"] + return restored_state + + +def transform_nnx_params(nnx_params: Dict[str, Any]) -> Dict[str, Any]: + """Applies specific transformations with verbose logging matching original format.""" + + def _transform(path, leaf: jax.Array) -> jax.Array: + key_str = keystr(path) + + if "layers" in key_str and hasattr(leaf, "ndim") and leaf.ndim >= 2: + print(f"TRANSPOSING: {key_str} with shape {leaf.shape}") + 
axes = (1, 0) + tuple(range(2, leaf.ndim)) + return jnp.transpose(leaf, axes=axes) + else: + if "token_embedder" in key_str: + print(f"SKIPPING Transpose: {key_str} because it is token_embedder") + else: + shape = getattr(leaf, "shape", "N/A") + print(f"SKIPPING Transpose: {key_str} with shape {shape} (ndim < 2)") + return leaf + + print("Applying transformations to NNX params...") + return tree_map_with_path(_transform, nnx_params) + + +def get_tree_structure_info(tree: Dict[str, Any]): + """Helper only used if structures differ.""" + flat_with_path, _ = tree_flatten_with_path(tree) + return {keystr(p): (getattr(l, "shape", "N/A"), str(getattr(l, "dtype", type(l).__name__))) for p, l in flat_with_path} + + +def print_structure_diff(params1, params2): + """Prints missing/added keys if structures differ.""" + info1 = get_tree_structure_info(params1) + info2 = get_tree_structure_info(params2) + keys1, keys2 = set(info1.keys()), set(info2.keys()) + + for k in sorted(keys2 - keys1): + print(f" + Added in NNX: {k}") + for k in sorted(keys1 - keys2): + print(f" - Missing in NNX: {k}") + + +def compare_params(params1: Dict[str, Any], params2: Dict[str, Any]) -> bool: + """ + Compares two parameter trees (e.g., JAX/Flax PyTrees) for structural and numerical equality. + + This function performs a deep comparison of two PyTrees. It first + validates that both trees share the exact same structure. If successful, it iterates + through every leaf node to verify: + 1. Shapes match. + 2. Data types (dtypes) match. + 3. Numerical values are close (within `jnp.allclose` tolerances). + + Args: + params1: The first parameter dictionary or PyTree (e.g., a Linen model). + params2: The second parameter dictionary or PyTree (e.g., an NNX model). + + Returns: + bool: True if structure, shapes, types, and values all match; False otherwise. 
+ """ + + if tree_structure(params1) != tree_structure(params2): + print("[] Tree structures differ.") + print_structure_diff(params1, params2) + return False + + print("[] Tree structures are the same.") + + all_match = True + + def _compare_leaf(path, x, y): + nonlocal all_match + key_str = keystr(path) + + try: + shape1 = getattr(x, "shape", "N/A") + shape2 = getattr(y, "shape", "N/A") + + if shape1 != shape2: + print(f"[{key_str}] SHAPE MISMATCH: {shape1} vs {shape2}") + all_match = False + return + + dtype1 = getattr(x, "dtype", type(x)) + dtype2 = getattr(y, "dtype", type(y)) + + if dtype1 != dtype2: + print(f"[{key_str}] DTYPE MISMATCH: {dtype1} vs {dtype2}") + all_match = False + return + + diff = x - y + abs_diff = jnp.abs(diff) + mean_diff_scalar = jnp.mean(abs_diff) + max_diff_scalar = jnp.max(abs_diff) + is_close_scalar = jnp.allclose(x, y) + + mean_diff = float(mean_diff_scalar) + max_diff = float(max_diff_scalar) + is_close = bool(is_close_scalar) + + print( + f"[{key_str}] " + f"Shape(Linen/NNX): {shape1} / {shape2} — " + f"Mean abs diff: {mean_diff:.2e}, " + f"Max abs diff: {max_diff:.2e}, " + f"AllClose: {is_close}" + ) + + if not is_close: + all_match = False + + except Exception as e: # pylint: disable=broad-exception-caught + print(f"[{key_str}] Error during comparison: {e}") + all_match = False + + tree_map_with_path(_compare_leaf, params1, params2) + + return all_match + + +def main(argv: Sequence[str]): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + linen_ckpt_path = _LINEN_CKPT_PATH.value + nnx_ckpt_path = _NNX_CKPT_PATH.value + + print(f"Linen Checkpoint Path: {linen_ckpt_path}") + print(f"NNX Checkpoint Path: {nnx_ckpt_path}") + + print("Loading Linen params...") + linen_params = load_checkpoint_params(linen_ckpt_path) + print("Loading NNX params...") + nnx_params = load_checkpoint_params(nnx_ckpt_path) + + if linen_params is not None and nnx_params is not None: + nnx_params_transformed = 
transform_nnx_params(nnx_params) + + print("\nComparing Linen params with Transformed NNX params...") + if compare_params(linen_params, nnx_params_transformed): + print("\nCheckpoints are considered the same (within np.allclose tolerance) after transformation!") + else: + print("\nCheckpoints DIFFER after transformation.") + else: + print("Failed to load params from one or both checkpoints.") + + +if __name__ == "__main__": + app.run(main) diff --git a/tests/unit/multi_token_prediction_test.py b/tests/unit/multi_token_prediction_test.py index 5f5542ec31..71e6e07f71 100644 --- a/tests/unit/multi_token_prediction_test.py +++ b/tests/unit/multi_token_prediction_test.py @@ -55,14 +55,23 @@ def setUp(self): devices_array = maxtext_utils.create_device_mesh(self.cfg) self.mesh = Mesh(devices_array, self.cfg.mesh_axes) - # Instantiate the Layer - self.mtp_layer = multi_token_prediction.MultiTokenPredictionLayer( - config=self.cfg, - mesh=self.mesh, - layer_number=TEST_LAYER_NUM, - transformer_layer_module=DecoderLayer, - rngs=self.rngs, - ) + if self.cfg.pure_nnx_decoder: + # Instantiate the Layer + self.mtp_layer = multi_token_prediction.MultiTokenPredictionLayer( + config=self.cfg, + mesh=self.mesh, + layer_number=TEST_LAYER_NUM, + transformer_layer_module=DecoderLayer, + rngs=self.rngs, + ) + else: + # Instantiate the Layer + self.mtp_layer = multi_token_prediction.MultiTokenPredictionLayerLinen( + config=self.cfg, + mesh=self.mesh, + layer_number=TEST_LAYER_NUM, + transformer_layer_module=DecoderLayer, + ) # Dimensions directly from the config object self.batch_size = int(self.cfg.per_device_batch_size)