Test #3422

Draft
hsuan-lun-chiang wants to merge 28 commits into AI-Hypercomputer:main from CIeNET-International:test/NNX-Migration-PR-Test

Conversation

@hsuan-lun-chiang
Collaborator

Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description includes relevant details and context, examples:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a bug or a GitHub issue, please include a link, e.g.:
FIXES: b/123456
FIXES: #123456

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@mesakhcienet mesakhcienet force-pushed the test/NNX-Migration-PR-Test branch from 3fa42e7 to a6ad144 Compare March 23, 2026 09:23
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the test/NNX-Migration-PR-Test branch from 8b7313a to a39561a Compare March 23, 2026 10:22
@mesakhcienet mesakhcienet force-pushed the test/NNX-Migration-PR-Test branch from a39561a to 26a122a Compare March 23, 2026 10:32
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the test/NNX-Migration-PR-Test branch 2 times, most recently from 2f932d2 to 26a122a Compare March 25, 2026 02:28
@mesakhcienet mesakhcienet force-pushed the test/NNX-Migration-PR-Test branch 3 times, most recently from 8bacbe0 to ca4db06 Compare March 27, 2026 02:45
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the test/NNX-Migration-PR-Test branch from ca4db06 to ce03e2f Compare March 27, 2026 06:17
@mesakhcienet mesakhcienet force-pushed the test/NNX-Migration-PR-Test branch from 0215bcc to 1dcd278 Compare March 27, 2026 07:24
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the test/NNX-Migration-PR-Test branch from 2301126 to c793145 Compare March 27, 2026 08:57
@mesakhcienet mesakhcienet force-pushed the test/NNX-Migration-PR-Test branch 3 times, most recently from bdd6279 to 257711e Compare March 27, 2026 09:59
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the test/NNX-Migration-PR-Test branch from f80fc67 to e6881e8 Compare March 30, 2026 08:33
@mesakhcienet mesakhcienet force-pushed the test/NNX-Migration-PR-Test branch 2 times, most recently from 791d0b1 to de5baa1 Compare March 30, 2026 09:32
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the test/NNX-Migration-PR-Test branch from a1c4186 to dcd8502 Compare March 30, 2026 10:03
@mesakhcienet mesakhcienet force-pushed the test/NNX-Migration-PR-Test branch 2 times, most recently from c993a10 to dd8fee9 Compare April 13, 2026 09:12
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the test/NNX-Migration-PR-Test branch 3 times, most recently from 1ff2d73 to 8ca8fc8 Compare April 28, 2026 07:40
@mesakhcienet mesakhcienet force-pushed the test/NNX-Migration-PR-Test branch 4 times, most recently from 6d00430 to eb9756c Compare April 29, 2026 07:04
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the test/NNX-Migration-PR-Test branch from 2f6ca7e to 8d1e4fb Compare April 29, 2026 07:07
@mesakhcienet mesakhcienet force-pushed the test/NNX-Migration-PR-Test branch from 8d1e4fb to 5939451 Compare April 29, 2026 07:23
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the test/NNX-Migration-PR-Test branch from 8c52dc7 to 6fec4a4 Compare May 29, 2026 11:45
ecnal-cienet and others added 27 commits May 29, 2026 11:45
…acked prefill cache)

PR7 (NNX-native MaxEngine inference) made the core prefill/generate/insert
path work under pure_nnx=True but left three serving features raising
NotImplementedError on the NNX path. This promotes all three to NNX-native.
Linen is preserved byte-for-byte: the original model.apply(..., mutable=["cache"])
calls are unchanged, just moved into else: branches, and every NNX edit is
gated `if config.pure_nnx:`.

maxengine.py:
- _prefill_multisampling_jit: drops the NotImplementedError; adds a pure_nnx
  branch that runs prefill through _nnx_run_model (MODEL_MODE_PREFILL, batch=1)
  with a fresh _nnx_init_cache_dict. The loop that draws num_samples first
  tokens from the shared logits is unchanged.
- prefill_concat: same swap; the packed positions and segment ids thread
  through _nnx_run_model unchanged.
- stack_prefill_result_cache=True: now supported for both scan_layers values.
  scan_layers=True already stacks the per-layer KV cache on axis 0 (the Linen
  post-stack shape), so _maybe_stack/_maybe_unstack_prefill_result_cache are
  no-ops and prefill_kv_cache_shardings stays the full tree. scan_layers=False
  keeps unstacked per-layer subtrees under cache["decoder"]["layers"][i] (int
  keys), so _maybe_stack stacks them into one subtree with a leading layer axis,
  _maybe_unstack splits it back into the int-keyed per-layer dict that
  bulk_insert/_insert_jit walk, and _load_params_nnx prepends a layer axis to
  each prefix-sharding spec (the NNX analog of the Linen P(None, *spec) +
  ["decoder"]["layers_0"] reshape).
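The scan_layers=False round trip described above can be sketched as follows. This is a minimal illustration, assuming the per-layer caches are plain dicts with identical structure keyed by int layer index; `stack_layers` and `unstack_layers` are hypothetical names, not the actual `_maybe_stack`/`_maybe_unstack_prefill_result_cache` helpers.

```python
import jax
import jax.numpy as jnp


def stack_layers(per_layer):
  """Stack an int-keyed dict of identically shaped layer trees on a new axis 0."""
  ordered = [per_layer[i] for i in sorted(per_layer)]
  return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves, axis=0), *ordered)


def unstack_layers(stacked):
  """Split the leading layer axis back into an int-keyed per-layer dict."""
  num_layers = jax.tree_util.tree_leaves(stacked)[0].shape[0]
  return {i: jax.tree_util.tree_map(lambda x, i=i: x[i], stacked) for i in range(num_layers)}


# Toy cache (assumed layout): 3 layers, each with (batch, seq, heads, dim) arrays.
cache = {
    i: {"key": jnp.full((1, 4, 2, 8), float(i)), "value": jnp.zeros((1, 4, 2, 8))}
    for i in range(3)
}
stacked = stack_layers(cache)       # leaves gain a leading layer axis
restored = unstack_layers(stacked)  # back to {0: ..., 1: ..., 2: ...}
```

With scan_layers=True the cache is already stacked on axis 0, which is why both helpers can be no-ops in that mode.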

tests/integration/maxengine_test.py:
- New _build_linen_params helper and a shared _stack_prefill_roundtrip helper.
- test_prefill_multisampling_nnx, test_prefill_concat_nnx: NNX vs Linen
  result-shape parity, finite logits + cache.
- test_stack_prefill_result_cache_nnx (scan_layers=True) and
  test_stack_prefill_result_cache_scan_layers_false_nnx (scan_layers=False):
  prefill -> insert -> generate round-trip, layer-stacked leaves, finite
  logits, next_pos advances.

Remaining NNX MaxEngine carve-outs are quantization (PR9) and LoRA (PR8),
which are other PRs' scope.

Closes the QK-Clip TODO and migrates the remaining Linen-only
checkpoint utilities to NNX. Linen paths preserved byte-for-byte
(every NNX edit is gated on `config.pure_nnx` or runtime state-shape
detection).

QK-Clip:
- qk_clip_utils.apply_qk_clip_nnx mutates state.model in place via
  nnx.split -> pure-dict tree_map -> nnx.replace_by_pure_dict ->
  nnx.update. Accepts both the production NNX intermediate shape
  (self_attention.attention_op.max_logits) and the synthetic-fixture
  shape from the existing Linen tests (self_attention.max_logits).
- train.py train_step dispatches to apply_qk_clip_nnx for NNX,
  removing the prior TODO at the QK-Clip call site.

Checkpoint utilities (NNX paths added):
- standalone_checkpointer.checkpoint_loop builds an NNX init_state_fn
  under pure_nnx; add_entropy_to_checkpoint dispatches across Linen
  TrainState, NNX TrainStateNNX Module, and post-split nnx.State
  shapes.
- generate_param_only_checkpoint: NNX init_state_fn under pure_nnx;
  _possibly_unroll_params_nnx slices scanned NNX layers via dict-style
  state mutation; _save_decode_checkpoint_nnx writes a bf16 pure-dict
  tree to orbax. Parallel LoRA decode flow operates on the
  single-nested LoRA delta tree from PR8's get_lora_abstract_state_nnx.
- convert_gpt3_ckpt_from_paxml: parallel NNX state_map keystr
  translation (.params['params']<rest> -> .model<rest>.value, etc.).
  End-to-end paxml -> NNX conversion is wired but not yet validated
  on hardware.
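A minimal sketch of the parameter keystr rule quoted above (.params['params']<rest> -> .model<rest>.value). Only that one rule is taken from the commit message; the function name and the example path are hypothetical, and the other translations covered by "etc." are deliberately not guessed at.

```python
LINEN_PARAM_PREFIX = ".params['params']"


def linen_to_nnx_keystr(keystr: str) -> str:
  """Translate a Linen parameter keystr into the NNX form described above."""
  if not keystr.startswith(LINEN_PARAM_PREFIX):
    raise ValueError(f"not a Linen parameter keystr: {keystr!r}")
  rest = keystr[len(LINEN_PARAM_PREFIX):]
  return f".model{rest}.value"


# Hypothetical example path, for illustration only.
example = linen_to_nnx_keystr(".params['params']['decoder']['decoder_norm']['scale']")
```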

Tests:
- qk_clip_test: 7 new NNX cases covering attention-type guard, MLA
  wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold,
  missing-stats resilience, Linen<->NNX numeric parity.
- standalone_checkpointer_nnx_test (new): 3 cases for adam mu/nu
  overwrite on TrainStateNNX Module shape, no mutation of state.model
  params, post-split nnx.State shape from setup_training_state.
- generate_param_only_checkpoint_nnx_test (new): 3 cases for scanned
  layer slicing (Llama-style; DeepSeek-style dense+moe split; LoRA
  delta unroll on the single-nested NNX shape).

NNX + AQT in MaxEngine and the layerwise_quantization NNX path are
split into the follow-up PR9.5.

Builds on PR9. Migrates the NNX + AQT integration so MaxEngine can both
load pre-quantized checkpoints directly and convert full-precision
checkpoints to int8 on load. Also bundles a pre-existing gpt3 prefill
bug surfaced by the AQT end-to-end validation.

NNX + AQT in MaxEngine:
- model_creation_utils threads quant_mode_str ("train" | "convert" |
  "serve") through from_config / create_model /
  get_nnx_create_model_fn / create_nnx_abstract_model /
  from_pretrained. Default "train" preserves existing callers; "serve"
  propagates to configure_quantization so AQT layers don't materialize
  the full-precision kernel when the on-disk checkpoint already
  carries qrhs scale factors.
- maxengine.__init__ selects the quant mode from
  config.checkpoint_is_quantized; _load_params_nnx drops its
  NotImplementedError. Two paths: pre-quantized
  (checkpoint_is_quantized=True) loads via quant_mode_str="serve";
  full-precision + quantization=int8 loads in TRAIN mode and AQT
  layers quantize per-forward (same numerical result for absmax
  calibration).
- layerwise_quantization._load_and_quantize_nnx: whole-model NNX
  convert path. Loads full-precision in TRAIN mode, transfers kernels
  into a CONVERT-mode model, runs forward to populate qrhs.frozen via
  the ToNNX(AqtDotGeneral) bridge, strips kernels at quantized paths,
  saves serve-mode-shaped state.

Sharding helpers and from_pretrained QTensor handling (fixes for 5
chained bugs that kept the serve-mode reload from working):
- maxtext_utils.get_nnx_named_sharding_with_scan_axis emits a
  parallel-tree of replicated NamedSharding leaves when a Variable's
  value is a composite pytree (AQT serve-mode QTensor with a qvalue
  int8 leaf and a list of bf16 scale leaves).
- model_creation_utils.from_pretrained: drops a redundant
  jax.set_mesh wrap in create_nnx_abstract_model that broke serve-mode
  AQT under Flax 0.12.6. _build_value_target / _free_device_memory /
  _unwrap_for_align use Variable.get_value() instead of v[...]
  indexing for QTensor leaves (QTensor.__getitem__ trips on the
  LogicallyPartitioned wrapper around qvalue). Widens the restore
  filter beyond nnx.Param to cover the aqt-typed qrhs.frozen Variable
  type. Skips QTensor leaves in the per-axis shape-alignment dispatch
  (their saved shape already matches the model). _build_value_target
  strips Partitioned wrappers around composite-leaf values so the
  restore tree path matches the on-disk layout (LogicallyPartitioned
  was adding an extra .value key under each QTensor leaf, which made
  orbax silently fill the path with zero-init values).

gpt3 prefill / autoregressive fix (pre-existing, surfaced here):
- Gpt3MultiHeadAttention.__call__ invoked attention_op(...) without
  ever calling update_kv_caches to build cached_values, so any
  non-TRAIN forward (prefill or autoregressive) tripped the
  `assert prefill_kv_cache` check. Mirror the standard Attention
  plumbing in attentions.py: __init__ constructs a KVCache_0 module
  when model_mode != MODEL_MODE_TRAIN, threads
  max_prefill_predict_length into AttentionOp; __call__ calls
  self.KVCache_0(...) and passes [prefill_kv_cache, ar_kv_cache] as
  cached_values to attention_op. TRAIN-mode shape unchanged.

Tests:
- layerwise_quantization_nnx_test (new): 3 cases for
  _strip_kernels_at_quantized_paths covering quantized removal,
  non-quantized preservation (norms, embeddings), mixed-shape trees.
- aqt_serve_roundtrip_nnx_test (new): end-to-end regression test that
  builds a small NNX model in CONVERT mode with int8, runs a forward
  to populate qrhs.frozen via the ToNNX bridge, saves the
  serve-mode-shape state to a tmp local orbax checkpoint, reloads via
  from_pretrained(quant_mode_str="serve"), and asserts every saved
  qrhs.frozen.qvalue array byte-matches what came back. Guards the
  full chain of QTensor / Partitioned / filter fixes.
- maxengine_test: replaced test_quantize_raises_for_nnx with
  test_quantize_passes_gate_for_nnx; added
  test_load_pre_quantized_nnx_passes_quant_gate and
  test_quantized_prefill_nnx_train_mode (real numerical verification
  with quantization=int8 + random params + TRAIN mode).

End-to-end on TPU (gpt3-52k): convert-mode forward + qrhs.frozen
extraction + serve-mode-shape save + reload via
from_pretrained(quant_mode_str="serve") + maxengine.load_params +
quantized prefill forward all work; loaded qrhs.frozen.qvalue
byte-matches the on-disk state.

Replaces the PR9.5 NNX vocab-tiling MVP (chunked forward + default
autograd backward) with a jax.custom_vjp that mirrors the Linen path's
backward-memory savings, then carves out the output-head params so the
custom_vjp's residuals + grad accumulator scale with LM-head size, not
with the full model. Linen vocab_tiling_linen_loss is byte-for-byte
unchanged. Call sites in train.py / pyconfig_deprecated.py /
configs/types.py are unchanged.

Custom_vjp + output-head carve-out (vocabulary_tiling.py):
- Outside the custom_vjp: 3-way nnx.split with a callable path filter
  (_is_output_head_param_path) matching {token_embedder,
  shared_embedding, decoder_norm, logits_dense} — the only nnx.Param
  paths apply_output_head touches. Returns (graphdef, head_params,
  other_params, rest).
- Custom_vjp primals: (head_params, other_params, rest, hidden_states,
  labels, segmentation). Only head_params and hidden_states are
  differentiated; other_params + rest are threaded through as
  non-differentiated primals so their tracers don't have to cross both
  the custom_vjp and the inner lax.scan boundary (which previously
  caused UnexpectedTracerError under logits_via_embedding=True).
- Forward (_chunked_cross_entropy_loss_fwd): reshapes to
  (num_vocab_tiling, vocab_tile_size, ...) and runs lax.scan whose body
  rebuilds the model per chunk via nnx.merge(graphdef, chunk_head,
  chunk_other, chunk_rest, copy=True) and calls
  logits_from_hidden_states. Initial scan accumulator is fp32 (was
  hidden_states.dtype previously — caused a lax.scan carry dtype
  mismatch with bf16 hidden_states since cross_entropy_with_logits
  always returns fp32). Residuals are (chunk_head, chunk_other,
  chunk_rest, reshaped_*, batch/seq/emb).
- Backward (_chunked_cross_entropy_loss_bwd): a second lax.scan whose
  body builds loss_fn_for_vjp = lambda p, h: ..., calls
  jax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk),
  accumulates grad_head via tree.map(add), emits per-chunk grad_hidden.
  Chain-rules grad_head *= loss_cotangent and dtype-casts to each
  primal's dtype (custom_vjp requires this). chunk_other_params and
  chunk_rest cotangents are explicit tree_map(jnp.zeros_like, ...) zero
  pytrees, NOT None — None makes JAX synthesize zeros at AOT trace time
  with axis-0 stacking (jax.scan convention) for nnx.scan-stacked
  transformer-layer params, which carry axis-1 stacking (nnx
  convention), and the cotangent-shape check fails as
  "Expected cotangent type bfloat16[E,M] for primal type bfloat16[E,M],
  but got bfloat16[L,E,M]". Materializing the zeros ties the cotangent
  shape to the primal shape exactly.
- Correctness: logits_from_hidden_states provably depends only on
  head_params; the gradient w.r.t. other_params through this loss is
  exactly zero. When train.py also calls the full model forward (which
  produces hidden_states), transformer-layer gradients flow back
  through grad_hidden_states → outer backward, unaffected by the
  carve-out.
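The structure above (fp32 scan accumulator, per-chunk jax.vjp in the backward scan, explicit zero cotangents for the untouched params) can be sketched in plain JAX. This is a simplified illustration, not the vocabulary_tiling.py implementation: it uses a single head matrix instead of an NNX graphdef/state split, chunks along the leading axis of hidden, and every name is hypothetical.

```python
import jax
import jax.numpy as jnp


def chunk_loss(head_w, hidden, labels):
  """Summed cross-entropy for one chunk of tokens; always returns fp32."""
  logits = (hidden @ head_w).astype(jnp.float32)
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  return -jnp.sum(jnp.take_along_axis(log_probs, labels[:, None], axis=-1))


@jax.custom_vjp
def chunked_xent(head_w, other_params, hidden, labels):
  return _fwd(head_w, other_params, hidden, labels)[0]


def _fwd(head_w, other_params, hidden, labels):
  def body(acc, chunk):
    h, y = chunk
    return acc + chunk_loss(head_w, h, y), None

  # fp32 carry: chunk_loss is fp32 even for bf16 hidden, and lax.scan
  # requires the carry dtype to stay fixed across iterations.
  loss, _ = jax.lax.scan(body, jnp.zeros((), jnp.float32), (hidden, labels))
  return loss, (head_w, other_params, hidden, labels)


def _bwd(residuals, loss_cotangent):
  head_w, other_params, hidden, labels = residuals

  def body(grad_w, chunk):
    h, y = chunk
    _, vjp_fn = jax.vjp(lambda w, x: chunk_loss(w, x, y), head_w, h)
    gw, gh = vjp_fn(jnp.ones((), jnp.float32))
    return grad_w + gw.astype(jnp.float32), gh  # accumulate head grad, emit hidden grad

  grad_w, grad_h = jax.lax.scan(body, jnp.zeros(head_w.shape, jnp.float32), (hidden, labels))
  # Chain rule with the incoming cotangent, then cast back to each primal's dtype.
  grad_w = (grad_w * loss_cotangent).astype(head_w.dtype)
  grad_h = (grad_h * loss_cotangent).astype(hidden.dtype)
  # Explicit zeros tie the cotangent shapes to the primal shapes.
  zeros_other = jax.tree_util.tree_map(jnp.zeros_like, other_params)
  return grad_w, zeros_other, grad_h, None  # integer labels get no cotangent


chunked_xent.defvjp(_fwd, _bwd)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
head_w = jax.random.normal(k1, (4, 16))         # (emb, vocab)
hidden = jax.random.normal(k2, (2, 3, 4))       # (num_chunks, chunk, emb)
labels = jax.random.randint(k3, (2, 3), 0, 16)  # (num_chunks, chunk)
other = {"layers": jnp.ones((5, 4))}            # params the loss never reads

g_w, g_other, g_h = jax.grad(chunked_xent, argnums=(0, 1, 2))(head_w, other, hidden, labels)
```

In this toy setup the zeros for other_params could also be left to JAX by returning None; the sketch materializes them only to mirror the shape-pinning rationale given in the commit message.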

Supporting fixes (touched for the carve-out to work end-to-end):
- nnx_decoders.py::apply_output_head logits_via_embedding=True branch
  reads embedding_table = shared_embedding.embedding[...] instead of
  the deprecated .value shim. The .value shim registers the access in
  NNX mutation tracking, which JAX detects as a tracer leak when the
  embedding is closure-captured / threaded across the custom_vjp +
  lax.scan boundaries. The Linen branch is unchanged.
- models.py: deletes dead-code self.hidden_states = None and
  if num_vocab_tiling > 1: self.hidden_states = hidden_state from the
  NNX Transformer class. Two lines left over from an early PR5
  implementation idea — neither path actually reads
  model.hidden_states (Linen reads via mutable=["intermediates"]; NNX
  reads via nnx.pop(model, nnx.Intermediate) from the decoder's sown
  ("decoder", "hidden_states") intermediate). Without this fix, AOT
  compile under pure_nnx=True + num_vocab_tiling>1 raised
  ValueError: Cannot assign data value of type 'LinearizeTracer' to
  static attribute 'hidden_states' of Pytree type 'Transformer' —
  would have silently broken any post-PR11 user with vocab tiling on.

Tests (tiling_test.py — new VocabTilingNNXTest class with 9 TPU tests):
- test_nnx_vocab_tiling_non_tied_embedding / _tied_embedding: loss +
  grad parity vs. full-vocab xent reference for both LM-head modes.
- test_nnx_vocab_tiling_total_z_loss_value_parity: asserts the second
  tuple element matches the reference (was untested before).
- test_nnx_vocab_tiling_padded_segmentation: half-padded mask;
  exercises the segmentation != 0 mask branch and asserts padded loss
  is strictly less than unpadded.
- test_nnx_vocab_tiling_grad_over_hidden_states: argnums=1
  differentiation; exercises the custom_vjp's second-primal cotangent
  path (grad_reshaped_hidden_states), shape + dtype + value parity.
- test_nnx_vocab_tiling_bf16_hidden_states: bf16 inputs with rtol/atol
  loosened to 5e-2; asserts grad_h.dtype == bf16 (the bwd dtype-cast
  preserves the primal's dtype). Caught the fp32-accumulator bug.
- test_nnx_vocab_tiling_z_loss_zero: z_loss_multiplier=0;
  total_z_loss == 0.0 exactly and grad parity holds.
- test_nnx_vocab_tiling_num_vocab_tiling_variants: runs n ∈ {2, 4, 8}
  and asserts identical loss + grads (catches off-by-one in
  vocab_tile_size and scan/reshape interactions).
- test_nnx_vocab_tiling_other_params_get_zero_grad (carve-out
  invariant): asserts every non-head leaf has gradient exactly zero
  AND at least one head leaf has non-zero gradient (so the test can't
  trivially pass by zeroing everything). Catches filter bugs (e.g.
  forgetting that NNX names the embedder token_embedder while Linen
  names it shared_embedding) and bwd zero-shape bugs.
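The carve-out invariant in the last test can be expressed as a small helper over a gradient pytree. This is a sketch only: the four head names come from the filter description in this commit message, but matching them by substring on the key path is an assumption, and `carve_out_report` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

HEAD_NAMES = ("token_embedder", "shared_embedding", "decoder_norm", "logits_dense")


def carve_out_report(grads):
  """Return (non-head paths with non-zero grad, whether any head leaf is non-zero)."""
  leaves, _ = jax.tree_util.tree_flatten_with_path(grads)
  leaking, head_nonzero = [], False
  for path, leaf in leaves:
    name = jax.tree_util.keystr(path)
    nonzero = bool(jnp.any(leaf != 0))
    if any(head in name for head in HEAD_NAMES):
      head_nonzero = head_nonzero or nonzero
    elif nonzero:
      leaking.append(name)
  return leaking, head_nonzero


# Toy gradient tree (assumed layout, for illustration only).
grads = {
    "decoder": {
        "logits_dense": {"kernel": jnp.ones((4, 8))},
        "layers": {"mlp": {"kernel": jnp.zeros((2, 4, 4))}},
    },
    "token_embedder": {"embedding": jnp.zeros((8, 4))},
}
leaking, head_nonzero = carve_out_report(grads)
```

Requiring `head_nonzero` alongside an empty `leaking` list is what stops the check from passing trivially when every gradient is zero.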

AOT compile coverage (train_compile_test.py):
- Removed the now-stale pytest.skip("Vocab tiling not supported on
  NNX.") in test_vocab_tiling_bf16.
- Added test_vocab_tiling_bf16_nnx (cpu_only): AOT-compiles the train
  step under pure_nnx=true + enable_nnx=true + pure_nnx_decoder=true
  with num_vocab_tiling=4 and weight_dtype=bfloat16. Surfaced both the
  models.py dead-code regression and the cotangent-axis-ordering issue
  the explicit-zeros bwd fixes.

Tests pass: 18 in tiling + AOT (7 Linen UTs + 9 NNX UTs + 2 AOT, one
Linen and one NNX); 52 in adjacent NNX surfaces (train_nnx, dpo_nnx,
grpo_nnx, lora_utils_nnx, maxengine, qk_clip, aqt_serve_roundtrip_nnx)
as a regression check for the nnx_decoders.py change.

…X default flip

Pre-flip safety: PR11 will flip pure_nnx/enable_nnx/pure_nnx_decoder from
False to True in base.yml. Some existing tests are Linen-coupled and would
either silently switch to NNX (and break) or silently SKIP after that flip.
Pin them to Linen explicitly so they keep exercising the Linen path, with
no behavior change today (the pin matches the current default).

tests/unit/tiling_test.py:
  LossAndGradientCorrectnessTest builds models via transformer_as_linen and
  exercises the Linen vocab_tiling path. Extend self.base_config in setUp
  with enable_nnx=False, pure_nnx=False, pure_nnx_decoder=False, then drop
  the 6 stale pytest.skip("We currently don't support vocab tiling on NNX
  module.") guards (NNX-side coverage lives in VocabTilingNNXTest in the
  same file, added in PR10).

tests/unit/pipeline_parallelism_test.py:
  Pipeline parallelism does not yet have an NNX path (deferred to PR11.5).
  Add _LINEN_PIN class const and append *self._LINEN_PIN to the 6
  train_main arg lists in test_full_train_circular,
  test_full_train_circular_pipeline_ag_per_repeat,
  test_full_train_non_circular, test_subset_layers, test_full_train_fp8,
  and test_full_train_nanoo_fp8. The unit-style
  assert_pipeline_same_output_and_grad tests bypass the dispatch by
  calling pipeline.create_pipeline + SimpleDecoderLayerToLinen directly,
  so they are flag-immune and need no change.
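A minimal sketch of the pinning pattern, assuming the train entry point takes a list of key=value overrides after a leading argv[0] placeholder and a config path. The three flags are the ones named in this commit; the class name, config path, and `steps=3` override are placeholders.

```python
class PipelineArgsExample:
  """Stands in for a test class; only the _LINEN_PIN pattern matters here."""

  # Explicit Linen pin: matches today's defaults, so behavior is unchanged now,
  # and stays Linen once the base.yml defaults flip.
  _LINEN_PIN = ("enable_nnx=False", "pure_nnx=False", "pure_nnx_decoder=False")

  def train_args(self, config_path, *overrides):
    # The pin is appended after the test's own overrides.
    return [None, config_path, *overrides, *self._LINEN_PIN]


args = PipelineArgsExample().train_args("base.yml", "steps=3")
```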
The PR6-PR10 sequence promoted every routed-to-Linen feature to
NNX-native (DPO/PR6, MaxEngine/PR7, LoRA+GRPO/PR8, QK-Clip + checkpoint
utilities/PR9, AQT + serve-mode/PR9.5, vocab tiling custom_vjp/PR10).
With those gaps closed, NNX is the production path; this commit makes
it the default.

Empirical break-test on CPU (pytest before/after the flip across
tiling_test, train_compile_test, sharding_compare_test,
maxtext_utils_test, maxengine_test) showed zero flip-induced failures:
every CPU unit-test failure pre-existed on PR10 tip. TPU smoke
verified end-to-end: gemma2-2b 3-step train under the new defaults
logged "pure_nnx: True" in pyconfig and produced loss
13.04 -> 12.32 -> 11.82 (decreasing, no NaN/inf, no Traceback).
Linen-only test files were already pinned in the prior commit, so the
flip causes no per-test breakage.

base.yml: enable_nnx, pure_nnx_decoder, pure_nnx all flip False -> True.

No use_nnx_pipeline flag is added: PR10 tip has no NNX pipeline path
to opt out of, so a one-valued flag would be dead weight. Pipeline
tests keep their Linen pin from the prior commit; the eventual NNX
pipeline work (PR11.5) will introduce its own opt-in if needed.

Sharding goldens not regenerated: tests/unit/sharding_compare_test.py
already pins enable_nnx=False, pure_nnx=False, pure_nnx_decoder=False
explicitly when invoking the dump utility, so existing goldens at
tests/utils/sharding_info/ stay valid against the flipped default.

…NX::test_nnx_model_dispatches_to_tree_map_with_path
1. Sanitize unmapped logical axes to None in maxtext_utils.py get_nnx_named_sharding_with_scan_axis to prevent compilation ValueError.

2. Fix qk_clip_utils.py broadcast shape mismatch (axis=0 to axis=-2) causing TypeError.

3. Update max_utils_test.py unscan utility to correctly parse TrainStateNNX and its parameters/sharding trees.

4. Fix muon_utils_test.py NNX dict mapping assertIsNone() against raw objects rather than .

5. Patch train_distill and train_sft to explicitly nnx.pop(Intermediate) to prevent GraphDef mutation ValueErrors.

6. Update diloco.py to use nnx.split instead of the deprecated filter() method for param extraction.

7. Update diloco_test.py to execute pure NNX training loop simulations instead of legacy Linen.
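Item 1 can be illustrated with a small helper. This is a sketch assuming logical axis rules are given as (logical_name, mesh_axis) pairs; `sanitize_logical_axes` and the rule values are hypothetical, and the actual change inside get_nnx_named_sharding_with_scan_axis may look different.

```python
def sanitize_logical_axes(axes, rules):
  """Replace logical axis names that have no rule with None (replicated)."""
  known = {logical_name for logical_name, _ in rules}
  return tuple(axis if axis in known else None for axis in axes)


# Hypothetical rules and axis names, for illustration only.
rules = (("embed", "fsdp"), ("mlp", "tensor"))
sanitized = sanitize_logical_axes(("embed", "unmapped_axis", None), rules)
```

Mapping unknown names to None leaves that dimension unsharded instead of handing an unresolvable name to the sharding machinery.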
After flipping pure_nnx/enable_nnx/pure_nnx_decoder to True, several
integration tests broke because their code paths assumed Linen. Fixes:

- maxengine_test: remove the Linen-only test_basic_prefill / test_basic_decode
  (they build the model with transformer_as_linen but the engine now expects
  NNX state). The NNX path is already covered by test_basic_prefill_nnx /
  test_basic_decode_nnx. Drop the now-unused imports and get_data helper.

- train_sft_deprecated: support the NNX train loop. Split the TrainStateNNX
  into GraphDef + flat state before jit, only pass a dropout rng on the Linen
  path (the NNX step takes (state, batch)), and read setup params via
  nnx.split on the NNX path.

- quantizations.maybe_quantize_model: qwix.quantize_model traces NNX modules
  and needs example inputs, so pass dummy decoder tokens/positions for the
  NNX path. Fixes the fp8 sparsity smoke test.

- generate_param_only_checkpoint (NNX param-only flow):
  - checkpointing._load_full_state_from_path: restore into a pure dict, since
    NNX checkpoints are saved as pure dicts; a boxed nnx.State did not match.
  - read opt_state from state.optimizer.opt_state on the NNX path.
  - save only nnx.Param leaves (the rng PRNGKeyArray can't be cast to bf16)
    and wrap each leaf as {"value": ...} so from_pretrained can read it back.
  - skip the int8 case: it is a convert-on-load scenario (the fp32 training
    checkpoint has no AqtDotGeneral state the int8 model expects); tracked as
    a follow-up alongside layerwise_quantization.
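The "wrap each leaf as {"value": ...}" step can be sketched on a pure dict. This assumes the tree has already been filtered down to parameter arrays (no rng keys); `to_decode_checkpoint_tree` and the toy tree are hypothetical, not the actual _save_decode_checkpoint_nnx code.

```python
import jax
import jax.numpy as jnp


def to_decode_checkpoint_tree(params):
  """Cast every array leaf to bf16 and wrap it as {"value": leaf}."""
  return jax.tree_util.tree_map(lambda x: {"value": x.astype(jnp.bfloat16)}, params)


# Toy parameter tree (assumed layout, params only).
params = {"decoder": {"logits_dense": {"kernel": jnp.ones((4, 8), jnp.float32)}}}
tree = to_decode_checkpoint_tree(params)
```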
…product test

NNX int8 parameter-only generation requires a convert-on-load setup, which causes a ValueError since the fp32 training checkpoint lacks the AqtDotGeneral state that the target int8 model expects. This aligns the GPU/dot-product test with the existing skip in the TPU/autoselected test variant.

fp8 Qwix stateful quant crashes under NNX (tracer leak / pytree ValueError via ToLinen). Skip the fp8_full and fp8_full_with_sparsity cases until b/509790223 is fixed.

PR11 flips the defaults to NNX, so the Linen reference engine in the prefill_multisampling/prefill_concat parity tests silently became NNX and crashed (device_put State-vs-dict), and test_stack_and_unstack_prefill_cache hit the NNX no-op branch. Drop the Linen comparisons and assert the NNX result shapes directly, rewrite the cache test for the NNX scan_layers=False path, and remove _build_linen_params and its imports.
@hsuan-lun-chiang hsuan-lun-chiang force-pushed the test/NNX-Migration-PR-Test branch from 6fec4a4 to eb8aeb9 Compare May 29, 2026 11:45


2 participants