Skip to content
Draft

Test #3422

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
aada2b3
Update
hsuan-lun-chiang May 29, 2026
e9ba176
NNX: finish MaxEngine inference carve-outs (multisampling, concat, st…
ecnal-cienet May 26, 2026
7d36a8a
NNX: native LoRA + GRPO (drop maxengine LoRA carve-out, drop GRPO pur…
ecnal-cienet May 6, 2026
22da7d5
NNX: QK-Clip on NNX + NNX-format checkpoint utilities
ecnal-cienet May 7, 2026
272353d
NNX: AQT in MaxEngine + serve-mode reload + gpt3 prefill fix
ecnal-cienet May 7, 2026
b37558e
NNX: vocab tiling custom_vjp with output-head carve-out
ecnal-cienet May 8, 2026
ed8ecef
tests: pin Linen-only vocab tiling and pipeline tests for upcoming NN…
ecnal-cienet May 8, 2026
cffcc7d
NNX: flip pure_nnx/enable_nnx/pure_nnx_decoder defaults to True
ecnal-cienet May 8, 2026
9da8c73
fix tests/unit/train_compile_test.py::TrainCompile::test_remat_save_q…
hsuan-lun-chiang May 19, 2026
05fd56d
Temp: tests/unit/train_compile_test.py::TrainCompile::test_qk_clip_do…
hsuan-lun-chiang May 19, 2026
938e4ff
fix cpu UT failure
May 19, 2026
0ebed5c
fix gpu UT failures
May 20, 2026
b72f7b9
Fix tests/unit/muon_utils_test.py::TestGetMuonWeightDimensionNumbersN…
hsuan-lun-chiang May 20, 2026
38dac66
tests/unit/max_utils_test.py::UnscanTest::test_unscan_train_state_params
hsuan-lun-chiang May 20, 2026
c7f65dc
tests/unit/max_utils_test.py::UnscanTest::test_unscan_train_state_params
hsuan-lun-chiang May 20, 2026
645b13e
Fix test compatibility with pure_nnx=True defaults
hsuan-lun-chiang May 20, 2026
d3df5cd
Fix diloco related unit tests
hsuan-lun-chiang May 21, 2026
92f8ee7
fix nnx_wrapper.py comment
May 21, 2026
30b9a83
fix nnx_wrapper.py gpu UT failure
May 21, 2026
6346ced
Fix integration test failures under NNX defaults
ecnal-cienet May 21, 2026
2582bc9
Revert fix for fp8
hsuan-lun-chiang May 22, 2026
14f94d9
test: skip NNX int8 parameter-only checkpoint generation for GPU dot …
hsuan-lun-chiang May 25, 2026
268da4c
Fix sft_llama3_demo_tpu.ipynb
hsuan-lun-chiang May 26, 2026
19e7c67
test: skip fp8 sparsity smoke cases under NNX (b/509790223)
ecnal-cienet May 26, 2026
e941ff7
test: make maxengine prefill/cache tests NNX-only
ecnal-cienet May 27, 2026
34a4ad9
Fix Fp8 related Unit tests
hsuan-lun-chiang May 28, 2026
870c0ad
Fix Fp8
hsuan-lun-chiang May 29, 2026
eb8aeb9
Update
hsuan-lun-chiang May 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import os
import sys

from flax import nnx
import jax
from jax import random
from jax.sharding import Mesh
Expand All @@ -48,11 +49,15 @@
from maxtext.common import checkpointing
from maxtext.common.common_types import MODEL_MODE_TRAIN
from maxtext.layers import quantizations
from maxtext.common import train_state_nnx
from maxtext.models.models import transformer_as_linen
from maxtext.optimizers import optimizers
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.utils import maxtext_utils_nnx
from maxtext.utils import model_creation_utils
from maxtext.utils import train_utils
import numpy as np
from psutil import Process
import tensorstore as ts
Expand Down Expand Up @@ -87,12 +92,23 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
devices_array = maxtext_utils.create_device_mesh(cfg)
mesh = Mesh(devices_array, cfg.mesh_axes)

# Output is Linen-format (keystr_map below uses Linen tree paths). Route to
# Linen regardless of pure_nnx.
quant = quantizations.configure_quantization(cfg)
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
if cfg.pure_nnx:
rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
_, tx = train_utils.create_training_optimizer(cfg, model)
_create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)

def init_state_fn():
nnx_model = _create_model_partial()
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)

