From bc62d649bb60971235d76aaccb69aebae043cbb5 Mon Sep 17 00:00:00 2001 From: Jacky Fang Date: Wed, 27 May 2026 10:07:15 +0000 Subject: [PATCH] feat: Add Gemma3 LoRA SFT integration and programmatic weight mapping end-to-end support --- .../checkpoint_conversion/to_huggingface.py | 6 +- .../checkpoint_conversion/utils/utils.py | 45 +++- src/maxtext/inference/vllm_decode.py | 85 ++++++- .../integration/tunix/tunix_adapter.py | 226 ++++++++++++++++++ src/maxtext/integration/tunix/utils.py | 43 +++- .../tunix/weight_mapping/__init__.py | 3 + .../tunix/weight_mapping/gemma3.py | 122 ++++++++++ .../vllm/maxtext_vllm_adapter/adapter.py | 5 + src/maxtext/utils/lora_utils.py | 19 +- src/maxtext/utils/model_creation_utils.py | 4 +- .../tpu/gemma3/4b/test_gemma3_lora.sh | 89 +++++++ .../integration/setup_train_loop_nnx_test.py | 5 +- tests/unit/hf_checkpoint_conversion_test.py | 119 +++++++++ tests/unit/test_checkpoint_merging.py | 142 +++++++++++ 14 files changed, 874 insertions(+), 39 deletions(-) create mode 100644 src/maxtext/integration/tunix/weight_mapping/gemma3.py create mode 100644 tests/end_to_end/tpu/gemma3/4b/test_gemma3_lora.sh create mode 100644 tests/unit/test_checkpoint_merging.py diff --git a/src/maxtext/checkpoint_conversion/to_huggingface.py b/src/maxtext/checkpoint_conversion/to_huggingface.py index f6c96b11df..30b7b3acb3 100644 --- a/src/maxtext/checkpoint_conversion/to_huggingface.py +++ b/src/maxtext/checkpoint_conversion/to_huggingface.py @@ -417,11 +417,11 @@ def main(argv: Sequence[str]) -> None: hook_fn_map = mappings["hook_fn_mapping"] # 4. 
Extract and transform weights for Linen/NNX-SFT/NNX-RL checkpoints - maxtext_state_dict = detect_and_extract_checkpoint(checkpoint_dict) + maxtext_state_dict = detect_and_extract_checkpoint(checkpoint_dict, config) # Validate that checkpoint keys match the parameter mapping - state_keys = set(maxtext_state_dict) | { - k.replace("_lora_a", "").replace("_lora_b", "") for k in maxtext_state_dict if "_lora_" in k + state_keys = { + k.replace("_lora_a", "").replace("_lora_b", "") for k in maxtext_state_dict } filtered_map_keys = validate_and_filter_param_map_keys(param_map, state_keys) diff --git a/src/maxtext/checkpoint_conversion/utils/utils.py b/src/maxtext/checkpoint_conversion/utils/utils.py index 93253cffb0..84fc800ff8 100644 --- a/src/maxtext/checkpoint_conversion/utils/utils.py +++ b/src/maxtext/checkpoint_conversion/utils/utils.py @@ -794,6 +794,16 @@ def format_meter( return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs) +def recursive_update(d: dict, u: dict) -> dict: + """Recursively updates dictionary d with dictionary u in place.""" + for k, v in u.items(): + if isinstance(v, dict) and k in d and isinstance(d[k], dict): + recursive_update(d[k], v) + else: + d[k] = v + return d + + def load_orbax_checkpoint(config) -> dict: """Loads Orbax checkpoints from Base and/or LoRA paths in config. 
@@ -837,7 +847,7 @@ def create_restore_args(tree_metadata): metadata.item_metadata.tree, is_leaf=lambda x: hasattr(x, "shape"), ) - merged_dict.update(ckptr.restore(checkpoint_path, restore_args=restore_args)) + recursive_update(merged_dict, ckptr.restore(checkpoint_path, restore_args=restore_args)) return merged_dict @@ -918,7 +928,7 @@ def extract_linen_weights(weights_dict: dict) -> dict[str, np.ndarray]: return result -def detect_and_extract_checkpoint(checkpoint_dict: dict) -> dict[str, np.ndarray]: +def detect_and_extract_checkpoint(checkpoint_dict: dict, config=None) -> dict[str, np.ndarray]: """Detect checkpoint type (Linen vs NNX) and extract weights. Handles multiple NNX checkpoint variants: @@ -932,24 +942,35 @@ def detect_and_extract_checkpoint(checkpoint_dict: dict) -> dict[str, np.ndarray Args: checkpoint_dict: Raw checkpoint dictionary from Orbax + config: Optional MaxText configuration Returns: Dictionary mapping MaxText parameter names to weight arrays """ + # Determine if we are using an NNX model from config or if there is no top-level "params" key + is_nnx = config.enable_nnx if config is not None else ("params" not in checkpoint_dict) + # Detect checkpoint type by structure actual_weights_dict = checkpoint_dict.get("params") - if actual_weights_dict is None: - # NNX checkpoint: structure is directly at the root - # Check for NNX-RL variant with 'base' wrapper - if "base" in checkpoint_dict and isinstance(checkpoint_dict["base"], dict): - # NNX-RL: {'base': {'decoder': ..., 'token_embedder': ...}} - max_logging.log("Detected NNX-RL checkpoint structure (with 'base' wrapper)") - return extract_nnx_weights(checkpoint_dict["base"]) + if is_nnx: + if actual_weights_dict is None: + # NNX checkpoint: structure is directly at the root + # Check for NNX-RL variant with 'base' wrapper + if "base" in checkpoint_dict and isinstance(checkpoint_dict["base"], dict): + # NNX-RL: {'base': {'decoder': ..., 'token_embedder': ...}} + max_logging.log("Detected 
NNX-RL checkpoint structure (with 'base' wrapper)") + return extract_nnx_weights(checkpoint_dict["base"]) + else: + # NNX-SFT: {'decoder': ..., 'token_embedder': ...} + max_logging.log("Detected NNX-SFT checkpoint structure") + return extract_nnx_weights(checkpoint_dict) else: - # NNX-SFT: {'decoder': ..., 'token_embedder': ...} - max_logging.log("Detected NNX-SFT checkpoint structure") - return extract_nnx_weights(checkpoint_dict) + # NNX checkpoint wrapped inside top-level 'params' key + if isinstance(actual_weights_dict, dict) and "params" in actual_weights_dict: + actual_weights_dict = actual_weights_dict["params"] + max_logging.log("Detected NNX checkpoint structure wrapped in 'params'") + return extract_nnx_weights(actual_weights_dict) else: # Linen checkpoint: check if there's a nested 'params' key if isinstance(actual_weights_dict, dict) and "params" in actual_weights_dict: diff --git a/src/maxtext/inference/vllm_decode.py b/src/maxtext/inference/vllm_decode.py index c2b1e5e5d2..5b9446c5c4 100644 --- a/src/maxtext/inference/vllm_decode.py +++ b/src/maxtext/inference/vllm_decode.py @@ -40,6 +40,7 @@ from maxtext.utils import model_creation_utils from maxtext.utils import max_logging +from maxtext.utils import lora_utils from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR from maxtext.common.common_types import Config from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter @@ -96,6 +97,17 @@ def decode_with_vllm(config: Config) -> None: }, } + if config.lora.enable_lora: + vllm_args["additional_config"]["maxtext_config"].update( + { + "lora.enable_lora": config.lora.enable_lora, + "lora.lora_restore_path": config.lora.lora_restore_path, + "lora.lora_rank": config.lora.lora_rank, + "lora.lora_alpha": config.lora.lora_alpha, + "lora.lora_module_path": config.lora.lora_module_path, + } + ) + if config.load_parameters_path: vllm_args["additional_config"]["maxtext_config"]["load_parameters_path"] = config.load_parameters_path else: @@ -178,7 
+190,19 @@ def decode_with_tunix( mesh: The JAX mesh for parallelism. """ # Wrap the model for Tunix - tunix_model = TunixMaxTextAdapter(base_model=model) + use_no_op_mappings = False + if hasattr(config, "vllm_hf_overrides") and config.vllm_hf_overrides: + overrides = config.vllm_hf_overrides + if isinstance(overrides, str) and "MaxTextForCausalLM" in overrides: + use_no_op_mappings = True + elif isinstance(overrides, dict) and "MaxTextForCausalLM" in overrides.get("architectures", []): + use_no_op_mappings = True + + tunix_model = TunixMaxTextAdapter( + base_model=model, + use_no_op_mappings=use_no_op_mappings, + mesh=mesh, + ) # Load the tokenizer tokenizer = transformers.AutoTokenizer.from_pretrained( @@ -210,13 +234,57 @@ def decode_with_tunix( ) # Create vLLM rollout for inference + rollout_vllm_lora_config = None + if config.lora.enable_lora: + rollout_vllm_lora_config = { + "module_path": lora_utils._get_lora_module_path(config), + "rank": config.lora.lora_rank, + "alpha": config.lora.lora_alpha, + } + + # MaxText uses -1 to mean "disabled"; vLLM requires top_p in (0, 1]. 
+ top_p = config.decode_sampling_nucleus_p if config.decode_sampling_nucleus_p > 0 else 1.0 + top_k = config.decode_sampling_top_k if config.decode_sampling_top_k > 0 else -1 + + rollout_vllm_additional_config = { + "maxtext_config": { + "model_name": config.model_name, + "weight_dtype": "bfloat16", + "allow_split_physical_axes": True, + "debug_sharding": config.debug_sharding, + "prefuse_moe_weights": config.prefuse_moe_weights, + "scan_layers": config.scan_layers, + } + } + + if config.lora.enable_lora: + rollout_vllm_additional_config["maxtext_config"]["lora"] = { + "enable_lora": config.lora.enable_lora, + "lora_restore_path": config.lora.lora_restore_path, + "lora_rank": config.lora.lora_rank, + "lora_alpha": config.lora.lora_alpha, + "lora_module_path": config.lora.lora_module_path, + } + + rollout_config = base_rollout.RolloutConfig( max_tokens_to_generate=max_tokens_to_generate, max_prompt_length=max_prompt_length, temperature=config.decode_sampling_temperature, - top_p=config.decode_sampling_nucleus_p, - top_k=config.decode_sampling_top_k, + top_p=top_p, + top_k=top_k, + rollout_vllm_model_version=config.tokenizer_path, + rollout_vllm_hbm_utilization=config.hbm_utilization_vllm, + rollout_vllm_init_with_random_weights=True, + rollout_vllm_tpu_backend_type="jax", + rollout_vllm_hf_config_path=config.model if hasattr(config, "model") else None, + rollout_vllm_lora_config=rollout_vllm_lora_config, + rollout_vllm_additional_config=rollout_vllm_additional_config, + rollout_vllm_kwargs={ + "hf_overrides": config.vllm_hf_overrides, + } if hasattr(config, "vllm_hf_overrides") and config.vllm_hf_overrides else {}, ) + vllm_rollout = VllmRollout( model=tunix_model, tokenizer=tokenizer, @@ -225,12 +293,7 @@ def decode_with_tunix( # other special formatting, which is not part of max_prompt_length. 
cache_config_or_size=max_prompt_length + max_tokens_to_generate + 256, mesh=mesh, - model_version=config.tokenizer_path, - hbm_utilization=0.8, - # Initialize vllm model with random weights to speed up bootstrap time. - # Actual model weights will be loaded later. - init_with_random_weights=True, - tpu_backend_type="jax", + rollout_config=rollout_config, ) # Generate text @@ -251,6 +314,10 @@ def main(argv: Sequence[str]) -> None: if FLAGS.use_tunix: maxtext_model, mesh = model_creation_utils.from_pretrained(config) + if config.lora.enable_lora: + maxtext_model = lora_utils.apply_lora_to_model(maxtext_model, mesh, config) + if config.lora.lora_restore_path: + lora_utils.restore_lora_from_path(maxtext_model, config) decode_with_tunix(config, model=maxtext_model, mesh=mesh) else: decode_with_vllm(config) diff --git a/src/maxtext/integration/tunix/tunix_adapter.py b/src/maxtext/integration/tunix/tunix_adapter.py index d509e512a1..a85949b40b 100644 --- a/src/maxtext/integration/tunix/tunix_adapter.py +++ b/src/maxtext/integration/tunix/tunix_adapter.py @@ -30,6 +30,183 @@ from maxtext.models.models import Transformer +# --- MONKEY-PATCH Weight Synchronization Bug in Tunix --- +try: + from collections import abc + import re + import gc + from typing import Any + from flax import traverse_util + from maxtext.utils import max_logging + from tunix.generate import utils as tunix_utils + + def patched_transfer_state_directly( + src_state, + dst_state, + reshard_fn, + scan_axis: int = 1, + delete_dst_buffers: bool = False, + reshard_chunk_size: Any = None, + ): + max_logging.log("MONKEY-PATCH: transfer_state_directly running!") + if delete_dst_buffers: + if hasattr(tunix_utils, '_delete_target_buffers'): + tunix_utils._delete_target_buffers(dst_state, src_state) + gc.collect() + + def safe_has_key(obj, key: str) -> bool: + if isinstance(obj, abc.Mapping): + return key in obj + return hasattr(obj, key) + + # Unwrap 'base' from src_state + if isinstance(src_state, abc.Mapping) 
and safe_has_key(src_state, 'base'): + max_logging.log("MONKEY-PATCH: Unwrapping 'base' from source state.") + src_state = src_state['base'] + + # Unwrap 'model' from dst_state + while isinstance(dst_state, abc.Mapping) and safe_has_key(dst_state, 'model'): + max_logging.log("MONKEY-PATCH: Unwrapping 'model' from destination state.") + dst_state = dst_state['model'] + + def to_pure_spec(node: Any) -> Any: + if hasattr(node, 'to_pure_dict'): + node = node.to_pure_dict() + if isinstance(node, abc.Mapping): + return {k: to_pure_spec(v) for k, v in node.items()} + if isinstance(node, nnx.Variable): + return to_pure_spec(node[...]) + if hasattr(node, 'value'): + return node.value + return node + + def intersect_trees(src, tgt_spec): + if not isinstance(src, abc.Mapping) or not isinstance(tgt_spec, abc.Mapping): + return src, tgt_spec + + src_flat = traverse_util.flatten_dict(src) + tgt_flat = traverse_util.flatten_dict(tgt_spec) + src_flat = tunix_utils._fuse_moe_weights(src_flat, tgt_flat) + + filtered_src_flat = {} + filtered_tgt_flat = {} + unstacked_cache = {} + layer_pattern = re.compile(r'^layers_(\d+)$') + + for key_tuple, tgt_val in tgt_flat.items(): + path_str = '.'.join(str(k) for k in key_tuple) + if key_tuple in src_flat: + src_val = src_flat[key_tuple] + src_val = tunix_utils._apply_dtype_cast(src_val, tgt_val.dtype, path_str) + src_val = tunix_utils._align_to_model_shape(src_val, tgt_val, path_str) + filtered_src_flat[key_tuple] = src_val + filtered_tgt_flat[key_tuple] = tgt_val + continue + + # Try scanned layer mapping + layer_idx = -1 + match_index = -1 + for i, part in enumerate(key_tuple): + if isinstance(part, str) and part.startswith('layers_'): + m = layer_pattern.match(part) + if m: + layer_idx = int(m.group(1)) + match_index = i + break + + if match_index != -1: + candidate_a = list(key_tuple) + candidate_a[match_index] = 'layers' + candidate_b = list(key_tuple) + candidate_b.pop(match_index) + + found_candidate = None + for cand in 
[tuple(candidate_a), tuple(candidate_b)]: + if cand in src_flat: + found_candidate = cand + break + + if found_candidate: + cache_key = (found_candidate, tgt_val.shape, 'aligned') + if cache_key not in unstacked_cache: + src_val = src_flat[found_candidate] + candidate_path = '.'.join(str(k) for k in found_candidate) + src_val = tunix_utils._apply_dtype_cast(src_val, tgt_val.dtype, candidate_path) + scanned_per_layer_shape = src_val.shape[:scan_axis] + src_val.shape[scan_axis + 1:] + if scanned_per_layer_shape == tgt_val.shape: + unstacked_cache[cache_key] = tunix_utils._unstack_scanned_param( + src_val, tgt_val, candidate_path, scan_axis=scan_axis + ) + else: + unstacked_cache[cache_key] = tunix_utils._bulk_align_and_unstack( + src_val, scan_axis, tgt_val, candidate_path + ) + + sliced_val = unstacked_cache[cache_key][layer_idx] + sliced_val = tunix_utils._align_to_model_shape(sliced_val, tgt_val, path_str) + filtered_src_flat[key_tuple] = sliced_val + filtered_tgt_flat[key_tuple] = tgt_val + continue + + return ( + traverse_util.unflatten_dict(filtered_src_flat), + traverse_util.unflatten_dict(filtered_tgt_flat), + ) + + full_source_dict = to_pure_spec(src_state) + full_target_spec = to_pure_spec(dst_state) + + final_source, final_spec = intersect_trees(full_source_dict, full_target_spec) + + dst_shardings_flat = { + k: tunix_utils._snapshot_dst_sharding( + tgt_val.value if hasattr(tgt_val, 'value') else tgt_val + ) + for k, tgt_val in traverse_util.flatten_dict(final_spec).items() + } + + resharded_weights = reshard_fn( + source=final_source, + target=traverse_util.unflatten_dict(dst_shardings_flat), + ) + + # Assign to target State + if isinstance(dst_state, nnx.State): + flat_resharded = traverse_util.flatten_dict(resharded_weights) + dst_vars = {path: var for path, var in dst_state.flat_state()} + for path, value in flat_resharded.items(): + if path in dst_vars: + var = dst_vars[path] + if isinstance(var, nnx.Variable): + var.value = value + else: + 
dst_state[path] = value + elif isinstance(dst_state, dict): + flat_resharded = traverse_util.flatten_dict(resharded_weights) + flat_dst = traverse_util.flatten_dict(dst_state) + for path, value in flat_resharded.items(): + if path in flat_dst: + var = flat_dst[path] + if isinstance(var, nnx.Variable): + var.value = value + else: + node = dst_state + for part in path[:-1]: + node = node[part] + node[path[-1]] = value + else: + nnx.update(dst_state, resharded_weights) + + gc.collect() + max_logging.log("MONKEY-PATCH: transfer_state_directly finished successfully!") + + tunix_utils.transfer_state_directly = patched_transfer_state_directly + max_logging.log("MONKEY-PATCH: Successfully applied tunix weight sync patch inside MaxText!") +except Exception as e: + from maxtext.utils import max_logging + max_logging.log(f"MONKEY-PATCH ERROR: Failed to apply tunix weight sync patch: {e}") + + class TunixMaxTextAdapter(nnx.Module): """Adapter exposing Tunix Trainer call signature over a Transformer model.""" @@ -38,8 +215,57 @@ def __init__( base_model: Transformer, use_standalone_mappings: bool = True, use_no_op_mappings: bool = False, + mesh: Optional[Any] = None, ): super().__init__() + config = base_model.config + if config and hasattr(config, "lora") and config.lora.enable_lora: + from maxtext.utils import lora_utils + from maxtext.utils import max_logging + import types + max_logging.log("Applying LoRA parameters to model via TunixMaxTextAdapter...") + base_model = lora_utils.apply_lora_to_model(base_model, mesh, config) + if config.lora.lora_restore_path: + max_logging.log(f"Restoring LoRA parameters from path in TunixMaxTextAdapter: {config.lora.lora_restore_path}") + lora_utils.restore_lora_from_path(base_model, config) + + # Mathematically premerge LoRA weights into base parameters so vLLM decoding is accurate + import jax.numpy as jnp + lora_rank = config.lora.lora_rank + lora_alpha = config.lora.lora_alpha + lora_scale_factor = lora_alpha / lora_rank + 
max_logging.log(f"Premerging LoRA weights into base parameters with scale factor: {lora_scale_factor}...") + + def merge_lora_recursively(module): + if hasattr(module, "kernel") and hasattr(module, "kernel_lora_a") and hasattr(module, "kernel_lora_b"): + lora_a = module.kernel_lora_a.value + lora_b = module.kernel_lora_b.value + base_w = module.kernel.value + if lora_a is not None and lora_b is not None and base_w is not None: + if len(base_w.shape) == 3: # Scanned layers: (input_dim, scan_dim, output_dim) + delta = jnp.einsum("isr,rsd->isd", lora_a, lora_b) + elif len(base_w.shape) == 2: # Non-scanned layers: (input_dim, output_dim) + delta = jnp.einsum("ir,rd->id", lora_a, lora_b) + else: + raise ValueError(f"Unexpected base weight shape: {base_w.shape}") + module.kernel.value = base_w + delta * lora_scale_factor + + # Recurse into all submodules/attributes of module + for name, attr in list(module.__dict__.items()): + if isinstance(attr, nnx.Module): + merge_lora_recursively(attr) + elif isinstance(attr, dict): + for k, v in attr.items(): + if isinstance(v, nnx.Module): + merge_lora_recursively(v) + elif isinstance(attr, (list, tuple)): + for item in attr: + if isinstance(item, nnx.Module): + merge_lora_recursively(item) + + merge_lora_recursively(base_model) + max_logging.log("Successfully premerged LoRA weights into base parameters!") + self.base = base_model self._vllm_weight_mapping = VllmWeightMapping( self.base.config.model_name, diff --git a/src/maxtext/integration/tunix/utils.py b/src/maxtext/integration/tunix/utils.py index 8d608956bc..d934696973 100644 --- a/src/maxtext/integration/tunix/utils.py +++ b/src/maxtext/integration/tunix/utils.py @@ -153,10 +153,47 @@ def to_hf_hook_fns(self): return {} def lora_to_hf_mappings(self): - if self.use_standalone_mappings: - return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].lora_to_hf_mappings() + # Dynamically generate LoRA mappings from base model weights mappings + base_mappings = self.to_hf_mapping() + 
if not base_mappings: + return None + + lora_mapping = {} + for maxtext_key, (hf_key, sharding_spec) in base_mappings.items(): + segments = set(maxtext_key.split(".")) + is_input_proj = any(p in segments for p in ["wi_0", "wi_1", "query", "key", "value", "wq_a", "wq_b", "wkv_a", "wkv_b"]) + is_output_proj = any(p in segments for p in ["wo", "out"]) + + if not (is_input_proj or is_output_proj): + continue + + # Derive MaxText LoRA keys + maxtext_lora_a = maxtext_key + "_lora_a" + maxtext_lora_b = maxtext_key + "_lora_b" + + # Derive HF/vLLM LoRA keys + if hf_key.endswith(".kernel"): + hf_lora_a = hf_key.replace(".kernel", ".kernel_lora_a") + hf_lora_b = hf_key.replace(".kernel", ".kernel_lora_b") + elif hf_key.endswith(".weight"): + hf_lora_a = hf_key.replace(".weight", ".weight_lora_a") + hf_lora_b = hf_key.replace(".weight", ".weight_lora_b") + else: + hf_lora_a = hf_key + "_lora_a" + hf_lora_b = hf_key + "_lora_b" + + # Derive sharding specifications for Qwix LoRA parameters + if is_input_proj: + sharding_a = (None, "layer", None) # Input -> Rank (unsharded) + sharding_b = sharding_spec # Rank -> Output (same as base) + else: + sharding_a = sharding_spec # Input -> Rank (same as base) + sharding_b = (None, "layer", None) # Rank -> Output (unsharded) + + lora_mapping[maxtext_lora_a] = (hf_lora_a, sharding_a) + lora_mapping[maxtext_lora_b] = (hf_lora_b, sharding_b) - return None + return lora_mapping def _generalize_maxtext_key(self, maxtext_key): """Generalizes the MaxText key to a common vLLM format.""" diff --git a/src/maxtext/integration/tunix/weight_mapping/__init__.py b/src/maxtext/integration/tunix/weight_mapping/__init__.py index 6fd7f35028..39ab12ff8f 100644 --- a/src/maxtext/integration/tunix/weight_mapping/__init__.py +++ b/src/maxtext/integration/tunix/weight_mapping/__init__.py @@ -19,6 +19,7 @@ model name. This allows for easy extension to support new models. 
""" from maxtext.integration.tunix.weight_mapping.deepseek3 import DEEPSEEK_VLLM_MAPPING +from maxtext.integration.tunix.weight_mapping.gemma3 import GEMMA3_VLLM_MAPPING from maxtext.integration.tunix.weight_mapping.gpt_oss import GPT_OSS_VLLM_MAPPING from maxtext.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING from maxtext.integration.tunix.weight_mapping.qwen2 import QWEN2_VLLM_MAPPING @@ -35,6 +36,8 @@ def __getattr__(self, name): return QWEN2_VLLM_MAPPING elif name.startswith("qwen3"): return QWEN3_VLLM_MAPPING + elif name.startswith("gemma3"): + return GEMMA3_VLLM_MAPPING elif name.startswith("deepseek3"): return DEEPSEEK_VLLM_MAPPING elif name.startswith("gpt-oss"): diff --git a/src/maxtext/integration/tunix/weight_mapping/gemma3.py b/src/maxtext/integration/tunix/weight_mapping/gemma3.py new file mode 100644 index 0000000000..2ac7cb2379 --- /dev/null +++ b/src/maxtext/integration/tunix/weight_mapping/gemma3.py @@ -0,0 +1,122 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Defines the weight mapping from MaxText's Gemma3 model to a vLLM-compatible format.""" + +from dataclasses import dataclass + +import numpy as np + + +@dataclass +class GEMMA3_VLLM_MAPPING: + """Mapping MaxText Gemma3 weights to vLLM's Gemma3 weights.""" + + @staticmethod + def to_hf_hook_fns(): + """Returns a dictionary of hook functions to be applied to MaxText weights.""" + + def scale_embedding(arr): + hidden_size = arr.shape[1] + normalizer = np.dtype(arr.dtype).type(hidden_size**0.5) + return arr / normalizer + + return { + "base.token_embedder.embedding": scale_embedding, + } + + @staticmethod + def to_hf_transpose_keys(): + """Returns the keys of weights that need to be transposed (an empty dict for Gemma3).""" + return {} + + @staticmethod + def lora_to_hf_mappings(): + """Returns the static mapping for LoRA (Low-Rank Adaptation) weights; None is defined here.""" + return None + + @staticmethod + def to_hf_mapping(): + """Mapping from MaxText model to HuggingFace vLLM model. + + Returns: + A dictionary mapping MaxText parameter names to HuggingFace parameter names and sharding.
+ """ + return { + # Token embeddings - shard vocab dimension + "base.token_embedder.embedding": ( + "model.language_model.embed_tokens.kernel", + ("model", None), + ), + # Final layer norm - no sharding needed + "base.decoder.decoder_norm.scale": ( + "model.language_model.norm.scale", + (None,), + ), + # Layer norms - no sharding needed + "base.decoder.layers.pre_self_attention_norm.scale": ( + "model.language_model.layers.*.input_layernorm.scale", + (None, "layer"), + ), + "base.decoder.layers.post_self_attention_norm.scale": ( + "model.language_model.layers.*.post_attention_layernorm.scale", + (None, "layer"), + ), + "base.decoder.layers.self_attention.query_norm.scale": ( + "model.language_model.layers.*.self_attn.q_norm.scale", + (None, "layer"), + ), + "base.decoder.layers.self_attention.key_norm.scale": ( + "model.language_model.layers.*.self_attn.k_norm.scale", + (None, "layer"), + ), + "base.decoder.layers.pre_ffw_norm.scale": ( + "model.language_model.layers.*.pre_feedforward_layernorm.scale", + (None, "layer"), + ), + "base.decoder.layers.post_ffw_norm.scale": ( + "model.language_model.layers.*.post_feedforward_layernorm.scale", + (None, "layer"), + ), + # MLP components - shard hidden dimensions + "base.decoder.layers.mlp.wi_0.kernel": ( + "model.language_model.layers.*.mlp.gate_proj.kernel", + (None, "layer", "model"), + ), + "base.decoder.layers.mlp.wi_1.kernel": ( + "model.language_model.layers.*.mlp.up_proj.kernel", + (None, "layer", "model"), + ), + "base.decoder.layers.mlp.wo.kernel": ( + "model.language_model.layers.*.mlp.down_proj.kernel", + ("model", "layer", None), + ), + # Attention components - shard head dimensions + "base.decoder.layers.self_attention.query.kernel": ( + "model.language_model.layers.*.self_attn.q_proj.kernel", + (None, "layer", "model", None), + ), + "base.decoder.layers.self_attention.key.kernel": ( + "model.language_model.layers.*.self_attn.k_proj.kernel", + (None, "layer", "model", None), + ), + 
"base.decoder.layers.self_attention.value.kernel": ( + "model.language_model.layers.*.self_attn.v_proj.kernel", + (None, "layer", "model", None), + ), + "base.decoder.layers.self_attention.out.kernel": ( + "model.language_model.layers.*.self_attn.o_proj.kernel", + ("model", "layer", None, None), + ), + } diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py index 07231f965e..dbec57620e 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -28,6 +28,7 @@ from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE from maxtext.utils import max_logging from maxtext.utils import model_creation_utils +from maxtext.utils import lora_utils try: @@ -323,4 +324,8 @@ def load_weights(self, rng_key: jax.Array) -> None: model = model_creation_utils.from_pretrained( self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key ) + if self.maxtext_config.lora.enable_lora: + model = lora_utils.apply_lora_to_model(model, self.mesh, self.maxtext_config) + if self.maxtext_config.lora.lora_restore_path: + lora_utils.restore_lora_from_path(model, self.maxtext_config) self.model = nnx.data(model) diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index ba7d540dae..70fde6eb3f 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -532,18 +532,23 @@ def apply_lora_to_model( return lora_model -def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any: +def restore_lora_from_path(trainer_or_model: Any, mt_config: pyconfig.HyperParameters) -> Any: """Restores LoRA parameter weights from an external Orbax checkpoint for a fresh run.""" lora_restore_path = mt_config.lora.lora_restore_path - train_steps = getattr(trainer, "train_steps", 0) + if isinstance(trainer_or_model, nnx.Module): + model = trainer_or_model + 
train_steps = 0 + else: + model = trainer_or_model.model + train_steps = getattr(trainer_or_model, "train_steps", 0) if train_steps > 0: max_logging.log( f"PeftTrainer restored current run at step {train_steps}; " f"ignoring lora_restore_path '{lora_restore_path}'." ) - return trainer + return trainer_or_model - if not is_lora_enabled(trainer.model): + if not is_lora_enabled(model): lora_module_path = _get_lora_module_path(mt_config) if not mt_config.lora.enable_lora: raise ValueError( @@ -551,7 +556,7 @@ def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> f"Set lora.enable_lora=True and verify lora_module_path ('{lora_module_path}') matches model modules." ) - abstract_lora_params = nnx.state(trainer.model, nnx.LoRAParam) + abstract_lora_params = nnx.state(model, nnx.LoRAParam) target_for_restore = jax.tree.map( lambda v: {"value": v.value}, @@ -607,6 +612,6 @@ def _map_to_state(path, variable): is_leaf=lambda n: isinstance(n, nnx.Variable), ) - nnx.update(trainer.model, abstract_lora_params) + nnx.update(model, abstract_lora_params) max_logging.log(f"LoRA restore complete from '{lora_restore_path}'.") - return trainer + return trainer_or_model diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index 646642786e..daa561c404 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -738,7 +738,7 @@ def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sa # TunixMaxTextAdapter wraps MaxText models to be compatible with Tunix's default APIs # The weight mappings for vllm (which is interfaced to from MaxText via Tunix) are model specific. 
# The mappings are defined inside src/maxtext/integration/tunix/weight_mapping - actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings) + actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings, mesh=reference_mesh) actor_model.config = None actor_mesh = reference_mesh else: @@ -1041,7 +1041,7 @@ def _walk_align(ckpt, model_arr, axes): if wrap_with_tunix_adapter: with mesh: use_no_op_mappings = "maxtext_config" in config.vllm_additional_config - model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings) + model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings, mesh=mesh) model.config = None if original_mesh: diff --git a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_lora.sh b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_lora.sh new file mode 100644 index 0000000000..9116f81a0c --- /dev/null +++ b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_lora.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# Validates the Gemma3-4B LoRA pipeline using a pre-converted MaxText checkpoint. + +# The flow of this script is as follows: +# 1. Run inference on the pre-converted checkpoint. +# 2. Run LoRA starting from the pre-converted checkpoint. +# 3. Run inference on the checkpoint produced by the LoRA run. +# 4. Convert the checkpoint produced by the LoRA run back to HuggingFace format. 
+ +# Usage: +# export HF_TOKEN=<your Hugging Face access token> +# export RUN_ID=$(date +%Y-%m-%d-%H-%M) +# bash test_gemma3_to_mt.sh $RUN_ID +# bash test_gemma3_lora.sh $RUN_ID + + +set -ex + +source /home/jackyf_google_com/maxtext/.venv/bin/activate +export PYTHONPATH=src:$PYTHONPATH + +run_id=${1:-$(date +%Y-%m-%d-%H-%M)} +MODEL_NAME='gemma3-4b' + +# Non-Googlers: please point `BASE_OUTPUT_DIRECTORY` to the GCS path where your scanned and unscanned checkpoints are stored +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME} +UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items +SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items + +# Step 1: Install torch +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + +# Step 2: Run inference on the original checkpoint converted from Hugging Face +python3 -m maxtext.inference.vllm_decode \ + model_name=${MODEL_NAME} \ + load_parameters_path=${UNSCANNED_CKPT_PATH} \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + hbm_utilization_vllm=0.6 \ + prompt="Suggest some famous landmarks in London."
\ + use_chat_template=True scan_layers=false + +# Step 3: Run LoRA on the converted checkpoint +python3 -m maxtext.trainers.post_train.sft.train_sft \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/lora \ + load_parameters_path=${SCANNED_CKPT_PATH} \ + per_device_batch_size=1 run_name=${run_id} \ + steps=5 scan_layers=true \ + model_name=${MODEL_NAME} \ + hf_path=openai/gsm8k \ + train_split=train \ + hf_data_dir=main \ + train_data_columns=['question','answer'] \ + max_target_length=1024 \ + learning_rate=3e-6 \ + chat_template_path=maxtext/examples/chat_templates/math_qa.json \ + lora.enable_lora=True \ + lora.lora_rank=16 \ + lora.lora_alpha=32.0 \ + enable_nnx=True \ + pure_nnx_decoder=True \ + enable_single_controller=True \ + checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False + + +# Step 4: Run inference on the checkpoint generated from the previous run +python3 -m maxtext.inference.vllm_decode \ + --use_tunix=True \ + model_name=${MODEL_NAME} \ + load_parameters_path=${SCANNED_CKPT_PATH} \ + lora.enable_lora=True \ + lora.lora_restore_path=${BASE_OUTPUT_DIRECTORY}/lora/${run_id}/checkpoints/5/model_params \ + lora.lora_rank=16 \ + lora.lora_alpha=32.0 \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + hbm_utilization_vllm=0.6 \ + prompt="Suggest some famous landmarks in London." 
\ + use_chat_template=True \ + enable_nnx=True \ + pure_nnx_decoder=True \ + scan_layers=True + +# Step 5: Convert the checkpoint from MaxText format to Hugging Face format +python3 -m maxtext.checkpoint_conversion.to_huggingface \ + model_name=${MODEL_NAME} \ + load_parameters_path=${SCANNED_CKPT_PATH} \ + lora.lora_restore_path=${BASE_OUTPUT_DIRECTORY}/lora/${run_id}/checkpoints/5/model_params \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \ + scan_layers=true diff --git a/tests/integration/setup_train_loop_nnx_test.py b/tests/integration/setup_train_loop_nnx_test.py index fb5cd6f0b6..a0af657b58 100644 --- a/tests/integration/setup_train_loop_nnx_test.py +++ b/tests/integration/setup_train_loop_nnx_test.py @@ -26,6 +26,7 @@ import sys import unittest + from flax import nnx import jax import jax.numpy as jnp @@ -141,7 +142,7 @@ def test_pure_nnx_setup_param_only_split_matches_model(self): def test_pure_nnx_dpo_setup_materializes_reference_model(self): """With use_dpo=True the NNX init_state_fn materializes a frozen reference - model alongside the policy (train_utils.py:233-237). + model alongside the policy. Both come from _create_model_partial() with the same init_weights_seed, so absent a step-0 checkpoint the reference starts bit-identical to the policy. @@ -149,8 +150,7 @@ def test_pure_nnx_dpo_setup_materializes_reference_model(self): Positive replacement for the removed test_pure_nnx_dpo_raises_not_implemented: NNX DPO is supported now, so setup_train_loop builds the reference instead - of - raising. + of raising. 
""" config = _tiny_nnx_pyconfig(use_dpo=True, packing=False) *_, train_state = setup_train_loop(config, recorder=None) diff --git a/tests/unit/hf_checkpoint_conversion_test.py b/tests/unit/hf_checkpoint_conversion_test.py index 02ed7a5598..4d3e757d3a 100644 --- a/tests/unit/hf_checkpoint_conversion_test.py +++ b/tests/unit/hf_checkpoint_conversion_test.py @@ -169,5 +169,124 @@ def test_process_and_stack_weights(self): self.assertEqual(stacked[1, 0, 0], 2.0) +class CheckpointMergingTest(unittest.TestCase): + """Tests the recursive_update and load_orbax_checkpoint functions to ensure we don't overwrite weights.""" + + def test_recursive_update(self): + from maxtext.checkpoint_conversion.utils.utils import recursive_update + + base = { + "params": { + "decoder": { + "layers": { + "kernel": np.ones((4, 4)), + } + } + } + } + lora = { + "params": { + "decoder": { + "layers": { + "kernel_lora_a": np.ones((4, 2)), + "kernel_lora_b": np.ones((2, 4)), + } + } + } + } + + merged = {} + recursive_update(merged, base) + recursive_update(merged, lora) + + # Verify that both base and lora weights are present and not overwritten + self.assertIn("kernel", merged["params"]["decoder"]["layers"]) + self.assertIn("kernel_lora_a", merged["params"]["decoder"]["layers"]) + self.assertIn("kernel_lora_b", merged["params"]["decoder"]["layers"]) + np.testing.assert_array_equal(merged["params"]["decoder"]["layers"]["kernel"], np.ones((4, 4))) + np.testing.assert_array_equal(merged["params"]["decoder"]["layers"]["kernel_lora_a"], np.ones((4, 2))) + np.testing.assert_array_equal(merged["params"]["decoder"]["layers"]["kernel_lora_b"], np.ones((2, 4))) + + @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.ocp.Checkpointer") + @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.epath.Path") + @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.jax.devices") + def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, mock_path, 
mock_checkpointer_cls): + from maxtext.checkpoint_conversion.utils.utils import load_orbax_checkpoint + + # Mock jax devices + mock_jax_devices.return_value = [MagicMock()] + + # Mock Orbax Checkpointer and its restore results + mock_ckptr = MagicMock() + mock_checkpointer_cls.return_value = mock_ckptr + + # Base checkpoint metadata and content + base_metadata = MagicMock() + base_metadata.item_metadata.tree = { + "params": { + "decoder": { + "layers": { + "kernel": MagicMock(shape=(4, 4)) + } + } + } + } + base_restore_content = { + "params": { + "decoder": { + "layers": { + "kernel": np.ones((4, 4)) + } + } + } + } + + # LoRA checkpoint metadata and content + lora_metadata = MagicMock() + lora_metadata.item_metadata.tree = { + "params": { + "decoder": { + "layers": { + "kernel_lora_a": MagicMock(shape=(4, 2)), + "kernel_lora_b": MagicMock(shape=(2, 4)), + } + } + } + } + lora_restore_content = { + "params": { + "decoder": { + "layers": { + "kernel_lora_a": np.ones((4, 2)), + "kernel_lora_b": np.ones((2, 4)), + } + } + } + } + + # Mock metadata and restore calls + mock_ckptr.metadata.side_effect = [base_metadata, lora_metadata] + mock_ckptr.restore.side_effect = [base_restore_content, lora_restore_content] + + # Create dummy config + config = MagicMock() + config.checkpoint_storage_concurrent_gb = 8 + config.checkpoint_storage_use_ocdbt = True + config.checkpoint_storage_use_zarr3 = True + config.load_parameters_path = "gs://base-bucket/checkpoints" + config.lora.lora_restore_path = "gs://lora-bucket/checkpoints" + + # Load and merge + merged = load_orbax_checkpoint(config) + + # Assert checkpointer was called twice and restored both + self.assertEqual(mock_ckptr.restore.call_count, 2) + + # Verify that the keys are recursively merged correctly! 
+ self.assertIn("kernel", merged["params"]["decoder"]["layers"]) + self.assertIn("kernel_lora_a", merged["params"]["decoder"]["layers"]) + self.assertIn("kernel_lora_b", merged["params"]["decoder"]["layers"]) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_checkpoint_merging.py b/tests/unit/test_checkpoint_merging.py new file mode 100644 index 0000000000..1aa78418b4 --- /dev/null +++ b/tests/unit/test_checkpoint_merging.py @@ -0,0 +1,142 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for recursive checkpoint merging and load_orbax_checkpoint.""" + +import unittest +from unittest.mock import MagicMock, patch +import numpy as np + + +class CheckpointMergingTest(unittest.TestCase): + """Tests the recursive_update and load_orbax_checkpoint functions to ensure we don't overwrite weights.""" + + def test_recursive_update(self): + from maxtext.checkpoint_conversion.utils.utils import recursive_update + + base = { + "params": { + "decoder": { + "layers": { + "kernel": np.ones((4, 4)), + } + } + } + } + lora = { + "params": { + "decoder": { + "layers": { + "kernel_lora_a": np.ones((4, 2)), + "kernel_lora_b": np.ones((2, 4)), + } + } + } + } + + merged = {} + recursive_update(merged, base) + recursive_update(merged, lora) + + # Verify that both base and lora weights are present and not overwritten + self.assertIn("kernel", merged["params"]["decoder"]["layers"]) + self.assertIn("kernel_lora_a", merged["params"]["decoder"]["layers"]) + self.assertIn("kernel_lora_b", merged["params"]["decoder"]["layers"]) + np.testing.assert_array_equal(merged["params"]["decoder"]["layers"]["kernel"], np.ones((4, 4))) + np.testing.assert_array_equal(merged["params"]["decoder"]["layers"]["kernel_lora_a"], np.ones((4, 2))) + np.testing.assert_array_equal(merged["params"]["decoder"]["layers"]["kernel_lora_b"], np.ones((2, 4))) + + @patch("maxtext.checkpoint_conversion.utils.utils.ocp.Checkpointer") + @patch("maxtext.checkpoint_conversion.utils.utils.epath.Path") + @patch("maxtext.checkpoint_conversion.utils.utils.jax.devices") + def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, mock_path, mock_checkpointer_cls): + from maxtext.checkpoint_conversion.utils.utils import load_orbax_checkpoint + + # Mock jax devices + mock_jax_devices.return_value = [MagicMock()] + + # Mock Orbax Checkpointer and its restore results + mock_ckptr = MagicMock() + mock_checkpointer_cls.return_value = mock_ckptr + + # Base checkpoint metadata and content + 
base_metadata = MagicMock() + base_metadata.item_metadata.tree = { + "params": { + "decoder": { + "layers": { + "kernel": MagicMock(shape=(4, 4)) + } + } + } + } + base_restore_content = { + "params": { + "decoder": { + "layers": { + "kernel": np.ones((4, 4)) + } + } + } + } + + # LoRA checkpoint metadata and content + lora_metadata = MagicMock() + lora_metadata.item_metadata.tree = { + "params": { + "decoder": { + "layers": { + "kernel_lora_a": MagicMock(shape=(4, 2)), + "kernel_lora_b": MagicMock(shape=(2, 4)), + } + } + } + } + lora_restore_content = { + "params": { + "decoder": { + "layers": { + "kernel_lora_a": np.ones((4, 2)), + "kernel_lora_b": np.ones((2, 4)), + } + } + } + } + + # Mock metadata and restore calls + mock_ckptr.metadata.side_effect = [base_metadata, lora_metadata] + mock_ckptr.restore.side_effect = [base_restore_content, lora_restore_content] + + # Create dummy config + config = MagicMock() + config.checkpoint_storage_concurrent_gb = 8 + config.checkpoint_storage_use_ocdbt = True + config.checkpoint_storage_use_zarr3 = True + config.load_parameters_path = "gs://base-bucket/checkpoints" + config.lora.lora_restore_path = "gs://lora-bucket/checkpoints" + + # Load and merge + merged = load_orbax_checkpoint(config) + + # Assert checkpointer was called twice and restored both + self.assertEqual(mock_ckptr.restore.call_count, 2) + + # Verify that the keys are recursively merged correctly! + self.assertIn("kernel", merged["params"]["decoder"]["layers"]) + self.assertIn("kernel_lora_a", merged["params"]["decoder"]["layers"]) + self.assertIn("kernel_lora_b", merged["params"]["decoder"]["layers"]) + + +if __name__ == "__main__": + unittest.main()