Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions src/maxtext/common/metric_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,23 @@ def _prepare_metrics_for_json(metrics, step, run_name):


def record_activation_metrics(output_metrics, intermediate_outputs, config):
"""Adds the activation metrics to the metrics dict"""
"""Adds the activation metrics to the metrics dict.

if config.scan_layers:
metrics_dict = intermediate_outputs["intermediates"]["decoder"]["decoder"]

for layer_num in range(config.num_decoder_layers):
output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = metrics_dict["activation_fraction_zero"][
0
][layer_num]
output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = metrics_dict["activation_mean"][0][layer_num]
output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = metrics_dict["activation_stdev"][0][layer_num]
else:
Collects each metric by path suffix rather than a hardcoded path, so it works for
both the Linen ("intermediates"-prefixed) and NNX (model-rooted) layouts and for
both scanned (one stacked leaf) and unscanned (one leaf per layer) decoders.
"""
for label, key in (
("activ_fraction_zero", "activation_fraction_zero"),
("activ_mean", "activation_mean"),
("activ_stdev", "activation_stdev"),
):
vals = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, key)
if not vals:
continue
per_layer = jax.numpy.concatenate(vals)
for layer_num in range(config.num_decoder_layers):
layer = intermediate_outputs["intermediates"]["decoder"][f"layers_{layer_num}"]
output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = layer["activation_fraction_zero"][0]
output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = layer["activation_mean"][0]
output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = layer["activation_stdev"][0]
output_metrics["scalar"][f"{label}/layer_{layer_num:03d}"] = per_layer[layer_num]


class MetadataKey(enum.Enum):
Expand Down
7 changes: 4 additions & 3 deletions src/maxtext/common/train_state_nnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,20 @@ def __init__(
self.model = model
self.optimizer = optimizer

def apply_gradients(self, grads: Any):
def apply_gradients(self, grads: Any, **kwargs):
"""Mimics the Linen apply_gradients function.

Updates the optimizer state, applies updates to parameters, and increments
the step counter. Only updates `self.model`.
the step counter. Only updates `self.model`. Extra kwargs (e.g. loss/grad_norm
for the skip-step-on-spikes optimizer) are forwarded to the optax update.
"""
if self.optimizer is None:
raise RuntimeError(
"Cannot call apply_gradients on a TrainStateNNX initialized without"
" an optimizer. This usually happens when the state was created for"
" inference only."
)
self.optimizer.update(self.model, grads)
self.optimizer.update(self.model, grads, **kwargs)


# On-disk checkpoint format.
Expand Down
18 changes: 18 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2078,6 +2078,16 @@ class RLHardware(BaseModel):
description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined.",
)
rollout_expert_parallelism: int = Field(1, description="Expert parallelism per replica for rollout")
inference_replicas: int = Field(1, description="Legacy experimental GRPO: number of inference (sampler) replicas.")
inference_devices_per_replica: int = Field(
4, description="Legacy experimental GRPO: devices per inference replica (single-controller device split)."
)
inference_rollouts: int = Field(
1, description="Legacy experimental GRPO: refresh rollouts every N steps (step % inference_rollouts)."
)
use_pathways_reshard: bool = Field(
True, description="Legacy experimental GRPO: use Pathways resharding to move policy params to the sampler."
)


class VLLM(BaseModel):
Expand Down Expand Up @@ -2696,6 +2706,14 @@ def validate_and_set_hlo_dump_defaults():
if not self.enable_nnx:
raise ValueError("a value of self.distill_beta > 0.0 requires self.enable_nnx = True")

if self.pure_nnx and not self.pure_nnx_decoder and self.use_qwix_quantization and not self.use_batch_split_schedule:
if self.quantization:
raise ValueError(
f"quantization='{self.quantization}' with use_qwix_quantization=True under pure_nnx=True requires "
"pure_nnx_decoder=True. The bridged Linen decoder (pure_nnx_decoder=False) is invisible to Qwix, "
"so quantization (and weight sparsity) would silently have no effect. Set pure_nnx_decoder=True."
)

# Validate distillation schedule parameters
if self.distill_alpha_end is not None and not 0.0 <= self.distill_alpha_end <= 1.0:
raise ValueError(f"distill_alpha_end must be in [0, 1], got {self.distill_alpha_end}")
Expand Down
73 changes: 58 additions & 15 deletions src/maxtext/experimental/rl/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,23 +471,15 @@ def _train_step_nnx(model_graphdef, config, state_mesh_shardings, state, data):
Args:
model_graphdef: NNX `GraphDef` of the `TrainStateNNX`.
config: Training configuration object.
state_mesh_shardings: Sharding spec for the train state. Unused on this
path; kept for signature parity with `train_step`.
state_mesh_shardings: Sharding spec for the train state; used to move the
optimizer state to device when `optimizer_memory_host_offload` is set.
state: Flat `nnx.State` matching `model_graphdef`.
data: A batch dict produced by the GRPO input pipeline.

