From 9ce8edd960b9776ae912a8f3ee962cc7b7f179b3 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 25 Jun 2026 19:11:06 +0000 Subject: [PATCH] NNX: fix Linen-parity gaps on the default path + unit tests With pure_nnx/enable_nnx/pure_nnx_decoder defaulting to True, several train/loss/decoder/metrics/GRPO paths diverged from Linen. Fixes: - skip_step_on_spikes: forward loss/grad_norm through apply_gradients to the optax skip-step optimizer; read is_skipped back off the NNX optimizer. - loss_fn: check the indexer dense-warmup before num_vocab_tiling (Linen order). - decoder logits guards: use the model_mode call-arg, not self.model_mode. - routed_bias read: dispatch the Linen intermediates path vs an NNX suffix match. - record_activation_metrics: collect by path suffix so it works for Linen and NNX, scanned and unscanned (also fixes a pre-existing Linen KeyError). - nnx_attrs_to_linen_vars: skip non-Variable attrs (qwix bookkeeping) not raise. - config: error when qwix quant can't reach a bridged Linen decoder under pure_nnx. - maxengine.set_engine_vars_from_base_engine: skip the quant copy and use the NNX kv-cache annotations on the NNX path. - GRPO _train_step_nnx: gradient-accumulation scan loop; fix the GA loss metric. - GRPO pathways reshard: drop the scan_layers=False NotImplementedError. - GRPO host-offload: move optimizer state to device before the in-place update. Tests: train_nnx_test, grpo_nnx_test, maxengine_nnx_test, nnx_quant_guard_test. --- src/maxtext/common/metric_logger.py | 30 ++--- src/maxtext/common/train_state_nnx.py | 7 +- src/maxtext/configs/types.py | 18 +++ src/maxtext/experimental/rl/grpo_trainer.py | 73 +++++++--- src/maxtext/experimental/rl/grpo_utils.py | 14 +- src/maxtext/inference/maxengine/maxengine.py | 19 +-- src/maxtext/layers/nnx_decoders.py | 4 +- src/maxtext/layers/nnx_wrappers.py | 13 +- src/maxtext/trainers/pre_train/train.py | 51 +++++-- tests/unit/grpo_nnx_test.py | 129 ++++++++++++++++++ tests/unit/maxengine_nnx_test.py | 70 ++++++++++ tests/unit/nnx_decoder_test.py | 85 ++++++++++++ tests/unit/nnx_quant_guard_test.py | 78 +++++++++++ tests/unit/quantizations_test.py | 1 + tests/unit/train_nnx_test.py | 132 +++++++++++++++++++ 15 files changed, 659 insertions(+), 65 deletions(-) create mode 100644 tests/unit/maxengine_nnx_test.py create mode 100644 tests/unit/nnx_decoder_test.py create mode 100644 tests/unit/nnx_quant_guard_test.py diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py index 472ede809e..2137dd6482 100644 --- a/src/maxtext/common/metric_logger.py +++ b/src/maxtext/common/metric_logger.py @@ -64,23 +64,23 @@ def _prepare_metrics_for_json(metrics, step, run_name): def record_activation_metrics(output_metrics, intermediate_outputs, config): - """Adds the activation metrics to the metrics dict""" + """Adds the activation metrics to the metrics dict. - if config.scan_layers: - metrics_dict = intermediate_outputs["intermediates"]["decoder"]["decoder"] - - for layer_num in range(config.num_decoder_layers): - output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = metrics_dict["activation_fraction_zero"][ - 0 - ][layer_num] - output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = metrics_dict["activation_mean"][0][layer_num] - output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = metrics_dict["activation_stdev"][0][layer_num] - else: + Collects each metric by path suffix rather than a hardcoded path, so it works for + both the Linen ("intermediates"-prefixed) and NNX (model-rooted) layouts and for + both scanned (one stacked leaf) and unscanned (one leaf per layer) decoders. + """ + for label, key in ( + ("activ_fraction_zero", "activation_fraction_zero"), + ("activ_mean", "activation_mean"), + ("activ_stdev", "activation_stdev"), + ): + vals = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, key) + if not vals: + continue + per_layer = jax.numpy.concatenate(vals) for layer_num in range(config.num_decoder_layers): - layer = intermediate_outputs["intermediates"]["decoder"][f"layers_{layer_num}"] - output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = layer["activation_fraction_zero"][0] - output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = layer["activation_mean"][0] - output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = layer["activation_stdev"][0] + output_metrics["scalar"][f"{label}/layer_{layer_num:03d}"] = per_layer[layer_num] class MetadataKey(enum.Enum): diff --git a/src/maxtext/common/train_state_nnx.py b/src/maxtext/common/train_state_nnx.py index 141820b35f..a123956a73 100644 --- a/src/maxtext/common/train_state_nnx.py +++ b/src/maxtext/common/train_state_nnx.py @@ -40,11 +40,12 @@ def __init__( self.model = model self.optimizer = optimizer - def apply_gradients(self, grads: Any): + def apply_gradients(self, grads: Any, **kwargs): """Mimics the Linen apply_gradients function. Updates the optimizer state, applies updates to parameters, and increments - the step counter. Only updates `self.model`. + the step counter. Only updates `self.model`. Extra kwargs (e.g. loss/grad_norm + for the skip-step-on-spikes optimizer) are forwarded to the optax update. """ if self.optimizer is None: raise RuntimeError( @@ -52,7 +53,7 @@ def apply_gradients(self, grads: Any): " an optimizer. This usually happens when the state was created for" " inference only." ) - self.optimizer.update(self.model, grads) + self.optimizer.update(self.model, grads, **kwargs) # On-disk checkpoint format. diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 757e04e515..0f991d424d 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2078,6 +2078,16 @@ class RLHardware(BaseModel): description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined.", ) rollout_expert_parallelism: int = Field(1, description="Expert parallelism per replica for rollout") + inference_replicas: int = Field(1, description="Legacy experimental GRPO: number of inference (sampler) replicas.") + inference_devices_per_replica: int = Field( + 4, description="Legacy experimental GRPO: devices per inference replica (single-controller device split)." + ) + inference_rollouts: int = Field( + 1, description="Legacy experimental GRPO: refresh rollouts every N steps (step % inference_rollouts)." + ) + use_pathways_reshard: bool = Field( + True, description="Legacy experimental GRPO: use Pathways resharding to move policy params to the sampler." + ) class VLLM(BaseModel): @@ -2696,6 +2706,14 @@ def validate_and_set_hlo_dump_defaults(): if not self.enable_nnx: raise ValueError("a value of self.distill_beta > 0.0 requires self.enable_nnx = True") + if self.pure_nnx and not self.pure_nnx_decoder and self.use_qwix_quantization and not self.use_batch_split_schedule: + if self.quantization: + raise ValueError( + f"quantization='{self.quantization}' with use_qwix_quantization=True under pure_nnx=True requires " + "pure_nnx_decoder=True. The bridged Linen decoder (pure_nnx_decoder=False) is invisible to Qwix, " + "so quantization (and weight sparsity) would silently have no effect. Set pure_nnx_decoder=True." + ) + # Validate distillation schedule parameters if self.distill_alpha_end is not None and not 0.0 <= self.distill_alpha_end <= 1.0: raise ValueError(f"distill_alpha_end must be in [0, 1], got {self.distill_alpha_end}") diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 41b20e4e04..9ea0a68f82 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -471,8 +471,8 @@ def _train_step_nnx(model_graphdef, config, state_mesh_shardings, state, data): Args: model_graphdef: NNX `GraphDef` of the `TrainStateNNX`. config: Training configuration object. - state_mesh_shardings: Sharding spec for the train state. Unused on this - path; kept for signature parity with `train_step`. + state_mesh_shardings: Sharding spec for the train state; used to move the + optimizer state to device when `optimizer_memory_host_offload` is set. state: Flat `nnx.State` matching `model_graphdef`. data: A batch dict produced by the GRPO input pipeline. @@ -480,14 +480,6 @@ def _train_step_nnx(model_graphdef, config, state_mesh_shardings, state, data): A tuple `(new_state, metrics)`. `new_state` is filtered to exclude `nnx.Intermediate`. `metrics` is a dict shaped like the Linen path's. """ - del state_mesh_shardings # Host-offload paths are not yet wired up here. - - if config.gradient_accumulation_steps > 1: - raise NotImplementedError( - "GRPO + pure_nnx + gradient_accumulation_steps>1 not supported yet. " - "Set gradient_accumulation_steps=1 or pure_nnx=False." - ) - state = nnx.merge(model_graphdef, state) # Reconstruct the TrainStateNNX. policy_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) # Split the reference model into (graphdef, state) so we pass `ref_state` as @@ -505,13 +497,61 @@ def diff_wrapper(param, rest, ref_state, config, data): return loss, (aux, new_rest) grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) - (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, ref_state, config, data) + + if config.gradient_accumulation_steps <= 1: + (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, ref_state, config, data) + else: + # Mirror the pre-train NNX gradient-accumulation loop and the Linen GRPO one: + # params stay fixed across microbatches while the non-param state (rest, e.g. + # RNGs) advances in the scan carry. Grads are accumulated weighted by each + # microbatch's total_weights, then normalized once after the scan. + def reshape_to_microbatch_accumulations(batch_arr): + microbatches = config.gradient_accumulation_steps + microbatch_shape = (microbatches, batch_arr.shape[0] // microbatches) + batch_arr.shape[1:] + return jnp.reshape(batch_arr, microbatch_shape) + + ga_data = jax.tree_util.tree_map(reshape_to_microbatch_accumulations, data) + + def accumulate_gradient(carry, microbatch): + (_, (aux, new_rest)), cur_grad = grad_func(curr_params, carry["rest"], ref_state, config, microbatch) + carry["loss"] += aux.total_loss + carry["grad"] = jax.tree_util.tree_map(lambda x, y: x * aux.total_weights + y, cur_grad, carry["grad"]) + carry["total_weights"] += aux.total_weights + carry["rest"] = new_rest + return carry, aux + + init_carry = { + "loss": 0.0, + "grad": jax.tree_util.tree_map(jnp.zeros_like, curr_params), + "total_weights": 0.0, + "rest": rest, + } + carry, aux = jax.lax.scan(accumulate_gradient, init_carry, ga_data, length=config.gradient_accumulation_steps) + # total_loss is already a per-batch mean (and includes moe_lb), so the full-batch + # loss is the mean across the equal-sized microbatches. + loss = carry["loss"] / config.gradient_accumulation_steps + raw_grads = jax.tree_util.tree_map(lambda arr: arr / carry["total_weights"], carry["grad"]) + aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) + new_rest = carry["rest"] + nnx.update(state.model, new_rest) if config.gradient_clipping_threshold > 0: grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) else: grads = raw_grads + if config.optimizer_memory_host_offload: + # Mirror the pre-train NNX path: move the optimizer state from pinned_host to + # device before the in-place optimizer update. (The Linen GRPO path also casts + # params/reference to bf16 under this flag; NNX host-offload moves the memory + # kind without casting, matching the pre-train NNX convention.) + device_opt_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + state_mesh_shardings.optimizer, + is_leaf=lambda x: isinstance(x, jax.sharding.NamedSharding), + ) + opt_state = nnx.state(state.optimizer) + nnx.update(state.optimizer, jax.device_put(opt_state, device_opt_shardings)) state.apply_gradients(grads) new_state = state @@ -524,11 +564,14 @@ def diff_wrapper(param, rest, ref_state, config, data): "learning/completion_length": aux.completion_length, "learning/moe_lb_loss": aux.moe_lb_loss, "learning/total_weights": aux.total_weights, - "learning/grad_norm": max_utils.l2norm_pytree(grads), - "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads), } - new_policy_params = nnx.state(new_state.model, nnx.Param) - scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_policy_params) + # These norms pull host-resident tensors back to device, defeating the offload, + # so skip them when offloading (matches the Linen GRPO path). + if not config.optimizer_memory_host_offload: + scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads) + scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads) + new_policy_params = nnx.state(new_state.model, nnx.Param) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_policy_params) metrics = {"scalar": scalar_metrics, "scalars": {}} return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics diff --git a/src/maxtext/experimental/rl/grpo_utils.py b/src/maxtext/experimental/rl/grpo_utils.py index 8989405eab..946ae552ef 100644 --- a/src/maxtext/experimental/rl/grpo_utils.py +++ b/src/maxtext/experimental/rl/grpo_utils.py @@ -244,8 +244,14 @@ def pathways_reshard_nnx( Splits the policy `nnx.Param` state out of the training-side TrainStateNNX model substate, reshards it onto the inference mesh, and pushes the - resharded params into the inference engine. Requires `scan_layers=True`; - the Linen `unscan_train_state_params` helper has no NNX equivalent yet. + resharded params into the inference engine. + + Unlike Linen — where the policy is always scanned and must be unrolled via + `unscan_train_state_params` when the inference side is unscanned — the NNX + policy model is built per `config.scan_layers` (scanned: a single stacked + `decoder/layers` subtree; unscanned: per-layer `decoder/layers/{i}`). The + inference-side model is built from the same config, so both layouts already + match and `reshard_pytree` maps them directly without an explicit unscan. Args: config: Training configuration object. @@ -255,10 +261,6 @@ def pathways_reshard_nnx( because the same shardings are already attached to the params. destination_shardings_model: Shardings for the inference-side model. """ - if not config.scan_layers: - raise NotImplementedError( - "GRPO + pure_nnx + scan_layers=False not supported yet. " "Use scan_layers=True or pure_nnx=False." - ) policy_params = nnx.state(policy_state_model, nnx.Param) source_param_shardings = nnx.state(source_shardings_model, nnx.Param) dest_param_shardings = nnx.state(destination_shardings_model, nnx.Param) diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index f5232e0d89..5377e86ec3 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -399,9 +399,7 @@ def _load_params_nnx(self, params, rng): # axis metadata but no physical .sharding. Resolve logical to physical here so # device_put actually reshards instead of being a no-op. with nn_partitioning.axis_rules(self.config.logical_axis_rules): - target_shardings = sharding.nnx_construct_named_sharding( - params_abs, self._mesh - ) + target_shardings = sharding.nnx_construct_named_sharding(params_abs, self._mesh) params_state = jax.device_put(params, target_shardings) # We only need a concrete `rest` (RNG vars) for nnx.merge. create_nnx_sharded_model # builds the model with a jitted out_shardings so params are produced already @@ -409,9 +407,7 @@ def _load_params_nnx(self, params, rng): # large models). self.model is abstract with no .sharding, so pass an explicit one. _, full_abs = nnx.split(self.model) with nn_partitioning.axis_rules(self.config.logical_axis_rules): - full_sharding = sharding.nnx_construct_named_sharding( - full_abs, self._mesh - ) + full_sharding = sharding.nnx_construct_named_sharding(full_abs, self._mesh) concrete_model = maxtext_utils_nnx.create_nnx_sharded_model( self.model, self._create_model_fn, mesh=self._mesh, named_sharding=full_sharding ) @@ -2047,11 +2043,18 @@ def set_engine_vars_from_base_engine( """Set internal vars from base_engine, which has already loaded the checkpoint and has sharding, mesh, and kv cache related vars set. """ - if base_engine.model.quant: + if not engine.config.pure_nnx and base_engine.model.quant: + # NNX bakes the quant mode in at construction (via _nnx_quant_mode_str) rather + # than mutating model.quant.quant_mode, so there's nothing to copy on that path. engine.model.quant.quant_mode = base_engine.model.quant.quant_mode engine.state_mesh_annotations = base_engine.state_mesh_annotations engine.abstract_params = base_engine.abstract_params - engine.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations(engine.model, engine.config, rng, engine.mesh) # pylint: disable=protected-access + if engine.config.pure_nnx: + # Linen's get_kv_cache_annotations calls model.init(); NNX modules have no + # .init, so use the abstract-model variant (mirrors _load_params_nnx). + engine.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations_nnx(engine.model_ar, engine.config, engine.mesh) + else: + engine.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations(engine.model, engine.config, rng, engine.mesh) # pylint: disable=protected-access engine.kv_cache_shardings = jax.tree_util.tree_map( lambda x: jax.sharding.NamedSharding(engine.mesh, x), engine.kv_cache_annotations, # pylint: disable=protected-access diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 6810f9a86c..2aa868cb1e 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -1891,12 +1891,12 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): # for efficiency, as the main model is frozen and the LM loss is not needed. elif ( cfg.use_indexer and cfg.indexer_loss_scaling_factor > 0.0 and not cfg.indexer_sparse_training - ) and self.model_mode == MODEL_MODE_TRAIN: + ) and model_mode == MODEL_MODE_TRAIN: logits = None # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory # Instead, we keep track on the hidden states, which has smaller size compared to full logits - elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + elif cfg.num_vocab_tiling > 1 and model_mode == MODEL_MODE_TRAIN: logits = None self.sow(nnx.Intermediate, "hidden_states", hidden_state) diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index 07faeefda2..2d4df4e44c 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -126,11 +126,14 @@ def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: """Convert a dict of NNX variables (or variable states) to Linen-style variables.""" linen_structured = {} for kp, v in nnx.traversals.flatten_mapping(nnx_attrs).items(): - if isinstance(v, variablelib.Variable): - col_name = variablelib.variable_name_from_type(v.type) - v = to_linen_var(v) - else: - raise ValueError(f"Cannot infer collection name from value: {v}") + if not isinstance(v, variablelib.Variable): + # Plain (non-Variable) attributes aren't Linen collections, so they have no + # place in the variables dict passed to the wrapped module's apply(). Qwix + # attaches bookkeeping attrs like qwix_path/qwix_rngs/disable_quant_stats_update + # to the module during interception; leave them on the module and skip them here. + continue + col_name = variablelib.variable_name_from_type(v.type) + v = to_linen_var(v) linen_structured[(col_name, *kp)] = v variables = nnx.traversals.unflatten_mapping(linen_structured) return variables diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index b2c06aaeef..ba8c5c13bb 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -208,15 +208,15 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr intermediate_outputs["mtp_losses"] = nnx.pop(model, mtp_losses).to_pure_dict() intermediate_outputs["mtp_acceptance"] = nnx.pop(model, mtp_acceptance).to_pure_dict() - if config.num_vocab_tiling > 1: - hidden_state_key = ("decoder", "hidden_states") - hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0] - xent_sum, total_z_loss = vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train) - elif (config.use_indexer and not config.indexer_sparse_training) and is_train: + if (config.use_indexer and not config.indexer_sparse_training) and is_train: # In Dense Warm-up stage, we skip main model loss calculation for efficiency. # The main model parameters are frozen and only the indexer is trained via KL divergence. xent_sum = 0.0 total_z_loss = 0.0 + elif config.num_vocab_tiling > 1: + hidden_state_key = ("decoder", "hidden_states") + hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0] + xent_sum, total_z_loss = vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train) else: one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size) xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier) @@ -293,8 +293,25 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr # get MoE routed bias term updates moe_bias_updates = None if config.routed_bias and config.routed_bias_update_rate > 0.0: - nested_key = ("intermediates", "decoder", "moe_layers", "moe_bias_updates") - moe_bias_updates = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, None) + if isinstance(model, nn.Module): + nested_key = ("intermediates", "decoder", "moe_layers", "moe_bias_updates") + moe_bias_updates = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, None) + else: + # NNX intermediates are model-rooted (no "intermediates" prefix), so match by + # suffix instead. Unlike collect_intermediates_by_suffix we must not ravel: + # the update is a 2-D matrix that's transposed at the apply site below. + moe_bias_updates = next( + ( + val + for path, val in jax.tree_util.tree_leaves_with_path(intermediate_outputs) + if tuple(k.key for k in path if hasattr(k, "key"))[-1:] == ("moe_bias_updates",) + ), + None, + ) + if moe_bias_updates is not None: + # The Linen path returns the sow tuple and indexes [0] downstream; tree_leaves + # already descended that tuple, so wrap it back so the apply site is uniform. + moe_bias_updates = (moe_bias_updates,) # Add the model's primary output to the intermediates dict so it can be used # by the acceptance rate calculation in eval_step. @@ -504,7 +521,14 @@ def move(path, value): opt_state = nnx.state(state.optimizer) new_opt_state = jax.device_put(opt_state, device_opt_shardings) nnx.update(state.optimizer, new_opt_state) - state.apply_gradients(grads) + if config.skip_step_on_spikes: + # The skip-step optimizer is a GradientTransformationExtraArgs that reads + # loss/grad_norm to decide whether to zero the update on a spike. nnx + # Optimizer.update forwards these kwargs to tx.update. + grad_norm = max_utils.l2norm_pytree(grads) + state.apply_gradients(grads, loss=loss, grad_norm=grad_norm) + else: + state.apply_gradients(grads) new_state = state # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family @@ -542,10 +566,15 @@ def move(path, value): model_params = nnx.state(new_state.model, nnx.Param) scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(model_params) - # Surface skip-step rejections as a TB metric. Linen path only — the NNX - # branch doesn't apply skip-step, so new_opt_state stays None. + # Surface skip-step rejections as a TB metric. The skip-step optimizer stores + # is_skipped in its opt_state: the Linen path gets it from the tx.update return, + # the NNX path reads it back off the optimizer it just updated in place. if config.skip_step_on_spikes: - is_skipped = new_opt_state.get("is_skipped") if isinstance(new_opt_state, dict) else None + if isinstance(model, nn.Module): + is_skipped = new_opt_state.get("is_skipped") if isinstance(new_opt_state, dict) else None + else: + opt_state = nnx.to_pure_dict(nnx.state(new_state.optimizer)).get("opt_state", {}) + is_skipped = opt_state.get("is_skipped") if isinstance(opt_state, dict) else None if is_skipped is not None: scalar_metrics["optim/step_skipped"] = is_skipped.astype(jnp.float32) metrics = { diff --git a/tests/unit/grpo_nnx_test.py b/tests/unit/grpo_nnx_test.py index 77f6361b9b..259ebae267 100644 --- a/tests/unit/grpo_nnx_test.py +++ b/tests/unit/grpo_nnx_test.py @@ -23,10 +23,13 @@ import jax import jax.numpy as jnp import numpy as np +import optax from flax import nnx +from maxtext.common import train_state_nnx from maxtext.experimental.rl import grpo_trainer from maxtext.experimental.rl import grpo_utils +from maxtext.utils import maxtext_utils_nnx class _MockTransformer(nnx.Module): @@ -162,6 +165,132 @@ def test_returns_correct_shape(self): self.assertEqual(log_probs.shape, (data["prompt_completions"].shape[0], data["prompt_completions"].shape[1] - 1)) +class TestGrpoTrainStepNnxGradAccum(unittest.TestCase): + """Gradient accumulation on the NNX GRPO step must match a single full-batch step.""" + + def _step(self, ga_steps): + """Run one `_train_step_nnx` from a fixed init; return (loss, updated policy params).""" + # pylint: disable=protected-access # the test drives the internal _train_step_nnx + policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + # Different seed from the policy so KL(policy||reference) != 0 and the step + # produces a real (non-zero) gradient — otherwise the equivalence check is vacuous. + reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(1)) + optimizer = nnx.Optimizer(policy, optax.sgd(0.1), wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(policy, optimizer) + state.reference_model = reference + graphdef, flat_state = nnx.split(state) + config = _make_grpo_config( + gradient_accumulation_steps=ga_steps, + gradient_clipping_threshold=0.0, + optimizer_memory_host_offload=False, + ) + # B=2, G=2 -> 4 rows; GA=2 splits into 2 microbatches that each hold one + # complete generation-group, so per-group advantages are unchanged and the + # accumulated gradient must equal the full-batch gradient. + data = _make_grpo_batch(B=2, G=2, S=6) + new_flat, metrics = grpo_trainer._train_step_nnx(graphdef, config, None, flat_state, data) + updated = nnx.merge(graphdef, new_flat) + params = jax.tree_util.tree_leaves(nnx.to_pure_dict(nnx.state(updated.model, nnx.Param))) + return float(metrics["scalar"]["learning/loss"]), params + + def test_gradient_accumulation_matches_single_shot(self): + loss_full, params_full = self._step(ga_steps=1) + loss_ga, params_ga = self._step(ga_steps=2) + + # Guard against a trivial pass: the step must actually move the params. + init = jax.tree_util.tree_leaves( + nnx.to_pure_dict(nnx.state(_MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)), nnx.Param)) + ) + moved = any(not np.allclose(np.asarray(p), np.asarray(i)) for p, i in zip(params_full, init)) + self.assertTrue(moved, "params did not change — test would be trivially true") + + # GA=2 must reproduce the full-batch step's loss and resulting parameters. + np.testing.assert_allclose(loss_ga, loss_full, rtol=1e-5, atol=1e-5) + self.assertEqual(len(params_full), len(params_ga)) + # GA reorders the per-microbatch gradient summation, so on lower-precision hardware + # (TPU bf16 matmuls) the updated params differ from the full-batch step by + # accumulation rounding (~1e-6 absolute); fp32/CPU matches to ~1e-10. A real GA bug + # (wrong normalization) would be a gross mismatch, so this tolerance still catches it. + for pf, pg in zip(params_full, params_ga): + np.testing.assert_allclose(np.asarray(pg), np.asarray(pf), rtol=1e-2, atol=1e-5) + + +class TestPathwaysReshardNnxScanLayersFalse(unittest.TestCase): + """scan_layers=False must no longer raise; the unscanned policy params reshard to the engine. + + The actual `reshard_pytree` runs through pathwaysutils (Pathways infra, not available off-cluster), + so it's mocked to a pass-through — the test pins *our* change: the guard is gone and the policy + params are pushed to the inference engine. + """ + + def test_unscanned_policy_pushes_params_to_engine(self): + captured = {} + + class _Engine: + + def update_params(self, params): + captured["params"] = params + + original_reshard = grpo_utils.reshard_pytree + grpo_utils.reshard_pytree = lambda source, target, **kw: source # pass-through (skip Pathways) + try: + policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + shardings_model = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + cfg = _make_grpo_config(scan_layers=False) + grpo_utils.pathways_reshard_nnx(cfg, _Engine(), policy, shardings_model, shardings_model) + finally: + grpo_utils.reshard_pytree = original_reshard + + self.assertIn("params", captured) # no NotImplementedError; engine received params + pushed = jax.tree_util.tree_leaves(captured["params"]) + expected = jax.tree_util.tree_leaves(nnx.state(_MockTransformer(8, 4, nnx.Rngs(0)), nnx.Param)) + self.assertEqual(len(pushed), len(expected)) + + +class TestGrpoHostOffloadNnx(unittest.TestCase): + """optimizer_memory_host_offload must run and not change the math (only memory placement). + + The memory-kind move needs TPU host-offload, so `move_memory_to_device` is mocked to identity; + the test verifies the surrounding plumbing (extract opt_state, device_put, nnx.update, then + apply_gradients) runs and yields the same params as the no-offload step. + """ + + def _step(self, host_offload): + """Run one GRPO `_train_step_nnx` with/without host-offload; return (loss, policy params).""" + # pylint: disable=protected-access # the test drives the internal _train_step_nnx + policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(1)) + optimizer = nnx.Optimizer(policy, optax.sgd(0.1), wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(policy, optimizer) + state.reference_model = reference + graphdef, flat_state = nnx.split(state) + cfg = _make_grpo_config( + gradient_accumulation_steps=1, gradient_clipping_threshold=0.0, optimizer_memory_host_offload=host_offload + ) + sms = None + original = maxtext_utils_nnx.move_memory_to_device + if host_offload: + mesh = jax.make_mesh((1,), ("x",)) + replicated = jax.tree.map( + lambda _: jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()), nnx.state(state.optimizer) + ) + sms = types.SimpleNamespace(optimizer=replicated) + maxtext_utils_nnx.move_memory_to_device = lambda path, x: x # identity (skip TPU memory kinds) + try: + new_flat, metrics = grpo_trainer._train_step_nnx(graphdef, cfg, sms, flat_state, _make_grpo_batch(B=2, G=2, S=6)) + finally: + maxtext_utils_nnx.move_memory_to_device = original + params = jax.tree_util.tree_leaves(nnx.to_pure_dict(nnx.state(nnx.merge(graphdef, new_flat).model, nnx.Param))) + return float(metrics["scalar"]["learning/loss"]), params + + def test_host_offload_matches_no_offload(self): + loss_off, params_off = self._step(host_offload=False) + loss_on, params_on = self._step(host_offload=True) + np.testing.assert_allclose(loss_on, loss_off, rtol=1e-6, atol=1e-6) + for a, b in zip(params_on, params_off): + np.testing.assert_allclose(np.asarray(a), np.asarray(b), rtol=1e-6, atol=1e-6) + + # --------------------------------------------------------------------------- # Linen-path regression smoke tests # --------------------------------------------------------------------------- diff --git a/tests/unit/maxengine_nnx_test.py b/tests/unit/maxengine_nnx_test.py new file mode 100644 index 0000000000..08303f817a --- /dev/null +++ b/tests/unit/maxengine_nnx_test.py @@ -0,0 +1,70 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for NNX dispatch in maxengine (no jetstream / checkpoint needed).""" + +import sys +import unittest + +import jax +import pytest + +from maxtext.configs import pyconfig + +pytest.importorskip("jetstream", reason="jetstream not installed") +from maxtext.inference.maxengine import maxengine +from tests.utils.test_helpers import get_test_config_path + + +class SetEngineVarsNNXTest(unittest.TestCase): + """set_engine_vars_from_base_engine must work on the NNX path.""" + + def _nnx_config(self, **kwargs): + """Tiny pure-NNX config from base.yml with the given overrides.""" + init_kwargs = { + "base_emb_dim": 32, + "base_num_query_heads": 2, + "base_num_kv_heads": 2, + "base_num_decoder_layers": 2, + "max_prefill_predict_length": 4, + "max_target_length": 8, + "per_device_batch_size": 1, + "enable_checkpointing": False, + "pure_nnx": True, + "enable_nnx": True, + "pure_nnx_decoder": True, + } | kwargs + return pyconfig.initialize([sys.argv[0], get_test_config_path()], **init_kwargs) + + def test_set_engine_vars_from_base_engine_nnx(self): + """NNX dispatches to get_kv_cache_annotations_nnx; the Linen model.init() path AttributeErrors. + + state_mesh_annotations / abstract_params are merely copied from the base engine, + so they're stubbed here — that lets the test exercise the kv-cache-annotations + dispatch without loading a checkpoint. + """ + cfg = self._nnx_config() + engine = maxengine.MaxEngine(cfg, jax.devices()) + engine.state_mesh_annotations = None + engine.abstract_params = None + + maxengine.set_engine_vars_from_base_engine(engine, engine, jax.random.PRNGKey(0)) + + self.assertIsNotNone(engine.kv_cache_annotations) + self.assertIsNotNone(engine.kv_cache_shardings) + self.assertGreater(len(jax.tree_util.tree_leaves(engine.kv_cache_annotations)), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/nnx_decoder_test.py b/tests/unit/nnx_decoder_test.py new file mode 100644 index 0000000000..70d59eb6a9 --- /dev/null +++ b/tests/unit/nnx_decoder_test.py @@ -0,0 +1,85 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The NNX decoder logits guards must read the model_mode passed to __call__. + +The vocab-tiling (and indexer warm-up) guards skip the output head only in TRAIN. +They must key off the model_mode argument, not the model_mode fixed at construction: +the same model invoked in TRAIN must skip the head (logits None) while invoked in a +serving mode must run it (real logits). +""" + +import sys +import unittest + +import jax +import jax.numpy as jnp +from flax import nnx + +from maxtext.common.common_types import MODEL_MODE_PREFILL, MODEL_MODE_TRAIN +from maxtext.configs import pyconfig +from maxtext.utils import model_creation_utils +from tests.utils.test_helpers import get_test_config_path + + +class DecoderLogitsGuardModelModeTest(unittest.TestCase): + """A tiny pure-NNX model with vocab tiling, built once and called in two modes. + + Built in PREFILL so the attention KV cache is allocated (TRAIN construction leaves + it None, which a serving call can't populate). Whether the output head runs must + then depend solely on the call-arg model_mode. + """ + + def _model(self): + """Build a tiny pure-NNX model with vocab tiling, constructed in PREFILL mode.""" + cfg = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + base_emb_dim=32, + base_num_query_heads=2, + base_num_kv_heads=2, + base_mlp_dim=64, + base_num_decoder_layers=2, + head_dim=16, + max_target_length=16, + max_prefill_predict_length=8, + per_device_batch_size=1, + enable_checkpointing=False, + scan_layers=False, + num_vocab_tiling=2, + pure_nnx=True, + enable_nnx=True, + pure_nnx_decoder=True, + ) + model = model_creation_utils.from_config(cfg, devices=jax.devices(), model_mode=MODEL_MODE_PREFILL, rngs=nnx.Rngs(0)) + return cfg, model + + def test_guard_follows_call_arg_not_construction_mode(self): + cfg, model = self._model() + seq = cfg.max_prefill_predict_length + toks = jnp.ones((1, seq), dtype=jnp.int32) + pos = jnp.broadcast_to(jnp.arange(seq), toks.shape) + + # Called in a serving mode, the output head runs -> real logits. + logits_serving = model(toks, pos, model_mode=MODEL_MODE_PREFILL, enable_dropout=False) + self.assertIsNotNone(logits_serving) + self.assertEqual(logits_serving.shape[-1], cfg.vocab_size) + + # Called in TRAIN with vocab tiling on, the head is skipped -> logits None. Keying + # off the construction mode (PREFILL) instead of the call-arg would run the head. + logits_train = model(toks, pos, model_mode=MODEL_MODE_TRAIN, enable_dropout=False) + self.assertIsNone(logits_train) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/nnx_quant_guard_test.py b/tests/unit/nnx_quant_guard_test.py new file mode 100644 index 0000000000..50cac5d349 --- /dev/null +++ b/tests/unit/nnx_quant_guard_test.py @@ -0,0 +1,78 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""qwix + NNX coverage: the config guard and the ToNNX->Linen bridge. + +- Config guard: qwix quantization under pure_nnx requires the pure NNX decoder. + The bridged Linen decoder (pure_nnx_decoder=False) is invisible to qwix, so + quantization/sparsity would silently no-op; validation must reject that combo. +- Bridge: nnx_attrs_to_linen_vars must skip qwix's non-Variable bookkeeping attrs + (qwix_path/qwix_rngs/disable_quant_stats_update) instead of raising. +""" + +import sys +import unittest + +import jax.numpy as jnp +from flax import nnx + +from maxtext.configs import pyconfig +from maxtext.layers import nnx_wrappers +from tests.utils.test_helpers import get_test_config_path + + +class QwixNnxQuantGuardTest(unittest.TestCase): + + def _init(self, **overrides): + overrides.setdefault("enable_checkpointing", False) + return pyconfig.initialize([sys.argv[0], get_test_config_path()], **overrides) + + def test_bridged_decoder_with_qwix_quant_raises(self): + with self.assertRaisesRegex(Exception, "pure_nnx_decoder"): + self._init(pure_nnx=True, pure_nnx_decoder=False, use_qwix_quantization=True, quantization="fp8_full") + + def test_pure_nnx_decoder_with_qwix_quant_ok(self): + cfg = self._init(pure_nnx=True, pure_nnx_decoder=True, use_qwix_quantization=True, quantization="fp8_full") + self.assertTrue(cfg.pure_nnx_decoder) + + def test_bridged_decoder_without_quant_ok(self): + cfg = self._init(pure_nnx=True, pure_nnx_decoder=False, quantization="") + self.assertEqual(cfg.quantization, "") + + +class NnxAttrsToLinenVarsBridgeTest(unittest.TestCase): + """The ToNNX->Linen conversion must skip qwix's non-Variable attrs.""" + + def test_non_variable_attrs_are_skipped(self): + # qwix attaches plain attrs (qwix_path / disable_quant_stats_update) during + # interception; before the fix these raised "Cannot infer collection name". + attrs = { + "kernel": nnx.Param(jnp.ones((2, 3))), + "qwix_path": ("decoder", "layer"), + "disable_quant_stats_update": True, + } + out = nnx_wrappers.nnx_attrs_to_linen_vars(attrs) # must not raise + keys = {k for kp in nnx.traversals.flatten_mapping(out) for k in kp} + self.assertIn("params", keys) # the real Variable survived, under its collection + self.assertIn("kernel", keys) + self.assertNotIn("qwix_path", keys) + self.assertNotIn("disable_quant_stats_update", keys) + + def test_only_non_variable_attrs_yields_empty(self): + out = nnx_wrappers.nnx_attrs_to_linen_vars({"qwix_path": ("x",), "qwix_rngs": 0}) + self.assertEqual(out, {}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/quantizations_test.py b/tests/unit/quantizations_test.py index 9335b1433b..c11f5ea3b7 100644 --- a/tests/unit/quantizations_test.py +++ b/tests/unit/quantizations_test.py @@ -662,6 +662,7 @@ def test_maybe_quantize_model_pops_intermediates(self): use_qwix_quantization=True, use_batch_split_schedule=False, pure_nnx=True, + pure_nnx_decoder=True, micro_batch_size_to_train_on=1, max_target_length=2, ) diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index 6467c6f196..c2d642fba7 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -20,12 +20,16 @@ """ from dataclasses import dataclass +import types as pytypes import unittest from flax import nnx import jax import jax.numpy as jnp +import numpy as np from maxtext.common import train_state_nnx +from maxtext.common.metric_logger import record_activation_metrics +from maxtext.optimizers import optimizers from maxtext.trainers.pre_train import train as pre_train import optax @@ -95,6 +99,16 @@ def __call__( return self.proj(h) +class _TinyDecoderMoEBias(_TinyDecoder): + """`_TinyDecoder` that also sows a `moe_bias_updates` intermediate (DeepSeek routed-bias).""" + + def __call__(self, decoder_input_tokens, decoder_positions, **kwargs): + out = super().__call__(decoder_input_tokens, decoder_positions, **kwargs) + # 2-D so the downstream `[0].transpose()` in train_step is shape-valid. + self.sow(nnx.Intermediate, "moe_bias_updates", jnp.ones((2, 3))) + return out + + def _make_data(batch=2, seq=4, vocab=8): return { "inputs": jnp.zeros((batch, seq), dtype=jnp.int32), @@ -156,6 +170,19 @@ def test_indexer_dense_warmup_skips_xent(self): self.assertEqual(float(aux["xent_sum"]), 0.0) self.assertEqual(float(loss), 0.0) + def test_indexer_warmup_precedes_vocab_tiling(self): + # The indexer dense warm-up branch must be checked before the num_vocab_tiling>1 + # branch. With the order reversed, a warm-up step with tiling on ran the + # vocab-tiling loss instead of skipping xent. With both on, xent must still be 0. + cfg, ts = _build_state() + cfg.use_indexer = True + cfg.indexer_sparse_training = False + cfg.num_vocab_tiling = 2 + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + loss, aux = pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) + self.assertEqual(float(aux["xent_sum"]), 0.0) + self.assertEqual(float(loss), 0.0) + class TestTrainStepNNX(unittest.TestCase): """Cover the NNX branch of train_step (the diff_wrapper / nnx.update path).""" @@ -218,5 +245,110 @@ def test_eval_step_returns_metrics(self): self.assertTrue(jnp.isfinite(metrics["scalar"]["evaluation/loss"])) +class TestSkipStepOnSpikesNNX(unittest.TestCase): + """The NNX optimizer must actually skip a loss/grad spike — i.e. apply_gradients forwards + loss/grad_norm to the GradientTransformationExtraArgs, and a skipped step freezes params.""" + + def _is_skipped(self, optimizer): + return bool(nnx.to_pure_dict(nnx.state(optimizer))["opt_state"]["is_skipped"]) + + def test_spike_is_skipped_and_params_frozen(self): + model = _TinyDecoder(8, hidden=4, rngs=nnx.Rngs(0)) + tx = optimizers.skip_step_on_spikes(optax.sgd(0.1), interval=4, scaling_factor=6.0) + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(model, optimizer) + grads = jax.tree.map(jnp.ones_like, nnx.state(model, nnx.Param)) + + # Prime a stable baseline (mean≈1, std≈0); these are applied, not skipped. + for _ in range(3): + state.apply_gradients(grads, loss=jnp.float32(1.0), grad_norm=jnp.float32(1.0)) + self.assertFalse(self._is_skipped(optimizer)) + + before = [np.asarray(x) for x in jax.tree_util.tree_leaves(nnx.to_pure_dict(nnx.state(model, nnx.Param)))] + # A large spike must be skipped (params unchanged). If apply_gradients did NOT forward + # loss/grad_norm, the optimizer would never skip and this would fail. + state.apply_gradients(grads, loss=jnp.float32(1e3), grad_norm=jnp.float32(1e3)) + self.assertTrue(self._is_skipped(optimizer)) + after = [np.asarray(x) for x in jax.tree_util.tree_leaves(nnx.to_pure_dict(nnx.state(model, nnx.Param)))] + for b, a in zip(before, after): + np.testing.assert_allclose(a, b) + + +class TestRoutedBiasReadNNX(unittest.TestCase): + """loss_fn must find the DeepSeek `moe_bias_updates` intermediate on the NNX (model-rooted) shape.""" + + def test_routed_bias_update_found_by_suffix(self): + cfg = _Cfg() + cfg.routed_bias = True + cfg.routed_bias_update_rate = 0.001 + model = _TinyDecoderMoEBias(cfg.vocab_size, hidden=4, rngs=nnx.Rngs(0)) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + _, aux = pre_train.loss_fn(model, cfg, data, None, None, is_train=True) + self.assertIsNotNone(aux["moe_bias_updates"]) + np.testing.assert_allclose(np.asarray(aux["moe_bias_updates"][0]), np.ones((2, 3))) + + def test_routed_bias_disabled_returns_none(self): + cfg = _Cfg() # routed_bias=False + model = _TinyDecoderMoEBias(cfg.vocab_size, hidden=4, rngs=nnx.Rngs(0)) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + _, aux = pre_train.loss_fn(model, cfg, data, None, None, is_train=True) + self.assertIsNone(aux["moe_bias_updates"]) + + +class TestRecordActivationMetricsParity(unittest.TestCase): + """record_activation_metrics must yield identical metrics for Linen- and NNX-shaped intermediates. + + Linen sows into the "intermediates" collection; NNX's `nnx.pop(...).to_pure_dict()` is + model-rooted with no "intermediates" prefix. The fix routes the NNX shape through a + suffix collector — this test pins that both shapes produce the same per-layer numbers. + """ + + def _metrics(self, intermediates, scan_layers, num_layers): + cfg = pytypes.SimpleNamespace(scan_layers=scan_layers, num_decoder_layers=num_layers) + out = {"scalar": {}} + record_activation_metrics(out, intermediates, cfg) + return out["scalar"] + + def test_scanned_layout_linen_matches_nnx(self): + num_layers = 3 + mean, std, fz = jnp.array([0.1, 0.2, 0.3]), jnp.array([1.0, 1.1, 1.2]), jnp.array([0.5, 0.4, 0.3]) + triples = {"activation_mean": (mean,), "activation_stdev": (std,), "activation_fraction_zero": (fz,)} + # Linen scanned: intermediates/decoder/decoder/[0][layer] + linen = {"intermediates": {"decoder": {"decoder": triples}}} + # NNX scanned: model-rooted, one stacked array per key (no "intermediates" prefix) + nnx_shaped = {"decoder": {"layers": triples}} + + m_linen = self._metrics(linen, scan_layers=True, num_layers=num_layers) + m_nnx = self._metrics(nnx_shaped, scan_layers=True, num_layers=num_layers) + self.assertEqual(set(m_linen), set(m_nnx)) + for key, expected in m_linen.items(): + np.testing.assert_allclose(np.asarray(m_nnx[key]), np.asarray(expected)) + np.testing.assert_allclose(np.asarray(m_nnx["activ_mean/layer_001"]), 0.2) + + def test_unscanned_layout_linen_matches_nnx(self): + num_layers = 3 + means, stds, fzs = [0.1, 0.2, 0.3], [1.0, 1.1, 1.2], [0.5, 0.4, 0.3] + + def per_layer(d, n): + return { + "activation_mean": (jnp.array(d[0][n]),), + "activation_stdev": (jnp.array(d[1][n]),), + "activation_fraction_zero": (jnp.array(d[2][n]),), + } + + data = (means, stds, fzs) + # Linen unscanned: intermediates/decoder/layers_/[0] + linen = {"intermediates": {"decoder": {f"layers_{n}": per_layer(data, n) for n in range(num_layers)}}} + # NNX unscanned: model-rooted per-layer entries (one leaf per layer, matched by suffix) + nnx_shaped = {"decoder": {f"layers_{n}": per_layer(data, n) for n in range(num_layers)}} + + m_linen = self._metrics(linen, scan_layers=False, num_layers=num_layers) + m_nnx = self._metrics(nnx_shaped, scan_layers=False, num_layers=num_layers) + self.assertEqual(set(m_linen), set(m_nnx)) + for key, expected in m_linen.items(): + np.testing.assert_allclose(np.asarray(m_nnx[key]), np.asarray(expected)) + np.testing.assert_allclose(np.asarray(m_nnx["activ_stdev/layer_002"]), 1.2) + + if __name__ == "__main__": unittest.main()