From 46f6eeed979cda3d5c5173f78c177f0690b4da6e Mon Sep 17 00:00:00 2001 From: Jacky Fang Date: Tue, 3 Feb 2026 02:41:48 +0000 Subject: [PATCH 01/12] feat: add LoRA/QLoRA configuration and application --- src/maxtext/configs/post_train/sft.yml | 10 ++ src/maxtext/configs/types.py | 26 ++++++ src/maxtext/layers/nnx_wrappers.py | 4 +- .../trainers/post_train/sft/train_sft.py | 93 +++++++++++++++++++ 4 files changed, 132 insertions(+), 1 deletion(-) diff --git a/src/maxtext/configs/post_train/sft.yml b/src/maxtext/configs/post_train/sft.yml index 32c86ddb31..6d040c8535 100644 --- a/src/maxtext/configs/post_train/sft.yml +++ b/src/maxtext/configs/post_train/sft.yml @@ -21,6 +21,16 @@ sft_train_on_completion_only: True packing: True learning_rate: 2.e-5 +# -------------- LoRA / QLoRA -------------- +# Enable LoRA/QLoRA by setting enable_lora: True and configuring the fields below. +enable_lora: False +lora_rank: 0 +lora_alpha: 0.0 +lora_module_path: "" +# For QLoRA, set lora_weight_qtype (e.g., "nf4") and optionally lora_tile_size. +lora_weight_qtype: null +lora_tile_size: null + # -------------- HF pipeline -------------- dataset_type: hf hf_path: 'HuggingFaceH4/ultrachat_200k' diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index fd0dcc7292..7369782416 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1051,6 +1051,31 @@ class FineTuning(BaseModel): use_grpo: None | bool = Field(None, description="If True, enables Group Relative Policy Optimization.") +class LoRA(BaseModel): + """Configuration for LoRA / QLoRA adapters.""" + + enable_lora: bool = Field(False, description="If True, enables LoRA/QLoRA during fine-tuning.") + lora_rank: NonNegativeInt = Field(0, description="LoRA rank. Set >0 when LoRA is enabled.") + lora_alpha: NonNegativeFloat = Field(0.0, description="LoRA alpha scaling factor.") + lora_module_path: str = Field( + "", + description=( + "Regex identifying target modules for LoRA, e.g." + " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'." + ), + ) + lora_weight_qtype: str | None = Field( + None, + description=( + "Optional quantization type for QLoRA (e.g., 'nf4'). If set, QLoRA is applied." + ), + ) + lora_tile_size: NonNegativeInt | None = Field( + None, + description="Optional tile size for QLoRA (e.g., 128 or 256).", + ) + + class Distillation(BaseModel): """Configuration for Knowledge Distillation.""" @@ -1865,6 +1890,7 @@ class MaxTextConfig( AdamW, Muon, FineTuning, + LoRA, Distillation, # Reinforcement Learning RLHardware, diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index fe41af9b40..2c4a0da4d9 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -132,7 +132,9 @@ def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: col_name = variablelib.variable_name_from_type(v.type) v = to_linen_var(v) else: - raise ValueError(f"Cannot infer collection name from value: {v}") + # Skip non-variable attributes (e.g., submodules or metadata) when + # converting to Linen-style variables. + continue linen_structured[(col_name, *kp)] = v variables = nnx.traversals.unflatten_mapping(linen_structured) return variables diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index f6af7e26ec..43efd1ff1d 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -40,6 +40,8 @@ from absl import app import os import jax +import jax.numpy as jnp +from flax import nnx import optax import pathwaysutils @@ -48,6 +50,7 @@ from orbax import checkpoint as ocp from tunix.sft import metrics_logger, peft_trainer, profiler +from tunix.rl import reshard from maxtext.configs import pyconfig from maxtext.trainers.pre_train.train import loss_fn @@ -141,12 +144,102 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ return trainer +def maybe_apply_lora(model, mesh, mt_config): + """Optionally applies LoRA/QLoRA to a MaxText model using Qwix.""" + if not getattr(mt_config, "enable_lora", False): + return model + + import qwix + + if mt_config.lora_rank <= 0: + raise ValueError("enable_lora is True but lora_rank is not set to a positive value.") + if not mt_config.lora_module_path: + raise ValueError("enable_lora is True but lora_module_path is empty.") + + lora_kwargs = { + "module_path": mt_config.lora_module_path, + "rank": mt_config.lora_rank, + "alpha": mt_config.lora_alpha, + } + if mt_config.lora_tile_size is not None: + lora_kwargs["tile_size"] = mt_config.lora_tile_size + if mt_config.lora_weight_qtype is not None: + lora_kwargs["weight_qtype"] = mt_config.lora_weight_qtype + max_logging.log("QLoRA is enabled with weight_qtype=%s", mt_config.lora_weight_qtype) + else: + max_logging.log("LoRA is enabled.") + + lora_provider = qwix.LoraProvider(**lora_kwargs) + + batch_size = getattr(mt_config, "per_device_batch_size", 1) + seq_len = getattr(mt_config, "max_target_length", 1) + if batch_size <= 0 or seq_len <= 0: + raise ValueError( + "per_device_batch_size and max_target_length must be positive when LoRA is enabled." + ) + + devices_data_fsdp = 1 + if mesh is not None: + devices_data_fsdp = mesh.shape.get("data", 1) * mesh.shape.get("fsdp", 1) + + dummy_bs = (max(batch_size, devices_data_fsdp) + devices_data_fsdp - 1) // devices_data_fsdp + dummy_bs *= devices_data_fsdp + + decoder_input_tokens = jnp.zeros((dummy_bs, seq_len), dtype=jnp.int32) + decoder_positions = jnp.broadcast_to(jnp.arange(seq_len, dtype=jnp.int32), (dummy_bs, seq_len)) + + lora_model = qwix.apply_lora_to_model( + model, + lora_provider, + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + ) + + if mesh is not None: + lora_model = reshard.reshard_model_to_mesh(lora_model, mesh) + + def _key_to_str(key): + if isinstance(key, str): + return key + if hasattr(key, "name"): + return str(key.name) + return str(key) + + lora_state = nnx.state(lora_model, nnx.LoRAParam) + lora_leaves = jax.tree_util.tree_leaves(lora_state) + lora_count = len(lora_leaves) + + if lora_count == 0: + full_state = nnx.state(lora_model) + paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path(full_state) + for path, _ in paths_and_leaves: + path_str = "/".join(_key_to_str(k) for k in path).lower() + if "lora" in path_str or "adapter" in path_str: + lora_count += 1 + + if lora_count == 0: + module_paths = [] + for path, _ in lora_model.iter_modules(): + module_paths.append("/".join(str(p) for p in path)) + if len(module_paths) >= 50: + break + max_logging.log( + f"LoRA module_path='{mt_config.lora_module_path}' did not match any weights. " + f"Sample module paths: {module_paths}" + ) + raise ValueError("LoRA enabled but no LoRA parameters found in decoder/model state.") + max_logging.log("LoRA verification: found %d LoRA parameter entries.", lora_count) + + return lora_model + + def setup_trainer_state(mt_config, goodput_recorder=None): """Set up prerequisites for training loop.""" tunix_config = get_tunix_config(mt_config) with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT): model, mesh = model_creation_utils.create_nnx_model(mt_config) + model = maybe_apply_lora(model, mesh, mt_config) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config) # pass in model for muon optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model) From 4a3396ae8ccda1d301142576eebf85a9f93096d6 Mon Sep 17 00:00:00 2001 From: Jacky Fang Date: Thu, 5 Feb 2026 09:44:00 +0000 Subject: [PATCH 02/12] linen decoder support --- .../trainers/post_train/sft/train_sft.py | 363 +++++++++++++++--- 1 file changed, 314 insertions(+), 49 deletions(-) diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index 43efd1ff1d..064c21d51e 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -15,11 +15,11 @@ """ SFT training script that calls a trainer in Tunix to run SFT on a MaxText model using `HuggingFaceH4/ultrachat_200k` dataset. The configurations for the dataset -are defined inside `src/maxtext/configs/post_train/sft.yml`. +are defined inside `src/MaxText/configs/sft.yml`. Example command: Training & Evaluation: - python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \ + python3 -m maxtext.trainers.post_train.sft.train_sft src/MaxText/configs/sft.yml \ run_name=$RUN_NAME base_output_directory=$BASE_OUTPUT_DIRECTORY \ model_name=$MODEL_NAME load_parameters_path=$CHECKPOINT_PATH \ hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH \ @@ -27,7 +27,7 @@ eval_interval=2 eval_steps=2 steps=10 profiler=xplane weight_dtype=bfloat16 Training: - python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \ + python3 -m maxtext.trainers.post_train.sft.train_sft src/MaxText/configs/sft.yml \ run_name=$RUN_NAME base_output_directory=$BASE_OUTPUT_DIRECTORY \ model_name=$MODEL_NAME load_parameters_path=$CHECKPOINT_PATH \ hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH \ @@ -50,10 +50,12 @@ from orbax import checkpoint as ocp from tunix.sft import metrics_logger, peft_trainer, profiler +from tunix.sft import utils as tunix_sft_utils from tunix.rl import reshard -from maxtext.configs import pyconfig -from maxtext.trainers.pre_train.train import loss_fn +from MaxText import optimizers +from MaxText import pyconfig +from MaxText.train import loss_fn from maxtext.common.goodput import ( GoodputEvent, RECORD_JOB_END_TIME, @@ -63,7 +65,6 @@ maybe_record_goodput, record_goodput, ) -from maxtext.optimizers import optimizers from maxtext.trainers.post_train.sft import hooks from maxtext.utils import max_utils from maxtext.utils import max_logging @@ -144,18 +145,14 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ return trainer -def maybe_apply_lora(model, mesh, mt_config): - """Optionally applies LoRA/QLoRA to a MaxText model using Qwix.""" - if not getattr(mt_config, "enable_lora", False): - return model - - import qwix - +def _validate_lora_config(mt_config): if mt_config.lora_rank <= 0: raise ValueError("enable_lora is True but lora_rank is not set to a positive value.") if not mt_config.lora_module_path: raise ValueError("enable_lora is True but lora_module_path is empty.") + +def _build_lora_provider(mt_config, qwix): lora_kwargs = { "module_path": mt_config.lora_module_path, "rank": mt_config.lora_rank, @@ -165,12 +162,200 @@ def maybe_apply_lora(model, mesh, mt_config): lora_kwargs["tile_size"] = mt_config.lora_tile_size if mt_config.lora_weight_qtype is not None: lora_kwargs["weight_qtype"] = mt_config.lora_weight_qtype - max_logging.log("QLoRA is enabled with weight_qtype=%s", mt_config.lora_weight_qtype) + max_logging.log( + "QLoRA configured: module_path=%s rank=%s alpha=%s weight_qtype=%s tile_size=%s" + % ( + mt_config.lora_module_path, + mt_config.lora_rank, + mt_config.lora_alpha, + mt_config.lora_weight_qtype, + mt_config.lora_tile_size, + ) + ) else: - max_logging.log("LoRA is enabled.") + max_logging.log( + "LoRA configured: module_path=%s rank=%s alpha=%s tile_size=%s" + % ( + mt_config.lora_module_path, + mt_config.lora_rank, + mt_config.lora_alpha, + mt_config.lora_tile_size, + ) + ) + return qwix.LoraProvider(**lora_kwargs) + + +def _patch_qwix_dot_general_with_3d(lora_provider, qwix_flax_util, qwix_lora, qwix_ptq, types): + def _dot_general_with_3d( + self, + lhs, + rhs, + dimension_numbers, + precision=None, + preferred_element_type=None, + out_sharding=None, + ): + res = qwix_ptq.PtqProvider.dot_general( + self, + lhs, + rhs, + dimension_numbers, + precision, + preferred_element_type, + out_sharding=out_sharding, + ) + + rule, _ = self._get_current_rule_and_op_id("dot_general", repeated_call=True) + if not isinstance(rule, qwix_lora.LoraRule): + return res + + weight_name = qwix_flax_util.find_param(rhs, qwix_lora.ptq.WithAux) + if weight_name is None: + return res + + if ( + len(rhs.shape) == 3 + and tuple(dimension_numbers[0][1]) == (0,) + and not dimension_numbers[1][1] + ): + lora_params = qwix_lora._get_or_create_lora_params( + name=weight_name, + rule=rule, + a_shape=(rhs.shape[0], rule.rank), + b_shape=(rule.rank, rhs.shape[1] * rhs.shape[2]), + a_sharding_transpose=(0, None), + b_sharding_transpose=(None, 1), + ) + lora_a, lora_b = lora_params[:2] + if rule.dropout > 0: + lhs = nnx.Dropout(rule.dropout)(lhs, rngs=qwix_flax_util.make_rng("dropout")) + lora_b = jnp.reshape(lora_b, (rule.rank, rhs.shape[1], rhs.shape[2])) + delta = jnp.einsum("...k,kr->...r", lhs, lora_a) + delta = jnp.einsum("...r,rnm->...nm", delta, lora_b) + return res + delta * (rule.alpha / rule.rank) + + if ( + len(rhs.shape) == 3 + and tuple(dimension_numbers[0][1]) == (0, 1) + and not dimension_numbers[1][1] + ): + k = rhs.shape[0] * rhs.shape[1] + lora_params = qwix_lora._get_or_create_lora_params( + name=weight_name, + rule=rule, + a_shape=(k, rule.rank), + b_shape=(rule.rank, rhs.shape[2]), + a_sharding_transpose=(0, None), + b_sharding_transpose=(None, 1), + ) + lora_a, lora_b = lora_params[:2] + if rule.dropout > 0: + lhs = nnx.Dropout(rule.dropout)(lhs, rngs=qwix_flax_util.make_rng("dropout")) + contract_axes = tuple(dimension_numbers[0][0]) + lhs_perm = [i for i in range(lhs.ndim) if i not in contract_axes] + list(contract_axes) + lhs_trans = jnp.transpose(lhs, lhs_perm) + lhs_shape = lhs_trans.shape + lhs_flat = jnp.reshape(lhs_trans, lhs_shape[:-len(contract_axes)] + (k,)) + delta = jnp.einsum("...k,kr->...r", lhs_flat, lora_a) + delta = jnp.einsum("...r,rm->...m", delta, lora_b) + return res + delta * (rule.alpha / rule.rank) + + return qwix_lora.LoraProvider.dot_general( + self, + lhs, + rhs, + dimension_numbers, + precision, + preferred_element_type, + out_sharding=out_sharding, + ) + + lora_provider.dot_general = types.MethodType(_dot_general_with_3d, lora_provider) + + +def _patch_qwix_find_param(qwix_flax_util): + if getattr(qwix_flax_util, "_maxtext_find_param_patched", False): + return + + original_find_param = qwix_flax_util.find_param + + def _safe_find_param(x, ptq_array_type=None): + module = qwix_flax_util.get_current_module() + candidates = {} + + # 1) Pure NNX: scan attributes for nnx.Params / ptq arrays. + if isinstance(module, nnx.Module): + array_types = (nnx.Param,) if ptq_array_type is None else (nnx.Param, ptq_array_type) + for name, node in module.__dict__.items(): + if isinstance(node, array_types): + value = getattr(node, "value", None) + if value is None: + try: + value = qwix_flax_util.unbox(node) + except Exception: + continue + candidates[name] = value + + + + else: + return original_find_param(x, ptq_array_type) + + candidates_by_id = {id(c): n for n, c in candidates.items()} + + if id(x) in candidates_by_id: + return candidates_by_id[id(x)] + + if isinstance(x, jax.core.Tracer) and hasattr(x, "parent"): + while True: + if id(x) in candidates_by_id: + return candidates_by_id[id(x)] + if x.parent and len(x.parent.in_tracers) == 1: + x = x.parent.in_tracers[0] + elif id(const := x.get_const()) in candidates_by_id: + return candidates_by_id[id(const)] + else: + return None + + if not hasattr(x, "shape"): + return None + candidates = {n: c for n, c in candidates.items() if getattr(c, "shape", None) == x.shape} + if len(candidates) > 2: + raise ValueError(f"Multiple candidate params found: {candidates.keys()}") + if len(candidates) == 1: + return list(candidates.keys())[0] + + return None + + qwix_flax_util.find_param = _safe_find_param + qwix_flax_util._maxtext_find_param_patched = True + + +def _patch_with_sharding_constraint(): + if getattr(jax.lax, "_maxtext_with_sharding_constraint_patched", False): + return + + jax.lax._original_with_sharding_constraint = jax.lax.with_sharding_constraint - lora_provider = qwix.LoraProvider(**lora_kwargs) + def _safe_with_sharding_constraint(x, sharding, *args, **kwargs): + def _safe_leaf_fn(x_leaf, s_leaf): + try: + spec = getattr(s_leaf, "spec", s_leaf) + if hasattr(spec, "__len__"): + ndim = getattr(x_leaf, "ndim", None) + if ndim is not None and len(spec) > ndim: + return x_leaf + except Exception: + pass + return jax.lax._original_with_sharding_constraint(x_leaf, s_leaf, *args, **kwargs) + return jax.tree_util.tree_map(_safe_leaf_fn, x, sharding) + + jax.lax.with_sharding_constraint = _safe_with_sharding_constraint + jax.lax._maxtext_with_sharding_constraint_patched = True + + +def _prepare_dummy_inputs(mt_config, mesh): batch_size = getattr(mt_config, "per_device_batch_size", 1) seq_len = getattr(mt_config, "max_target_length", 1) if batch_size <= 0 or seq_len <= 0: @@ -187,37 +372,76 @@ def maybe_apply_lora(model, mesh, mt_config): decoder_input_tokens = jnp.zeros((dummy_bs, seq_len), dtype=jnp.int32) decoder_positions = jnp.broadcast_to(jnp.arange(seq_len, dtype=jnp.int32), (dummy_bs, seq_len)) - - lora_model = qwix.apply_lora_to_model( - model, - lora_provider, - decoder_input_tokens=decoder_input_tokens, - decoder_positions=decoder_positions, - ) - - if mesh is not None: - lora_model = reshard.reshard_model_to_mesh(lora_model, mesh) - - def _key_to_str(key): - if isinstance(key, str): - return key - if hasattr(key, "name"): - return str(key.name) - return str(key) - - lora_state = nnx.state(lora_model, nnx.LoRAParam) - lora_leaves = jax.tree_util.tree_leaves(lora_state) - lora_count = len(lora_leaves) - - if lora_count == 0: - full_state = nnx.state(lora_model) - paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path(full_state) - for path, _ in paths_and_leaves: - path_str = "/".join(_key_to_str(k) for k in path).lower() - if "lora" in path_str or "adapter" in path_str: - lora_count += 1 - - if lora_count == 0: + return decoder_input_tokens, decoder_positions + +def _precreate_lora_params(target_model, lora_provider, qwix_flax_util, qwix_lora, math, re, types): + rules = getattr(lora_provider, "_rules", []) + if not rules: + return + + for path, module in target_model.iter_modules(): + module_path = "/".join(map(str, path)) + + for rule in rules: + if rule.op_names and "dot_general" not in rule.op_names: + continue + + kernel_tensor = None + in_rank = 0 + out_rank = 0 + found_param_name = "kernel" + + # Case A: Pure NNX (Standard attributes) + if hasattr(module, "kernel") and hasattr(module, "in_features_shape"): + try: + kernel_tensor = qwix_flax_util.unbox(module.kernel) + in_rank = len(module.in_features_shape) + out_rank = len(module.out_features_shape) + except Exception: + kernel_tensor = None + + + + if kernel_tensor is None: + continue + + # Closure to define LoRA A and B shapes and sharding + def _init_for_module(self, k_tensor=kernel_tensor, i_r=in_rank, o_r=out_rank, p_name=found_param_name): + kernel_shape = getattr(k_tensor, "shape", ()) + extra_rank = max(0, len(kernel_shape) - (i_r + o_r)) + + prefix_shape = kernel_shape[:extra_rank] + in_shape = kernel_shape[extra_rank : extra_rank + i_r] + out_shape = kernel_shape[extra_rank + i_r :] + + in_size = int(math.prod(in_shape)) + out_size = int(math.prod(out_shape)) + + if in_size <= 0 or out_size <= 0: + return + + a_shape = prefix_shape + (in_size, rule.rank) + b_shape = prefix_shape + (rule.rank, out_size) + + prefix_axes = tuple(range(extra_rank)) + a_sharding_transpose = prefix_axes + (None,) + b_sharding_transpose = prefix_axes + (None,) + + qwix_lora._get_or_create_lora_params( + name=p_name, + rule=rule, + a_shape=a_shape, + b_shape=b_shape, + a_sharding_transpose=a_sharding_transpose, + b_sharding_transpose=b_sharding_transpose, + ) + + types.MethodType(_init_for_module, module)() + + +def _verify_lora_parameters(lora_model, mt_config): + is_lora_enabled = tunix_sft_utils.is_lora_enabled(lora_model) + if not is_lora_enabled: module_paths = [] for path, _ in lora_model.iter_modules(): module_paths.append("/".join(str(p) for p in path)) @@ -228,8 +452,49 @@ def _key_to_str(key): f"Sample module paths: {module_paths}" ) raise ValueError("LoRA enabled but no LoRA parameters found in decoder/model state.") - max_logging.log("LoRA verification: found %d LoRA parameter entries.", lora_count) + max_logging.log("LoRA verification: tunix_sft_utils.is_lora_enabled=True") + + +def maybe_apply_lora(model, mesh, mt_config): + """Optionally applies LoRA/QLoRA to a MaxText model using Qwix.""" + # Skip Qwix LoRA if MaxText LoRA adapters are loaded + if hasattr(mt_config, 'lora_input_adapters_path') and mt_config.lora_input_adapters_path: + max_logging.log("MaxText LoRA adapters loaded, skipping Qwix LoRA application") + return model + + if not getattr(mt_config, "enable_lora", False): + return model + + import qwix + import math + import re + import qwix._src.flax_util as qwix_flax_util + import qwix._src.providers.lora as qwix_lora + import qwix._src.providers.ptq as qwix_ptq + import types + + _validate_lora_config(mt_config) + lora_provider = _build_lora_provider(mt_config, qwix) + + _patch_qwix_dot_general_with_3d(lora_provider, qwix_flax_util, qwix_lora, qwix_ptq, types) + _patch_qwix_find_param(qwix_flax_util) + _patch_with_sharding_constraint() + + decoder_input_tokens, decoder_positions = _prepare_dummy_inputs(mt_config, mesh) + lora_model = qwix.apply_lora_to_model( + model, + lora_provider, + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + skip_nnx_init=True, + ) + + _precreate_lora_params(lora_model, lora_provider, qwix_flax_util, qwix_lora, math, re, types) + + if mesh is not None: + lora_model = reshard.reshard_model_to_mesh(lora_model, mesh) + _verify_lora_parameters(lora_model, mt_config) return lora_model @@ -306,4 +571,4 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) + app.run(main) \ No newline at end of file From 32b1646fb5f4a555cf3f1784ee72a9f77170df62 Mon Sep 17 00:00:00 2001 From: Jacky Fang Date: Fri, 13 Feb 2026 04:27:18 +0000 Subject: [PATCH 03/12] feat: add HF LoRA adapter support and skip Qwix LoRA --- .../checkpoint_conversion/lora_to_maxtext.py | 311 ++++++++++++++++++ src/maxtext/configs/post_train/sft.yml | 4 + src/maxtext/configs/types.py | 1 + 3 files changed, 316 insertions(+) create mode 100644 src/maxtext/checkpoint_conversion/lora_to_maxtext.py diff --git a/src/maxtext/checkpoint_conversion/lora_to_maxtext.py b/src/maxtext/checkpoint_conversion/lora_to_maxtext.py new file mode 100644 index 0000000000..54eee25dd0 --- /dev/null +++ b/src/maxtext/checkpoint_conversion/lora_to_maxtext.py @@ -0,0 +1,311 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script converts a HuggingFace LoRA adapter to MaxText LoRA adapter format. + +Key Parameters (to be set in the config file or as command-line overrides): + model_name: (Required) The name of the model (e.g., "llama3.1-8b"). + base_output_directory: (Required) The directory where the MaxText LoRA adapter + will be saved. Can be set in config file or as command-line override. + hf_lora_adapter_path: (Required) Path to the HF LoRA adapter directory or HuggingFace repo ID. + scan_layers: (bool) Whether the MaxText model uses scanned layers. + This must match the training configuration. + +Environment Variables: + HF_AUTH_TOKEN: (Optional) HuggingFace authentication token if needed for adapter. + +Example Usage: + To convert HF LoRA to MaxText adapter: + + python src/MaxText/utils/ckpt_conversion/apply_lora.py \ + MaxText/configs/sft.yml model_name="llama3.1-8b" \ + hf_lora_adapter_path="username/lora-adapter-repo" \ + base_output_directory="/path/to/output/directory" \ + scan_layers=False +""" + +import argparse +import os +import sys +import json +from typing import Sequence +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +import torch +from safetensors import safe_open +from huggingface_hub import hf_hub_download +from transformers import AutoConfig +from etils import epath +from flax import nnx + +from orbax import checkpoint as ocp +from MaxText import pyconfig +from MaxText.common_types import MODEL_MODE_TRAIN +from MaxText.layers import models, quantizations +from MaxText.utils.ckpt_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING +from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, HF_IDS +from maxtext.utils import max_logging +from maxtext.utils import maxtext_utils +from maxtext.utils import max_utils +from absl import logging + + +def load_hf_lora_adapter(adapter_path: str, hf_model_id: str) -> dict: + """Load HF LoRA adapter weights directly from safetensors files.""" + max_logging.log(f"Loading HF LoRA adapter from {adapter_path}") + + # Check adapter compatibility + adapter_config = None + if os.path.isdir(adapter_path): + # Local directory + adapter_dir = epath.Path(adapter_path) + config_file = adapter_dir / "adapter_config.json" + if config_file.exists(): + with open(config_file, 'r') as f: + adapter_config = json.load(f) + else: + # HF Hub repo + try: + config_file = hf_hub_download( + adapter_path, + "adapter_config.json", + token=os.environ.get("HF_AUTH_TOKEN") + ) + with open(config_file, 'r') as f: + adapter_config = json.load(f) + except Exception: + max_logging.log("Warning: Could not load adapter_config.json from HF Hub") + + if adapter_config: + base_model = adapter_config.get("base_model_name_or_path") + # if base_model and base_model.replace("-Instruct", "") != hf_model_id.replace("-Instruct", ""): + # raise ValueError(f"Adapter base model '{base_model}' does not match expected model '{hf_model_id}'") + max_logging.log(f"Adapter compatible with model {hf_model_id}") + + # Handle both local paths and HF Hub paths + if os.path.isdir(adapter_path): + # Local directory + adapter_dir = epath.Path(adapter_path) + adapter_files = list(adapter_dir.glob("*.safetensors")) + if not adapter_files: + adapter_files = list(adapter_dir.glob("*.bin")) + if not adapter_files: + raise ValueError(f"No LoRA adapter files found in {adapter_path}") + adapter_file = adapter_files[0] + else: + # Assume it's a HF Hub repo ID + try: + # Try to download the adapter config to get the file list + from huggingface_hub import list_repo_files + files = list_repo_files(adapter_path, token=os.environ.get("HF_AUTH_TOKEN")) + safetensor_files = [f for f in files if f.endswith('.safetensors')] + if not safetensor_files: + bin_files = [f for f in files if f.endswith('.bin')] + if not bin_files: + raise ValueError(f"No LoRA adapter files found in {adapter_path}") + adapter_file = bin_files[0] + else: + adapter_file = safetensor_files[0] + + # Download the adapter file + adapter_file = hf_hub_download( + adapter_path, + adapter_file, + token=os.environ.get("HF_AUTH_TOKEN") + ) + except Exception as e: + raise ValueError(f"Failed to load LoRA adapter from {adapter_path}: {e}") + + # Load the adapter weights + if adapter_file.endswith('.safetensors'): + with safe_open(adapter_file, framework="numpy") as f: + lora_weights = {k: f.get_tensor(k) for k in f.keys()} + else: + # For .bin files, we'd need torch.load, but safetensors is preferred + raise ValueError(f"Unsupported adapter file format: {adapter_file}") + + max_logging.log(f"Loaded {len(lora_weights)} LoRA parameters from adapter") + return lora_weights + + +def convert_hf_lora_key_to_maxtext(hf_key: str, param_mapping: dict, config) -> str: + """Convert HF LoRA key to MaxText parameter path using the mapping from to_maxtext.py.""" + # HF LoRA keys: base_model.model.layers.{layer}.{module}.lora_A/B.weight + + # 1. Clean up LoRA suffixes to get the base module path + # e.g. ...q_proj.lora_A.weight -> ...q_proj + hf_param_key = hf_key.replace(".lora_A.weight", "").replace(".lora_B.weight", "") + hf_param_key = hf_param_key.replace(".lora_A", "").replace(".lora_B", "") + + # 2. Handle prefix. Expected target is usually "model.layers..." + # Input could be "base_model.model.model.layers..." or "base_model.model.layers..." + if hf_param_key.startswith("base_model.model."): + hf_param_key = hf_param_key[len("base_model.model."):] + + # 3. Search for the corresponding MaxText key + for mt_key, hf_keys in param_mapping.items(): + if isinstance(hf_keys, list): + for hf_k in hf_keys: + # Match disregarding .weight suffix on the base model param + if hf_k.replace(".weight", "") == hf_param_key: + return mt_key + elif isinstance(hf_keys, str): + if hf_keys.replace(".weight", "") == hf_param_key: + return mt_key + + return None + + +def convert_lora_to_maxtext_adapter(config, lora_weights: dict, output_path: str, hf_model_id: str): + """Converts HF LoRA weights to MaxText adapter format without merging.""" + + # 1. Setup Mesh and Model Structure (Abstractly) + devices_array = maxtext_utils.create_device_mesh(config) + mesh = jax.sharding.Mesh(devices_array, axis_names=config.mesh_axes) + quant = quantizations.configure_quantization(config) + + # Initialize rngs for model creation + rngs = nnx.Rngs(params=jax.random.PRNGKey(0), dropout=jax.random.PRNGKey(1)) + + # Use the model definition to understand the target parameter paths + model = models.Transformer(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN, rngs=rngs) + + hf_token = config.hf_access_token + + # Get the parameter mapping (MT -> HF) + model_key = config.model_name + if "-Instruct" in model_key: + max_logging.log("Warning: You want an Instruct version, so we are using the base model architecture instead.") + model_key = model_key.replace("-Instruct", "") + hf_config_obj = AutoConfig.from_pretrained(hf_model_id, token=hf_token) + param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config, config.scan_layers) + + # 2. Initialize an empty dictionary for the MaxText Adapter + mt_adapter_tree = {} + mapped_count = 0 + + # 3. Map HF LoRA weights to MaxText keys + for hf_key, weight in lora_weights.items(): + # Identify the MaxText path for this specific HF weight + mt_key = convert_hf_lora_key_to_maxtext(hf_key, param_map_mt_to_hf, config) + + if mt_key: + # Determine if this is the 'A' or 'B' matrix + suffix = "lora_A" if "lora_A" in hf_key else "lora_B" + + # Construct a nested dictionary path in mt_adapter_tree + # MaxText expects: { 'decoder': { 'layers': { '0': { 'query': { 'lora_A': ... } } } } } + parts = mt_key.split("/") + current = mt_adapter_tree + for part in parts: + if part not in current: + current[part] = {} + current = current[part] + + # Convert weight to JAX array and store + current[suffix] = jnp.array(weight) + mapped_count += 1 + else: + max_logging.log(f"Warning: Could not map HF LoRA key {hf_key} to MaxText key") + + max_logging.log(f"Successfully mapped {mapped_count} out of {len(lora_weights)} LoRA parameters") + + # 4. Save as a standalone adapter checkpoint + max_logging.log(f"Saving MaxText LoRA adapter to {output_path}") + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + ckptr.save(epath.Path(output_path), mt_adapter_tree) + + max_logging.log("LoRA adapter conversion completed successfully") + + +def main(args: Sequence[str]) -> None: + # Set logging to INFO level to see max_logging.log messages + logging.set_verbosity(logging.INFO) + + # Check if the user is using an Instruct version. If so, use the base model architecture + original_model_name = None + for i, arg in enumerate(args): + if arg.startswith("model_name="): + model_name_arg = args[i].split("=")[1] + # Remove quotes if present + model_name_arg = model_name_arg.strip("'").strip('"') + original_model_name = model_name_arg + + if "-Instruct" in model_name_arg: + max_logging.log("Warning: You want an Instruct version, so we are using the base model architecture instead.") + model_name_arg = model_name_arg.replace("-Instruct", "") + args[i] = f"model_name={model_name_arg}" + break + + # Initialize maxtext config + config = pyconfig.initialize(args) + + if not hasattr(config, 'hf_lora_adapter_path') or not config.hf_lora_adapter_path: + raise ValueError("hf_lora_adapter_path must be specified") + + # Determine HF model ID and check if supported + hf_model_id = HF_IDS.get(config.model_name) + if hf_model_id is None: + raise ValueError(f"Model '{config.model_name}' is not supported. Use a supported model_name from HF_IDS.") + + if not hasattr(config, 'base_output_directory') or not config.base_output_directory: + raise ValueError("base_output_directory must be specified (in config file or as command-line argument)") + + output_dir = config.base_output_directory + + # Use original model name for output path + model_name_for_path = original_model_name or config.model_name + adapter_name = os.path.basename(config.hf_lora_adapter_path) + full_output_path = os.path.join(output_dir, model_name_for_path, adapter_name) + + os.makedirs(os.path.dirname(full_output_path), exist_ok=True) + + if os.path.exists(full_output_path): + import shutil + max_logging.log(f"Output directory {full_output_path} exists. Removing it to allow Orbax to save.") + shutil.rmtree(full_output_path) + + # Load LoRA adapter and check compatibility + lora_weights = load_hf_lora_adapter(config.hf_lora_adapter_path, hf_model_id) + + # Convert LoRA to MaxText adapter format and save + convert_lora_to_maxtext_adapter(config, lora_weights, full_output_path, hf_model_id) + + # Verify output was created + if not os.path.exists(full_output_path): + raise RuntimeError(f"Failed to create output directory {full_output_path}") + + +if __name__ == "__main__": + # Argument parsing similar to to_maxtext.py + parser = argparse.ArgumentParser() + parser.add_argument( + "--simulated_cpu_devices_count", type=int, required=False, default=16 + ) + + # Parse local arguments + local_args, remaining_args = parser.parse_known_args() + + # Reconstruct model_args (script name + the args MaxText needs) + model_args = [sys.argv[0]] + remaining_args + + # Set jax environment + jax.config.update("jax_platforms", "cpu") + os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}" + + main(model_args) \ No newline at end of file diff --git a/src/maxtext/configs/post_train/sft.yml b/src/maxtext/configs/post_train/sft.yml index 6d040c8535..e71447e838 100644 --- a/src/maxtext/configs/post_train/sft.yml +++ b/src/maxtext/configs/post_train/sft.yml @@ -31,6 +31,10 @@ lora_module_path: "" lora_weight_qtype: null lora_tile_size: null +# -------------- HF LoRA Adapter -------------- +# HuggingFace LoRA adapter repo ID (e.g., 'username/adapter-repo') or local path to directory containing adapter_model.safetensors +hf_lora_adapter_path: "" + # -------------- HF pipeline -------------- dataset_type: hf hf_path: 'HuggingFaceH4/ultrachat_200k' diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 7369782416..140e0243a1 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -287,6 +287,7 @@ class Checkpointing(BaseModel): load_parameters_path: PathStr = Field("", description="Loads only model parameters from a specific checkpoint path.") lora_input_adapters_path: PathStr = Field("", description="Input GCS path for LoRA adapters.") + hf_lora_adapter_path: PathStr = Field("", description="HuggingFace LoRA adapter repo ID (e.g., 'username/adapter-repo') or local path to directory containing adapter_model.safetensors.") load_full_state_path: PathStr = Field("", description="Loads the complete training state from a checkpoint path.") enable_checkpointing: bool = Field(True, description="If True, enables saving checkpoints during training.") load_checkpoint_only_once: bool = Field(False, description="If True, deep copy the reference model to the actor model.") From 9a89716f05ba5cc39edad94e961e4ef7a6629dbc Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Tue, 3 Feb 2026 11:45:32 +0000 Subject: [PATCH 04/12] Migrate Decoder (Gemma3/Deepseek/Llama4) and utils to NNX --- src/maxtext/configs/base.yml | 5 +- src/maxtext/configs/types.py | 1 + src/maxtext/layers/multi_token_prediction.py | 22 +- src/maxtext/layers/nnx_decoders.py | 967 +++++++++++++++++++ src/maxtext/models/gemma3.py | 2 - src/maxtext/models/models.py | 97 +- tests/checkpoint_compare.py | 179 ++++ tests/unit/multi_token_prediction_test.py | 25 +- 8 files changed, 1248 insertions(+), 50 deletions(-) create mode 100644 src/maxtext/layers/nnx_decoders.py create mode 100644 tests/checkpoint_compare.py diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index d97b2f3256..b55a54e4e8 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -718,7 +718,7 @@ autoregressive_decode_assert: "" # For nsys profiler, pass the training command to nsys command # e.g. nsys profile -s none --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop {training command} -profiler: "" # Supported profiler: '', xplane, nsys +profiler: "xplane" # Supported profiler: '', xplane, nsys # If set to true, upload all profiler results from all hosts. Otherwise, only upload the profiler result from the first host. upload_all_profiler_results: False # Skip first n steps for profiling, to omit things like compilation and to give @@ -1074,7 +1074,8 @@ position_id_per_seconds: 25 subslice_shape: "" # NNX -enable_nnx: false +enable_nnx: True +pure_nnx_decoder: True ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 140e0243a1..c6b4be3ce6 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -778,6 +778,7 @@ class HardwareAndMesh(BaseModel): enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.") optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.") shardy: bool = Field(True, description="Whether to use shardy XLA backend.") + pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.") class LayoutAndSharding(BaseModel): diff --git a/src/maxtext/layers/multi_token_prediction.py b/src/maxtext/layers/multi_token_prediction.py index c9647b8368..d97a8e3592 100644 --- a/src/maxtext/layers/multi_token_prediction.py +++ b/src/maxtext/layers/multi_token_prediction.py @@ -108,12 +108,22 @@ def __init__( rngs=rngs, ) # Use MODEL_MODE_TRAIN for initialization; runtime model_mode is passed dynamically. - mtp_transformer_layer = transformer_layer_module( - config=cfg, - mesh=mesh, - model_mode=MODEL_MODE_TRAIN, - name=f"mtp_{k}_transformer_layer", - ) + if cfg.pure_nnx_decoder: + mtp_transformer_layer = transformer_layer_module( + config=cfg, + mesh=mesh, + model_mode=MODEL_MODE_TRAIN, + name=f"mtp_{k}_transformer_layer", + rngs=rngs, + ) + else: + mtp_transformer_layer = transformer_layer_module( + config=cfg, + mesh=mesh, + model_mode=MODEL_MODE_TRAIN, + name=f"mtp_{k}_transformer_layer", + ) + self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs) # ToNNX requires explicit initialization with sample inputs for proper parameter setup. diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py new file mode 100644 index 0000000000..647030e0c2 --- /dev/null +++ b/src/maxtext/layers/nnx_decoders.py @@ -0,0 +1,967 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for decoder layers""" +# pylint: disable=arguments-differ +# pylint: disable=no-name-in-module + +import functools +from typing import Any +import warnings +import inspect + +import jax +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name +from jax.sharding import Mesh + +from flax import linen as nn +from flax import nnx +from flax.nnx import wrappers as nnx_wrappers + +from maxtext.common.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT +from maxtext.common.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from maxtext.layers import linears +from maxtext.layers import mhc +from maxtext.layers import normalizations +from maxtext.layers import initializers +from maxtext.layers import quantizations +from maxtext.layers.attentions import Attention +from maxtext.layers.normalizations import RMSNorm +from maxtext.layers.embeddings import Embed, attend_on_embedding, PositionalEmbedding +from maxtext.layers.quantizations import AqtQuantization as Quant +from maxtext.models import ( + deepseek, + deepseek_batchsplit, + gemma, + gemma2, + gemma3, + gpt3, + gpt_oss, + llama2, + llama4, + mistral, + mixtral, + qwen3, + simple_layer, + olmo3, +) +from maxtext.multimodal import utils as mm_utils +from maxtext.utils.sharding import create_sharding +from maxtext.utils import max_logging +from maxtext.utils import sharding +from maxtext.utils import maxtext_utils +from maxtext.inference import page_manager + +# ------------------------------------------------------------------------------ +# The network: Decoder Definitions +# ------------------------------------------------------------------------------ + + +class NNXDecoderLayer(nnx.Module): + """ + Transformer decoder layer converted to NNX. + """ + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + quant: None | Quant = None, + name: str = "decoder_layer", + *, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + + cfg = self.config + + self.pre_self_attention_norm = RMSNorm( + num_features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=rngs, + ) + + self.self_attention = Attention( + config=self.config, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + inputs_q_shape=(1, 1, cfg.emb_dim), + inputs_kv_shape=(1, 1, cfg.emb_dim), + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + float32_qk_product=cfg.float32_qk_product, + float32_logits=cfg.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(cfg), + prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), + reshape_q=cfg.reshape_q, + use_mrope=cfg.use_mrope, + mrope_section=cfg.mrope_section, + model_mode=model_mode, + ) + + self.mlp = linears.MlpBlock( + in_features=cfg.emb_dim, + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + model_mode=model_mode, + config=cfg, + quant=self.quant, + mesh=self.mesh, + rngs=rngs, + ) + + self.dropout = linears.Dropout(rate=cfg.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + kv_cache: jax.Array | None = None, + attention_metadata: dict[str, Any] | None = None, + ): + cfg = self.config + mesh = self.mesh + _maybe_shard_with_logical = functools.partial( + sharding.maybe_shard_with_logical, + mesh=mesh, + shard_mode=cfg.shard_mode, + debug_sharding=cfg.debug_sharding, + ) + + if self.model_mode == MODEL_MODE_PREFILL: + logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") + elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: + logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed") + else: + logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed") + + inputs = _maybe_shard_with_logical(inputs, logical_axis_names) + inputs = checkpoint_name(inputs, "decoder_layer_input") + + lnx = self.pre_self_attention_norm(inputs) + lnx = _maybe_shard_with_logical(lnx, logical_axis_names) + + attention_lnx, kv_cache = self.self_attention( + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) + + mlp_lnx = self.mlp(lnx, deterministic=deterministic) + mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names) + + next_layer_addition = mlp_lnx + attention_lnx + next_layer_addition_dropped_out = self.dropout(next_layer_addition, deterministic=deterministic) + + layer_output = next_layer_addition_dropped_out + inputs + layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names) + + if cfg.record_internal_nn_metrics: + self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output)) + self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output)) + self.sow( + nnx.Intermediate, + "activation_fraction_zero", + jnp.sum(layer_output == 0) / jnp.size(layer_output), + ) + + if cfg.scan_layers: + return layer_output, None + else: + return layer_output, kv_cache + + +def deepstack_process(hidden_states, bidirectional_mask, visual_embeds): + """Process deepstack visual embeddings by adding them to hidden states at visual token positions. + + Args: + hidden_states: [batch, seq_len, hidden_dim] decoder hidden states + bidirectional_mask: [batch, seq_len] boolean mask marking visual token positions + visual_embeds: [batch, num_visual_tokens, hidden_dim] visual features from encoder layer + + Returns: + Updated hidden_states with visual features added at visual positions + """ + # Expand mask to [batch, seq_len, 1] for broadcasting + mask_expanded = bidirectional_mask[:, :, jnp.newaxis] + # Use cumsum to map each True position in mask to its index in visual_embeds + visual_token_idx = jnp.cumsum(bidirectional_mask, axis=1) - 1 # [batch, seq_len], 0-indexed + + # Gather visual tokens: for each position, get the corresponding visual token + batch_idx = jnp.arange(hidden_states.shape[0])[:, jnp.newaxis] # [batch, 1] + visual_embeds_scattered = visual_embeds[batch_idx, visual_token_idx, :] # [batch, seq_len, hidden] + + # Only add where mask is True: hidden_states += visual_embeds * mask + hidden_states = hidden_states + visual_embeds_scattered * mask_expanded + return hidden_states + + +class NNXDecoder(nnx.Module): + """A stack of decoder layers as a part of an encoder-decoder architecture, using NNX.""" + + def __init__( + self, + config: Config, + mesh: Mesh, + quant: None | Quant = None, + model_mode: str = MODEL_MODE_TRAIN, + *, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.quant = quant + self.model_mode = model_mode + self.rngs = rngs + + decoder_block_classes = self.get_decoder_layers() + + self.decoder_norm = self.get_norm_layer(num_features=config.emb_dim, rngs=rngs)( + dtype=config.dtype, + weight_dtype=config.weight_dtype, + epsilon=config.normalization_layer_epsilon, + kernel_axes=("norm",), + parameter_memory_host_offload=config.parameter_memory_host_offload, + ) + + if config.trainable_position_size > 0: + self.position_embedder = Embed( + num_embeddings=config.trainable_position_size, + num_features=config.emb_dim, + dtype=config.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + config=config, + mesh=self.mesh, + rngs=rngs, + ) + + self.dropout = linears.Dropout(rate=config.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) + + self.positional_embedding = PositionalEmbedding(embedding_dims=config.base_emb_dim) + + if not config.logits_via_embedding: + self.logits_dense = linears.DenseGeneral( + in_features_shape=config.emb_dim, + out_features_shape=config.vocab_size, + weight_dtype=config.weight_dtype, + dtype=jnp.float32 if config.logits_dot_in_fp32 else config.dtype, + kernel_axes=("embed", "vocab"), + shard_mode=config.shard_mode, + matmul_precision=self.config.matmul_precision, + parameter_memory_host_offload=config.parameter_memory_host_offload, + rngs=rngs, + ) + + self.scanned_layers = None + self.is_deepseek = self.config.decoder_block == DecoderBlockType.DEEPSEEK + self.is_gemma3 = self.config.decoder_block == DecoderBlockType.GEMMA3 + + if self.config.scan_layers: + if self.is_deepseek: + assert len(decoder_block_classes) == 2 + dense_cls, moe_cls = decoder_block_classes + + num_dense = config.first_num_dense_layers + self.dense_layers = self._create_scanned_layers(dense_cls, length=num_dense, rngs=rngs) + + num_moe = config.num_decoder_layers - config.first_num_dense_layers + + self.moe_stack = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs) + elif self.is_gemma3: + attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) + scan_length = config.num_decoder_layers // attention_pattern_length + num_remaining_layers = config.num_decoder_layers % attention_pattern_length + layer_kwargs = {"num_of_layers": attention_pattern_length} + + rem_layer_kwargs = {"num_of_layers": num_remaining_layers} + + RemattedGemma3Block = gemma3.Gemma3ScannableBlock + + if scan_length > 0: + self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs, **layer_kwargs) + self.layers_remainder = RemattedGemma3Block( + config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs + ) # pytype: disable=wrong-keyword-args + else: + layer_cls = decoder_block_classes[0] + num_layers = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval) + layer_kwargs = {} + if config.decoder_block == DecoderBlockType.LLAMA4: + layer_kwargs = { + "nope_layer_interval": self.config.nope_layer_interval, + "interleave_moe_layer_step": self.config.interleave_moe_layer_step, + } + + self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs) + else: + self.layers = nnx.List([]) + + if self.is_deepseek: + dense_cls, moe_cls = decoder_block_classes + for i in range(config.first_num_dense_layers): + self._create_and_register_layer(dense_cls, rngs, "dense_layer", i) + for i in range(config.num_decoder_layers - config.first_num_dense_layers): + self._create_and_register_layer(moe_cls, rngs, "moe_layer", i) + else: + layer_cls = decoder_block_classes[0] + + for lyr in range(config.num_decoder_layers): + layer_kwargs = {} + if config.decoder_block == DecoderBlockType.GEMMA3: + layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)} + elif config.decoder_block == DecoderBlockType.LLAMA4: + layer_kwargs = { + "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), + "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), + } + elif config.decoder_block == DecoderBlockType.QWEN3_NEXT: + layer_kwargs = {"layer_idx": lyr} + elif config.decoder_block == DecoderBlockType.GPT_OSS: + layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} + elif config.decoder_block == DecoderBlockType.OLMO3: + layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)} + + self._create_and_register_layer(layer_cls, rngs, "layers", lyr, **layer_kwargs) + + def _create_and_register_layer(self, layer_cls, rngs, base_name, i, **layer_kwargs): + attr_name = f"{base_name}_{i}" + layer = self._create_single_layer(layer_cls, rngs, **layer_kwargs) + setattr(self, attr_name, layer) + self.layers.append(layer) + + def _create_single_layer(self, decoder_layer_class, rngs, **kwargs): + """Helper to create a single layer (Linen or NNX).""" + if issubclass(decoder_layer_class, nnx.Module): + return decoder_layer_class( + config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs, **kwargs + ) + else: + layer_linen = decoder_layer_class( + config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, **kwargs + ) + return nnx_wrappers.ToNNX(layer_linen, rngs=rngs) + + def _create_scanned_layers(self, decoder_layer_class, length: int, rngs: nnx.Rngs, **layer_kwargs): + """Creates a VMapped stack of layers, forcing parameter init for Compact modules.""" + + def create_layer_fn(rng): + layer = decoder_layer_class( + config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rng, **layer_kwargs + ) + + return layer + + # Workaround for Deepseek MTP test failure. + # TODO: Handle this properly. + try: + forked_rngs = rngs.fork(split=length) + + except: # pylint: disable=bare-except + pass + + out_axes = nnx.StateAxes({nnx.Param: self.config.param_scan_axis, ...: 0}) + layers_vmapped = nnx.vmap( + create_layer_fn, + in_axes=0, + out_axes=out_axes, + axis_name="layers", + transform_metadata={nnx.PARTITION_NAME: "layers"}, + )(forked_rngs) + + return layers_vmapped + + def _apply_layer_with_remat(self, layer: nnx.Module, y: jax.Array, policy: Any, prevent_cse: bool, **kwargs): + """Helper to cleanly apply jax.checkpoint to a single unscanned layer or block.""" + + graphdef, state = nnx.split(layer) + + def pure_layer_fn(state_in, y_in): + merged_layer = nnx.merge(graphdef, state_in) + out = merged_layer(y_in, **kwargs) + return out, nnx.state(merged_layer) + + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + out, new_state = checkpointed_fn(state, y) + nnx.update(layer, new_state) + + return out + + def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs): + """Runs the layer stack using nnx.scan.""" + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config) + graphdef, params, state = nnx.split( + layers, nnx.Param, ... + ) # state: the mutable state we carry (KV cache, RNGs, etc.) + + scan_axis = self.config.param_scan_axis + if scan_axis != 0: + # Move scan_axis to 0 so scan can iterate over it + params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params) + + layer_cls = layers.__class__ + sig = inspect.signature(layer_cls.__call__) + valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} + + layer_cls = layers.__class__ # Access the underlying class + sig = inspect.signature(layer_cls.__call__) + # Filter kwargs to only include keys that exist in the layer's signature + valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} + + def layer_fn(carry, scanned_vars): + # Unpack the sliced variables for THIS layer + current_params, current_state = scanned_vars + + # Merge using the SLICED state + layer = nnx.merge(graphdef, current_params, current_state) + + # Run the layer (Filter kwargs if using the solution from previous turn) + layer_out = layer(carry, *args, **valid_kwargs) + + new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out + + # Extract the updated state to return it + # _, new_current_state = nnx.split(layer, nnx.Param, ...) + new_current_state = nnx.state(layer) + return new_carry, new_current_state + + layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) + + final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) + + if scan_axis != 0: + scanned_params, scanned_other = scanned_state.split(nnx.Param, ...) + scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params) + scanned_state = nnx.State.merge(scanned_params, scanned_other) + + return final_carry, nnx.merge(graphdef, scanned_state) + + def get_decoder_layers(self): + """Retrieves decoder layer classes based on config using a dictionary lookup.""" + cfg = self.config + + def get_scannable(normal_cls, scannable_cls): + return [scannable_cls] if cfg.scan_layers else [normal_cls] + + def get_deepseek(): + if cfg.use_batch_split_schedule: + return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer] + return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer] + + layer_map = { + DecoderBlockType.DEFAULT: [NNXDecoderLayer], + DecoderBlockType.LLAMA2: [llama2.LlamaDecoderLayer], + DecoderBlockType.MISTRAL: [mistral.MistralDecoderLayer], + DecoderBlockType.MIXTRAL: [mixtral.MixtralDecoderLayer], + DecoderBlockType.GEMMA: [gemma.GemmaDecoderLayer], + DecoderBlockType.GEMMA2: [gemma2.Gemma2DecoderLayer], + DecoderBlockType.GEMMA3: [gemma3.Gemma3DecoderLayer], + DecoderBlockType.GPT3: [gpt3.Gpt3DecoderLayer], + DecoderBlockType.QWEN3: [qwen3.Qwen3DecoderLayer], + DecoderBlockType.QWEN3_MOE: [qwen3.Qwen3MoeDecoderLayer], + DecoderBlockType.SIMPLE: [simple_layer.SimpleDecoderLayer], + DecoderBlockType.SIMPLE_MLP: [simple_layer.SimpleMlpDecoderLayer], + DecoderBlockType.DEEPSEEK: get_deepseek(), + DecoderBlockType.GPT_OSS: get_scannable(gpt_oss.GptOssDecoderLayer, gpt_oss.GptOssScannableBlock), + DecoderBlockType.QWEN3_NEXT: get_scannable(qwen3.Qwen3NextDecoderLayer, qwen3.Qwen3NextScannableBlock), + DecoderBlockType.LLAMA4: get_scannable(llama4.Llama4DecoderLayer, llama4.Llama4ScannableBlock), + DecoderBlockType.OLMO3: get_scannable(olmo3.Olmo3DecoderLayer, olmo3.Olmo3ScannableBlock), + } + + if cfg.decoder_block not in layer_map: + raise ValueError(f"Incorrect decoder_block name {cfg.decoder_block.value=}") + + return layer_map[cfg.decoder_block] + + def minimal_policy(self, with_context=False, with_quantization=False): + """Helper for creating minimal checkpoint policies.""" + names = [ + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwi_0", + "mlpwi_1", + "mlpwi", + "mlpwo", + ] + if with_context: + names.append("context") + if with_quantization: + names.append("quantization") + return jax.checkpoint_policies.save_only_these_names(*names) + + def get_remat_policy(self): + """Get remat policy for jax.checkpoint.""" + policy = None + cfg = self.config + if cfg.remat_policy != "none": + if cfg.remat_policy in ("minimal_with_context", "minimal_flash"): + # save all + if cfg.remat_policy == "minimal_flash": + max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.") + policy = self.minimal_policy(with_context=True) + elif cfg.remat_policy == "minimal": + # save all except context + policy = self.minimal_policy() + elif cfg.remat_policy == "minimal_with_quantization": + if cfg.scan_layers: + warnings.warn( + "Scan layers can introduce overhead to checkpointed values that in some configurations is slower" + "than not checkpointing at all. If you are using scan layers, benchmark with and without quantization " + "checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is " + "beneficial for performance." + ) + policy = self.minimal_policy(with_context=False, with_quantization=True) + elif cfg.remat_policy == "minimal_with_context_and_quantization": + if cfg.scan_layers: + warnings.warn( + "Scan layers can introduce overhead to checkpointed values that in some configurations is slower" + "than not checkpointing at all. If you are using scan layers, benchmark with and without quantization " + "checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is " + "beneficial for performance." + ) + policy = self.minimal_policy(with_context=True, with_quantization=True) + elif cfg.remat_policy == "save_dot_with_context_except_mlp": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "context", + "out_proj", + ) + elif cfg.remat_policy == "save_dot_except_mlpwi": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwo", + ) + elif cfg.remat_policy == "save_dot_except_mlp": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + ) + elif cfg.remat_policy == "save_qkv_proj": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + ) + elif cfg.remat_policy == "qkv_proj_offloaded": + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=[], + names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"], + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "minimal_offloaded": + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=[], + names_which_can_be_offloaded=[ + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwi_0", + "mlpwi_1", + "mlpwi", + "mlpwo", + ], + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "custom": + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=cfg.tensors_on_device, + names_which_can_be_offloaded=cfg.tensors_to_offload, + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "save_out_proj": + policy = jax.checkpoint_policies.save_only_these_names("out_proj") + else: + assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" + policy = None + return policy + + def get_norm_layer(self, num_features: int, rngs: nnx.Rngs): + """get normalization layer (return type inherits from nn.Module)""" + if self.config.decoder_block in ( + DecoderBlockType.DEFAULT, + DecoderBlockType.LLAMA2, + DecoderBlockType.MISTRAL, + DecoderBlockType.MIXTRAL, + DecoderBlockType.DEEPSEEK, + DecoderBlockType.GEMMA, + DecoderBlockType.GEMMA2, + DecoderBlockType.GEMMA3, + DecoderBlockType.QWEN3, + DecoderBlockType.QWEN3_MOE, + DecoderBlockType.GPT_OSS, + DecoderBlockType.SIMPLE, + DecoderBlockType.SIMPLE_MLP, + DecoderBlockType.LLAMA4, + DecoderBlockType.OLMO3, + ): + return functools.partial(RMSNorm, num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs) + elif self.config.decoder_block == DecoderBlockType.GPT3: + return functools.partial( + gpt3.Gpt3LayerNorm, num_features=num_features, reductions_in_fp32=False, use_bias=True, rngs=rngs + ) + elif self.config.decoder_block == DecoderBlockType.QWEN3_NEXT: + return functools.partial( + normalizations.Qwen3NextRMSNorm, num_features=num_features, shard_mode=self.config.shard_mode + ) + else: + raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") + + def _apply_embedding( + self, + shared_embedding: nnx.Module, + decoder_input_tokens, + decoder_positions, + deterministic, + model_mode, + image_embeddings=None, + bidirectional_mask=None, + image_masks=None, + audio_embeddings=None, + audio_masks=None, + ): + """Applies token and positional embeddings to the input tokens.""" + cfg = self.config + + y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) + + if image_embeddings is not None and cfg.use_multimodal: + if cfg.model_name in [ + "gemma3-4b", + "gemma3-12b", + "gemma3-27b", + "llama4-17b-16e", + "llama4-17b-128e", + "qwen3-omni-30b-a3b", + ]: + y = mm_utils.merge_mm_embeddings( + text_embeddings=y, + multimodal_embeddings=image_embeddings, + mask=bidirectional_mask, + token_masks=image_masks, + ) + # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed + else: + raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") + + if audio_embeddings is not None and cfg.use_audio: + if cfg.model_name in ["qwen3-omni-30b-a3b"]: + y = mm_utils.merge_mm_embeddings( + text_embeddings=y, + multimodal_embeddings=audio_embeddings, + mask=audio_masks, + token_masks=None, + ) + else: + raise ValueError(f"Unsupported model_name for audio: {cfg.model_name}") + + y = self.dropout(y, deterministic=deterministic) + y = y.astype(cfg.dtype) + + if cfg.use_untrainable_positional_embedding: + y += self.positional_embedding(y, decoder_positions) + + if cfg.trainable_position_size > 0 and self.position_embedder: + y += self.position_embedder(decoder_positions.astype("int32"), model_mode=model_mode) + + return y + + def apply_output_head(self, shared_embedding, y, deterministic, model_mode): + """Applies final normalization and projects hidden states to logits.""" + + cfg = self.config + if cfg.shard_mode == ShardMode.EXPLICIT: + norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) + else: + norm_out_sharding = None + + y = self.decoder_norm(y, out_sharding=norm_out_sharding) + y = self.dropout(y, deterministic=deterministic) # NNX call + + if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) + else: + out_sharding = create_sharding( + self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab") + ) + + if cfg.logits_via_embedding: + if isinstance(shared_embedding, nnx.Module): + embedding_table = shared_embedding.embedding.value + else: + embedding_table = shared_embedding.variables["params"]["embedding"] + if isinstance(embedding_table, nn.spmd.LogicallyPartitioned): + embedding_table = embedding_table.unbox() + attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype + logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding) + + if self.config.normalize_embedding_logits: + logits = logits / jnp.sqrt(y.shape[-1]) + if cfg.final_logits_soft_cap: + logits = logits / cfg.final_logits_soft_cap + logits = jnp.tanh(logits) * cfg.final_logits_soft_cap + else: + logits = self.logits_dense(y, out_sharding=out_sharding) + + if self.config.cast_logits_to_fp32: + logits = logits.astype(jnp.float32) + + return logits + + def __call__( + self, + shared_embedding: Any, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + deterministic=False, + model_mode=MODEL_MODE_TRAIN, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + bidirectional_mask: None | Any = None, + image_embeddings: None | jnp.ndarray = None, + image_masks: None | jnp.ndarray = None, + kv_caches: list[jax.Array] | None = None, + attention_metadata=None, + audio_embeddings: None | jnp.ndarray = None, + audio_masks: None | jnp.ndarray = None, + deepstack_visual_embeds: None | list[jnp.ndarray] = None, + ): + cfg = self.config + assert decoder_input_tokens.ndim == 2 # [batch, len] + + y = self._apply_embedding( + shared_embedding, + decoder_input_tokens, + decoder_positions, + deterministic, + model_mode, + image_embeddings, + bidirectional_mask, + image_masks, + audio_embeddings, + audio_masks, + ) + + mhc_expand, mhc_reduce = mhc.get_functions(cfg.mhc_expansion_rate) + if cfg.mhc_expansion_rate > 1: + # (batch, length, emb_dim) --> (batch, length, mhc_expansion_rate, emb_dim) + y = mhc_expand(y) + + layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) + + layer_kwargs = {} + if cfg.decoder_block == DecoderBlockType.GEMMA3: + layer_kwargs["bidirectional_mask"] = bidirectional_mask + + if cfg.scan_layers: + if self.is_deepseek: + layer_kwargs = { + "previous_chunk": previous_chunk, + "page_state": page_state, + "slot": slot, + } + y, self.dense_layers = self._apply_layers_sequentially( + self.dense_layers, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs + ) + + num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers + y, self.moe_stack = self._apply_layers_sequentially( + self.moe_stack, y, *layer_args, length=num_moe, **layer_kwargs + ) + elif self.is_gemma3: + y = self._apply_gemma3_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + bidirectional_mask, + previous_chunk, + page_state, + slot, + ) + else: + y, self.layers = self._apply_layers_sequentially( + self.layers, y, *layer_args, length=cfg.num_decoder_layers, **layer_kwargs + ) + else: + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) + + # Hoisted function to preserve XLA cache ID + def pure_layer_fn(graphdef, state_in, y_in, kv_in): + merged_layer = nnx.merge(graphdef, state_in) + out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **layer_kwargs) + return out_y, out_kv, nnx.state(merged_layer) + + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + + for lyr, layer in enumerate(self.layers): + graphdef, state = nnx.split(layer) + kv_cache = kv_caches[lyr] if kv_caches is not None else None + + y, kv_cache, new_state = checkpointed_fn(graphdef, state, y, kv_cache) + nnx.update(layer, new_state) + + if kv_caches is not None and kv_cache is not None: + kv_caches[lyr] = kv_cache + + if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds): + visual_embeds = deepstack_visual_embeds[lyr] + if bidirectional_mask is not None and visual_embeds is not None: + y = deepstack_process(y, bidirectional_mask, visual_embeds) + + assert isinstance(y, jax.Array) + + # After the final transformer layer, `y` holds the raw, un-normalized hidden state. + if cfg.mhc_expansion_rate > 1: + # (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim) + hidden_state = mhc_reduce(y) + else: + hidden_state = y + + # When invoking from vLLM with RPA attention, logit computation is deferred to a later stage. + if cfg.attention == "vllm_rpa": + logits = None + + # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory + # Instead, we keep track on the hidden states, which has smaller size compared to full logits + if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + logits = None + self.sow(nnx.Intermediate, "hidden_states", hidden_state) + + else: + logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) + + return logits, hidden_state, kv_caches + + def _apply_gemma3_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + bidirectional_mask, + previous_chunk, + page_state, + slot, + ): + """Applies Gemma3 scanned decoder blocks, handling main scan and remainders.""" + + cfg = self.config + + # Define the repeating pattern length and calculate how many full blocks to scan + attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) + scan_length = cfg.num_decoder_layers // attention_pattern_length + + layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) + layer_kwargs = {"bidirectional_mask": bidirectional_mask} + + # Apply the main scan over the full blocks + if scan_length > 0: + y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) + + # Apply any remaining layers that did not fit into a full scanned block + num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length + if num_remaining_layers > 0: + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) + + def pure_gemma_fn(graphdef, state_in, y_in): + merged_layer = nnx.merge(graphdef, state_in) + out_y, _ = merged_layer(y_in, *layer_args, **layer_kwargs) + return out_y, nnx.state(merged_layer) + + checkpointed_gemma_fn = jax.checkpoint(pure_gemma_fn, policy=policy, prevent_cse=prevent_cse) + + graphdef, state = nnx.split(self.layers_remainder) + y, new_state = checkpointed_gemma_fn(graphdef, state, y) + nnx.update(self.layers_remainder, new_state) + + return y + + +def decoder_as_linen( + config: Config, + mesh: Mesh, + rngs: nnx.Rngs, + model_mode: str, + quant: None | Quant = None, +): + """Creates a Decoder module.""" + module = nnx_wrappers.to_linen( + NNXDecoder, + config=config, + mesh=mesh, + model_mode=model_mode, + rngs=rngs, + quant=quant, + name="decoder", + abstract_init=False, + metadata_fn=initializers.variable_to_logically_partitioned, + ) + return module diff --git a/src/maxtext/models/gemma3.py b/src/maxtext/models/gemma3.py index 588ffa6db2..630497e224 100644 --- a/src/maxtext/models/gemma3.py +++ b/src/maxtext/models/gemma3.py @@ -91,7 +91,6 @@ def __init__( batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) - self.pre_self_attention_norm = RMSNorm( num_features=config.emb_dim, dtype=config.dtype, @@ -198,7 +197,6 @@ def __call__( inputs = inputs[0] inputs = nn.with_logical_constraint(inputs, self.activation_axis_names) inputs = checkpoint_name(inputs, "decoder_layer_input") - lnx = self.pre_self_attention_norm(inputs) lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index cfd837c6c5..f4e751554d 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -18,13 +18,16 @@ from typing import Any -from flax import linen as nn -from flax import nnx import jax import jax.numpy as jnp from jax.sharding import Mesh + +from flax import linen as nn +from flax import nnx + from maxtext.common.common_types import Config, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN from maxtext.inference import page_manager +from maxtext.layers.nnx_decoders import NNXDecoder, decoder_as_linen from maxtext.layers import initializers from maxtext.layers import nnx_wrappers from maxtext.layers.decoders import Decoder @@ -85,7 +88,13 @@ def setup(self): ) self.vision_encoder = vision_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_multimodal else None self.audio_encoder = audio_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_audio else None - self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) + if cfg.pure_nnx_decoder: + self.decoder = decoder_as_linen( + config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=nnx.Rngs(0) + ) + else: + self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) + # If MTP is enabled via config, set up the MTP block. if self.config.mtp_num_layers > 0: # Get the list of layer blueprints for the current model. @@ -328,9 +337,11 @@ def __init__( ) self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_multimodal else None self.audio_encoder = AudioEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_audio else None - - decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) - self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs) + if cfg.pure_nnx_decoder: + self.decoder = NNXDecoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs) + else: + self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) + self.decoder = nnx_wrappers.ToNNX(self.decoder, rngs=rngs) self.hidden_states = None batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode) @@ -356,12 +367,13 @@ def __init__( else: dummy_attention_metadata = None - self.decoder.lazy_init( - shared_embedding=self.token_embedder, - decoder_input_tokens=dummy_decoder_input_tokens, - decoder_positions=dummy_decoder_positions, - attention_metadata=dummy_attention_metadata, - ) + if not cfg.pure_nnx_decoder: + self.decoder.lazy_init( + shared_embedding=self.token_embedder, + decoder_input_tokens=dummy_decoder_input_tokens, + decoder_positions=dummy_decoder_positions, + attention_metadata=dummy_attention_metadata, + ) # If MTP is enabled via config, set up the MTP block. if self.config.mtp_num_layers > 0: @@ -483,26 +495,47 @@ def __call__( if self.config.distill_beta > 0.0 and "intermediates" not in mutable_collections: mutable_collections.append("intermediates") - logits, hidden_state, kv_caches = self.decoder( - shared_embedding=self.token_embedder, - decoder_input_tokens=decoder_input_tokens, - decoder_positions=decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=not enable_dropout, - model_mode=model_mode, - previous_chunk=previous_chunk, - slot=slot, - page_state=page_state, - bidirectional_mask=bidirectional_mask, - image_embeddings=image_embeddings, - image_masks=encoder_image_masks, - audio_embeddings=audio_embeddings, - audio_masks=audio_masks, - kv_caches=kv_caches, - attention_metadata=attention_metadata, - deepstack_visual_embeds=deepstack_visual_embeds, - mutable=mutable_collections, - ) + if self.config.pure_nnx_decoder: + logits, hidden_state, kv_caches = self.decoder( + shared_embedding=self.token_embedder, + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=not enable_dropout, + model_mode=model_mode, + previous_chunk=previous_chunk, + slot=slot, + page_state=page_state, + bidirectional_mask=bidirectional_mask, + image_embeddings=image_embeddings, + image_masks=encoder_image_masks, + audio_embeddings=audio_embeddings, + audio_masks=audio_masks, + kv_caches=kv_caches, + attention_metadata=attention_metadata, + deepstack_visual_embeds=deepstack_visual_embeds, + ) + else: + logits, hidden_state, kv_caches = self.decoder( + shared_embedding=self.token_embedder, + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=not enable_dropout, + model_mode=model_mode, + previous_chunk=previous_chunk, + slot=slot, + page_state=page_state, + bidirectional_mask=bidirectional_mask, + image_embeddings=image_embeddings, + image_masks=encoder_image_masks, + audio_embeddings=audio_embeddings, + audio_masks=audio_masks, + kv_caches=kv_caches, + attention_metadata=attention_metadata, + deepstack_visual_embeds=deepstack_visual_embeds, + mutable=mutable_collections, + ) # Materialize hidden state when vocab tiling is enabled if self.config.num_vocab_tiling > 1: diff --git a/tests/checkpoint_compare.py b/tests/checkpoint_compare.py new file mode 100644 index 0000000000..112f524df0 --- /dev/null +++ b/tests/checkpoint_compare.py @@ -0,0 +1,179 @@ +"""Script for comparing parameters between two checkpoints.""" + +import jax +import jax.numpy as jnp +import orbax.checkpoint as ocp +from typing import Any, Dict, Sequence +from jax.tree_util import tree_flatten_with_path, keystr, tree_structure, tree_map_with_path +from absl import app +from absl import flags + + +_LINEN_CKPT_PATH = flags.DEFINE_string( + "linen_ckpt_path", None, "Path to the Linen model checkpoint items directory.", required=True +) +_NNX_CKPT_PATH = flags.DEFINE_string( + "nnx_ckpt_path", None, "Path to the NNX model checkpoint items directory.", required=True +) + + +def load_checkpoint_params(path: str) -> Dict[str, Any]: + """Loads parameters from an Orbax checkpoint path.""" + print(f"Loading checkpoint from: {path}") + checkpointer = ocp.PyTreeCheckpointer() + restored_state = checkpointer.restore(path) + if restored_state is None: + raise ValueError(f"Failed to restore checkpoint from {path}") + if isinstance(restored_state, dict) and "params" in restored_state: + return restored_state["params"] + return restored_state + + +def transform_nnx_params(nnx_params: Dict[str, Any]) -> Dict[str, Any]: + """Applies specific transformations with verbose logging matching original format.""" + + def _transform(path, leaf: jax.Array) -> jax.Array: + key_str = keystr(path) + + if "layers" in key_str and hasattr(leaf, "ndim") and leaf.ndim >= 2: + print(f"TRANSPOSING: {key_str} with shape {leaf.shape}") + axes = (1, 0) + tuple(range(2, leaf.ndim)) + return jnp.transpose(leaf, axes=axes) + else: + if "token_embedder" in key_str: + print(f"SKIPPING Transpose: {key_str} because it is token_embedder") + else: + shape = getattr(leaf, "shape", "N/A") + print(f"SKIPPING Transpose: {key_str} with shape {shape} (ndim < 2)") + return leaf + + print("Applying transformations to NNX params...") + return tree_map_with_path(_transform, nnx_params) + + +def get_tree_structure_info(tree: Dict[str, Any]): + """Helper only used if structures differ.""" + flat_with_path, _ = tree_flatten_with_path(tree) + return {keystr(p): (getattr(l, "shape", "N/A"), str(getattr(l, "dtype", type(l).__name__))) for p, l in flat_with_path} + + +def print_structure_diff(params1, params2): + """Prints missing/added keys if structures differ.""" + info1 = get_tree_structure_info(params1) + info2 = get_tree_structure_info(params2) + keys1, keys2 = set(info1.keys()), set(info2.keys()) + + for k in sorted(keys2 - keys1): + print(f" + Added in NNX: {k}") + for k in sorted(keys1 - keys2): + print(f" - Missing in NNX: {k}") + + +def compare_params(params1: Dict[str, Any], params2: Dict[str, Any]) -> bool: + """ + Compares two parameter trees (e.g., JAX/Flax PyTrees) for structural and numerical equality. + + This function performs a deep comparison of two PyTrees. It first + validates that both trees share the exact same structure. If successful, it iterates + through every leaf node to verify: + 1. Shapes match. + 2. Data types (dtypes) match. + 3. Numerical values are close (within `jnp.allclose` tolerances). + + Args: + params1: The first parameter dictionary or PyTree (e.g., a Linen model). + params2: The second parameter dictionary or PyTree (e.g., an NNX model). + + Returns: + bool: True if structure, shapes, types, and values all match; False otherwise. + """ + + if tree_structure(params1) != tree_structure(params2): + print("[] Tree structures differ.") + print_structure_diff(params1, params2) + return False + + print("[] Tree structures are the same.") + + all_match = True + + def _compare_leaf(path, x, y): + nonlocal all_match + key_str = keystr(path) + + try: + shape1 = getattr(x, "shape", "N/A") + shape2 = getattr(y, "shape", "N/A") + + if shape1 != shape2: + print(f"[{key_str}] SHAPE MISMATCH: {shape1} vs {shape2}") + all_match = False + return + + dtype1 = getattr(x, "dtype", type(x)) + dtype2 = getattr(y, "dtype", type(y)) + + if dtype1 != dtype2: + print(f"[{key_str}] DTYPE MISMATCH: {dtype1} vs {dtype2}") + all_match = False + return + + diff = x - y + abs_diff = jnp.abs(diff) + mean_diff_scalar = jnp.mean(abs_diff) + max_diff_scalar = jnp.max(abs_diff) + is_close_scalar = jnp.allclose(x, y) + + mean_diff = float(mean_diff_scalar) + max_diff = float(max_diff_scalar) + is_close = bool(is_close_scalar) + + print( + f"[{key_str}] " + f"Shape(Linen/NNX): {shape1} / {shape2} — " + f"Mean abs diff: {mean_diff:.2e}, " + f"Max abs diff: {max_diff:.2e}, " + f"AllClose: {is_close}" + ) + + if not is_close: + all_match = False + + except Exception as e: # pylint: disable=broad-exception-caught + print(f"[{key_str}] Error during comparison: {e}") + all_match = False + + tree_map_with_path(_compare_leaf, params1, params2) + + return all_match + + +def main(argv: Sequence[str]): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + linen_ckpt_path = _LINEN_CKPT_PATH.value + nnx_ckpt_path = _NNX_CKPT_PATH.value + + print(f"Linen Checkpoint Path: {linen_ckpt_path}") + print(f"NNX Checkpoint Path: {nnx_ckpt_path}") + + print("Loading Linen params...") + linen_params = load_checkpoint_params(linen_ckpt_path) + print("Loading NNX params...") + nnx_params = load_checkpoint_params(nnx_ckpt_path) + + if linen_params is not None and nnx_params is not None: + nnx_params_transformed = transform_nnx_params(nnx_params) + + print("\nComparing Linen params with Transformed NNX params...") + if compare_params(linen_params, nnx_params_transformed): + print("\nCheckpoints are considered the same (within np.allclose tolerance) after transformation!") + else: + print("\nCheckpoints DIFFER after transformation.") + else: + print("Failed to load params from one or both checkpoints.") + + +if __name__ == "__main__": + app.run(main) diff --git a/tests/unit/multi_token_prediction_test.py b/tests/unit/multi_token_prediction_test.py index 5f5542ec31..71e6e07f71 100644 --- a/tests/unit/multi_token_prediction_test.py +++ b/tests/unit/multi_token_prediction_test.py @@ -55,14 +55,23 @@ def setUp(self): devices_array = maxtext_utils.create_device_mesh(self.cfg) self.mesh = Mesh(devices_array, self.cfg.mesh_axes) - # Instantiate the Layer - self.mtp_layer = multi_token_prediction.MultiTokenPredictionLayer( - config=self.cfg, - mesh=self.mesh, - layer_number=TEST_LAYER_NUM, - transformer_layer_module=DecoderLayer, - rngs=self.rngs, - ) + if self.cfg.pure_nnx_decoder: + # Instantiate the Layer + self.mtp_layer = multi_token_prediction.MultiTokenPredictionLayer( + config=self.cfg, + mesh=self.mesh, + layer_number=TEST_LAYER_NUM, + transformer_layer_module=DecoderLayer, + rngs=self.rngs, + ) + else: + # Instantiate the Layer + self.mtp_layer = multi_token_prediction.MultiTokenPredictionLayerLinen( + config=self.cfg, + mesh=self.mesh, + layer_number=TEST_LAYER_NUM, + transformer_layer_module=DecoderLayer, + ) # Dimensions directly from the config object self.batch_size = int(self.cfg.per_device_batch_size) From 2592be6048e880e4aef0bfdcf516711e98889b04 Mon Sep 17 00:00:00 2001 From: Jacky Fang Date: Tue, 3 Mar 2026 10:42:32 +0000 Subject: [PATCH 05/12] feat: enhance LoRA parameter handling and logging --- .../trainers/post_train/sft/train_sft.py | 255 ++++++++++++------ 1 file changed, 176 insertions(+), 79 deletions(-) diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index 064c21d51e..47ce88b368 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -38,7 +38,9 @@ from typing import Sequence from absl import app +import math import os +import re import jax import jax.numpy as jnp from flax import nnx @@ -53,9 +55,9 @@ from tunix.sft import utils as tunix_sft_utils from tunix.rl import reshard -from MaxText import optimizers -from MaxText import pyconfig -from MaxText.train import loss_fn +from maxtext.optimizers import optimizers +from maxtext.configs import pyconfig +from maxtext.trainers.pre_train.train import loss_fn from maxtext.common.goodput import ( GoodputEvent, RECORD_JOB_END_TIME, @@ -195,6 +197,17 @@ def _dot_general_with_3d( preferred_element_type=None, out_sharding=None, ): + def _fallback_dot_general(): + return qwix_lora.LoraProvider.dot_general( + self, + lhs, + rhs, + dimension_numbers, + precision, + preferred_element_type, + out_sharding=out_sharding, + ) + res = qwix_ptq.PtqProvider.dot_general( self, lhs, @@ -256,19 +269,13 @@ def _dot_general_with_3d( lhs_trans = jnp.transpose(lhs, lhs_perm) lhs_shape = lhs_trans.shape lhs_flat = jnp.reshape(lhs_trans, lhs_shape[:-len(contract_axes)] + (k,)) + if lora_a.shape[0] != k: + return _fallback_dot_general() delta = jnp.einsum("...k,kr->...r", lhs_flat, lora_a) delta = jnp.einsum("...r,rm->...m", delta, lora_b) return res + delta * (rule.alpha / rule.rank) - return qwix_lora.LoraProvider.dot_general( - self, - lhs, - rhs, - dimension_numbers, - precision, - preferred_element_type, - out_sharding=out_sharding, - ) + return _fallback_dot_general() lora_provider.dot_general = types.MethodType(_dot_general_with_3d, lora_provider) @@ -296,8 +303,6 @@ def _safe_find_param(x, ptq_array_type=None): continue candidates[name] = value - - else: return original_find_param(x, ptq_array_type) @@ -374,85 +379,180 @@ def _prepare_dummy_inputs(mt_config, mesh): decoder_positions = jnp.broadcast_to(jnp.arange(seq_len, dtype=jnp.int32), (dummy_bs, seq_len)) return decoder_input_tokens, decoder_positions -def _precreate_lora_params(target_model, lora_provider, qwix_flax_util, qwix_lora, math, re, types): - rules = getattr(lora_provider, "_rules", []) + +def _precreate_lora_params(lora_model, lora_provider, mt_config, qwix_flax_util, qwix_lora, types): + rules = [rule for rule in getattr(lora_provider, "_rules", []) if isinstance(rule, qwix_lora.LoraRule)] if not rules: + max_logging.log("LoRA precreate: no LoRA rules found on provider, skipping.") return - for path, module in target_model.iter_modules(): - module_path = "/".join(map(str, path)) - - for rule in rules: - if rule.op_names and "dot_general" not in rule.op_names: + # MaxText uses a single LoRA rule from the provided module_path regex. + rule = rules[0] + compiled_module_path = re.compile(mt_config.lora_module_path) + num_decoder_layers = getattr(mt_config, "num_decoder_layers", None) + if num_decoder_layers is None: + num_decoder_layers = getattr(mt_config, "base_num_decoder_layers", None) + param_scan_axis = int(getattr(mt_config, "param_scan_axis", 0)) + + def _with_layer_axis(base_shape_or_transpose, layer_value): + axis = max(0, min(param_scan_axis, len(base_shape_or_transpose))) + values = list(base_shape_or_transpose) + values.insert(axis, layer_value) + return tuple(values) + + matched_modules = 0 + precreated_modules = 0 + skipped_modules = [] + precreated_shapes = [] + + for path, module in lora_model.iter_modules(): + module_path = "/".join(str(p) for p in path) + if not compiled_module_path.search(module_path): + continue + + matched_modules += 1 + kernel = getattr(module, "kernel", None) + if kernel is None: + if len(skipped_modules) < 10: + skipped_modules.append(f"{module_path}: no kernel") + continue + + try: + kernel_value = qwix_flax_util.unbox(kernel) + except Exception: + if len(skipped_modules) < 10: + skipped_modules.append(f"{module_path}: cannot unbox kernel") + continue + + kernel_shape = getattr(kernel_value, "shape", None) + if kernel_shape is None and hasattr(kernel_value, "array"): + kernel_shape = getattr(kernel_value.array, "shape", None) + if kernel_shape is None and hasattr(kernel_value.array, "qvalue"): + kernel_shape = getattr(kernel_value.array.qvalue, "shape", None) + if kernel_shape is None or len(kernel_shape) < 2: + if len(skipped_modules) < 10: + skipped_modules.append(f"{module_path}: unsupported kernel shape {kernel_shape}") + continue + + is_scanned_decoder_module = ( + "decoder/layers/" in module_path + and isinstance(num_decoder_layers, int) + and num_decoder_layers > 1 + ) + + if is_scanned_decoder_module: + layer_axis = None + if 0 <= param_scan_axis < len(kernel_shape): + layer_axis = int(param_scan_axis) + elif len(kernel_shape) > 1 and int(kernel_shape[1]) == int(num_decoder_layers): + layer_axis = 1 + else: + for axis, dim in enumerate(kernel_shape): + if int(dim) == int(num_decoder_layers): + layer_axis = axis + break + if layer_axis is None: + if len(skipped_modules) < 10: + skipped_modules.append(f"{module_path}: unable to infer layer axis from kernel shape {kernel_shape}") continue - kernel_tensor = None - in_rank = 0 - out_rank = 0 - found_param_name = "kernel" + effective_shape = tuple(int(dim) for i, dim in enumerate(kernel_shape) if i != layer_axis) + if len(effective_shape) < 2: + if len(skipped_modules) < 10: + skipped_modules.append(f"{module_path}: unsupported effective shape {effective_shape}") + continue - # Case A: Pure NNX (Standard attributes) - if hasattr(module, "kernel") and hasattr(module, "in_features_shape"): - try: - kernel_tensor = qwix_flax_util.unbox(module.kernel) - in_rank = len(module.in_features_shape) - out_rank = len(module.out_features_shape) - except Exception: - kernel_tensor = None + if "decoder/layers/self_attention/out" in module_path and len(effective_shape) >= 3: + in_dim = int(math.prod(effective_shape[:-1])) + out_dim = int(effective_shape[-1]) + else: + in_dim = int(effective_shape[0]) + out_dim = int(math.prod(effective_shape[1:])) + if in_dim <= 0 or out_dim <= 0: + if len(skipped_modules) < 10: + skipped_modules.append(f"{module_path}: non-positive dims in={in_dim} out={out_dim}") + continue - + in_axis = next(i for i in range(len(kernel_shape)) if i != layer_axis) + out_axis = next(i for i in range(len(kernel_shape) - 1, -1, -1) if i != layer_axis) - if kernel_tensor is None: + a_shape = _with_layer_axis((in_dim, rule.rank), num_decoder_layers) + b_shape = _with_layer_axis((rule.rank, out_dim), num_decoder_layers) + a_sharding_transpose = _with_layer_axis((in_axis, None), layer_axis) + b_sharding_transpose = _with_layer_axis((None, out_axis), layer_axis) + else: + prefix_shape = tuple(kernel_shape[:-2]) + in_dim = int(kernel_shape[-2]) + out_dim = int(kernel_shape[-1]) + if in_dim <= 0 or out_dim <= 0: + if len(skipped_modules) < 10: + skipped_modules.append(f"{module_path}: non-positive dims in={in_dim} out={out_dim}") continue - # Closure to define LoRA A and B shapes and sharding - def _init_for_module(self, k_tensor=kernel_tensor, i_r=in_rank, o_r=out_rank, p_name=found_param_name): - kernel_shape = getattr(k_tensor, "shape", ()) - extra_rank = max(0, len(kernel_shape) - (i_r + o_r)) - - prefix_shape = kernel_shape[:extra_rank] - in_shape = kernel_shape[extra_rank : extra_rank + i_r] - out_shape = kernel_shape[extra_rank + i_r :] - - in_size = int(math.prod(in_shape)) - out_size = int(math.prod(out_shape)) - - if in_size <= 0 or out_size <= 0: - return - - a_shape = prefix_shape + (in_size, rule.rank) - b_shape = prefix_shape + (rule.rank, out_size) - - prefix_axes = tuple(range(extra_rank)) - a_sharding_transpose = prefix_axes + (None,) - b_sharding_transpose = prefix_axes + (None,) - - qwix_lora._get_or_create_lora_params( - name=p_name, - rule=rule, - a_shape=a_shape, - b_shape=b_shape, - a_sharding_transpose=a_sharding_transpose, - b_sharding_transpose=b_sharding_transpose, - ) + full_prefix_shape = prefix_shape + a_shape = full_prefix_shape + (in_dim, rule.rank) + b_shape = full_prefix_shape + (rule.rank, out_dim) + + prefix_rank = len(full_prefix_shape) + a_sharding_transpose = tuple(range(prefix_rank)) + (prefix_rank, None) + b_sharding_transpose = tuple(range(prefix_rank)) + (None, prefix_rank + 1) - types.MethodType(_init_for_module, module)() + def _init_for_module( + self, + a_shape=a_shape, + b_shape=b_shape, + a_sharding_transpose=a_sharding_transpose, + b_sharding_transpose=b_sharding_transpose, + ): + qwix_lora._get_or_create_lora_params( + name="kernel", + rule=rule, + a_shape=a_shape, + b_shape=b_shape, + a_sharding_transpose=a_sharding_transpose, + b_sharding_transpose=b_sharding_transpose, + ) + + types.MethodType(_init_for_module, module)() + precreated_modules += 1 + if len(precreated_shapes) < 10: + precreated_shapes.append((module_path, a_shape, b_shape)) + + max_logging.log( + "LoRA precreate: matched_modules=%s precreated_modules=%s skipped_sample=%s shape_sample=%s" + % (matched_modules, precreated_modules, skipped_modules, precreated_shapes) + ) def _verify_lora_parameters(lora_model, mt_config): + compiled_module_path = re.compile(mt_config.lora_module_path) + matched_module_paths = [] + sample_module_paths = [] + + for path, _ in lora_model.iter_modules(): + module_path = "/".join(str(p) for p in path) + if len(sample_module_paths) < 50: + sample_module_paths.append(module_path) + if compiled_module_path.search(module_path): + matched_module_paths.append(module_path) + is_lora_enabled = tunix_sft_utils.is_lora_enabled(lora_model) - if not is_lora_enabled: - module_paths = [] - for path, _ in lora_model.iter_modules(): - module_paths.append("/".join(str(p) for p in path)) - if len(module_paths) >= 50: - break + if is_lora_enabled: + max_logging.log("LoRA verification: tunix_sft_utils.is_lora_enabled=True") + return + + if not matched_module_paths: max_logging.log( f"LoRA module_path='{mt_config.lora_module_path}' did not match any weights. " - f"Sample module paths: {module_paths}" + f"Sample module paths: {sample_module_paths}" ) raise ValueError("LoRA enabled but no LoRA parameters found in decoder/model state.") - max_logging.log("LoRA verification: tunix_sft_utils.is_lora_enabled=True") + + max_logging.log( + "LoRA verification: matched %s target modules but LoRA params are not yet materialized; " + "continuing with lazy LoRA initialization. Sample matches: %s" + % (len(matched_module_paths), matched_module_paths[:10]) + ) def maybe_apply_lora(model, mesh, mt_config): @@ -461,13 +561,11 @@ def maybe_apply_lora(model, mesh, mt_config): if hasattr(mt_config, 'lora_input_adapters_path') and mt_config.lora_input_adapters_path: max_logging.log("MaxText LoRA adapters loaded, skipping Qwix LoRA application") return model - + if not getattr(mt_config, "enable_lora", False): return model import qwix - import math - import re import qwix._src.flax_util as qwix_flax_util import qwix._src.providers.lora as qwix_lora import qwix._src.providers.ptq as qwix_ptq @@ -488,8 +586,7 @@ def maybe_apply_lora(model, mesh, mt_config): decoder_positions=decoder_positions, skip_nnx_init=True, ) - - _precreate_lora_params(lora_model, lora_provider, qwix_flax_util, qwix_lora, math, re, types) + _precreate_lora_params(lora_model, lora_provider, mt_config, qwix_flax_util, qwix_lora, types) if mesh is not None: lora_model = reshard.reshard_model_to_mesh(lora_model, mesh) From 344c64cbfbbf52b12ed8802c521d2fc2bcc4671b Mon Sep 17 00:00:00 2001 From: Charles Li Date: Wed, 4 Mar 2026 20:33:31 +0000 Subject: [PATCH 06/12] Fix lora_to_maxtext convert script --- .../checkpoint_conversion/lora_to_maxtext.py | 222 +++++++++--------- 1 file changed, 108 insertions(+), 114 deletions(-) diff --git a/src/maxtext/checkpoint_conversion/lora_to_maxtext.py b/src/maxtext/checkpoint_conversion/lora_to_maxtext.py index 54eee25dd0..ae882654dd 100644 --- a/src/maxtext/checkpoint_conversion/lora_to_maxtext.py +++ b/src/maxtext/checkpoint_conversion/lora_to_maxtext.py @@ -29,7 +29,7 @@ Example Usage: To convert HF LoRA to MaxText adapter: - python src/MaxText/utils/ckpt_conversion/apply_lora.py \ + python src/maxtext/ckpt_conversion/apply_lora.py \ MaxText/configs/sft.yml model_name="llama3.1-8b" \ hf_lora_adapter_path="username/lora-adapter-repo" \ base_output_directory="/path/to/output/directory" \ @@ -54,11 +54,12 @@ from flax import nnx from orbax import checkpoint as ocp -from MaxText import pyconfig -from MaxText.common_types import MODEL_MODE_TRAIN -from MaxText.layers import models, quantizations -from MaxText.utils.ckpt_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING -from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, HF_IDS +from maxtext.configs import pyconfig +from maxtext.common.common_types import MODEL_MODE_TRAIN +from maxtext.layers import quantizations +from maxtext.models import models +from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING +from maxtext.checkpoint_conversion.utils.utils import apply_hook_fns, HF_IDS from maxtext.utils import max_logging from maxtext.utils import maxtext_utils from maxtext.utils import max_utils @@ -68,7 +69,7 @@ def load_hf_lora_adapter(adapter_path: str, hf_model_id: str) -> dict: """Load HF LoRA adapter weights directly from safetensors files.""" max_logging.log(f"Loading HF LoRA adapter from {adapter_path}") - + # Check adapter compatibility adapter_config = None if os.path.isdir(adapter_path): @@ -76,27 +77,23 @@ def load_hf_lora_adapter(adapter_path: str, hf_model_id: str) -> dict: adapter_dir = epath.Path(adapter_path) config_file = adapter_dir / "adapter_config.json" if config_file.exists(): - with open(config_file, 'r') as f: + with open(config_file, "r") as f: adapter_config = json.load(f) else: # HF Hub repo try: - config_file = hf_hub_download( - adapter_path, - "adapter_config.json", - token=os.environ.get("HF_AUTH_TOKEN") - ) - with open(config_file, 'r') as f: + config_file = hf_hub_download(adapter_path, "adapter_config.json", token=os.environ.get("HF_AUTH_TOKEN")) + with open(config_file, "r") as f: adapter_config = json.load(f) except Exception: max_logging.log("Warning: Could not load adapter_config.json from HF Hub") - + if adapter_config: base_model = adapter_config.get("base_model_name_or_path") # if base_model and base_model.replace("-Instruct", "") != hf_model_id.replace("-Instruct", ""): # raise ValueError(f"Adapter base model '{base_model}' does not match expected model '{hf_model_id}'") max_logging.log(f"Adapter compatible with model {hf_model_id}") - + # Handle both local paths and HF Hub paths if os.path.isdir(adapter_path): # Local directory @@ -112,33 +109,30 @@ def load_hf_lora_adapter(adapter_path: str, hf_model_id: str) -> dict: try: # Try to download the adapter config to get the file list from huggingface_hub import list_repo_files + files = list_repo_files(adapter_path, token=os.environ.get("HF_AUTH_TOKEN")) - safetensor_files = [f for f in files if f.endswith('.safetensors')] + safetensor_files = [f for f in files if f.endswith(".safetensors")] if not safetensor_files: - bin_files = [f for f in files if f.endswith('.bin')] + bin_files = [f for f in files if f.endswith(".bin")] if not bin_files: raise ValueError(f"No LoRA adapter files found in {adapter_path}") adapter_file = bin_files[0] else: adapter_file = safetensor_files[0] - + # Download the adapter file - adapter_file = hf_hub_download( - adapter_path, - adapter_file, - token=os.environ.get("HF_AUTH_TOKEN") - ) + adapter_file = hf_hub_download(adapter_path, adapter_file, token=os.environ.get("HF_AUTH_TOKEN")) except Exception as e: raise ValueError(f"Failed to load LoRA adapter from {adapter_path}: {e}") - + # Load the adapter weights - if adapter_file.endswith('.safetensors'): + if adapter_file.endswith(".safetensors"): with safe_open(adapter_file, framework="numpy") as f: lora_weights = {k: f.get_tensor(k) for k in f.keys()} else: # For .bin files, we'd need torch.load, but safetensors is preferred raise ValueError(f"Unsupported adapter file format: {adapter_file}") - + max_logging.log(f"Loaded {len(lora_weights)} LoRA parameters from adapter") return lora_weights @@ -146,17 +140,17 @@ def load_hf_lora_adapter(adapter_path: str, hf_model_id: str) -> dict: def convert_hf_lora_key_to_maxtext(hf_key: str, param_mapping: dict, config) -> str: """Convert HF LoRA key to MaxText parameter path using the mapping from to_maxtext.py.""" # HF LoRA keys: base_model.model.layers.{layer}.{module}.lora_A/B.weight - + # 1. Clean up LoRA suffixes to get the base module path # e.g. ...q_proj.lora_A.weight -> ...q_proj hf_param_key = hf_key.replace(".lora_A.weight", "").replace(".lora_B.weight", "") hf_param_key = hf_param_key.replace(".lora_A", "").replace(".lora_B", "") - + # 2. Handle prefix. Expected target is usually "model.layers..." # Input could be "base_model.model.model.layers..." or "base_model.model.layers..." if hf_param_key.startswith("base_model.model."): - hf_param_key = hf_param_key[len("base_model.model."):] - + hf_param_key = hf_param_key[len("base_model.model.") :] + # 3. Search for the corresponding MaxText key for mt_key, hf_keys in param_mapping.items(): if isinstance(hf_keys, list): @@ -167,76 +161,76 @@ def convert_hf_lora_key_to_maxtext(hf_key: str, param_mapping: dict, config) -> elif isinstance(hf_keys, str): if hf_keys.replace(".weight", "") == hf_param_key: return mt_key - + return None def convert_lora_to_maxtext_adapter(config, lora_weights: dict, output_path: str, hf_model_id: str): - """Converts HF LoRA weights to MaxText adapter format without merging.""" - - # 1. Setup Mesh and Model Structure (Abstractly) - devices_array = maxtext_utils.create_device_mesh(config) - mesh = jax.sharding.Mesh(devices_array, axis_names=config.mesh_axes) - quant = quantizations.configure_quantization(config) - - # Initialize rngs for model creation - rngs = nnx.Rngs(params=jax.random.PRNGKey(0), dropout=jax.random.PRNGKey(1)) - - # Use the model definition to understand the target parameter paths - model = models.Transformer(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN, rngs=rngs) - - hf_token = config.hf_access_token - - # Get the parameter mapping (MT -> HF) - model_key = config.model_name - if "-Instruct" in model_key: - max_logging.log("Warning: You want an Instruct version, so we are using the base model architecture instead.") - model_key = model_key.replace("-Instruct", "") - hf_config_obj = AutoConfig.from_pretrained(hf_model_id, token=hf_token) - param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config, config.scan_layers) - - # 2. Initialize an empty dictionary for the MaxText Adapter - mt_adapter_tree = {} - mapped_count = 0 - - # 3. Map HF LoRA weights to MaxText keys - for hf_key, weight in lora_weights.items(): - # Identify the MaxText path for this specific HF weight - mt_key = convert_hf_lora_key_to_maxtext(hf_key, param_map_mt_to_hf, config) - - if mt_key: - # Determine if this is the 'A' or 'B' matrix - suffix = "lora_A" if "lora_A" in hf_key else "lora_B" - - # Construct a nested dictionary path in mt_adapter_tree - # MaxText expects: { 'decoder': { 'layers': { '0': { 'query': { 'lora_A': ... } } } } } - parts = mt_key.split("/") - current = mt_adapter_tree - for part in parts: - if part not in current: - current[part] = {} - current = current[part] - - # Convert weight to JAX array and store - current[suffix] = jnp.array(weight) - mapped_count += 1 - else: - max_logging.log(f"Warning: Could not map HF LoRA key {hf_key} to MaxText key") - - max_logging.log(f"Successfully mapped {mapped_count} out of {len(lora_weights)} LoRA parameters") - - # 4. Save as a standalone adapter checkpoint - max_logging.log(f"Saving MaxText LoRA adapter to {output_path}") - ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) - ckptr.save(epath.Path(output_path), mt_adapter_tree) - - max_logging.log("LoRA adapter conversion completed successfully") + """Converts HF LoRA weights to MaxText adapter format without merging.""" + + # 1. Setup Mesh and Model Structure (Abstractly) + devices_array = maxtext_utils.create_device_mesh(config) + mesh = jax.sharding.Mesh(devices_array, axis_names=config.mesh_axes) + quant = quantizations.configure_quantization(config) + + # Initialize rngs for model creation + rngs = nnx.Rngs(params=jax.random.PRNGKey(0), dropout=jax.random.PRNGKey(1)) + + # Use the model definition to understand the target parameter paths + model = models.Transformer(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN, rngs=rngs) + + hf_token = config.hf_access_token + + # Get the parameter mapping (MT -> HF) + model_key = config.model_name + if "-Instruct" in model_key: + max_logging.log("Warning: You want an Instruct version, so we are using the base model architecture instead.") + model_key = model_key.replace("-Instruct", "") + hf_config_obj = AutoConfig.from_pretrained(hf_model_id, token=hf_token) + param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config, config.scan_layers) + + # 2. Initialize an empty dictionary for the MaxText Adapter + mt_adapter_tree = {} + mapped_count = 0 + + # 3. Map HF LoRA weights to MaxText keys + for hf_key, weight in lora_weights.items(): + # Identify the MaxText path for this specific HF weight + mt_key = convert_hf_lora_key_to_maxtext(hf_key, param_map_mt_to_hf, config) + + if mt_key: + # Determine if this is the 'A' or 'B' matrix + suffix = "lora_A" if "lora_A" in hf_key else "lora_B" + + # Construct a nested dictionary path in mt_adapter_tree + # MaxText expects: { 'decoder': { 'layers': { '0': { 'query': { 'lora_A': ... } } } } } + parts = mt_key.split("/") + current = mt_adapter_tree + for part in parts: + if part not in current: + current[part] = {} + current = current[part] + + # Convert weight to JAX array and store + current[suffix] = jnp.array(weight) + mapped_count += 1 + else: + max_logging.log(f"Warning: Could not map HF LoRA key {hf_key} to MaxText key") + + max_logging.log(f"Successfully mapped {mapped_count} out of {len(lora_weights)} LoRA parameters") + + # 4. Save as a standalone adapter checkpoint + max_logging.log(f"Saving MaxText LoRA adapter to {output_path}") + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + ckptr.save(epath.Path(output_path), mt_adapter_tree) + + max_logging.log("LoRA adapter conversion completed successfully") def main(args: Sequence[str]) -> None: # Set logging to INFO level to see max_logging.log messages logging.set_verbosity(logging.INFO) - + # Check if the user is using an Instruct version. If so, use the base model architecture original_model_name = None for i, arg in enumerate(args): @@ -245,7 +239,7 @@ def main(args: Sequence[str]) -> None: # Remove quotes if present model_name_arg = model_name_arg.strip("'").strip('"') original_model_name = model_name_arg - + if "-Instruct" in model_name_arg: max_logging.log("Warning: You want an Instruct version, so we are using the base model architecture instead.") model_name_arg = model_name_arg.replace("-Instruct", "") @@ -254,58 +248,58 @@ def main(args: Sequence[str]) -> None: # Initialize maxtext config config = pyconfig.initialize(args) - - if not hasattr(config, 'hf_lora_adapter_path') or not config.hf_lora_adapter_path: + + if not hasattr(config, "hf_lora_adapter_path") or not config.hf_lora_adapter_path: raise ValueError("hf_lora_adapter_path must be specified") - + # Determine HF model ID and check if supported hf_model_id = HF_IDS.get(config.model_name) if hf_model_id is None: raise ValueError(f"Model '{config.model_name}' is not supported. Use a supported model_name from HF_IDS.") - - if not hasattr(config, 'base_output_directory') or not config.base_output_directory: + + if not hasattr(config, "base_output_directory") or not config.base_output_directory: raise ValueError("base_output_directory must be specified (in config file or as command-line argument)") - + output_dir = config.base_output_directory - + # Use original model name for output path model_name_for_path = original_model_name or config.model_name adapter_name = os.path.basename(config.hf_lora_adapter_path) full_output_path = os.path.join(output_dir, model_name_for_path, adapter_name) - + os.makedirs(os.path.dirname(full_output_path), exist_ok=True) - + if os.path.exists(full_output_path): import shutil + max_logging.log(f"Output directory {full_output_path} exists. Removing it to allow Orbax to save.") shutil.rmtree(full_output_path) - + # Load LoRA adapter and check compatibility lora_weights = load_hf_lora_adapter(config.hf_lora_adapter_path, hf_model_id) - + # Convert LoRA to MaxText adapter format and save convert_lora_to_maxtext_adapter(config, lora_weights, full_output_path, hf_model_id) - - # Verify output was created - if not os.path.exists(full_output_path): + + # Verify output was created #using epath for local file and gcs compatibility + outputpath = epath.Path(full_output_path) + if not outputpath.exists(): raise RuntimeError(f"Failed to create output directory {full_output_path}") if __name__ == "__main__": # Argument parsing similar to to_maxtext.py parser = argparse.ArgumentParser() - parser.add_argument( - "--simulated_cpu_devices_count", type=int, required=False, default=16 - ) - + parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16) + # Parse local arguments local_args, remaining_args = parser.parse_known_args() - + # Reconstruct model_args (script name + the args MaxText needs) model_args = [sys.argv[0]] + remaining_args # Set jax environment jax.config.update("jax_platforms", "cpu") os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}" - - main(model_args) \ No newline at end of file + + main(model_args) From 02c71d2527bf65e29eeb80113ca4bea8afece4d9 Mon Sep 17 00:00:00 2001 From: Charles Li Date: Wed, 4 Mar 2026 23:05:34 +0000 Subject: [PATCH 07/12] Fix pyink warning --- src/maxtext/configs/types.py | 17 ++-- .../trainers/post_train/sft/train_sft.py | 84 ++++++++----------- 2 files changed, 46 insertions(+), 55 deletions(-) diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index c6b4be3ce6..c9f0930bab 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -287,7 +287,10 @@ class Checkpointing(BaseModel): load_parameters_path: PathStr = Field("", description="Loads only model parameters from a specific checkpoint path.") lora_input_adapters_path: PathStr = Field("", description="Input GCS path for LoRA adapters.") - hf_lora_adapter_path: PathStr = Field("", description="HuggingFace LoRA adapter repo ID (e.g., 'username/adapter-repo') or local path to directory containing adapter_model.safetensors.") + hf_lora_adapter_path: PathStr = Field( + "", + description="HuggingFace LoRA adapter repo ID (e.g., 'username/adapter-repo') or local path to directory containing adapter_model.safetensors.", + ) load_full_state_path: PathStr = Field("", description="Loads the complete training state from a checkpoint path.") enable_checkpointing: bool = Field(True, description="If True, enables saving checkpoints during training.") load_checkpoint_only_once: bool = Field(False, description="If True, deep copy the reference model to the actor model.") @@ -1062,15 +1065,12 @@ class LoRA(BaseModel): lora_module_path: str = Field( "", description=( - "Regex identifying target modules for LoRA, e.g." - " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'." + "Regex identifying target modules for LoRA, e.g." " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'." ), ) lora_weight_qtype: str | None = Field( None, - description=( - "Optional quantization type for QLoRA (e.g., 'nf4'). If set, QLoRA is applied." - ), + description=("Optional quantization type for QLoRA (e.g., 'nf4'). If set, QLoRA is applied."), ) lora_tile_size: NonNegativeInt | None = Field( None, @@ -1410,7 +1410,10 @@ class Profiling(BaseModel): xprof_e2e_enable_fw_throttle_event: bool = Field(False, description="Enable FW throttle event.") xprof_e2e_enable_fw_power_level_event: bool = Field(False, description="Enable FW power level event.") xprof_e2e_enable_fw_thermal_event: bool = Field(False, description="Enable FW thermal event.") - profile_power_events: bool = Field(False, description="Enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.") + profile_power_events: bool = Field( + False, + description="Enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.", + ) class HloDump(BaseModel): diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index 47ce88b368..d83fb3b60e 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -165,24 +165,24 @@ def _build_lora_provider(mt_config, qwix): if mt_config.lora_weight_qtype is not None: lora_kwargs["weight_qtype"] = mt_config.lora_weight_qtype max_logging.log( - "QLoRA configured: module_path=%s rank=%s alpha=%s weight_qtype=%s tile_size=%s" - % ( - mt_config.lora_module_path, - mt_config.lora_rank, - mt_config.lora_alpha, - mt_config.lora_weight_qtype, - mt_config.lora_tile_size, - ) + "QLoRA configured: module_path=%s rank=%s alpha=%s weight_qtype=%s tile_size=%s" + % ( + mt_config.lora_module_path, + mt_config.lora_rank, + mt_config.lora_alpha, + mt_config.lora_weight_qtype, + mt_config.lora_tile_size, + ) ) else: max_logging.log( - "LoRA configured: module_path=%s rank=%s alpha=%s tile_size=%s" - % ( - mt_config.lora_module_path, - mt_config.lora_rank, - mt_config.lora_alpha, - mt_config.lora_tile_size, - ) + "LoRA configured: module_path=%s rank=%s alpha=%s tile_size=%s" + % ( + mt_config.lora_module_path, + mt_config.lora_rank, + mt_config.lora_alpha, + mt_config.lora_tile_size, + ) ) return qwix.LoraProvider(**lora_kwargs) @@ -206,7 +206,7 @@ def _fallback_dot_general(): precision, preferred_element_type, out_sharding=out_sharding, - ) + ) res = qwix_ptq.PtqProvider.dot_general( self, @@ -226,18 +226,14 @@ def _fallback_dot_general(): if weight_name is None: return res - if ( - len(rhs.shape) == 3 - and tuple(dimension_numbers[0][1]) == (0,) - and not dimension_numbers[1][1] - ): + if len(rhs.shape) == 3 and tuple(dimension_numbers[0][1]) == (0,) and not dimension_numbers[1][1]: lora_params = qwix_lora._get_or_create_lora_params( - name=weight_name, - rule=rule, - a_shape=(rhs.shape[0], rule.rank), - b_shape=(rule.rank, rhs.shape[1] * rhs.shape[2]), - a_sharding_transpose=(0, None), - b_sharding_transpose=(None, 1), + name=weight_name, + rule=rule, + a_shape=(rhs.shape[0], rule.rank), + b_shape=(rule.rank, rhs.shape[1] * rhs.shape[2]), + a_sharding_transpose=(0, None), + b_sharding_transpose=(None, 1), ) lora_a, lora_b = lora_params[:2] if rule.dropout > 0: @@ -247,11 +243,7 @@ def _fallback_dot_general(): delta = jnp.einsum("...r,rnm->...nm", delta, lora_b) return res + delta * (rule.alpha / rule.rank) - if ( - len(rhs.shape) == 3 - and tuple(dimension_numbers[0][1]) == (0, 1) - and not dimension_numbers[1][1] - ): + if len(rhs.shape) == 3 and tuple(dimension_numbers[0][1]) == (0, 1) and not dimension_numbers[1][1]: k = rhs.shape[0] * rhs.shape[1] lora_params = qwix_lora._get_or_create_lora_params( name=weight_name, @@ -268,7 +260,7 @@ def _fallback_dot_general(): lhs_perm = [i for i in range(lhs.ndim) if i not in contract_axes] + list(contract_axes) lhs_trans = jnp.transpose(lhs, lhs_perm) lhs_shape = lhs_trans.shape - lhs_flat = jnp.reshape(lhs_trans, lhs_shape[:-len(contract_axes)] + (k,)) + lhs_flat = jnp.reshape(lhs_trans, lhs_shape[: -len(contract_axes)] + (k,)) if lora_a.shape[0] != k: return _fallback_dot_general() delta = jnp.einsum("...k,kr->...r", lhs_flat, lora_a) @@ -364,9 +356,7 @@ def _prepare_dummy_inputs(mt_config, mesh): batch_size = getattr(mt_config, "per_device_batch_size", 1) seq_len = getattr(mt_config, "max_target_length", 1) if batch_size <= 0 or seq_len <= 0: - raise ValueError( - "per_device_batch_size and max_target_length must be positive when LoRA is enabled." - ) + raise ValueError("per_device_batch_size and max_target_length must be positive when LoRA is enabled.") devices_data_fsdp = 1 if mesh is not None: @@ -435,9 +425,7 @@ def _with_layer_axis(base_shape_or_transpose, layer_value): continue is_scanned_decoder_module = ( - "decoder/layers/" in module_path - and isinstance(num_decoder_layers, int) - and num_decoder_layers > 1 + "decoder/layers/" in module_path and isinstance(num_decoder_layers, int) and num_decoder_layers > 1 ) if is_scanned_decoder_module: @@ -543,8 +531,8 @@ def _verify_lora_parameters(lora_model, mt_config): if not matched_module_paths: max_logging.log( - f"LoRA module_path='{mt_config.lora_module_path}' did not match any weights. " - f"Sample module paths: {sample_module_paths}" + f"LoRA module_path='{mt_config.lora_module_path}' did not match any weights. " + f"Sample module paths: {sample_module_paths}" ) raise ValueError("LoRA enabled but no LoRA parameters found in decoder/model state.") @@ -558,7 +546,7 @@ def _verify_lora_parameters(lora_model, mt_config): def maybe_apply_lora(model, mesh, mt_config): """Optionally applies LoRA/QLoRA to a MaxText model using Qwix.""" # Skip Qwix LoRA if MaxText LoRA adapters are loaded - if hasattr(mt_config, 'lora_input_adapters_path') and mt_config.lora_input_adapters_path: + if hasattr(mt_config, "lora_input_adapters_path") and mt_config.lora_input_adapters_path: max_logging.log("MaxText LoRA adapters loaded, skipping Qwix LoRA application") return model @@ -580,11 +568,11 @@ def maybe_apply_lora(model, mesh, mt_config): decoder_input_tokens, decoder_positions = _prepare_dummy_inputs(mt_config, mesh) lora_model = qwix.apply_lora_to_model( - model, - lora_provider, - decoder_input_tokens=decoder_input_tokens, - decoder_positions=decoder_positions, - skip_nnx_init=True, + model, + lora_provider, + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + skip_nnx_init=True, ) _precreate_lora_params(lora_model, lora_provider, mt_config, qwix_flax_util, qwix_lora, types) @@ -668,4 +656,4 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) \ No newline at end of file + app.run(main) From 11a7d3a95421f1b9220c42a7b8ddf9a90ab7a27f Mon Sep 17 00:00:00 2001 From: Jacky Fang Date: Thu, 5 Mar 2026 10:50:12 +0000 Subject: [PATCH 08/12] chore: update code structure for better readability --- .../checkpoint_conversion/lora_to_maxtext.py | 54 ++-- src/maxtext/configs/post_train/sft.yml | 2 + src/maxtext/configs/types.py | 20 +- .../trainers/post_train/sft/train_sft.py | 244 ++++++++++-------- 4 files changed, 175 insertions(+), 145 deletions(-) diff --git a/src/maxtext/checkpoint_conversion/lora_to_maxtext.py b/src/maxtext/checkpoint_conversion/lora_to_maxtext.py index ae882654dd..d88d82079b 100644 --- a/src/maxtext/checkpoint_conversion/lora_to_maxtext.py +++ b/src/maxtext/checkpoint_conversion/lora_to_maxtext.py @@ -37,32 +37,25 @@ """ import argparse +import json import os +import shutil import sys -import json from typing import Sequence -from functools import partial import jax import jax.numpy as jnp -import numpy as np -import torch -from safetensors import safe_open +from etils import epath from huggingface_hub import hf_hub_download +from huggingface_hub import list_repo_files +from safetensors import safe_open from transformers import AutoConfig -from etils import epath -from flax import nnx from orbax import checkpoint as ocp +from maxtext.checkpoint_conversion.utils.param_mapping import PARAM_MAPPING +from maxtext.checkpoint_conversion.utils.utils import HF_IDS from maxtext.configs import pyconfig -from maxtext.common.common_types import MODEL_MODE_TRAIN -from maxtext.layers import quantizations -from maxtext.models import models -from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING -from maxtext.checkpoint_conversion.utils.utils import apply_hook_fns, HF_IDS from maxtext.utils import max_logging -from maxtext.utils import maxtext_utils -from maxtext.utils import max_utils from absl import logging @@ -77,19 +70,20 @@ def load_hf_lora_adapter(adapter_path: str, hf_model_id: str) -> dict: adapter_dir = epath.Path(adapter_path) config_file = adapter_dir / "adapter_config.json" if config_file.exists(): - with open(config_file, "r") as f: + with open(config_file, "r", encoding="utf-8") as f: adapter_config = json.load(f) else: # HF Hub repo try: config_file = hf_hub_download(adapter_path, "adapter_config.json", token=os.environ.get("HF_AUTH_TOKEN")) - with open(config_file, "r") as f: + with open(config_file, "r", encoding="utf-8") as f: adapter_config = json.load(f) - except Exception: - max_logging.log("Warning: Could not load adapter_config.json from HF Hub") + except Exception as exc: # pylint: disable=broad-exception-caught + max_logging.log(f"Warning: Could not load adapter_config.json from HF Hub: {exc}") if adapter_config: - base_model = adapter_config.get("base_model_name_or_path") + if adapter_config.get("base_model_name_or_path"): + max_logging.log(f"Adapter base model: {adapter_config['base_model_name_or_path']}") # if base_model and base_model.replace("-Instruct", "") != hf_model_id.replace("-Instruct", ""): # raise ValueError(f"Adapter base model '{base_model}' does not match expected model '{hf_model_id}'") max_logging.log(f"Adapter compatible with model {hf_model_id}") @@ -107,9 +101,6 @@ def load_hf_lora_adapter(adapter_path: str, hf_model_id: str) -> dict: else: # Assume it's a HF Hub repo ID try: - # Try to download the adapter config to get the file list - from huggingface_hub import list_repo_files - files = list_repo_files(adapter_path, token=os.environ.get("HF_AUTH_TOKEN")) safetensor_files = [f for f in files if f.endswith(".safetensors")] if not safetensor_files: @@ -123,7 +114,7 @@ def load_hf_lora_adapter(adapter_path: str, hf_model_id: str) -> dict: # Download the adapter file adapter_file = hf_hub_download(adapter_path, adapter_file, token=os.environ.get("HF_AUTH_TOKEN")) except Exception as e: - raise ValueError(f"Failed to load LoRA adapter from {adapter_path}: {e}") + raise ValueError(f"Failed to load LoRA adapter from {adapter_path}: {e}") from e # Load the adapter weights if adapter_file.endswith(".safetensors"): @@ -137,7 +128,7 @@ def load_hf_lora_adapter(adapter_path: str, hf_model_id: str) -> dict: return lora_weights -def convert_hf_lora_key_to_maxtext(hf_key: str, param_mapping: dict, config) -> str: +def convert_hf_lora_key_to_maxtext(hf_key: str, param_mapping: dict) -> str: """Convert HF LoRA key to MaxText parameter path using the mapping from to_maxtext.py.""" # HF LoRA keys: base_model.model.layers.{layer}.{module}.lora_A/B.weight @@ -168,17 +159,6 @@ def convert_hf_lora_key_to_maxtext(hf_key: str, param_mapping: dict, config) -> def convert_lora_to_maxtext_adapter(config, lora_weights: dict, output_path: str, hf_model_id: str): """Converts HF LoRA weights to MaxText adapter format without merging.""" - # 1. Setup Mesh and Model Structure (Abstractly) - devices_array = maxtext_utils.create_device_mesh(config) - mesh = jax.sharding.Mesh(devices_array, axis_names=config.mesh_axes) - quant = quantizations.configure_quantization(config) - - # Initialize rngs for model creation - rngs = nnx.Rngs(params=jax.random.PRNGKey(0), dropout=jax.random.PRNGKey(1)) - - # Use the model definition to understand the target parameter paths - model = models.Transformer(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN, rngs=rngs) - hf_token = config.hf_access_token # Get the parameter mapping (MT -> HF) @@ -196,7 +176,7 @@ def convert_lora_to_maxtext_adapter(config, lora_weights: dict, output_path: str # 3. Map HF LoRA weights to MaxText keys for hf_key, weight in lora_weights.items(): # Identify the MaxText path for this specific HF weight - mt_key = convert_hf_lora_key_to_maxtext(hf_key, param_map_mt_to_hf, config) + mt_key = convert_hf_lora_key_to_maxtext(hf_key, param_map_mt_to_hf) if mt_key: # Determine if this is the 'A' or 'B' matrix @@ -270,8 +250,6 @@ def main(args: Sequence[str]) -> None: os.makedirs(os.path.dirname(full_output_path), exist_ok=True) if os.path.exists(full_output_path): - import shutil - max_logging.log(f"Output directory {full_output_path} exists. Removing it to allow Orbax to save.") shutil.rmtree(full_output_path) diff --git a/src/maxtext/configs/post_train/sft.yml b/src/maxtext/configs/post_train/sft.yml index e71447e838..e4dd3b25de 100644 --- a/src/maxtext/configs/post_train/sft.yml +++ b/src/maxtext/configs/post_train/sft.yml @@ -30,6 +30,8 @@ lora_module_path: "" # For QLoRA, set lora_weight_qtype (e.g., "nf4") and optionally lora_tile_size. lora_weight_qtype: null lora_tile_size: null +# Optional NNX LoRA restore checkpoint path (direct `model_params` directory). +lora_restore_path: "" # -------------- HF LoRA Adapter -------------- # HuggingFace LoRA adapter repo ID (e.g., 'username/adapter-repo') or local path to directory containing adapter_model.safetensors diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index c9f0930bab..931c854624 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -289,7 +289,10 @@ class Checkpointing(BaseModel): lora_input_adapters_path: PathStr = Field("", description="Input GCS path for LoRA adapters.") hf_lora_adapter_path: PathStr = Field( "", - description="HuggingFace LoRA adapter repo ID (e.g., 'username/adapter-repo') or local path to directory containing adapter_model.safetensors.", + description=( + "HuggingFace LoRA adapter repo ID (e.g., 'username/adapter-repo') or local " + "path to directory containing adapter_model.safetensors." + ), ) load_full_state_path: PathStr = Field("", description="Loads the complete training state from a checkpoint path.") enable_checkpointing: bool = Field(True, description="If True, enables saving checkpoints during training.") @@ -800,7 +803,10 @@ class LayoutAndSharding(BaseModel): description="Allowed percentage of non-sharded parameters.", ) shard_optimizer_over_data: bool = Field(False, description="Enable ZeRO-1 optimizer sharding over the data axis.") - internal_compile: bool = Field(False, description="Use internal_compile to bypass open-source topology mappings.") + internal_compile: bool = Field( + False, + description="Use internal_compile to bypass open-source topology mappings.", + ) internal_compile_num_devices: int = Field(-1, description="Number of devices when using internal_compile.") @@ -1076,6 +1082,13 @@ class LoRA(BaseModel): None, description="Optional tile size for QLoRA (e.g., 128 or 256).", ) + lora_restore_path: PathStr = Field( + "", + description=( + "Optional NNX LoRA checkpoint path to restore adapter weights from." + " This must be the direct `model_params` path." + ), + ) class Distillation(BaseModel): @@ -1412,7 +1425,8 @@ class Profiling(BaseModel): xprof_e2e_enable_fw_thermal_event: bool = Field(False, description="Enable FW thermal event.") profile_power_events: bool = Field( False, - description="Enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.", + description="Enable TPU-specific power/thermal profiling events." + " Defaults to False to avoid breaking GPU xplane tracing.", ) diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index d83fb3b60e..b0e31a54a6 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -19,18 +19,18 @@ Example command: Training & Evaluation: - python3 -m maxtext.trainers.post_train.sft.train_sft src/MaxText/configs/sft.yml \ - run_name=$RUN_NAME base_output_directory=$BASE_OUTPUT_DIRECTORY \ - model_name=$MODEL_NAME load_parameters_path=$CHECKPOINT_PATH \ - hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH \ + python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \ + run_name=${RUN_NAME?} base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ + model_name=${MODEL_NAME?} load_parameters_path=${CHECKPOINT_PATH?} \ + hf_access_token=${HF_ACCESS_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} \ per_device_batch_size=1 max_target_length=1024 \ eval_interval=2 eval_steps=2 steps=10 profiler=xplane weight_dtype=bfloat16 Training: - python3 -m maxtext.trainers.post_train.sft.train_sft src/MaxText/configs/sft.yml \ - run_name=$RUN_NAME base_output_directory=$BASE_OUTPUT_DIRECTORY \ - model_name=$MODEL_NAME load_parameters_path=$CHECKPOINT_PATH \ - hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH \ + python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \ + run_name=${RUN_NAME?} base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ + model_name=${MODEL_NAME?} load_parameters_path=${CHECKPOINT_PATH?} \ + hf_access_token=${HF_ACCESS_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} \ per_device_batch_size=1 max_target_length=1024 \ eval_interval=-1 steps=10 profiler=xplane weight_dtype=bfloat16 """ @@ -132,7 +132,15 @@ def use_maxtext_loss_function(trainer, mt_config): The trainer configured with the MaxText loss function. """ - def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targets_position, targets_segmentation): + def loss_func( + model, + inputs, + inputs_position, + inputs_segmentation, + targets, + targets_position, + targets_segmentation, + ): data = { "inputs": inputs, "inputs_position": inputs_position, @@ -148,6 +156,7 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ def _validate_lora_config(mt_config): + """Validates required LoRA configuration fields.""" if mt_config.lora_rank <= 0: raise ValueError("enable_lora is True but lora_rank is not set to a positive value.") if not mt_config.lora_module_path: @@ -155,6 +164,7 @@ def _validate_lora_config(mt_config): def _build_lora_provider(mt_config, qwix): + """Builds a Qwix LoRA provider from MaxText LoRA settings.""" lora_kwargs = { "module_path": mt_config.lora_module_path, "rank": mt_config.lora_rank, @@ -165,29 +175,23 @@ def _build_lora_provider(mt_config, qwix): if mt_config.lora_weight_qtype is not None: lora_kwargs["weight_qtype"] = mt_config.lora_weight_qtype max_logging.log( - "QLoRA configured: module_path=%s rank=%s alpha=%s weight_qtype=%s tile_size=%s" - % ( - mt_config.lora_module_path, - mt_config.lora_rank, - mt_config.lora_alpha, - mt_config.lora_weight_qtype, - mt_config.lora_tile_size, - ) + f"QLoRA configured: module_path={mt_config.lora_module_path} " + f"rank={mt_config.lora_rank} alpha={mt_config.lora_alpha} " + f"weight_qtype={mt_config.lora_weight_qtype} " + f"tile_size={mt_config.lora_tile_size}" ) else: max_logging.log( - "LoRA configured: module_path=%s rank=%s alpha=%s tile_size=%s" - % ( - mt_config.lora_module_path, - mt_config.lora_rank, - mt_config.lora_alpha, - mt_config.lora_tile_size, - ) + f"LoRA configured: module_path={mt_config.lora_module_path} " + f"rank={mt_config.lora_rank} alpha={mt_config.lora_alpha} " + f"tile_size={mt_config.lora_tile_size}" ) return qwix.LoraProvider(**lora_kwargs) def _patch_qwix_dot_general_with_3d(lora_provider, qwix_flax_util, qwix_lora, qwix_ptq, types): + """Patches Qwix LoRA dot_general to support selected 3D-kernel paths.""" + def _dot_general_with_3d( self, lhs, @@ -218,7 +222,7 @@ def _fallback_dot_general(): out_sharding=out_sharding, ) - rule, _ = self._get_current_rule_and_op_id("dot_general", repeated_call=True) + rule, _ = self._get_current_rule_and_op_id("dot_general", repeated_call=True) # pylint: disable=protected-access if not isinstance(rule, qwix_lora.LoraRule): return res @@ -227,7 +231,7 @@ def _fallback_dot_general(): return res if len(rhs.shape) == 3 and tuple(dimension_numbers[0][1]) == (0,) and not dimension_numbers[1][1]: - lora_params = qwix_lora._get_or_create_lora_params( + lora_params = qwix_lora._get_or_create_lora_params( # pylint: disable=protected-access name=weight_name, rule=rule, a_shape=(rhs.shape[0], rule.rank), @@ -245,7 +249,7 @@ def _fallback_dot_general(): if len(rhs.shape) == 3 and tuple(dimension_numbers[0][1]) == (0, 1) and not dimension_numbers[1][1]: k = rhs.shape[0] * rhs.shape[1] - lora_params = qwix_lora._get_or_create_lora_params( + lora_params = qwix_lora._get_or_create_lora_params( # pylint: disable=protected-access name=weight_name, rule=rule, a_shape=(k, rule.rank), @@ -272,67 +276,12 @@ def _fallback_dot_general(): lora_provider.dot_general = types.MethodType(_dot_general_with_3d, lora_provider) -def _patch_qwix_find_param(qwix_flax_util): - if getattr(qwix_flax_util, "_maxtext_find_param_patched", False): - return - - original_find_param = qwix_flax_util.find_param - - def _safe_find_param(x, ptq_array_type=None): - module = qwix_flax_util.get_current_module() - candidates = {} - - # 1) Pure NNX: scan attributes for nnx.Params / ptq arrays. - if isinstance(module, nnx.Module): - array_types = (nnx.Param,) if ptq_array_type is None else (nnx.Param, ptq_array_type) - for name, node in module.__dict__.items(): - if isinstance(node, array_types): - value = getattr(node, "value", None) - if value is None: - try: - value = qwix_flax_util.unbox(node) - except Exception: - continue - candidates[name] = value - - else: - return original_find_param(x, ptq_array_type) - - candidates_by_id = {id(c): n for n, c in candidates.items()} - - if id(x) in candidates_by_id: - return candidates_by_id[id(x)] - - if isinstance(x, jax.core.Tracer) and hasattr(x, "parent"): - while True: - if id(x) in candidates_by_id: - return candidates_by_id[id(x)] - if x.parent and len(x.parent.in_tracers) == 1: - x = x.parent.in_tracers[0] - elif id(const := x.get_const()) in candidates_by_id: - return candidates_by_id[id(const)] - else: - return None - - if not hasattr(x, "shape"): - return None - candidates = {n: c for n, c in candidates.items() if getattr(c, "shape", None) == x.shape} - if len(candidates) > 2: - raise ValueError(f"Multiple candidate params found: {candidates.keys()}") - if len(candidates) == 1: - return list(candidates.keys())[0] - - return None - - qwix_flax_util.find_param = _safe_find_param - qwix_flax_util._maxtext_find_param_patched = True - - def _patch_with_sharding_constraint(): + """Patches sharding constraint to tolerate shape/spec rank mismatches.""" if getattr(jax.lax, "_maxtext_with_sharding_constraint_patched", False): return - jax.lax._original_with_sharding_constraint = jax.lax.with_sharding_constraint + jax.lax._original_with_sharding_constraint = jax.lax.with_sharding_constraint # pylint: disable=protected-access def _safe_with_sharding_constraint(x, sharding, *args, **kwargs): def _safe_leaf_fn(x_leaf, s_leaf): @@ -342,17 +291,18 @@ def _safe_leaf_fn(x_leaf, s_leaf): ndim = getattr(x_leaf, "ndim", None) if ndim is not None and len(spec) > ndim: return x_leaf - except Exception: + except Exception: # pylint: disable=broad-exception-caught pass - return jax.lax._original_with_sharding_constraint(x_leaf, s_leaf, *args, **kwargs) + return jax.lax._original_with_sharding_constraint(x_leaf, s_leaf, *args, **kwargs) # pylint: disable=protected-access return jax.tree_util.tree_map(_safe_leaf_fn, x, sharding) jax.lax.with_sharding_constraint = _safe_with_sharding_constraint - jax.lax._maxtext_with_sharding_constraint_patched = True + jax.lax._maxtext_with_sharding_constraint_patched = True # pylint: disable=protected-access def _prepare_dummy_inputs(mt_config, mesh): + """Builds dummy decoder inputs used to materialize LoRA parameters.""" batch_size = getattr(mt_config, "per_device_batch_size", 1) seq_len = getattr(mt_config, "max_target_length", 1) if batch_size <= 0 or seq_len <= 0: @@ -371,6 +321,7 @@ def _prepare_dummy_inputs(mt_config, mesh): def _precreate_lora_params(lora_model, lora_provider, mt_config, qwix_flax_util, qwix_lora, types): + """Pre-creates LoRA parameter tensors for modules matching the target regex.""" rules = [rule for rule in getattr(lora_provider, "_rules", []) if isinstance(rule, qwix_lora.LoraRule)] if not rules: max_logging.log("LoRA precreate: no LoRA rules found on provider, skipping.") @@ -395,7 +346,7 @@ def _with_layer_axis(base_shape_or_transpose, layer_value): skipped_modules = [] precreated_shapes = [] - for path, module in lora_model.iter_modules(): + for path, module in nnx.iter_modules(lora_model): module_path = "/".join(str(p) for p in path) if not compiled_module_path.search(module_path): continue @@ -409,7 +360,7 @@ def _with_layer_axis(base_shape_or_transpose, layer_value): try: kernel_value = qwix_flax_util.unbox(kernel) - except Exception: + except Exception: # pylint: disable=broad-exception-caught if len(skipped_modules) < 10: skipped_modules.append(f"{module_path}: cannot unbox kernel") continue @@ -486,13 +437,13 @@ def _with_layer_axis(base_shape_or_transpose, layer_value): b_sharding_transpose = tuple(range(prefix_rank)) + (None, prefix_rank + 1) def _init_for_module( - self, + self, # pylint: disable=unused-argument a_shape=a_shape, b_shape=b_shape, a_sharding_transpose=a_sharding_transpose, b_sharding_transpose=b_sharding_transpose, ): - qwix_lora._get_or_create_lora_params( + qwix_lora._get_or_create_lora_params( # pylint: disable=protected-access name="kernel", rule=rule, a_shape=a_shape, @@ -507,17 +458,20 @@ def _init_for_module( precreated_shapes.append((module_path, a_shape, b_shape)) max_logging.log( - "LoRA precreate: matched_modules=%s precreated_modules=%s skipped_sample=%s shape_sample=%s" - % (matched_modules, precreated_modules, skipped_modules, precreated_shapes) + f"LoRA precreate: matched_modules={matched_modules} " + f"precreated_modules={precreated_modules} " + f"skipped_sample={skipped_modules} " + f"shape_sample={precreated_shapes}" ) def _verify_lora_parameters(lora_model, mt_config): + """Validates that LoRA is active or that target modules were matched.""" compiled_module_path = re.compile(mt_config.lora_module_path) matched_module_paths = [] sample_module_paths = [] - for path, _ in lora_model.iter_modules(): + for path, _ in nnx.iter_modules(lora_model): module_path = "/".join(str(p) for p in path) if len(sample_module_paths) < 50: sample_module_paths.append(module_path) @@ -537,9 +491,9 @@ def _verify_lora_parameters(lora_model, mt_config): raise ValueError("LoRA enabled but no LoRA parameters found in decoder/model state.") max_logging.log( - "LoRA verification: matched %s target modules but LoRA params are not yet materialized; " - "continuing with lazy LoRA initialization. Sample matches: %s" - % (len(matched_module_paths), matched_module_paths[:10]) + f"LoRA verification: matched {len(matched_module_paths)} target modules but " + "LoRA params are not yet materialized; continuing with lazy LoRA initialization. " + f"Sample matches: {matched_module_paths[:10]}" ) @@ -553,17 +507,16 @@ def maybe_apply_lora(model, mesh, mt_config): if not getattr(mt_config, "enable_lora", False): return model - import qwix - import qwix._src.flax_util as qwix_flax_util - import qwix._src.providers.lora as qwix_lora - import qwix._src.providers.ptq as qwix_ptq - import types + import qwix # pylint: disable=import-outside-toplevel + import qwix._src.flax_util as qwix_flax_util # pylint: disable=import-outside-toplevel + import qwix._src.providers.lora as qwix_lora # pylint: disable=import-outside-toplevel + import qwix._src.providers.ptq as qwix_ptq # pylint: disable=import-outside-toplevel + import types # pylint: disable=import-outside-toplevel _validate_lora_config(mt_config) lora_provider = _build_lora_provider(mt_config, qwix) _patch_qwix_dot_general_with_3d(lora_provider, qwix_flax_util, qwix_lora, qwix_ptq, types) - _patch_qwix_find_param(qwix_flax_util) _patch_with_sharding_constraint() decoder_input_tokens, decoder_positions = _prepare_dummy_inputs(mt_config, mesh) @@ -574,7 +527,11 @@ def maybe_apply_lora(model, mesh, mt_config): decoder_positions=decoder_positions, skip_nnx_init=True, ) + + # Materialize LoRA parameters. Qwix 0.1.5+ unsets RNGs after apply_lora_to_model, + lora_model.set_attributes(qwix_rngs=nnx.Rngs(10003)) _precreate_lora_params(lora_model, lora_provider, mt_config, qwix_flax_util, qwix_lora, types) + lora_model.set_attributes(qwix_rngs=None) if mesh is not None: lora_model = reshard.reshard_model_to_mesh(lora_model, mesh) @@ -583,6 +540,81 @@ def maybe_apply_lora(model, mesh, mt_config): return lora_model +def maybe_restore_lora_from_path(model, mt_config, mesh=None): + """Optionally restores LoRA params from a dedicated adapter checkpoint path. + + If `lora_restore_path` is set and LoRA params have not yet been materialized on + the model, this function attempts to apply LoRA first (when enabled) before + restoring adapter weights. + """ + lora_restore_path = getattr(mt_config, "lora_restore_path", "") + if not lora_restore_path: + return model + + if not tunix_sft_utils.is_lora_enabled(model): + if getattr(mt_config, "enable_lora", False): + max_logging.log("lora_restore_path is set but model has no LoRA params yet; " "applying LoRA before restore.") + model = maybe_apply_lora(model, mesh, mt_config) + + if not tunix_sft_utils.is_lora_enabled(model): + raise ValueError( + "lora_restore_path is set but LoRA is not enabled on the model. " + "Set enable_lora=True and verify lora_module_path matches model modules." + ) + + if not os.path.exists(lora_restore_path): + raise ValueError(f"lora_restore_path does not exist: {lora_restore_path}") + + max_logging.log(f"Restoring LoRA params from: {lora_restore_path}") + + ckptr = ocp.Checkpointer( + ocp.PyTreeCheckpointHandler( + restore_concurrent_gb=mt_config.checkpoint_storage_concurrent_gb, + save_concurrent_gb=mt_config.checkpoint_storage_concurrent_gb, + use_ocdbt=mt_config.checkpoint_storage_use_ocdbt, + use_zarr3=mt_config.checkpoint_storage_use_zarr3, + ) + ) + + lora_state = nnx.state(model, nnx.LoRAParam) + metadata = ckptr.metadata(lora_restore_path) + + # Restore is target-driven from the currently materialized `lora_state`. + # Checkpoint adapter paths that do not match these LoRA params are not + # remapped automatically by Orbax during restore. + + # LoRA restore path is NNX-only. + if "params" in metadata.item_metadata.tree.keys() and "params" in metadata.item_metadata.tree.get("params", {}).keys(): + raise ValueError("lora_restore_path must point to an NNX LoRA checkpoint (not Linen format).") + + target_for_restore = jax.tree.map( + lambda v: {"value": v.value}, + lora_state, + is_leaf=lambda n: isinstance(n, nnx.Variable), + ) + item_to_restore = target_for_restore + restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore) + + restored = ckptr.restore( + lora_restore_path, + item=item_to_restore, + transforms={}, + restore_args=restore_args, + ) + + restored_lora = jax.tree.map( + lambda v: v["value"], + restored, + is_leaf=lambda x: isinstance(x, dict) and "value" in x and not isinstance(x.get("value"), dict), + ) + + if restored_lora: + nnx.update(model, restored_lora) + max_logging.log("LoRA restore complete.") + + return model + + def setup_trainer_state(mt_config, goodput_recorder=None): """Set up prerequisites for training loop.""" tunix_config = get_tunix_config(mt_config) @@ -590,6 +622,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None): with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT): model, mesh = model_creation_utils.create_nnx_model(mt_config) model = maybe_apply_lora(model, mesh, mt_config) + model = maybe_restore_lora_from_path(model, mt_config, mesh) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config) # pass in model for muon optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model) @@ -615,7 +648,10 @@ def setup_trainer_state(mt_config, goodput_recorder=None): def train_model(mt_config, trainer, mesh): """Runs the SFT training loop in Tunix.""" with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): - trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator) + trainer.train( + trainer.data_hooks.train_data_iterator, + trainer.data_hooks.eval_data_iterator, + ) return trainer From 376fbfef12db4c979c0f3c35bd00ca90ce62535b Mon Sep 17 00:00:00 2001 From: Jacky Fang Date: Mon, 9 Mar 2026 08:31:02 +0000 Subject: [PATCH 09/12] chore: remove unused sharding constraint patch --- .../trainers/post_train/sft/train_sft.py | 27 ------------------- 1 file changed, 27 deletions(-) diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index b0e31a54a6..fac34373de 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -275,32 +275,6 @@ def _fallback_dot_general(): lora_provider.dot_general = types.MethodType(_dot_general_with_3d, lora_provider) - -def _patch_with_sharding_constraint(): - """Patches sharding constraint to tolerate shape/spec rank mismatches.""" - if getattr(jax.lax, "_maxtext_with_sharding_constraint_patched", False): - return - - jax.lax._original_with_sharding_constraint = jax.lax.with_sharding_constraint # pylint: disable=protected-access - - def _safe_with_sharding_constraint(x, sharding, *args, **kwargs): - def _safe_leaf_fn(x_leaf, s_leaf): - try: - spec = getattr(s_leaf, "spec", s_leaf) - if hasattr(spec, "__len__"): - ndim = getattr(x_leaf, "ndim", None) - if ndim is not None and len(spec) > ndim: - return x_leaf - except Exception: # pylint: disable=broad-exception-caught - pass - return jax.lax._original_with_sharding_constraint(x_leaf, s_leaf, *args, **kwargs) # pylint: disable=protected-access - - return jax.tree_util.tree_map(_safe_leaf_fn, x, sharding) - - jax.lax.with_sharding_constraint = _safe_with_sharding_constraint - jax.lax._maxtext_with_sharding_constraint_patched = True # pylint: disable=protected-access - - def _prepare_dummy_inputs(mt_config, mesh): """Builds dummy decoder inputs used to materialize LoRA parameters.""" batch_size = getattr(mt_config, "per_device_batch_size", 1) @@ -517,7 +491,6 @@ def maybe_apply_lora(model, mesh, mt_config): lora_provider = _build_lora_provider(mt_config, qwix) _patch_qwix_dot_general_with_3d(lora_provider, qwix_flax_util, qwix_lora, qwix_ptq, types) - _patch_with_sharding_constraint() decoder_input_tokens, decoder_positions = _prepare_dummy_inputs(mt_config, mesh) lora_model = qwix.apply_lora_to_model( From 94e515c11b5fc70796c4e6bca37effe08d5257ba Mon Sep 17 00:00:00 2001 From: Jacky Fang Date: Tue, 10 Mar 2026 10:51:08 +0000 Subject: [PATCH 10/12] feat: improve LoRA parameter handling in dot_general --- .../trainers/post_train/sft/train_sft.py | 381 ++++++++++-------- 1 file changed, 215 insertions(+), 166 deletions(-) diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index fac34373de..ad81bd5ad2 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -201,17 +201,6 @@ def _dot_general_with_3d( preferred_element_type=None, out_sharding=None, ): - def _fallback_dot_general(): - return qwix_lora.LoraProvider.dot_general( - self, - lhs, - rhs, - dimension_numbers, - precision, - preferred_element_type, - out_sharding=out_sharding, - ) - res = qwix_ptq.PtqProvider.dot_general( self, lhs, @@ -222,56 +211,74 @@ def _fallback_dot_general(): out_sharding=out_sharding, ) - rule, _ = self._get_current_rule_and_op_id("dot_general", repeated_call=True) # pylint: disable=protected-access + rule, _ = self._get_current_rule_and_op_id("dot_general", repeated_call=True) if not isinstance(rule, qwix_lora.LoraRule): return res weight_name = qwix_flax_util.find_param(rhs, qwix_lora.ptq.WithAux) if weight_name is None: - return res - - if len(rhs.shape) == 3 and tuple(dimension_numbers[0][1]) == (0,) and not dimension_numbers[1][1]: - lora_params = qwix_lora._get_or_create_lora_params( # pylint: disable=protected-access - name=weight_name, - rule=rule, - a_shape=(rhs.shape[0], rule.rank), - b_shape=(rule.rank, rhs.shape[1] * rhs.shape[2]), - a_sharding_transpose=(0, None), - b_sharding_transpose=(None, 1), + return qwix_lora.LoraProvider.dot_general( + self, lhs, rhs, dimension_numbers, precision, preferred_element_type, out_sharding=out_sharding ) - lora_a, lora_b = lora_params[:2] - if rule.dropout > 0: - lhs = nnx.Dropout(rule.dropout)(lhs, rngs=qwix_flax_util.make_rng("dropout")) - lora_b = jnp.reshape(lora_b, (rule.rank, rhs.shape[1], rhs.shape[2])) - delta = jnp.einsum("...k,kr->...r", lhs, lora_a) - delta = jnp.einsum("...r,rnm->...nm", delta, lora_b) - return res + delta * (rule.alpha / rule.rank) - - if len(rhs.shape) == 3 and tuple(dimension_numbers[0][1]) == (0, 1) and not dimension_numbers[1][1]: - k = rhs.shape[0] * rhs.shape[1] - lora_params = qwix_lora._get_or_create_lora_params( # pylint: disable=protected-access - name=weight_name, - rule=rule, - a_shape=(k, rule.rank), - b_shape=(rule.rank, rhs.shape[2]), - a_sharding_transpose=(0, None), - b_sharding_transpose=(None, 1), + + try: + current_module = qwix_flax_util.get_current_module() + lora_a = getattr(current_module, f"{weight_name}_lora_a", None) + lora_b = getattr(current_module, f"{weight_name}_lora_b", None) + + if lora_a is None or lora_b is None: + return qwix_lora.LoraProvider.dot_general( + self, lhs, rhs, dimension_numbers, precision, preferred_element_type, out_sharding=out_sharding + ) + + if isinstance(lora_a, nnx.Variable): + lora_a = lora_a[...] + if isinstance(lora_b, nnx.Variable): + lora_b = lora_b[...] + except Exception: + return qwix_lora.LoraProvider.dot_general( + self, lhs, rhs, dimension_numbers, precision, preferred_element_type, out_sharding=out_sharding ) - lora_a, lora_b = lora_params[:2] - if rule.dropout > 0: - lhs = nnx.Dropout(rule.dropout)(lhs, rngs=qwix_flax_util.make_rng("dropout")) - contract_axes = tuple(dimension_numbers[0][0]) - lhs_perm = [i for i in range(lhs.ndim) if i not in contract_axes] + list(contract_axes) - lhs_trans = jnp.transpose(lhs, lhs_perm) - lhs_shape = lhs_trans.shape - lhs_flat = jnp.reshape(lhs_trans, lhs_shape[: -len(contract_axes)] + (k,)) - if lora_a.shape[0] != k: - return _fallback_dot_general() - delta = jnp.einsum("...k,kr->...r", lhs_flat, lora_a) - delta = jnp.einsum("...r,rm->...m", delta, lora_b) - return res + delta * (rule.alpha / rule.rank) - - return _fallback_dot_general() + + if rule.dropout > 0: + lhs = nnx.Dropout(rule.dropout)(lhs, rngs=qwix_flax_util.make_rng("dropout")) + + contract_axes_lhs = tuple(dimension_numbers[0][0]) + contract_axes_rhs = tuple(dimension_numbers[0][1]) + + # If the default provider fails due to shape, we handle it universally here. + if len(rhs.shape) > 2: + k = 1 + for axis in contract_axes_rhs: + k *= rhs.shape[axis] + + out_dim = lora_b.size // rule.rank + + # Validate that LoRA shapes make mathematical sense + if lora_a.size == k * rule.rank and lora_b.size == rule.rank * out_dim: + # Reshape A to 2D + lora_a_flat = jnp.reshape(lora_a, (k, rule.rank)) + + # Reshape B to 2D + lora_b_flat = jnp.reshape(lora_b, (rule.rank, out_dim)) + + # Flatten LHS to abstract over multiple batch/sequence dimensions + lhs_perm = [i for i in range(lhs.ndim) if i not in contract_axes_lhs] + list(contract_axes_lhs) + lhs_trans = jnp.transpose(lhs, lhs_perm) + lhs_shape = lhs_trans.shape + lhs_flat = jnp.reshape(lhs_trans, (-1, k)) + + # Do the 2D LoRA math + delta_flat = lhs_flat @ lora_a_flat @ lora_b_flat + + # Unflatten the delta to match the original result shape + delta = jnp.reshape(delta_flat, res.shape) + + return res + delta * (rule.alpha / rule.rank) + + return qwix_lora.LoraProvider.dot_general( + self, lhs, rhs, dimension_numbers, precision, preferred_element_type, out_sharding=out_sharding + ) lora_provider.dot_general = types.MethodType(_dot_general_with_3d, lora_provider) @@ -315,110 +322,101 @@ def _with_layer_axis(base_shape_or_transpose, layer_value): values.insert(axis, layer_value) return tuple(values) + def _extract_kernel_shape(kernel_value): + kernel_shape = getattr(kernel_value, "shape", None) + if kernel_shape is None and hasattr(kernel_value, "array"): + kernel_shape = getattr(kernel_value.array, "shape", None) + if kernel_shape is None and hasattr(kernel_value.array, "qvalue"): + kernel_shape = getattr(kernel_value.array.qvalue, "shape", None) + if kernel_shape is None: + return None + return tuple(int(dim) for dim in kernel_shape) + matched_modules = 0 precreated_modules = 0 skipped_modules = [] precreated_shapes = [] - for path, module in nnx.iter_modules(lora_model): - module_path = "/".join(str(p) for p in path) - if not compiled_module_path.search(module_path): - continue - - matched_modules += 1 - kernel = getattr(module, "kernel", None) - if kernel is None: - if len(skipped_modules) < 10: - skipped_modules.append(f"{module_path}: no kernel") - continue - + def _process_param(module, module_path, param_name, param_obj, in_features_shape, out_features_shape): + nonlocal precreated_modules try: - kernel_value = qwix_flax_util.unbox(kernel) - except Exception: # pylint: disable=broad-exception-caught + kernel_value = qwix_flax_util.unbox(param_obj) + except Exception: if len(skipped_modules) < 10: - skipped_modules.append(f"{module_path}: cannot unbox kernel") - continue + skipped_modules.append(f"{module_path}.{param_name}: cannot unbox kernel") + return False - kernel_shape = getattr(kernel_value, "shape", None) - if kernel_shape is None and hasattr(kernel_value, "array"): - kernel_shape = getattr(kernel_value.array, "shape", None) - if kernel_shape is None and hasattr(kernel_value.array, "qvalue"): - kernel_shape = getattr(kernel_value.array.qvalue, "shape", None) + kernel_shape = _extract_kernel_shape(kernel_value) if kernel_shape is None or len(kernel_shape) < 2: if len(skipped_modules) < 10: - skipped_modules.append(f"{module_path}: unsupported kernel shape {kernel_shape}") - continue + skipped_modules.append(f"{module_path}.{param_name}: unsupported kernel shape {kernel_shape}") + return False + + expected_suffix = in_features_shape + out_features_shape + layer_axis = None + base_kernel_shape = None + + # 1. Determine if this parameter is scanned over layers + if isinstance(num_decoder_layers, int) and len(kernel_shape) >= len(expected_suffix) + 1: + # Prefer param_scan_axis if it matches the expected layer count + if kernel_shape[param_scan_axis] == num_decoder_layers: + candidate_base = tuple(dim for i, dim in enumerate(kernel_shape) if i != param_scan_axis) + if candidate_base[-len(expected_suffix):] == expected_suffix: + layer_axis = param_scan_axis + base_kernel_shape = candidate_base + + # If not found at param_scan_axis, search other axes (for edge cases where scan axis might differ) + if layer_axis is None: + for axis in range(len(kernel_shape)): + if kernel_shape[axis] == num_decoder_layers: + candidate_base = tuple(dim for i, dim in enumerate(kernel_shape) if i != axis) + if candidate_base[-len(expected_suffix):] == expected_suffix: + layer_axis = axis + base_kernel_shape = candidate_base + break + + # 2. Check if it's an unscanned parameter + if layer_axis is None and len(kernel_shape) >= len(expected_suffix): + if kernel_shape[-len(expected_suffix):] == expected_suffix: + base_kernel_shape = kernel_shape + + # 3. If neither matched, skip this parameter + if base_kernel_shape is None: + if len(skipped_modules) < 10: + skipped_modules.append(f"{module_path}.{param_name}: kernel shape {kernel_shape} does not match expected suffix {expected_suffix}") + return False - is_scanned_decoder_module = ( - "decoder/layers/" in module_path and isinstance(num_decoder_layers, int) and num_decoder_layers > 1 - ) + prefix_shape = base_kernel_shape[:-len(expected_suffix)] if len(expected_suffix) > 0 else base_kernel_shape - if is_scanned_decoder_module: - layer_axis = None - if 0 <= param_scan_axis < len(kernel_shape): - layer_axis = int(param_scan_axis) - elif len(kernel_shape) > 1 and int(kernel_shape[1]) == int(num_decoder_layers): - layer_axis = 1 - else: - for axis, dim in enumerate(kernel_shape): - if int(dim) == int(num_decoder_layers): - layer_axis = axis - break - if layer_axis is None: - if len(skipped_modules) < 10: - skipped_modules.append(f"{module_path}: unable to infer layer axis from kernel shape {kernel_shape}") - continue - - effective_shape = tuple(int(dim) for i, dim in enumerate(kernel_shape) if i != layer_axis) - if len(effective_shape) < 2: - if len(skipped_modules) < 10: - skipped_modules.append(f"{module_path}: unsupported effective shape {effective_shape}") - continue - - if "decoder/layers/self_attention/out" in module_path and len(effective_shape) >= 3: - in_dim = int(math.prod(effective_shape[:-1])) - out_dim = int(effective_shape[-1]) - else: - in_dim = int(effective_shape[0]) - out_dim = int(math.prod(effective_shape[1:])) - if in_dim <= 0 or out_dim <= 0: - if len(skipped_modules) < 10: - skipped_modules.append(f"{module_path}: non-positive dims in={in_dim} out={out_dim}") - continue - - in_axis = next(i for i in range(len(kernel_shape)) if i != layer_axis) - out_axis = next(i for i in range(len(kernel_shape) - 1, -1, -1) if i != layer_axis) - - a_shape = _with_layer_axis((in_dim, rule.rank), num_decoder_layers) - b_shape = _with_layer_axis((rule.rank, out_dim), num_decoder_layers) - a_sharding_transpose = _with_layer_axis((in_axis, None), layer_axis) - b_sharding_transpose = _with_layer_axis((None, out_axis), layer_axis) + # 4. Compute axes mapped sequentially for the base (unscanned) shape + prefix_axes_base = tuple(range(len(prefix_shape))) + input_axes_base = tuple(range(len(prefix_shape), len(prefix_shape) + len(in_features_shape))) + output_axes_base = tuple(range(len(prefix_shape) + len(in_features_shape), len(base_kernel_shape))) + + # 5. Shift axes to account for the layer_axis insertion + if layer_axis is not None: + def shift_axes(axes): + return tuple(axis if axis < layer_axis else axis + 1 for axis in axes) + + a_shape = _with_layer_axis(prefix_shape + in_features_shape + (rule.rank,), num_decoder_layers) + b_shape = _with_layer_axis(prefix_shape + (rule.rank,) + out_features_shape, num_decoder_layers) + a_sharding_transpose = _with_layer_axis(shift_axes(prefix_axes_base + input_axes_base) + (None,), layer_axis) + b_sharding_transpose = _with_layer_axis(shift_axes(prefix_axes_base) + (None,) + shift_axes(output_axes_base), layer_axis) else: - prefix_shape = tuple(kernel_shape[:-2]) - in_dim = int(kernel_shape[-2]) - out_dim = int(kernel_shape[-1]) - if in_dim <= 0 or out_dim <= 0: - if len(skipped_modules) < 10: - skipped_modules.append(f"{module_path}: non-positive dims in={in_dim} out={out_dim}") - continue - - full_prefix_shape = prefix_shape - a_shape = full_prefix_shape + (in_dim, rule.rank) - b_shape = full_prefix_shape + (rule.rank, out_dim) - - prefix_rank = len(full_prefix_shape) - a_sharding_transpose = tuple(range(prefix_rank)) + (prefix_rank, None) - b_sharding_transpose = tuple(range(prefix_rank)) + (None, prefix_rank + 1) + a_shape = prefix_shape + in_features_shape + (rule.rank,) + b_shape = prefix_shape + (rule.rank,) + out_features_shape + a_sharding_transpose = prefix_axes_base + input_axes_base + (None,) + b_sharding_transpose = prefix_axes_base + (None,) + output_axes_base def _init_for_module( - self, # pylint: disable=unused-argument + self, a_shape=a_shape, b_shape=b_shape, a_sharding_transpose=a_sharding_transpose, b_sharding_transpose=b_sharding_transpose, ): qwix_lora._get_or_create_lora_params( # pylint: disable=protected-access - name="kernel", + name=param_name, rule=rule, a_shape=a_shape, b_shape=b_shape, @@ -429,7 +427,36 @@ def _init_for_module( types.MethodType(_init_for_module, module)() precreated_modules += 1 if len(precreated_shapes) < 10: - precreated_shapes.append((module_path, a_shape, b_shape)) + precreated_shapes.append((f"{module_path}.{param_name}", a_shape, b_shape)) + return True + + + for path, module in nnx.iter_modules(lora_model): + module_path = "/".join(str(p) for p in path) + if not compiled_module_path.search(module_path): + continue + + matched_modules += 1 + + # DenseGeneral-style layers (Standard, Vision, Audio) + if hasattr(module, "in_features_shape") and hasattr(module, "out_features_shape"): + in_features_shape = tuple(int(dim) for dim in getattr(module, "in_features_shape", ())) + out_features_shape = tuple(int(dim) for dim in getattr(module, "out_features_shape", ())) + if hasattr(module, "kernel"): + _process_param(module, module_path, "kernel", module.kernel, in_features_shape, out_features_shape) + + # MoE-style layers (RoutedMoE, RoutedAndSharedMoE) + elif type(module).__name__ in ["RoutedMoE", "RoutedAndSharedMoE"]: + emb_dim = getattr(getattr(module, "config", None), "emb_dim", None) + if emb_dim is not None: + intermediate_dim = getattr(module, "intermediate_dim", getattr(getattr(module, "config", None), "moe_mlp_dim", None)) + if intermediate_dim is not None: + if hasattr(module, "wi_0"): + _process_param(module, module_path, "wi_0", module.wi_0, (emb_dim,), (intermediate_dim,)) + if hasattr(module, "wi_1"): + _process_param(module, module_path, "wi_1", module.wi_1, (emb_dim,), (intermediate_dim,)) + if hasattr(module, "wo"): + _process_param(module, module_path, "wo", module.wo, (intermediate_dim,), (emb_dim,)) max_logging.log( f"LoRA precreate: matched_modules={matched_modules} " @@ -549,44 +576,64 @@ def maybe_restore_lora_from_path(model, mt_config, mesh=None): ) ) - lora_state = nnx.state(model, nnx.LoRAParam) metadata = ckptr.metadata(lora_restore_path) - # Restore is target-driven from the currently materialized `lora_state`. - # Checkpoint adapter paths that do not match these LoRA params are not - # remapped automatically by Orbax during restore. - - # LoRA restore path is NNX-only. if "params" in metadata.item_metadata.tree.keys() and "params" in metadata.item_metadata.tree.get("params", {}).keys(): raise ValueError("lora_restore_path must point to an NNX LoRA checkpoint (not Linen format).") - target_for_restore = jax.tree.map( - lambda v: {"value": v.value}, - lora_state, - is_leaf=lambda n: isinstance(n, nnx.Variable), - ) - item_to_restore = target_for_restore - restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore) - - restored = ckptr.restore( - lora_restore_path, - item=item_to_restore, - transforms={}, - restore_args=restore_args, - ) - - restored_lora = jax.tree.map( - lambda v: v["value"], - restored, - is_leaf=lambda x: isinstance(x, dict) and "value" in x and not isinstance(x.get("value"), dict), - ) + lora_state = nnx.state(model, nnx.LoRAParam) - if restored_lora: - nnx.update(model, restored_lora) + # Restore without target to avoid shape mismatch issues during restoration. + # We will handle shape matching and potential reshaping manually. + restored = ckptr.restore(lora_restore_path) + + if restored: + flat_restored, _ = jax.tree_util.tree_flatten_with_path(restored) + for path, source in flat_restored: + try: + var_key_path = path + last_k = path[-1] + last_key = last_k.key if isinstance(last_k, jax.tree_util.DictKey) else (last_k.idx if hasattr(last_k, "idx") else last_k) + if last_key == 'value': + var_key_path = path[:-1] + val = source + else: + val = source["value"] if isinstance(source, dict) and "value" in source else source + + target = lora_state + for k in var_key_path: + k_val = k.key if isinstance(k, jax.tree_util.DictKey) else (k.idx if hasattr(k, "idx") else k) + target = target[k_val] + + if isinstance(target, nnx.Variable): + if hasattr(val, "shape") and val.shape != target.value.shape: + val = jnp.reshape(val, target.value.shape) + target.value = val + except Exception as e: + max_logging.log(f"Failed to restore path {path}: {e}") + + nnx.update(model, lora_state) max_logging.log("LoRA restore complete.") return model +def _maybe_resume_trainer_from_lora_restore_path(trainer, mt_config, tunix_config): + """Updates trainer steps if restoring from a dedicated LoRA checkpoint.""" + lora_restore_path = getattr(mt_config, "lora_restore_path", "") + if lora_restore_path and getattr(trainer, "_train_steps", 0) == 0: + parts = lora_restore_path.strip('/').split('/') + if len(parts) >= 2 and parts[-1] == 'model_params': + try: + start_step = int(parts[-2]) + trainer._train_steps = start_step + grad_accum = getattr(tunix_config, "gradient_accumulation_steps", None) or 1 + trainer._iter_steps = start_step * grad_accum + if hasattr(trainer, "_prof") and trainer._prof: + trainer._prof.initial_step = trainer._iter_steps + max_logging.log(f"Resuming trainer manually from step {start_step} based on lora_restore_path.") + except ValueError: + pass + return trainer def setup_trainer_state(mt_config, goodput_recorder=None): """Set up prerequisites for training loop.""" @@ -611,6 +658,8 @@ def setup_trainer_state(mt_config, goodput_recorder=None): data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder) trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config) + trainer = _maybe_resume_trainer_from_lora_restore_path(trainer, mt_config, tunix_config) + trainer.with_training_hooks(training_hooks) trainer.with_data_hooks(data_hooks) trainer = use_maxtext_loss_function(trainer, mt_config) From 69e481b6cfbada151a20d2f66943a164e85fca4b Mon Sep 17 00:00:00 2001 From: Jacky Fang Date: Wed, 11 Mar 2026 07:51:09 +0000 Subject: [PATCH 11/12] feat: enhance LoRA checkpoint handling and restore logic --- src/maxtext/configs/types.py | 5 +- .../trainers/post_train/sft/train_sft.py | 141 +++++++++--------- 2 files changed, 76 insertions(+), 70 deletions(-) diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 931c854624..fce9d69577 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1085,8 +1085,9 @@ class LoRA(BaseModel): lora_restore_path: PathStr = Field( "", description=( - "Optional NNX LoRA checkpoint path to restore adapter weights from." - " This must be the direct `model_params` path." + "Optional Tunix NNX LoRA checkpoint path to restore adapter weights from." + " This may point to the checkpoint root, a numeric step directory," + " or a direct `model_params` path." ), ) diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index ad81bd5ad2..5fe8c76c8f 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -51,6 +51,7 @@ from orbax import checkpoint as ocp +from tunix.sft import checkpoint_manager as tunix_checkpoint_manager from tunix.sft import metrics_logger, peft_trainer, profiler from tunix.sft import utils as tunix_sft_utils from tunix.rl import reshard @@ -540,16 +541,42 @@ def maybe_apply_lora(model, mesh, mt_config): return lora_model +def _resolve_lora_restore_checkpoint(lora_restore_path): + """Normalizes lora_restore_path into Tunix checkpoint manager inputs.""" + normalized_path = os.path.normpath(lora_restore_path) + basename = os.path.basename(normalized_path) + + if basename == "model_params": + step_dir = os.path.dirname(normalized_path) + root_directory = os.path.dirname(step_dir) + step_name = os.path.basename(step_dir) + try: + return root_directory, int(step_name) + except ValueError as exc: + raise ValueError( + "lora_restore_path ending in 'model_params' must live under a numeric step directory." + ) from exc + + if basename.isdigit(): + return os.path.dirname(normalized_path), int(basename) + + return normalized_path, None + + def maybe_restore_lora_from_path(model, mt_config, mesh=None): """Optionally restores LoRA params from a dedicated adapter checkpoint path. If `lora_restore_path` is set and LoRA params have not yet been materialized on the model, this function attempts to apply LoRA first (when enabled) before restoring adapter weights. + + Returns: + A tuple of `(model, resume_step)` where `resume_step` is the step returned + by Tunix checkpoint restore. """ lora_restore_path = getattr(mt_config, "lora_restore_path", "") if not lora_restore_path: - return model + return model, None if not tunix_sft_utils.is_lora_enabled(model): if getattr(mt_config, "enable_lora", False): @@ -565,74 +592,47 @@ def maybe_restore_lora_from_path(model, mt_config, mesh=None): if not os.path.exists(lora_restore_path): raise ValueError(f"lora_restore_path does not exist: {lora_restore_path}") - max_logging.log(f"Restoring LoRA params from: {lora_restore_path}") + restore_root_directory, restore_step = _resolve_lora_restore_checkpoint(lora_restore_path) + max_logging.log( + f"Restoring LoRA params from checkpoint root '{restore_root_directory}' " + f"at step {restore_step if restore_step is not None else 'latest'}." + ) - ckptr = ocp.Checkpointer( - ocp.PyTreeCheckpointHandler( - restore_concurrent_gb=mt_config.checkpoint_storage_concurrent_gb, - save_concurrent_gb=mt_config.checkpoint_storage_concurrent_gb, - use_ocdbt=mt_config.checkpoint_storage_use_ocdbt, - use_zarr3=mt_config.checkpoint_storage_use_zarr3, - ) + checkpoint_manager = tunix_checkpoint_manager.CheckpointManager( + root_directory=restore_root_directory, ) + try: + restored_step, _ = checkpoint_manager.maybe_restore( + model, + step=restore_step, + restore_only_lora_params=True, + ) + finally: + checkpoint_manager.close() - metadata = ckptr.metadata(lora_restore_path) - - if "params" in metadata.item_metadata.tree.keys() and "params" in metadata.item_metadata.tree.get("params", {}).keys(): - raise ValueError("lora_restore_path must point to an NNX LoRA checkpoint (not Linen format).") - - lora_state = nnx.state(model, nnx.LoRAParam) - - # Restore without target to avoid shape mismatch issues during restoration. - # We will handle shape matching and potential reshaping manually. - restored = ckptr.restore(lora_restore_path) - - if restored: - flat_restored, _ = jax.tree_util.tree_flatten_with_path(restored) - for path, source in flat_restored: - try: - var_key_path = path - last_k = path[-1] - last_key = last_k.key if isinstance(last_k, jax.tree_util.DictKey) else (last_k.idx if hasattr(last_k, "idx") else last_k) - if last_key == 'value': - var_key_path = path[:-1] - val = source - else: - val = source["value"] if isinstance(source, dict) and "value" in source else source - - target = lora_state - for k in var_key_path: - k_val = k.key if isinstance(k, jax.tree_util.DictKey) else (k.idx if hasattr(k, "idx") else k) - target = target[k_val] - - if isinstance(target, nnx.Variable): - if hasattr(val, "shape") and val.shape != target.value.shape: - val = jnp.reshape(val, target.value.shape) - target.value = val - except Exception as e: - max_logging.log(f"Failed to restore path {path}: {e}") - - nnx.update(model, lora_state) - max_logging.log("LoRA restore complete.") - - return model - -def _maybe_resume_trainer_from_lora_restore_path(trainer, mt_config, tunix_config): - """Updates trainer steps if restoring from a dedicated LoRA checkpoint.""" - lora_restore_path = getattr(mt_config, "lora_restore_path", "") - if lora_restore_path and getattr(trainer, "_train_steps", 0) == 0: - parts = lora_restore_path.strip('/').split('/') - if len(parts) >= 2 and parts[-1] == 'model_params': - try: - start_step = int(parts[-2]) - trainer._train_steps = start_step - grad_accum = getattr(tunix_config, "gradient_accumulation_steps", None) or 1 - trainer._iter_steps = start_step * grad_accum - if hasattr(trainer, "_prof") and trainer._prof: - trainer._prof.initial_step = trainer._iter_steps - max_logging.log(f"Resuming trainer manually from step {start_step} based on lora_restore_path.") - except ValueError: - pass + if restore_step is not None and restored_step != restore_step: + raise ValueError( + f"Expected LoRA restore from step {restore_step}, got step {restored_step}." + ) + + if restored_step == 0: + raise ValueError(f"No LoRA checkpoint found for lora_restore_path: {lora_restore_path}") + + max_logging.log("LoRA restore complete.") + return model, restored_step + + +def _maybe_resume_trainer_from_step(trainer, resume_step, tunix_config, source): + """Applies a recovered step to a freshly initialized trainer if needed.""" + if not resume_step or getattr(trainer, "_train_steps", 0) != 0: + return trainer + + grad_accum = getattr(tunix_config, "gradient_accumulation_steps", None) or 1 + trainer._train_steps = resume_step + trainer._iter_steps = resume_step * grad_accum + if hasattr(trainer, "_prof") and trainer._prof: + trainer._prof.initial_step = trainer._iter_steps + max_logging.log(f"Resuming trainer manually from step {resume_step} based on {source}.") return trainer def setup_trainer_state(mt_config, goodput_recorder=None): @@ -642,7 +642,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None): with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT): model, mesh = model_creation_utils.create_nnx_model(mt_config) model = maybe_apply_lora(model, mesh, mt_config) - model = maybe_restore_lora_from_path(model, mt_config, mesh) + model, lora_resume_step = maybe_restore_lora_from_path(model, mt_config, mesh) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config) # pass in model for muon optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model) @@ -658,7 +658,12 @@ def setup_trainer_state(mt_config, goodput_recorder=None): data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder) trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config) - trainer = _maybe_resume_trainer_from_lora_restore_path(trainer, mt_config, tunix_config) + trainer = _maybe_resume_trainer_from_step( + trainer, + lora_resume_step, + tunix_config, + source="lora_restore_path", + ) trainer.with_training_hooks(training_hooks) trainer.with_data_hooks(data_hooks) From 468623f5b6b3c7b55649f4d0dbcb899663cd9de8 Mon Sep 17 00:00:00 2001 From: Emma Lien Date: Wed, 11 Mar 2026 07:23:35 +0000 Subject: [PATCH 12/12] feat: add automated MaxText to HF LoRA conversion script --- .../maxtext_to_hf_lora.py | 143 ++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 src/maxtext/checkpoint_conversion/maxtext_to_hf_lora.py diff --git a/src/maxtext/checkpoint_conversion/maxtext_to_hf_lora.py b/src/maxtext/checkpoint_conversion/maxtext_to_hf_lora.py new file mode 100644 index 0000000000..7e2d993ebb --- /dev/null +++ b/src/maxtext/checkpoint_conversion/maxtext_to_hf_lora.py @@ -0,0 +1,143 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script converts a MaxText LoRA adapter (checkpoint) back to HuggingFace PEFT format. + +Key Parameters: + model_name: The name of the model (e.g., "llama3.1-8b"). + maxtext_ckpt_path: Path to the MaxText checkpoint directory (e.g., .../checkpoints/100/model_params). + hf_model_id: The base HuggingFace model ID for config mapping. + output_dir: The directory where the HuggingFace adapter will be saved. + lora_r: The rank of the LoRA adapter. + lora_alpha: The alpha parameter for LoRA. + +Example Usage: + python src/maxtext/checkpoint_conversion/maxtext_to_hf_lora.py \ + model_name="llama3.1-8b" \ + maxtext_ckpt_path="/path/to/maxtext_lora/ckpt" \ + hf_model_id="meta-llama/Llama-3.1-8B" \ + output_dir="/path/to/hf/adapter/output" +""" + +import os +import json +import numpy as np +import sys +from safetensors.numpy import save_file +from orbax import checkpoint as ocp +from etils import epath +from transformers import AutoConfig +from maxtext.checkpoint_conversion.utils.param_mapping import PARAM_MAPPING + +def parse_args(args): + """Parses command line arguments in the format key=value.""" + parsed_args = {} + for arg in args: + if "=" in arg: + key, value = arg.split("=", 1) + parsed_args[key] = value + return parsed_args + +def convert(model_name, maxtext_ckpt_path, hf_model_id, output_dir, lora_r=16, lora_alpha=32): + print(f"[*] Starting conversion from {maxtext_ckpt_path}") + + # Initialize Orbax Checkpointer + mngr = ocp.PyTreeCheckpointer() + mt_params = mngr.restore(epath.Path(maxtext_ckpt_path)) + + # Load HF Config for mapping + hf_config = AutoConfig.from_pretrained(hf_model_id).to_dict() + + class MockConfig: + scan_layers = True + model_name = "llama3.1-8b" + + # Get the parameter mapping for the specific model + mapping = PARAM_MAPPING[model_name](hf_config, MockConfig(), scan_layers=True) + final_hf_weights = {} + + def process_data(current_dict, parent_path="decoder/layers"): + """Recursive function to traverse MaxText params and map to HF.""" + for module_name, content in current_dict.items(): + path = f"{parent_path}/{module_name}" + + # Identify LoRA layers + if isinstance(content, dict) and 'kernel_lora_a' in content: + lookup_key = "params-" + path.replace("/", "-") + "-kernel" + + if lookup_key in mapping: + # Get the JAX values (as numpy) + data_a = np.array(content['kernel_lora_a']['value']) + data_b = np.array(content['kernel_lora_b']['value']) + hf_paths = mapping[lookup_key] + + # MaxText stacks multiple heads/projections, iterate through them + for i in range(data_a.shape[1]): + name = hf_paths[i].replace(".weight", "") + # Apply Transpose (.T) to match PyTorch dimension logic + final_hf_weights[f"base_model.model.{name}.lora_A.weight"] = data_a[:, i, :].T + final_hf_weights[f"base_model.model.{name}.lora_B.weight"] = data_b[:, i, :].T + + print(f"[DEBUG] Processed: {path}") + + elif isinstance(content, dict): + process_data(content, path) + + # Start recursion + process_data(mt_params['decoder']['layers']) + + # Save Safetensors + os.makedirs(output_dir, exist_ok=True) + adapter_file = os.path.join(output_dir, "adapter_model.safetensors") + save_file(final_hf_weights, adapter_file) + + # Create PEFT adapter_config.json + config_json = { + "base_model_name_or_path": hf_model_id, + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": int(lora_r), + "lora_alpha": int(lora_alpha), + "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + "lora_dropout": 0.0, + "bias": "none", + "inference_mode": True + } + + config_file = os.path.join(output_dir, "adapter_config.json") + with open(config_file, "w") as f: + json.dump(config_json, f, indent=4) + + print(f"\n[!] Conversion Complete!") + print(f" Saved weights to: {adapter_file}") + print(f" Saved config to: {config_file}") + +if __name__ == "__main__": + cli_args = parse_args(sys.argv[1:]) + + # Required parameters check + required = ["model_name", "maxtext_ckpt_path", "hf_model_id", "output_dir"] + if not all(k in cli_args for k in required): + print(__doc__) + sys.exit(1) + + convert( + model_name=cli_args["model_name"], + maxtext_ckpt_path=cli_args["maxtext_ckpt_path"], + hf_model_id=cli_args["hf_model_id"], + output_dir=cli_args["output_dir"], + lora_r=cli_args.get("lora_r", 16), + lora_alpha=cli_args.get("lora_alpha", 32) + ) \ No newline at end of file