[NNX] NNX migration prep (5/N): correctness fixes and feature enablements#3766
Draft
ecnal-cienet wants to merge 3 commits intomainfrom
Draft
[NNX] NNX migration prep (5/N): correctness fixes and feature enablements#3766ecnal-cienet wants to merge 3 commits intomainfrom
ecnal-cienet wants to merge 3 commits intomainfrom
Conversation
bf95288 to
7602b11
Compare
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests - Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils - Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py - Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
…raining fixes Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities: - modify print_shardings_params to support NNX (maxtext_utils.py) - add --pure_nnx flag to run_sharding_dump.py - add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py) - add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py) Part 2 — post-training bug fixes: - models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the whole object as multimodal_input= kwarg; NNXDecoder only accepts individual fields) - optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams (callable() check before invoking learning_rate_fn) - train_distill.py: fix nested NNX transform issue (nnx.value_and_grad inside nnx.jit raises conflicting outer_index error); refactored to jax.value_and_grad + explicit nnx.split/merge pattern; teacher inference moved outside value_and_grad
Bug fixes (run as no-op while pure_nnx=False stays default):
- nnx_wrappers.py: add _refresh_variable_trace_state + is_linen_initializing;
call from ToLinen after nnx.update to fix "Cannot extract graph node from
different trace level" when grad tracers leak into Variable._trace_state.
- gpt_oss.py / olmo3.py: replace inline nn.Dropout(...) with self.dropout =
linears.Dropout(...) in __init__ to fix CallCompactUnboundModuleError.
- normalizations.py: Qwen3NextRMSNorm signature: eps -> epsilon, accept
shard_mode/kernel_axes/parameter_memory_host_offload for callsite parity.
- attentions.py / qwen3.py: callsites eps= -> epsilon=.
- moe.py: per_expert_scale block moved into the unfused-kernel else branch
(was scaling wo even when fused_kernel was active).
- models.py: build MTP block as MultiTokenPredictionBlock(...) directly
(drop the ToNNX(linen) + lazy_init wrap); pass multimodal_input whole
to NNXDecoder instead of unpacking 5 fields.
- gradient_accumulation.py: ZeRO-1+GA all-reduce annotation deferred until
after lax.scan (reduced/unreduced PartitionSpec is rejected inside scan
carry); use nnx.merge(..., copy=True) to avoid Variable reuse.
- diloco.py: NNX-aware state handling — state.params -> state.model.filter
(nnx.Param), step counter at state.optimizer.step, replace_nnx_model_params
helper for jax.lax.cond pytree-structure parity.
- train_compile.py: new _collect_nnx_activation_shardings helper (forward
pass populates _ACTIVATION_SHARDINGS_DUMP — get_abstract_state_nnx only
traces __init__); NNX path now passes 2-arg shaped_train_args (no rng);
diloco path patched to handle the 2-vs-3 length difference.
- muon_utils.py: get_model_mdn default pure_nnx=True; wrap NNX result as
{"params": nnx.to_pure_dict(...)} for parity with Linen tree shape.
- nnx_decoders.py: FP8+NNX scan fix — Linen FP8 ops (fp8_nanoo, fp8_gpu)
retain tracers in Linen scope across re-traces. Skip jax.checkpoint and
use a Python for-loop instead of jax.lax.scan when quantization is FP8.
Makes FP8 quantization usable on the NNX path.
- train.py (pre-train train_step): return nnx.state(new_state, nnx.Not
(nnx.Intermediate)) so sowed forward-pass artifacts (e.g. max_logits for
QK-Clip) don't break leaf-count parity with state_mesh_shardings.
- llama2.py: pass parameter_memory_host_offload to pre_self_attention_layer
_norm RMSNorm (was missing on this norm only).
- base.yml: add 4 pipeline-related logical_axis_rules — layers_outside
_pipeline, layers_per_stage, num_activations, circular_repeats. Additive,
no-op without use_nnx_pipeline=True.
NNX feature enablements (clear all 17 "Pure NNX support has not been
implemented yet" NotImplementedError sites by routing Linen-coupled
utilities to the Linen path; their on-disk format is Linen):
- layerwise_quantization.py (2 sites): operates on Linen-format checkpoints
via DeepSeek*ToLinen layers.
- lora_utils.py (1 site): downstream get_lora_abstract_state expects Linen
tree shape; LoRA adapters on disk are Linen.
- standalone_checkpointer.py (2 sites): add_entropy_to_checkpoint accesses
state.opt_state[0]._replace(mu=..., nu=...) — Linen-only.
- generate_param_only_checkpoint.py (3 sites): _possibly_unroll_params and
_save_decode_checkpoint use state.params["params"]["decoder"] — Linen.
- convert_gpt3_ckpt_from_paxml.py (2 sites): keystr_map targets Linen tree
paths (.params['params'], .opt_state.mu['params']).
- maxengine.py (3 sites): inference engine uses state.params and serves
Linen-format inference checkpoints.
- grpo_trainer.py (4 sites): RL trainer is end-to-end Linen-shaped; route
to Linen with a clear log warning since NNX-format checkpoints will fail
at restore time.
Vocab tiling on NNX (real implementation, not just routing):
- models.py: add Transformer.logits_from_hidden_states on the NNX
Transformer class — wraps NNXDecoder.apply_output_head with the
token_embedder; mirrors TransformerLinenPure.logits_from_hidden_states.
- vocabulary_tiling.py: add vocab_tiling_nnx_loss — chunks the vocab axis
via jax.lax.scan and calls model.logits_from_hidden_states(chunk) per
chunk. The NNX model carries its parameters internally so no explicit
FSDP gather is needed (unlike the Linen gathered_params pattern). MVP
uses default autograd; custom_vjp memory-savings optimization is a
follow-up if backward memory becomes a concern.
- train.py (NNX loss_fn): replace the NotImplementedError with the call
to vocab_tiling_nnx_loss using hidden_states from intermediates.
- pyconfig_deprecated.py / configs/types.py: drop the num_vocab_tiling > 1
and enable_nnx validation guards (no longer needed).
DPO + NNX retained as NotImplementedError but with a much more informative
message (points users at pure_nnx=False workaround). Full implementation
is deferred — needs a new TrainState shape carrying both policy and
reference NNX models plus an NNX dpo_loss_fn.
Stats: 26 source files modified, +406 / -171 lines. Linen invariant
verified: pure_nnx / enable_nnx / pure_nnx_decoder still default to False;
Linen-path UTs unaffected (3 pre-existing failures on the parent branch
remain unchanged — sharding_compare_test::deepseek2-16b,
optimizers_test::test_model_integration_kimi-k2-1t, diloco_test::two
_slices x2). All "Pure NNX support has not been implemented yet"
NotImplementedError sites cleared (was 17, now 0).
7602b11 to
4f7763a
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
NNX Migration Route Map
pure_nnxflag,init_state_fn,TrainStateNNX, NNX utils. Linen workflow unchanged. (PR #3427)get_abstract_state_nnx,get_named_sharding_nnx,set_named_sharding_nnx,get_partition_spec_nnx,get_mesh_from_config. (PR #3470)TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR #3500)pure_nnx=Falsestays default; preps the codebase for the flag flip.Description
The original plan was for PR5 to bundle the flag flip (
pure_nnx: True) together with the correctness fixes and feature enablements that make NNX usable end-to-end. That bundle produced many UT failures and was hard to review. This PR splits out the no-op-with-flags-False portion so reviewers can evaluate bug fixes and feature enablements separately from the flag flip and sharding-golden regeneration that PR6 will own.Part 1: NNX correctness fixes
Real bug fixes that activate when
pure_nnx=True. Each addresses a concrete crash or numerical bug observed during NNX trial runs.layers/nnx_wrappers.py_refresh_variable_trace_state+is_linen_initializing; called fromToLinenafternnx.update. Fixes"Cannot extract graph node from different trace level"whenjax.gradtracers leak intoVariable._trace_state.models/gpt_oss.py,models/olmo3.pynn.Dropout(...)()withself.dropout = linears.Dropout(...)in__init__. FixesCallCompactUnboundModuleError.layers/normalizations.pyQwen3NextRMSNormsignature:eps→epsilon; acceptshard_mode/kernel_axes/parameter_memory_host_offloadfor callsite parity.layers/attentions.py,models/qwen3.pyeps=→epsilon=.layers/moe.pyper_expert_scaleblock moved into the unfused-kernelelsebranch (was scalingwoeven whenfused_kernelwas active).models/models.pyMultiTokenPredictionBlock(...)directly (drop theToNNX(linen)+lazy_initwrap); passmultimodal_inputwhole toNNXDecoderinstead of unpacking 5 fields (undo the bad unpack from PR #3652).utils/gradient_accumulation.pylax.scan(reduced/unreducedPartitionSpecis rejected inside scan carry). Usennx.merge(..., copy=True)to avoidVariablereuse across re-traces.trainers/diloco/diloco.pystate.params→state.model.filter(nnx.Param), step counter atstate.optimizer.step,replace_nnx_model_paramshelper forjax.lax.condpytree-structure parity.trainers/pre_train/train_compile.py_collect_nnx_activation_shardingshelper (forward pass populates_ACTIVATION_SHARDINGS_DUMP—get_abstract_state_nnxonly traces__init__). NNX path now passes 2-argshaped_train_args(no rng); diloco path patched to handle the 2-vs-3 length difference.utils/muon_utils.pyget_model_mdndefaultpure_nnx=True; wrap NNX result as{"params": nnx.to_pure_dict(...)}for parity with the Linen tree shape.layers/nnx_decoders.pyfp8_nanoo,fp8_gpu) retain JAX tracers in Linen scope across re-traces. Skipjax.checkpointand use a Python for-loop instead ofjax.lax.scanwhenquantization in ("fp8_nanoo", "fp8_gpu"). Makes FP8 quantization usable on the NNX path.trainers/pre_train/train.py(NNXtrain_step)nnx.state(new_state, nnx.Not(nnx.Intermediate))so sowed forward-pass artifacts (e.g.,max_logitsfor QK-Clip) don't break leaf-count parity withstate_mesh_shardings.models/llama2.pyparameter_memory_host_offloadtopre_self_attention_layer_normRMSNorm(was missing on this norm only — broke host-offload for Llama-2 NNX).configs/base.ymllogical_axis_rules:layers_outside_pipeline,layers_per_stage,num_activations,circular_repeats. Additive; no-op withoutuse_nnx_pipeline=True.Part 2: NNX feature enablements (clear all
NotImplementedErrorsites)Clear all 17
Pure NNX support has not been implemented yetNotImplementedErrorsites by routing the Linen-coupled utilities to the Linen path internally. The on-disk format these utilities operate on is fundamentally Linen, so thepure_nnxflag (which affects training, not these tools) shouldn't gate them.utils/layerwise_quantization.pyDeepSeek*ToLinenlayer classes.utils/lora_utils.pyget_lora_abstract_stateexpects Linen tree shape; LoRA adapters on disk are Linen.utils/standalone_checkpointer.pyadd_entropy_to_checkpointaccessesstate.opt_state[0]._replace(mu=..., nu=...)— Linen-only.utils/generate_param_only_checkpoint.py_possibly_unroll_paramsand_save_decode_checkpointusestate.params["params"]["decoder"]— Linen tree.checkpoint_conversion/.../convert_gpt3_ckpt_from_paxml.pykeystr_maptargets Linen tree paths (.params['params'],.opt_state.mu['params']).inference/maxengine/maxengine.pystate.paramsand serves Linen-format inference checkpoints.experimental/rl/grpo_trainer.pyNNX-native versions of these (DPO, MaxEngine, LoRA, GRPO, the remaining checkpoint utilities) are scheduled for follow-up PRs after the flag flip.
Part 3: Vocab tiling on NNX (real implementation, not just routing)
models/models.pyTransformer.logits_from_hidden_states(hidden_states, deterministic, model_mode)on the NNXTransformerclass — wrapsNNXDecoder.apply_output_head(shared_embedding=self.token_embedder, ...). MirrorsTransformerLinenPure.logits_from_hidden_states.utils/vocabulary_tiling.pyvocab_tiling_nnx_loss(model, hidden_states, data, config, is_train). Chunks the vocab axis viajax.lax.scanand callsmodel.logits_from_hidden_states(chunk)per chunk. The NNX model carries its parameters internally so no explicit FSDP gather is needed (unlike the Linengathered_paramspattern). MVP uses default autograd;custom_vjpmemory-savings optimization is a perf follow-up.trainers/pre_train/train.py(NNXloss_fn)if config.num_vocab_tiling > 1: raise NotImplementedErrorwithvocab_tiling_nnx_lossdispatch usinghidden_statesfromnnx.Intermediate.configs/pyconfig_deprecated.py,configs/types.pynum_vocab_tiling > 1 and enable_nnxvalidation guards (no longer needed).Deferred (intentionally out of scope)
train.py:322NotImplementedErrorretained but with a much more informative message pointing users at thepure_nnx=Falseworkaround. Full implementation requires a new TrainState shape carrying both policy and reference NNX models plus an NNXdpo_loss_fn— too large for this PR; tracked as PR7 follow-up.nnx_decoders.pylarger structural rewrite (gemma4,deepseek_batchsplit_fp8,engram_layers,NNXSequentialPipelineStage/NNXScannedPipelineStage): depends on PR #3114. Reserved for PR6 base.pipeline.py106 KB NNX pipeline rewrite: depends on PR #2885. Reserved for PR6 base.maxtext_utils.pyPartitionSpec→Palias rename: pure cosmetic, no behavior change. PR6 picks this up cleanly.base.ymlflag flip + cosmetic whitespace strip: PR6's whole purpose.Tests
pure_nnx=False,enable_nnx=False,pure_nnx_decoder=Falsestill hold (nobase.ymlflag changes);tests/unit/model_test.py3/3 still passes.pure_nnx=True.Transformer.logits_from_hidden_statesandvocab_tiling_nnx_lossexposed and reachable.grep -c "Pure NNX support has not been implemented" src/is now 0 (was 17).sharding_compare_test::deepseek2-16b,optimizers_test::test_model_integration_kimi-k2-1t,diloco_test::two_slices×2 — all also fail on the parent branch HEAD without these changes (confirmed viagit stash).Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.