else:
quant = quantizations.configure_quantization(cfg)
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)

checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
cfg.checkpoint_dir,
Expand All @@ -101,7 +117,6 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
cfg.checkpoint_period,
)

init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
max_logging.log("start")
max_utils.print_mem_stats("After params initialized")
Expand Down Expand Up @@ -186,10 +201,21 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
"['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None),
}

state_map = {
".step": ("step", None),
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
}
if cfg.pure_nnx:
# NNX state-tree paths after `nnx.split(TrainStateNNX)`:
# model params -> ['model']<rest>.value
# adam mu / nu -> ['optimizer']['opt_state']['mu' | 'nu']<rest>.value
# step -> ['optimizer']['step'].value
# opt count -> ['optimizer']['opt_state']['count'].value
state_map = {
".optimizer.step.value": ("step", None),
".optimizer.opt_state.count.value": ("opt_states_0.no_prefix_0.count", None),
}
else:
state_map = {
".step": ("step", None),
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
}

def get_layer_prefix(keystr_pax):
# different path format between decoder_layer variable
Expand All @@ -201,19 +227,27 @@ def get_layer_prefix(keystr_pax):
return prefix_pax_opt_state

for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items():
# model variable
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
prefix_pax_opt_state = get_layer_prefix(keystr_pax)
# first momentum in optimizer state
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
transform_fn,
)
# second momentum in optimizer state
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
transform_fn,
)
if cfg.pure_nnx:
state_map[f".model{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
state_map[f".optimizer.opt_state.mu{keystr_maxtext}.value"] = (
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
transform_fn,
)
state_map[f".optimizer.opt_state.nu{keystr_maxtext}.value"] = (
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
transform_fn,
)
else:
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
transform_fn,
)
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
transform_fn,
)

def verify_fn(key_path, _):
keystr = jax.tree_util.keystr(key_path)
Expand Down Expand Up @@ -265,10 +299,11 @@ def map_fn(key_path, value):
max_logging.log("converted state finished")
max_utils.print_mem_stats("converted state finished")

if checkpointing.save_checkpoint(checkpoint_manager, converted_state.step, converted_state):
max_logging.log(f"saved a checkpoint at step {converted_state.step}")
step_value = int(converted_state.optimizer.step.value) if cfg.pure_nnx else converted_state.step
if checkpointing.save_checkpoint(checkpoint_manager, step_value, converted_state):
max_logging.log(f"saved a checkpoint at step {step_value}")
# Upon preemption, exit when and only when all ongoing saves are complete.
if checkpoint_manager.reached_preemption(converted_state.step):
if checkpoint_manager.reached_preemption(step_value):
checkpoint_manager.wait_until_finished()
sys.exit()

Expand Down
31 changes: 22 additions & 9 deletions src/maxtext/checkpoint_conversion/to_maxtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,23 +319,36 @@ def get_maxtext_model_info(config):
# Get abstract model structure (name, shape) without materializing the weights to save memory
abstract_params_tree = maxtext_utils.get_abstract_param(maxtext_model_flax, config)["params"]

abstract_params_flat, _ = jax.tree_util.tree_flatten_with_path(abstract_params_tree)
# Standardize abstract tree for later unflattening
abstract_params_tree = jax.tree.map(
lambda _: 0,
abstract_params_tree,
is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned),
abstract_params_flat, abstract_params_treedef = jax.tree_util.tree_flatten_with_path(
abstract_params_tree, is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned)
)
abstract_params_treedef = jax.tree_util.tree_structure(abstract_params_tree)

max_logging.log("MaxText abstract model and state initialized.")

# preprocess state
maxtext_abstract_dict = {}
for mt_target_idx, (path_tuple, abstract_leaf_value) in enumerate(abstract_params_flat):
key_parts = [k.key for k in path_tuple if hasattr(k, "key")]
key_parts = []
for k in path_tuple:
# JAX path components can be DictKey(key), FlattenedIndexKey(key), SequenceKey(idx),
# or GetAttrKey(name). We read `key` or `idx` when present and otherwise fall back
# to str(k). If the value is an integer or digit-string index, we assume it's
# a layer/block index and join it with the previous part using '_', matching
# MaxText's Linen-style naming convention (e.g., layers_0).
val = getattr(k, "key", getattr(k, "idx", None))
if val is None:
val = str(k)

