[NNX] NNX migration prep (5/N): correctness fixes and feature enablements#3766

Draft
ecnal-cienet wants to merge 3 commits into main from feat/nnx-correctness-fixes

Conversation

@ecnal-cienet
Collaborator

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR #3500)
  4. ✅ NNX sharding diagnostics, bidirectional Linen↔NNX checkpoint conversion utilities, and post-training fixes. (PR #3652)
  5. 🔄 [This PR] NNX correctness fixes, feature enablements, and vocab tiling on NNX. No-op while pure_nnx=False stays default; preps the codebase for the flag flip.
  6. ❌ Enable NNX by default; regenerate sharding goldens; fix unit and integration test failures.
  7. ❌ NNX-native DPO, MaxEngine inference, LoRA, GRPO, and remaining checkpoint utilities.
  8. ❌ Remove Linen-specific code paths and NNX compatibility flags.

Description

The original plan was for PR5 to bundle the flag flip (pure_nnx: True) together with the correctness fixes and feature enablements that make NNX usable end-to-end. That bundle produced many UT failures and was hard to review. This PR splits out the no-op-with-flags-False portion so reviewers can evaluate bug fixes and feature enablements separately from the flag flip and sharding-golden regeneration that PR6 will own.

Part 1: NNX correctness fixes

Real bug fixes that activate when pure_nnx=True. Each addresses a concrete crash or numerical bug observed during NNX trial runs.

| File | Fix |
| --- | --- |
| `layers/nnx_wrappers.py` | Add `_refresh_variable_trace_state` + `is_linen_initializing`; called from `ToLinen` after `nnx.update`. Fixes "Cannot extract graph node from different trace level" when `jax.grad` tracers leak into `Variable._trace_state`. |
| `models/gpt_oss.py`, `models/olmo3.py` | Replace inline `nn.Dropout(...)()` with `self.dropout = linears.Dropout(...)` in `__init__`. Fixes `CallCompactUnboundModuleError`. |
| `layers/normalizations.py` | `Qwen3NextRMSNorm` signature: `eps` → `epsilon`; accept `shard_mode`/`kernel_axes`/`parameter_memory_host_offload` for callsite parity. |
| `layers/attentions.py`, `models/qwen3.py` | Callsite kwargs: `eps=` → `epsilon=`. |
| `layers/moe.py` | `per_expert_scale` block moved into the unfused-kernel `else` branch (was scaling `wo` even when `fused_kernel` was active). |
| `models/models.py` | Build the MTP block as `MultiTokenPredictionBlock(...)` directly (drop the `ToNNX(linen)` + `lazy_init` wrap); pass `multimodal_input` whole to `NNXDecoder` instead of unpacking 5 fields (undoes the bad unpack from PR #3652). |
| `utils/gradient_accumulation.py` | ZeRO-1 + GA: defer the unreduced → all-reduce annotation until after `lax.scan` (a reduced/unreduced `PartitionSpec` is rejected inside the scan carry). Use `nnx.merge(..., copy=True)` to avoid `Variable` reuse across re-traces. |
| `trainers/diloco/diloco.py` | NNX-aware state handling: `state.params` → `state.model.filter(nnx.Param)`, step counter at `state.optimizer.step`, `replace_nnx_model_params` helper for `jax.lax.cond` pytree-structure parity. |
| `trainers/pre_train/train_compile.py` | New `_collect_nnx_activation_shardings` helper (the forward pass populates `_ACTIVATION_SHARDINGS_DUMP`; `get_abstract_state_nnx` only traces `__init__`). The NNX path now passes 2-arg `shaped_train_args` (no rng); the diloco path is patched to handle the 2-vs-3 length difference. |
| `utils/muon_utils.py` | `get_model_mdn` defaults to `pure_nnx=True`; wrap the NNX result as `{"params": nnx.to_pure_dict(...)}` for parity with the Linen tree shape. |
| `layers/nnx_decoders.py` | FP8 + NNX scan fix: Linen FP8 ops (`fp8_nanoo`, `fp8_gpu`) retain JAX tracers in Linen scope across re-traces. Skip `jax.checkpoint` and use a Python `for` loop instead of `jax.lax.scan` when `quantization in ("fp8_nanoo", "fp8_gpu")`. Makes FP8 quantization usable on the NNX path. |
| `trainers/pre_train/train.py` (NNX `train_step`) | Return `nnx.state(new_state, nnx.Not(nnx.Intermediate))` so sowed forward-pass artifacts (e.g., `max_logits` for QK-Clip) don't break leaf-count parity with `state_mesh_shardings`. |
| `models/llama2.py` | Pass `parameter_memory_host_offload` to the `pre_self_attention_layer_norm` RMSNorm (it was missing on this norm only, which broke host offload for Llama-2 NNX). |
| `configs/base.yml` | Add 4 pipeline-related `logical_axis_rules`: `layers_outside_pipeline`, `layers_per_stage`, `num_activations`, `circular_repeats`. Additive; no-op without `use_nnx_pipeline=True`. |
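The FP8 loop dispatch in `layers/nnx_decoders.py` can be pictured with a minimal sketch. All names here (`run_layers`, `layer_fn`) are illustrative, not the PR's actual API; the point is the branch between a plain Python loop and the checkpointed scan:

```python
import jax
import jax.numpy as jnp

# Quantization modes named in the PR whose Linen ops keep tracers alive
# across re-traces, making them incompatible with jax.lax.scan + jax.checkpoint.
FP8_MODES = ("fp8_nanoo", "fp8_gpu")

def run_layers(carry, layer_fn, num_layers, quantization=None):
  """Apply `layer_fn` num_layers times, choosing the loop strategy."""
  if quantization in FP8_MODES:
    # Unrolled Python loop: no scan carry and no rematerialization, so no
    # stale tracers are captured between re-traces.
    for _ in range(num_layers):
      carry = layer_fn(carry)
    return carry
  # Default path: rematerialize each layer under jax.checkpoint and scan.
  ckpt_layer = jax.checkpoint(layer_fn)
  carry, _ = jax.lax.scan(
      lambda c, _: (ckpt_layer(c), None), carry, xs=None, length=num_layers
  )
  return carry
```

Both branches compute the same function; the FP8 branch trades longer compile time (unrolling) for avoiding the tracer-leak failure.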

Part 2: NNX feature enablements (clear all NotImplementedError sites)

Clear all 17 "Pure NNX support has not been implemented yet" `NotImplementedError` sites by routing the Linen-coupled utilities to the Linen path internally. The on-disk format these utilities operate on is fundamentally Linen, so the `pure_nnx` flag (which affects training, not these tools) shouldn't gate them.

| File | Sites cleared | Why route to Linen |
| --- | --- | --- |
| `utils/layerwise_quantization.py` | 2 | Operates on Linen-format checkpoints via `DeepSeek*ToLinen` layer classes. |
| `utils/lora_utils.py` | 1 | Downstream `get_lora_abstract_state` expects the Linen tree shape; LoRA adapters on disk are Linen. |
| `utils/standalone_checkpointer.py` | 2 | `add_entropy_to_checkpoint` accesses `state.opt_state[0]._replace(mu=..., nu=...)`, which is Linen-only. |
| `utils/generate_param_only_checkpoint.py` | 3 | `_possibly_unroll_params` and `_save_decode_checkpoint` use `state.params["params"]["decoder"]`, a Linen tree. |
| `checkpoint_conversion/.../convert_gpt3_ckpt_from_paxml.py` | 2 | `keystr_map` targets Linen tree paths (`.params['params']`, `.opt_state.mu['params']`). |
| `inference/maxengine/maxengine.py` | 3 | The inference engine uses `state.params` and serves Linen-format inference checkpoints. |
| `experimental/rl/grpo_trainer.py` | 4 | The RL trainer is end-to-end Linen-shaped (uses MaxEngine for sampling). Routed with a clear log warning, since NNX-format checkpoints will fail at restore time. |
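The routing pattern shared by these sites can be sketched as below. `route_to_linen` and `linen_path_fn` are hypothetical names for illustration; the real call sites differ per utility:

```python
import logging

log = logging.getLogger(__name__)

def route_to_linen(config, linen_path_fn, *args, **kwargs):
  # Sketch of the dispatch described above: instead of raising
  # NotImplementedError when pure_nnx is set, the Linen-coupled utility
  # warns and falls through to its existing Linen code path, because the
  # on-disk format it handles is Linen either way.
  if getattr(config, "pure_nnx", False):
    log.warning(
        "This utility operates on Linen-format checkpoints; routing to the "
        "Linen path. NNX-format checkpoints will fail at restore time."
    )
  return linen_path_fn(config, *args, **kwargs)
```

The warning matters most for `grpo_trainer.py`, where an NNX-format checkpoint would otherwise fail later, at restore time, with a less obvious error.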

NNX-native versions of these (DPO, MaxEngine, LoRA, GRPO, the remaining checkpoint utilities) are scheduled for follow-up PRs after the flag flip.

Part 3: Vocab tiling on NNX (real implementation, not just routing)

| File | Change |
| --- | --- |
| `models/models.py` | Add `Transformer.logits_from_hidden_states(hidden_states, deterministic, model_mode)` on the NNX `Transformer` class; it wraps `NNXDecoder.apply_output_head(shared_embedding=self.token_embedder, ...)` and mirrors `TransformerLinenPure.logits_from_hidden_states`. |
| `utils/vocabulary_tiling.py` | Add `vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train)`. Chunks the vocab axis via `jax.lax.scan` and calls `model.logits_from_hidden_states(chunk)` per chunk. The NNX model carries its parameters internally, so no explicit FSDP gather is needed (unlike the Linen `gathered_params` pattern). The MVP uses default autograd; a `custom_vjp` memory-savings optimization is a perf follow-up. |
| `trainers/pre_train/train.py` (NNX `loss_fn`) | Replace `if config.num_vocab_tiling > 1: raise NotImplementedError` with a `vocab_tiling_nnx_loss` dispatch using `hidden_states` from `nnx.Intermediate`. |
| `configs/pyconfig_deprecated.py`, `configs/types.py` | Drop the `num_vocab_tiling > 1` and `enable_nnx` validation guards (no longer needed). |
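A standalone sketch of the chunked loss, assuming the output head is a plain embedding matmul. The real `vocab_tiling_nnx_loss` calls `model.logits_from_hidden_states` per chunk instead, and `vocab_tiled_xent` plus all parameter names below are illustrative only:

```python
import jax
import jax.numpy as jnp

def vocab_tiled_xent(hidden, targets, embedding, num_tiles):
  """Cross-entropy without materializing the full [tokens, vocab] logits.

  hidden: [tokens, d_model]; targets: [tokens] int ids;
  embedding: [vocab, d_model]; assumes vocab % num_tiles == 0.
  """
  vocab, _ = embedding.shape
  chunk = vocab // num_tiles
  emb_chunks = embedding.reshape(num_tiles, chunk, -1)
  starts = jnp.arange(num_tiles) * chunk

  def step(carry, xs):
    lse, tgt_logit = carry
    emb_c, start = xs
    logits = hidden @ emb_c.T  # [tokens, chunk] -- only one tile live at a time
    # Streaming logsumexp over the vocab axis.
    lse = jnp.logaddexp(lse, jax.nn.logsumexp(logits, axis=-1))
    # Pick up the target logit only in the chunk that contains it.
    idx = targets - start
    in_chunk = (idx >= 0) & (idx < chunk)
    picked = jnp.take_along_axis(
        logits, jnp.clip(idx, 0, chunk - 1)[:, None], axis=1
    )[:, 0]
    tgt_logit = tgt_logit + jnp.where(in_chunk, picked, 0.0)
    return (lse, tgt_logit), None

  tokens = hidden.shape[0]
  init = (jnp.full((tokens,), -jnp.inf), jnp.zeros((tokens,)))
  (lse, tgt_logit), _ = jax.lax.scan(step, init, (emb_chunks, starts))
  return jnp.mean(lse - tgt_logit)
```

The chunked result matches the unchunked softmax cross-entropy, while peak logits memory drops from `[tokens, vocab]` to `[tokens, vocab // num_tiles]`.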

Deferred (intentionally out of scope)

  • DPO + NNX: the `train.py:322` `NotImplementedError` is retained, but with a much more informative message pointing users at the `pure_nnx=False` workaround. A full implementation requires a new `TrainState` shape carrying both the policy and reference NNX models, plus an NNX `dpo_loss_fn`; too large for this PR, tracked as a PR7 follow-up.
  • nnx_decoders.py larger structural rewrite (gemma4, deepseek_batchsplit_fp8, engram_layers, NNXSequentialPipelineStage/NNXScannedPipelineStage): depends on PR #3114. Reserved for PR6 base.
  • pipeline.py 106 KB NNX pipeline rewrite: depends on PR #2885. Reserved for PR6 base.
  • `maxtext_utils.py` `PartitionSpec` → `P` alias rename: purely cosmetic, no behavior change. PR6 picks this up cleanly.
  • base.yml flag flip + cosmetic whitespace strip: PR6's whole purpose.

Tests

  • Linen invariant verified: `pure_nnx=False`, `enable_nnx=False`, `pure_nnx_decoder=False` all still hold as defaults (no base.yml flag changes); `tests/unit/model_test.py` still passes 3/3.
  • NNX path imports verified: all 26 modified files import cleanly under pure_nnx=True. Transformer.logits_from_hidden_states and vocab_tiling_nnx_loss exposed and reachable.
  • NotImplementedError audit: grep -c "Pure NNX support has not been implemented" src/ is now 0 (was 17).
  • Pre-existing failures left unchanged: sharding_compare_test::deepseek2-16b, optimizers_test::test_model_integration_kimi-k2-1t, diloco_test::two_slices ×2 — all also fail on the parent branch HEAD without these changes (confirmed via git stash).

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-correctness-fixes branch 2 times, most recently from bf95288 to 7602b11 on April 28, 2026 at 21:56
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
…raining fixes

Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities:
- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)

Part 2 — post-training bug fixes:
- models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the
  whole object as multimodal_input= kwarg; NNXDecoder only accepts individual fields)
- optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams
  (callable() check before invoking learning_rate_fn)
- train_distill.py: fix nested NNX transform issue (nnx.value_and_grad inside nnx.jit
  raises conflicting outer_index error); refactored to jax.value_and_grad + explicit
  nnx.split/merge pattern; teacher inference moved outside value_and_grad
Bug fixes (run as no-op while pure_nnx=False stays default):
- nnx_wrappers.py: add _refresh_variable_trace_state + is_linen_initializing;
  call from ToLinen after nnx.update to fix "Cannot extract graph node from
  different trace level" when grad tracers leak into Variable._trace_state.
- gpt_oss.py / olmo3.py: replace inline nn.Dropout(...) with self.dropout =
  linears.Dropout(...) in __init__ to fix CallCompactUnboundModuleError.
- normalizations.py: Qwen3NextRMSNorm signature: eps -> epsilon, accept
  shard_mode/kernel_axes/parameter_memory_host_offload for callsite parity.
- attentions.py / qwen3.py: callsites eps= -> epsilon=.
- moe.py: per_expert_scale block moved into the unfused-kernel else branch
  (was scaling wo even when fused_kernel was active).
- models.py: build MTP block as MultiTokenPredictionBlock(...) directly
  (drop the ToNNX(linen) + lazy_init wrap); pass multimodal_input whole
  to NNXDecoder instead of unpacking 5 fields.
- gradient_accumulation.py: ZeRO-1+GA all-reduce annotation deferred until
  after lax.scan (reduced/unreduced PartitionSpec is rejected inside scan
  carry); use nnx.merge(..., copy=True) to avoid Variable reuse.
- diloco.py: NNX-aware state handling — state.params -> state.model.filter
  (nnx.Param), step counter at state.optimizer.step, replace_nnx_model_params
  helper for jax.lax.cond pytree-structure parity.
- train_compile.py: new _collect_nnx_activation_shardings helper (forward
  pass populates _ACTIVATION_SHARDINGS_DUMP — get_abstract_state_nnx only
  traces __init__); NNX path now passes 2-arg shaped_train_args (no rng);
  diloco path patched to handle the 2-vs-3 length difference.
- muon_utils.py: get_model_mdn default pure_nnx=True; wrap NNX result as
  {"params": nnx.to_pure_dict(...)} for parity with Linen tree shape.
- nnx_decoders.py: FP8+NNX scan fix — Linen FP8 ops (fp8_nanoo, fp8_gpu)
  retain tracers in Linen scope across re-traces. Skip jax.checkpoint and
  use a Python for-loop instead of jax.lax.scan when quantization is FP8.
  Makes FP8 quantization usable on the NNX path.
- train.py (pre-train train_step): return nnx.state(new_state, nnx.Not
  (nnx.Intermediate)) so sowed forward-pass artifacts (e.g. max_logits for
  QK-Clip) don't break leaf-count parity with state_mesh_shardings.
- llama2.py: pass parameter_memory_host_offload to
  pre_self_attention_layer_norm RMSNorm (was missing on this norm only).
- base.yml: add 4 pipeline-related logical_axis_rules
  (layers_outside_pipeline, layers_per_stage, num_activations,
  circular_repeats). Additive, no-op without use_nnx_pipeline=True.

NNX feature enablements (clear all 17 "Pure NNX support has not been
implemented yet" NotImplementedError sites by routing Linen-coupled
utilities to the Linen path; their on-disk format is Linen):
- layerwise_quantization.py (2 sites): operates on Linen-format checkpoints
  via DeepSeek*ToLinen layers.
- lora_utils.py (1 site): downstream get_lora_abstract_state expects Linen
  tree shape; LoRA adapters on disk are Linen.
- standalone_checkpointer.py (2 sites): add_entropy_to_checkpoint accesses
  state.opt_state[0]._replace(mu=..., nu=...) — Linen-only.
- generate_param_only_checkpoint.py (3 sites): _possibly_unroll_params and
  _save_decode_checkpoint use state.params["params"]["decoder"] — Linen.
- convert_gpt3_ckpt_from_paxml.py (2 sites): keystr_map targets Linen tree
  paths (.params['params'], .opt_state.mu['params']).
- maxengine.py (3 sites): inference engine uses state.params and serves
  Linen-format inference checkpoints.
- grpo_trainer.py (4 sites): RL trainer is end-to-end Linen-shaped; route
  to Linen with a clear log warning since NNX-format checkpoints will fail
  at restore time.

Vocab tiling on NNX (real implementation, not just routing):
- models.py: add Transformer.logits_from_hidden_states on the NNX
  Transformer class — wraps NNXDecoder.apply_output_head with the
  token_embedder; mirrors TransformerLinenPure.logits_from_hidden_states.
- vocabulary_tiling.py: add vocab_tiling_nnx_loss — chunks the vocab axis
  via jax.lax.scan and calls model.logits_from_hidden_states(chunk) per
  chunk. The NNX model carries its parameters internally so no explicit
  FSDP gather is needed (unlike the Linen gathered_params pattern). MVP
  uses default autograd; custom_vjp memory-savings optimization is a
  follow-up if backward memory becomes a concern.
- train.py (NNX loss_fn): replace the NotImplementedError with the call
  to vocab_tiling_nnx_loss using hidden_states from intermediates.
- pyconfig_deprecated.py / configs/types.py: drop the num_vocab_tiling > 1
  and enable_nnx validation guards (no longer needed).

DPO + NNX retained as NotImplementedError but with a much more informative
message (points users at pure_nnx=False workaround). Full implementation
is deferred — needs a new TrainState shape carrying both policy and
reference NNX models plus an NNX dpo_loss_fn.

Stats: 26 source files modified, +406 / -171 lines. Linen invariant
verified: pure_nnx / enable_nnx / pure_nnx_decoder still default to False;
Linen-path UTs unaffected (3 pre-existing failures on the parent branch
remain unchanged — sharding_compare_test::deepseek2-16b,
optimizers_test::test_model_integration_kimi-k2-1t,
diloco_test::two_slices x2). All "Pure NNX support has not been implemented yet"
NotImplementedError sites cleared (was 17, now 0).
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-correctness-fixes branch from 7602b11 to 4f7763a on April 29, 2026 at 14:02
@ecnal-cienet ecnal-cienet changed the title [NNX] NNX migration prep (4.5/N): correctness fixes and feature enablements [NNX] NNX migration prep (5/N): correctness fixes and feature enablements Apr 29, 2026
