Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/maxtext/checkpoint_conversion/to_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,11 +417,11 @@ def main(argv: Sequence[str]) -> None:
hook_fn_map = mappings["hook_fn_mapping"]

# 4. Extract and transform weights for Linen/NNX-SFT/NNX-RL checkpoints
maxtext_state_dict = detect_and_extract_checkpoint(checkpoint_dict)
maxtext_state_dict = detect_and_extract_checkpoint(checkpoint_dict, config)

# Validate that checkpoint keys match the parameter mapping
state_keys = set(maxtext_state_dict) | {
k.replace("_lora_a", "").replace("_lora_b", "") for k in maxtext_state_dict if "_lora_" in k
state_keys = {
k.replace("_lora_a", "").replace("_lora_b", "") for k in maxtext_state_dict
}
filtered_map_keys = validate_and_filter_param_map_keys(param_map, state_keys)

Expand Down
45 changes: 33 additions & 12 deletions src/maxtext/checkpoint_conversion/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,16 @@ def format_meter(
return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)


def recursive_update(d: dict, u: dict) -> dict:
"""Recursively updates dictionary d with dictionary u in place."""
for k, v in u.items():
if isinstance(v, dict) and k in d and isinstance(d[k], dict):
recursive_update(d[k], v)
else:
d[k] = v
return d


def load_orbax_checkpoint(config) -> dict:
"""Loads Orbax checkpoints from Base and/or LoRA paths in config.

Expand Down Expand Up @@ -837,7 +847,7 @@ def create_restore_args(tree_metadata):
metadata.item_metadata.tree,
is_leaf=lambda x: hasattr(x, "shape"),
)
merged_dict.update(ckptr.restore(checkpoint_path, restore_args=restore_args))
recursive_update(merged_dict, ckptr.restore(checkpoint_path, restore_args=restore_args))

return merged_dict

Expand Down Expand Up @@ -918,7 +928,7 @@ def extract_linen_weights(weights_dict: dict) -> dict[str, np.ndarray]:
return result


def detect_and_extract_checkpoint(checkpoint_dict: dict) -> dict[str, np.ndarray]:
def detect_and_extract_checkpoint(checkpoint_dict: dict, config=None) -> dict[str, np.ndarray]:
"""Detect checkpoint type (Linen vs NNX) and extract weights.

Handles multiple NNX checkpoint variants:
Expand All @@ -932,24 +942,35 @@ def detect_and_extract_checkpoint(checkpoint_dict: dict) -> dict[str, np.ndarray

Args:
checkpoint_dict: Raw checkpoint dictionary from Orbax
config: Optional MaxText configuration

Returns:
Dictionary mapping MaxText parameter names to weight arrays
"""
# Determine if we are using an NNX model from config or if there is no top-level "params" key
is_nnx = config.enable_nnx if config is not None else ("params" not in checkpoint_dict)

# Detect checkpoint type by structure
actual_weights_dict = checkpoint_dict.get("params")

if actual_weights_dict is None:
# NNX checkpoint: structure is directly at the root
# Check for NNX-RL variant with 'base' wrapper
if "base" in checkpoint_dict and isinstance(checkpoint_dict["base"], dict):
# NNX-RL: {'base': {'decoder': ..., 'token_embedder': ...}}
max_logging.log("Detected NNX-RL checkpoint structure (with 'base' wrapper)")
return extract_nnx_weights(checkpoint_dict["base"])
if is_nnx:
if actual_weights_dict is None:
# NNX checkpoint: structure is directly at the root
# Check for NNX-RL variant with 'base' wrapper
if "base" in checkpoint_dict and isinstance(checkpoint_dict["base"], dict):
# NNX-RL: {'base': {'decoder': ..., 'token_embedder': ...}}
max_logging.log("Detected NNX-RL checkpoint structure (with 'base' wrapper)")
return extract_nnx_weights(checkpoint_dict["base"])
else:
# NNX-SFT: {'decoder': ..., 'token_embedder': ...}
max_logging.log("Detected NNX-SFT checkpoint structure")
return extract_nnx_weights(checkpoint_dict)
else:
# NNX-SFT: {'decoder': ..., 'token_embedder': ...}
max_logging.log("Detected NNX-SFT checkpoint structure")
return extract_nnx_weights(checkpoint_dict)
# NNX checkpoint wrapped inside top-level 'params' key
if isinstance(actual_weights_dict, dict) and "params" in actual_weights_dict:
actual_weights_dict = actual_weights_dict["params"]
max_logging.log("Detected NNX checkpoint structure wrapped in 'params'")
return extract_nnx_weights(actual_weights_dict)
else:
# Linen checkpoint: check if there's a nested 'params' key
if isinstance(actual_weights_dict, dict) and "params" in actual_weights_dict:
Expand Down
85 changes: 76 additions & 9 deletions src/maxtext/inference/vllm_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from maxtext.utils import model_creation_utils
from maxtext.utils import max_logging
from maxtext.utils import lora_utils
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
from maxtext.common.common_types import Config
from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter
Expand Down Expand Up @@ -96,6 +97,17 @@ def decode_with_vllm(config: Config) -> None:
},
}

if config.lora.enable_lora:
vllm_args["additional_config"]["maxtext_config"].update(
{
"lora.enable_lora": config.lora.enable_lora,
"lora.lora_restore_path": config.lora.lora_restore_path,
"lora.lora_rank": config.lora.lora_rank,
"lora.lora_alpha": config.lora.lora_alpha,
"lora.lora_module_path": config.lora.lora_module_path,
}
)

if config.load_parameters_path:
vllm_args["additional_config"]["maxtext_config"]["load_parameters_path"] = config.load_parameters_path
else:
Expand Down Expand Up @@ -178,7 +190,19 @@ def decode_with_tunix(
mesh: The JAX mesh for parallelism.
"""
# Wrap the model for Tunix
tunix_model = TunixMaxTextAdapter(base_model=model)
use_no_op_mappings = False
if hasattr(config, "vllm_hf_overrides") and config.vllm_hf_overrides:
overrides = config.vllm_hf_overrides
if isinstance(overrides, str) and "MaxTextForCausalLM" in overrides:
use_no_op_mappings = True
elif isinstance(overrides, dict) and "MaxTextForCausalLM" in overrides.get("architectures", []):
use_no_op_mappings = True

tunix_model = TunixMaxTextAdapter(
base_model=model,
use_no_op_mappings=use_no_op_mappings,
mesh=mesh,
)

# Load the tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
Expand Down Expand Up @@ -210,13 +234,57 @@ def decode_with_tunix(
)

# Create vLLM rollout for inference
rollout_vllm_lora_config = None
if config.lora.enable_lora:
rollout_vllm_lora_config = {
"module_path": lora_utils._get_lora_module_path(config),
"rank": config.lora.lora_rank,
"alpha": config.lora.lora_alpha,
}

# MaxText uses -1 to mean "disabled"; vLLM requires top_p in (0, 1].
top_p = config.decode_sampling_nucleus_p if config.decode_sampling_nucleus_p > 0 else 1.0
top_k = config.decode_sampling_top_k if config.decode_sampling_top_k > 0 else -1

rollout_vllm_additional_config = {
"maxtext_config": {
"model_name": config.model_name,
"weight_dtype": "bfloat16",
"allow_split_physical_axes": True,
"debug_sharding": config.debug_sharding,
"prefuse_moe_weights": config.prefuse_moe_weights,
"scan_layers": config.scan_layers,
}
}

if config.lora.enable_lora:
rollout_vllm_additional_config["maxtext_config"]["lora"] = {
"enable_lora": config.lora.enable_lora,
"lora_restore_path": config.lora.lora_restore_path,
"lora_rank": config.lora.lora_rank,
"lora_alpha": config.lora.lora_alpha,
"lora_module_path": config.lora.lora_module_path,
}


rollout_config = base_rollout.RolloutConfig(
max_tokens_to_generate=max_tokens_to_generate,
max_prompt_length=max_prompt_length,
temperature=config.decode_sampling_temperature,
top_p=config.decode_sampling_nucleus_p,
top_k=config.decode_sampling_top_k,
top_p=top_p,
top_k=top_k,
rollout_vllm_model_version=config.tokenizer_path,
rollout_vllm_hbm_utilization=config.hbm_utilization_vllm,
rollout_vllm_init_with_random_weights=True,
rollout_vllm_tpu_backend_type="jax",
rollout_vllm_hf_config_path=config.model if hasattr(config, "model") else None,
rollout_vllm_lora_config=rollout_vllm_lora_config,
rollout_vllm_additional_config=rollout_vllm_additional_config,
rollout_vllm_kwargs={
"hf_overrides": config.vllm_hf_overrides,
} if hasattr(config, "vllm_hf_overrides") and config.vllm_hf_overrides else {},
)

vllm_rollout = VllmRollout(
model=tunix_model,
tokenizer=tokenizer,
Expand All @@ -225,12 +293,7 @@ def decode_with_tunix(
# other special formatting, which is not part of max_prompt_length.
cache_config_or_size=max_prompt_length + max_tokens_to_generate + 256,
mesh=mesh,
model_version=config.tokenizer_path,
hbm_utilization=0.8,
# Initialize vllm model with random weights to speed up bootstrap time.
# Actual model weights will be loaded later.
init_with_random_weights=True,
tpu_backend_type="jax",
rollout_config=rollout_config,
)

# Generate text
Expand All @@ -251,6 +314,10 @@ def main(argv: Sequence[str]) -> None:

if FLAGS.use_tunix:
maxtext_model, mesh = model_creation_utils.from_pretrained(config)
if config.lora.enable_lora:
maxtext_model = lora_utils.apply_lora_to_model(maxtext_model, mesh, config)
if config.lora.lora_restore_path:
lora_utils.restore_lora_from_path(maxtext_model, config)
decode_with_tunix(config, model=maxtext_model, mesh=mesh)
else:
decode_with_vllm(config)
Expand Down
Loading
Loading