val_str = str(val)
if (isinstance(val, int) or val_str.isdigit()) and key_parts:
key_parts[-1] = f"{key_parts[-1]}_{val_str}"
else:
key_parts.append(val_str)

mt_param_key = "params-" + "-".join(key_parts)
mt_target_shape = abstract_leaf_value.shape
if isinstance(abstract_leaf_value, nn.LogicallyPartitioned):
mt_target_shape = abstract_leaf_value.value.shape
else:
mt_target_shape = abstract_leaf_value.shape
maxtext_abstract_dict[mt_param_key] = (mt_target_idx, mt_target_shape)

return maxtext_abstract_dict, abstract_params_treedef
Expand Down
26 changes: 23 additions & 3 deletions src/maxtext/checkpoint_conversion/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,9 +883,19 @@ def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
result = {}
leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
for path_tuple, leaf_value in leaves_with_paths:
path_keys = [k.key for k in path_tuple]
path_keys = []
for k in path_tuple:
val = getattr(k, "key", getattr(k, "idx", None))
if val is None:
val = str(k)
val_str = str(val)
if (isinstance(val, int) or val_str.isdigit()) and path_keys:
path_keys[-1] = f"{path_keys[-1]}_{val_str}"
else:
path_keys.append(val_str)

# Skip NNX RNG state variables (not model weights)
if "to_nnx__rngs" in path_keys or any(k.endswith("_rngs") for k in path_keys):
if "to_nnx__rngs" in path_keys or any(k == "rngs" or k.endswith("_rngs") for k in path_keys):
continue
# Skip if this is the "value" key itself - we want the parent path
if path_keys[-1] == "value":
Expand All @@ -912,7 +922,17 @@ def extract_linen_weights(weights_dict: dict) -> dict[str, np.ndarray]:
result = {}
leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
for path_tuple, leaf_value in leaves_with_paths:
path_keys = [k.key for k in path_tuple]
path_keys = []
for k in path_tuple:
val = getattr(k, "key", getattr(k, "idx", None))
if val is None:
val = str(k)
val_str = str(val)
if (isinstance(val, int) or val_str.isdigit()) and path_keys:
path_keys[-1] = f"{path_keys[-1]}_{val_str}"
else:
path_keys.append(val_str)

# Construct maxtext_param_key from path_tuple
maxtext_param_key = "params-" + "-".join(path_keys)
if not isinstance(leaf_value, (jax.Array, np.ndarray)):
Expand Down
10 changes: 8 additions & 2 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,17 @@ def combine_sharding(sds, shardings):
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
)
# NNX checkpoints are saved as a pure dict (see maybe_save_checkpoint), so the
# restore target must also be a pure dict. A boxed nnx.State would not match
# the on-disk tree.
restore_target = abstract_unboxed_pre_state
if isinstance(abstract_unboxed_pre_state, nnx.State):
restore_target = abstract_unboxed_pre_state.to_pure_dict()
# Provide sharding info to ensure restoration returns JAX arrays (not NumPy arrays).
restore_args = jax.tree_util.tree_map(
lambda x: ocp.type_handlers.ArrayRestoreArgs(sharding=x.sharding), abstract_unboxed_pre_state
lambda x: ocp.type_handlers.ArrayRestoreArgs(sharding=x.sharding), restore_target
)
return ocp.Checkpointer(handler).restore(p, abstract_unboxed_pre_state, restore_args=restore_args)
return ocp.Checkpointer(handler).restore(p, restore_target, restore_args=restore_args)


def create_orbax_checkpoint_manager(
Expand Down
7 changes: 4 additions & 3 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1169,9 +1169,10 @@ position_id_per_seconds: 25
subslice_shape: ""

# NNX
enable_nnx: false
pure_nnx_decoder: false
pure_nnx: false

enable_nnx: True
pure_nnx_decoder: True
pure_nnx: True

################################## Qwen3-Next Specific Configs ##################################
# Kernel size for the 1D convolution in the Gated Delta Net
Expand Down
Loading
Loading