Returns:
A tuple `(new_state, metrics)`. `new_state` is filtered to exclude
`nnx.Intermediate`. `metrics` is a dict shaped like the Linen path's.
"""
del state_mesh_shardings # Host-offload paths are not yet wired up here.

if config.gradient_accumulation_steps > 1:
raise NotImplementedError(
"GRPO + pure_nnx + gradient_accumulation_steps>1 not supported yet. "
"Set gradient_accumulation_steps=1 or pure_nnx=False."
)

state = nnx.merge(model_graphdef, state) # Reconstruct the TrainStateNNX.
policy_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...)
# Split the reference model into (graphdef, state) so we pass `ref_state` as
Expand All @@ -505,13 +497,61 @@ def diff_wrapper(param, rest, ref_state, config, data):
return loss, (aux, new_rest)

grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True)
(loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, ref_state, config, data)

if config.gradient_accumulation_steps <= 1:
(loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, ref_state, config, data)
else:
# Mirror the pre-train NNX gradient-accumulation loop and the Linen GRPO one:
# params stay fixed across microbatches while the non-param state (rest, e.g.
# RNGs) advances in the scan carry. Grads are accumulated weighted by each
# microbatch's total_weights, then normalized once after the scan.
def reshape_to_microbatch_accumulations(batch_arr):
microbatches = config.gradient_accumulation_steps
microbatch_shape = (microbatches, batch_arr.shape[0] // microbatches) + batch_arr.shape[1:]
return jnp.reshape(batch_arr, microbatch_shape)

ga_data = jax.tree_util.tree_map(reshape_to_microbatch_accumulations, data)

def accumulate_gradient(carry, microbatch):
(_, (aux, new_rest)), cur_grad = grad_func(curr_params, carry["rest"], ref_state, config, microbatch)
carry["loss"] += aux.total_loss
carry["grad"] = jax.tree_util.tree_map(lambda x, y: x * aux.total_weights + y, cur_grad, carry["grad"])
carry["total_weights"] += aux.total_weights
carry["rest"] = new_rest
return carry, aux

init_carry = {
"loss": 0.0,
"grad": jax.tree_util.tree_map(jnp.zeros_like, curr_params),
"total_weights": 0.0,
"rest": rest,
}
carry, aux = jax.lax.scan(accumulate_gradient, init_carry, ga_data, length=config.gradient_accumulation_steps)
# total_loss is already a per-batch mean (and includes moe_lb), so the full-batch
# loss is the mean across the equal-sized microbatches.
loss = carry["loss"] / config.gradient_accumulation_steps
raw_grads = jax.tree_util.tree_map(lambda arr: arr / carry["total_weights"], carry["grad"])
aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux)
new_rest = carry["rest"]

nnx.update(state.model, new_rest)

if config.gradient_clipping_threshold > 0:
grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold)
else:
grads = raw_grads
if config.optimizer_memory_host_offload:
# Mirror the pre-train NNX path: move the optimizer state from pinned_host to
# device before the in-place optimizer update. (The Linen GRPO path also casts
# params/reference to bf16 under this flag; NNX host-offload moves the memory
# kind without casting, matching the pre-train NNX convention.)
device_opt_shardings = jax.tree_util.tree_map_with_path(
maxtext_utils_nnx.move_memory_to_device,
state_mesh_shardings.optimizer,
is_leaf=lambda x: isinstance(x, jax.sharding.NamedSharding),
)
opt_state = nnx.state(state.optimizer)
nnx.update(state.optimizer, jax.device_put(opt_state, device_opt_shardings))
state.apply_gradients(grads)
new_state = state

Expand All @@ -524,11 +564,14 @@ def diff_wrapper(param, rest, ref_state, config, data):
"learning/completion_length": aux.completion_length,
"learning/moe_lb_loss": aux.moe_lb_loss,
"learning/total_weights": aux.total_weights,
"learning/grad_norm": max_utils.l2norm_pytree(grads),
"learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads),
}
new_policy_params = nnx.state(new_state.model, nnx.Param)
scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_policy_params)
# These norms pull host-resident tensors back to device, defeating the offload,
# so skip them when offloading (matches the Linen GRPO path).
if not config.optimizer_memory_host_offload:
scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads)
scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads)
new_policy_params = nnx.state(new_state.model, nnx.Param)
scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_policy_params)
metrics = {"scalar": scalar_metrics, "scalars": {}}

return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics
Expand Down
14 changes: 8 additions & 6 deletions src/maxtext/experimental/rl/grpo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,14 @@ def pathways_reshard_nnx(

Splits the policy `nnx.Param` state out of the training-side TrainStateNNX
model substate, reshards it onto the inference mesh, and pushes the
resharded params into the inference engine. Requires `scan_layers=True`;
the Linen `unscan_train_state_params` helper has no NNX equivalent yet.
resharded params into the inference engine.

Unlike Linen — where the policy is always scanned and must be unrolled via
`unscan_train_state_params` when the inference side is unscanned — the NNX
policy model is built per `config.scan_layers` (scanned: a single stacked
`decoder/layers` subtree; unscanned: per-layer `decoder/layers/{i}`). The
inference-side model is built from the same config, so both layouts already
match and `reshard_pytree` maps them directly without an explicit unscan.

Args:
config: Training configuration object.
Expand All @@ -255,10 +261,6 @@ def pathways_reshard_nnx(
because the same shardings are already attached to the params.
destination_shardings_model: Shardings for the inference-side model.
"""
if not config.scan_layers:
raise NotImplementedError(
"GRPO + pure_nnx + scan_layers=False not supported yet. " "Use scan_layers=True or pure_nnx=False."
)
policy_params = nnx.state(policy_state_model, nnx.Param)
source_param_shardings = nnx.state(source_shardings_model, nnx.Param)
dest_param_shardings = nnx.state(destination_shardings_model, nnx.Param)
Expand Down
19 changes: 11 additions & 8 deletions src/maxtext/inference/maxengine/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,19 +399,15 @@ def _load_params_nnx(self, params, rng):
# axis metadata but no physical .sharding. Resolve logical to physical here so
# device_put actually reshards instead of being a no-op.
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
target_shardings = sharding.nnx_construct_named_sharding(
params_abs, self._mesh
)
target_shardings = sharding.nnx_construct_named_sharding(params_abs, self._mesh)
params_state = jax.device_put(params, target_shardings)
# We only need a concrete `rest` (RNG vars) for nnx.merge. create_nnx_sharded_model
# builds the model with a jitted out_shardings so params are produced already
# sharded, avoiding a single-device allocation of the full model (an OOM risk for
# large models). self.model is abstract with no .sharding, so pass an explicit one.
_, full_abs = nnx.split(self.model)
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
full_sharding = sharding.nnx_construct_named_sharding(
full_abs, self._mesh
)
full_sharding = sharding.nnx_construct_named_sharding(full_abs, self._mesh)
concrete_model = maxtext_utils_nnx.create_nnx_sharded_model(
self.model, self._create_model_fn, mesh=self._mesh, named_sharding=full_sharding
)
Expand Down Expand Up @@ -2047,11 +2043,18 @@ def set_engine_vars_from_base_engine(
"""Set internal vars from base_engine, which has already loaded the checkpoint and has sharding,
mesh, and kv cache related vars set.
"""
if base_engine.model.quant:
if not engine.config.pure_nnx and base_engine.model.quant:
# NNX bakes the quant mode in at construction (via _nnx_quant_mode_str) rather
# than mutating model.quant.quant_mode, so there's nothing to copy on that path.
engine.model.quant.quant_mode = base_engine.model.quant.quant_mode
engine.state_mesh_annotations = base_engine.state_mesh_annotations
engine.abstract_params = base_engine.abstract_params
engine.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations(engine.model, engine.config, rng, engine.mesh) # pylint: disable=protected-access
if engine.config.pure_nnx:
# Linen's get_kv_cache_annotations calls model.init(); NNX modules have no
# .init, so use the abstract-model variant (mirrors _load_params_nnx).
engine.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations_nnx(engine.model_ar, engine.config, engine.mesh)
else:
engine.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations(engine.model, engine.config, rng, engine.mesh) # pylint: disable=protected-access
engine.kv_cache_shardings = jax.tree_util.tree_map(
lambda x: jax.sharding.NamedSharding(engine.mesh, x),
engine.kv_cache_annotations, # pylint: disable=protected-access
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/layers/nnx_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,12 +1891,12 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in):
# for efficiency, as the main model is frozen and the LM loss is not needed.
elif (
cfg.use_indexer and cfg.indexer_loss_scaling_factor > 0.0 and not cfg.indexer_sparse_training
) and self.model_mode == MODEL_MODE_TRAIN:
) and model_mode == MODEL_MODE_TRAIN:
logits = None

# When vocab tiling is enabled in training mode, the full logits are not generated, to reduce memory.
# Instead, we keep track of the hidden states, which are smaller than the full logits.
elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
elif cfg.num_vocab_tiling > 1 and model_mode == MODEL_MODE_TRAIN:
logits = None
self.sow(nnx.Intermediate, "hidden_states", hidden_state)

Expand Down
13 changes: 8 additions & 5 deletions src/maxtext/layers/nnx_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,14 @@ def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
"""Convert a dict of NNX variables (or variable states) to Linen-style variables."""
linen_structured = {}
for kp, v in nnx.traversals.flatten_mapping(nnx_attrs).items():
if isinstance(v, variablelib.Variable):
col_name = variablelib.variable_name_from_type(v.type)
v = to_linen_var(v)
else:
raise ValueError(f"Cannot infer collection name from value: {v}")
if not isinstance(v, variablelib.Variable):
# Plain (non-Variable) attributes aren't Linen collections, so they have no
# place in the variables dict passed to the wrapped module's apply(). Qwix
# attaches bookkeeping attrs like qwix_path/qwix_rngs/disable_quant_stats_update
# to the module during interception; leave them on the module and skip them here.
continue
col_name = variablelib.variable_name_from_type(v.type)
v = to_linen_var(v)
linen_structured[(col_name, *kp)] = v
variables = nnx.traversals.unflatten_mapping(linen_structured)
return variables
Expand Down
Loading
Loading