Skip to content

Add fp32 LM head, docs_per_step accumulation, and document-count metrics#520

Open
jlamypoirier wants to merge 8 commits into
mainfrom
jlp_rl_features
Open

Add fp32 LM head, docs_per_step accumulation, and document-count metrics#520
jlamypoirier wants to merge 8 commits into
mainfrom
jlp_rl_features

Conversation

@jlamypoirier

@jlamypoirier jlamypoirier commented May 19, 2026

Copy link
Copy Markdown
Collaborator

Summary

Training-side changes for RL fine-tuning, consolidated into a single PR to main
(previously stacked PRs #526 + #520; #526 is closed in favor of this one).

fp32_lm_head — FP32 LM head logits

New fp32_lm_head flag on LanguageModelHeadConfig (default False). When enabled, the LM head
upcasts input and weight to FP32 for the logits projection and casts back, matching vLLM's
bf16_last_layer_fp32, so the trainer computes log-probabilities at the same precision the actor
sampled with. Includes the gradient-flow fix for the detached FP32 weight copy — gradients are
accumulated back into the BF16 parameter's buffer.

Dynamic docs_per_step accumulation

New ScheduleConfig.docs_per_step field. When >0, each step accumulates microbatches one at a
time, all-reduces the per-microbatch document count, and stops once the global total reaches the
target, instead of using a fixed microbatch count. The final step total is broadcast to every
microbatch so the loss-normalization denominator stays consistent. Off by default
(docs_per_step=0 keeps the original static-schedule path).

num_documents / documents_seen metrics

Logs the per-step document count (the divisor docs_per_step produces) and the cumulative document
total as training metrics — lets the dynamic accumulation be verified, and gives documents-seen as a
cross-run x-axis. Gated on docs_per_step>0; no effect on the static path.

GSPO segment-index fix for padded sequences

Clamp global_document_index_q to num_documents_in_sequence in fast_llm/data/document/token.py.
Padding tokens fall past the last real document, so searchsorted assigned them a phantom
out-of-range segment index, causing a CUDA device-side assert in the GSPO index_add_. Padding
targets are masked, so clamping them onto the last real document contributes zero.

Test plan

  • pytest tests/layers/test_docs_per_step.py

Split from #502.

@jlamypoirier jlamypoirier mentioned this pull request May 19, 2026
4 tasks
@jlamypoirier jlamypoirier changed the title RL training features (#502 minus GSPO) Deepspeed parity hacks May 21, 2026
@jlamypoirier jlamypoirier changed the title Deepspeed parity hacks Deepspeed parity tweaks May 25, 2026
jlamypoirier and others added 2 commits May 27, 2026 15:33
When True, upcasts the LM head linear's input and weight to FP32 before
the matmul, matching vLLM's bf16_last_layer_fp32 quantization. This lets
the trainer compute log-probabilities at the same numerical precision as
the actor's sampling, so the importance-sampling ratio starts near 1.0
instead of being inflated by trainer/actor precision mismatch.

The detached FP32 weight has requires_grad=False, which makes
output_parallel_linear_backward skip the weight-grad path. The FSDP
gradient contract is restored by computing grad_weight explicitly and
accumulating into the original BF16 param's grad_buffer via
accumulate_gradient.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
A schedule config field that replaces the static microbatch count with a
runtime document-count target. Matches DeepSpeed's
gradient_accumulation_passes semantics for RL: each microbatch holds one
rollout and the step boundary is set by total rollouts rather than a
fixed microbatch count.

- ScheduleConfig.docs_per_step — when >0, Trainer._prefetch_to_doc_target
  fetches microbatches one at a time, all-reduces the per-microbatch doc
  count, and stops once the global total reaches the target. The final
  step total is broadcast to every microbatch so the loss normalization
  stays consistent.
- Trainer._get_or_build_schedule(N) builds and caches a per-N Schedule
  with _depth_first_override = N // breadth_first_micro_batches, reusing
  the schedule machinery without touching the runner.
- Schedule._eff_{depth_first,sequential_micro_batches,num_inputs} expose
  the effective values under an override.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jlamypoirier jlamypoirier changed the title Deepspeed parity tweaks Add docs_per_step for dynamic microbatch accumulation May 27, 2026
@jlamypoirier jlamypoirier changed the base branch from main to jlp_fp32_lm_head May 27, 2026 19:37
jlamypoirier and others added 5 commits June 2, 2026 20:08
Surface the per-step document count produced by `_prefetch_to_doc_target`
(the loss-normalization denominator) and the cumulative document total as
training metrics. Lets the dynamic `docs_per_step` accumulation be verified
in production and gives documents-seen as a cross-run x-axis. Gated on
`docs_per_step > 0`; no effect on the static-schedule path.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@jlamypoirier jlamypoirier changed the base branch from jlp_fp32_lm_head to main June 25, 2026 16:53
@jlamypoirier jlamypoirier changed the title Add docs_per_step for dynamic microbatch accumulation Add fp32 LM head, docs_per_step accumulation, and document-count metrics Jun 25, 2026
Padding tokens fall past the last real document, so searchsorted assigned them
a phantom (num_documents+1)-th index, one past the per-segment buffer sized by
num_documents_in_sequence -> CUDA device-side assert in the GSPO index_add_.
Clamp the 1-based document index onto the last real document; padding targets
are masked so the contribution is zero.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant