Skip to content
3 changes: 1 addition & 2 deletions fast_llm/data/document/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,14 @@ def _set_target_inputs(

model_input.targets.append(target_input)

def _get_label_counts(self, mask: torch.Tensor):
def _get_label_counts(self, mask: torch.Tensor) -> torch.Tensor:
# Count the number of non-masked labels in each document through cumulative sums.
mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)])
length_cumsum = torch.tensor([0] + self.lengths, device=self.device).cumsum(0)
label_count_cumsum = mask_cumsum[length_cumsum]
labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1]
# Expand to one entry per token: find each token's document index via the sorted
# length cumsum, then look up that document's label count.
# TODO: Document index already computed in `LengthModelInputPreprocessor`.
document_index = torch.searchsorted(
length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right"
)
Expand Down
13 changes: 8 additions & 5 deletions fast_llm/data/document/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,19 @@ def _get_model_input(self, begin: int, end: int, config: TokenPreprocessingConfi
global_cumulative_lengths = torch.from_numpy(padded_cumsum(self.lengths)).to(
dtype=torch.int32, device=self.device
)
# Exclude the trailing padding "length" from the count.
num_documents_in_sequence = len(self.lengths) - (1 if self.unpadded_length < len(self.tokens) else 0)
# Padding tokens fall past the last real document, so `searchsorted` would assign them the
# phantom (num_documents + 1)-th index — one past the per-segment buffer sized by
# `num_documents_in_sequence`. Clamp them onto the last real document; their target is masked
# so the contribution is zero, and the 1-based index stays within `num_documents_in_sequence`.
model_input.global_document_index_q = torch.searchsorted(
global_cumulative_lengths,
torch.arange(begin, end, device=self.device),
side="right",
out_int32=True,
)
# Exclude the padding "length" from the count.
model_input.num_documents_in_sequence = len(self.lengths) - (
1 if self.unpadded_length < len(self.tokens) else 0
)
).clamp_(max=num_documents_in_sequence)
model_input.num_documents_in_sequence = num_documents_in_sequence

LengthModelInputPreprocessor(
lengths=lengths,
Expand Down
10 changes: 10 additions & 0 deletions fast_llm/engine/schedule/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ class ScheduleConfig(Config):
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
docs_per_step: int = Field(
default=0,
desc="Target number of documents (rollouts) per optimizer step, globally across all data-parallel ranks. "
"When >0, each training step dynamically accumulates microbatches until the globally all-reduced "
"document count reaches this value, then triggers the optimizer step. "
"depth_first_micro_batches is ignored when this is set. "
"0 = use depth_first_micro_batches as-is (fixed microbatch count per step).",
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
breadth_first_micro_batches: int = Field(
default=1,
desc="Number of micro-batches processed breadth-first, i.e., interleaved across model stages.",
Expand Down
5 changes: 3 additions & 2 deletions fast_llm/engine/schedule/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ def _preprocess_data(
if context.schedule.phase.is_training
else None
)
model_inputs = [next(data_iterator) for _ in range(self._config.sequential_micro_batches)]
n_micro_batches = context.schedule._eff_sequential_micro_batches
model_inputs = [next(data_iterator) for _ in range(n_micro_batches)]
model_inputs[0][0].share_batch_data(
[model_input for model_inputs_ in model_inputs for model_input in model_inputs_], self._distributed
)
Expand All @@ -336,7 +337,7 @@ def _preprocess_data(
extra_kwargs={
"grad_output": grad_output,
"micro_batch": micro_batch,
"num_micro_batches": self._config.sequential_micro_batches,
"num_micro_batches": n_micro_batches,
"micro_batch_splits": self._config.micro_batch_splits,
},
)
Expand Down
38 changes: 25 additions & 13 deletions fast_llm/engine/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,17 @@ def __init__(
batch_meta: list[ModelInput],
distributed_config: DistributedConfig,
phase: PhaseType,
_depth_first_override: int | None = None,
):
super().__init__(config)
self._depth_first_override = _depth_first_override
self._multi_stage = multi_stage
self._distributed_config = distributed_config
self._num_stages = len(self._multi_stage.stages)
self._phase = phase
self._is_training = self._phase.is_training

if self._config.num_inputs < self._distributed_config.pipeline_parallel:
if self._eff_num_inputs < self._distributed_config.pipeline_parallel:
warnings.warn("Not enough input to achieve true pipeline parallelism.")

# Setup the activation metas.
Expand Down Expand Up @@ -155,9 +157,25 @@ def __init__(
def phase(self) -> PhaseType:
return self._phase

@property
def _eff_depth_first(self) -> int:
    """Return the depth-first micro-batch count this schedule is built with.

    The per-instance `_depth_first_override` takes precedence whenever it is
    not `None` (an override of `0` is honoured as-is); otherwise the value
    configured in `depth_first_micro_batches` is used.
    """
    override = self._depth_first_override
    if override is None:
        return self._config.depth_first_micro_batches
    return override

@property
def _eff_sequential_micro_batches(self) -> int:
    """Return the effective number of sequential micro-batches.

    This is the effective depth-first count (`_eff_depth_first`) multiplied by
    the configured `breadth_first_micro_batches`.
    """
    depth_first = self._eff_depth_first
    breadth_first = self._config.breadth_first_micro_batches
    return depth_first * breadth_first

@property
def _eff_num_inputs(self) -> int:
    """Return the effective number of schedule inputs.

    Each effective sequential micro-batch is divided into
    `micro_batch_splits` inputs, so the result is their product.
    """
    sequential = self._eff_sequential_micro_batches
    splits = self._config.micro_batch_splits
    return sequential * splits

@property
def samples_per_batch(self) -> int:
    """Return the number of samples consumed by one batch of this schedule.

    Computed as the effective sequential micro-batch count multiplied by the
    `batch_data_parallel` size of the distributed configuration. The effective
    count is used (rather than `self._config.sequential_micro_batches`) so the
    result stays correct for schedules built with a depth-first override.
    """
    return self._eff_sequential_micro_batches * self._distributed_config.batch_data_parallel

def iterate(self, pipeline_rank: int | None = None) -> typing.Iterator[Step]:
return iter(self._steps if pipeline_rank is None else self._device_steps[pipeline_rank])
Expand Down Expand Up @@ -189,7 +207,7 @@ def _create_index(self) -> None:
Assert.in_range(
step.index,
0,
self._config.num_inputs,
self._eff_num_inputs,
)
Assert.incl(step.type_, (StepType.forward, StepType.backward))
step.global_index = i
Expand All @@ -205,7 +223,7 @@ def _create_index(self) -> None:
Assert.custom(all, self._device_steps)
# Consistency checks
step_map = self._step_map.copy()
for data_index in range(self._config.num_inputs):
for data_index in range(self._eff_num_inputs):
for type_ in (StepType.forward, StepType.backward):
for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages):
assert (
Expand Down Expand Up @@ -470,14 +488,11 @@ def _create_steps(self) -> tuple[list[Step], int]:
first_grad_stage += 1
else:
first_grad_stage = self._num_stages
for depth_first_micro_batch in range(self._config.depth_first_micro_batches):
for depth_first_micro_batch in range(self._eff_depth_first):
for stage in range(self._num_stages):
for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches):
for micro_batch_split in range(self._config.micro_batch_splits):
micro_batch = (
breadth_first_micro_batch * self._config.depth_first_micro_batches
+ depth_first_micro_batch
)
micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch
steps.append(
Step(
stage=stage,
Expand All @@ -492,10 +507,7 @@ def _create_steps(self) -> tuple[list[Step], int]:
for stage in reversed(range(first_grad_stage, self._num_stages)):
for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches):
for micro_batch_split in reversed(range(self._config.micro_batch_splits)):
micro_batch = (
breadth_first_micro_batch * self._config.depth_first_micro_batches
+ depth_first_micro_batch
)
micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch
steps.append(
Step(
stage=stage,
Expand Down
71 changes: 64 additions & 7 deletions fast_llm/engine/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,12 @@ def setup(self, distributed: Distributed, run: Run) -> None:
preprocessing_config = self._multi_stage.get_preprocessing_config(
PhaseType.training, self._config.schedule.micro_batch_splits
)
self._single_mb_meta = preprocessing_config.get_input_meta(self._data.config.micro_batch_size)
self._schedule_cache: dict[int, Schedule] = {}
self._schedule = Schedule(
config=self._config.schedule,
multi_stage=self._multi_stage,
batch_meta=preprocessing_config.get_input_meta(self._data.config.micro_batch_size),
batch_meta=self._single_mb_meta,
distributed_config=self._config.model.distributed,
phase=PhaseType.training,
)
Expand All @@ -140,6 +142,41 @@ def setup(self, distributed: Distributed, run: Run) -> None:

self._is_setup = True

def _get_or_build_schedule(self, n_microbatches: int) -> Schedule:
    """Return the training schedule for a step made of `n_microbatches` micro-batches.

    Schedules are memoized in `self._schedule_cache`, keyed by the micro-batch
    count, so each distinct count is only built once.

    Args:
        n_microbatches: Total number of sequential micro-batches in the step.
            Must be a positive multiple of
            `schedule.breadth_first_micro_batches`.

    Returns:
        The cached or newly built `Schedule`.

    Raises:
        ValueError: If `n_microbatches` is not a positive multiple of
            `breadth_first_micro_batches`. Flooring the division instead would
            build a schedule that covers fewer micro-batches than were fetched.
    """
    cached = self._schedule_cache.get(n_microbatches)
    if cached is not None:
        return cached

    bfmb = self._config.schedule.breadth_first_micro_batches
    depth_first, remainder = divmod(n_microbatches, bfmb)
    if remainder != 0 or depth_first <= 0:
        raise ValueError(
            f"n_microbatches={n_microbatches} must be a positive multiple of "
            f"breadth_first_micro_batches={bfmb}"
        )

    schedule = Schedule(
        config=self._config.schedule,
        multi_stage=self._multi_stage,
        batch_meta=self._single_mb_meta,
        distributed_config=self._config.model.distributed,
        phase=PhaseType.training,
        _depth_first_override=depth_first,
    )
    self._schedule_cache[n_microbatches] = schedule
    return schedule

def _prefetch_to_doc_target(self, data_iterator) -> list:
    """Fetch micro-batches until the step's document target is reached.

    Micro-batches are pulled from `data_iterator` one at a time. For each one,
    `share_batch_data` is called on its first model input, and that input's
    `num_documents_in_batch` is added to the running total (presumably the
    globally reduced count after sharing -- confirm against
    `share_batch_data`). Fetching continues until the total reaches
    `schedule.docs_per_step` AND the number of fetched micro-batches is a
    multiple of `schedule.breadth_first_micro_batches`, so the result can
    always be split evenly into depth-first groups. Stopping as soon as the
    document target is met would make the step fail whenever the count
    happened not to be divisible.

    Finally, `num_documents_in_batch` on every model input of every fetched
    micro-batch is overwritten with the step total.

    Args:
        data_iterator: Iterator yielding micro-batches, each a sequence of
            model inputs.

    Returns:
        The list of fetched micro-batches. It is empty only if
        `docs_per_step` is not positive.

    Raises:
        StopIteration: Propagated from `next(data_iterator)` if the iterator
            is exhausted before the stopping condition is met.
    """
    schedule_config = self._config.schedule
    target = schedule_config.docs_per_step
    bfmb = schedule_config.breadth_first_micro_batches
    buffer = []
    total_docs = 0
    # Keep going past the document target if needed to complete the current
    # group of `bfmb` micro-batches.
    while total_docs < target or len(buffer) % bfmb != 0:
        mb = next(data_iterator)
        mb[0].share_batch_data(mb, self._distributed)
        total_docs += mb[0].num_documents_in_batch
        buffer.append(mb)
    # Reset num_documents_in_batch to the step total on all microbatches.
    for mb in buffer:
        for mi in mb:
            mi.num_documents_in_batch = total_docs
    return buffer

@abc.abstractmethod
def _get_data(self) -> Data:
pass
Expand Down Expand Up @@ -184,6 +221,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
skipped_iters = 0
nan_iters = 0
total_losses = {loss_def.name: 0.0 for loss_def in self._loss_definitions}
total_documents_seen = 0

# Profiling
profiler = self._config.profiling.get_profiler(
Expand Down Expand Up @@ -220,12 +258,26 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:

# TODO: Data loader hates getting all micro-batches at once.
# (Also preprocessing adds overhead)
reduced_losses, update_successful, train_metrics = self._runner.run_step(
train_iterator,
self._schedule,
iteration=self._completed_steps,
return_metrics=is_logging,
)
if self._config.schedule.docs_per_step > 0:
buffer = self._prefetch_to_doc_target(train_iterator)
# `_prefetch_to_doc_target` broadcasts the step document total onto every microbatch.
step_num_documents = buffer[0][0].num_documents_in_batch
total_documents_seen += step_num_documents
step_schedule = self._get_or_build_schedule(len(buffer))
reduced_losses, update_successful, train_metrics = self._runner.run_step(
iter(buffer),
step_schedule,
iteration=self._completed_steps,
return_metrics=is_logging,
)
else:
step_num_documents = None
reduced_losses, update_successful, train_metrics = self._runner.run_step(
train_iterator,
self._schedule,
iteration=self._completed_steps,
return_metrics=is_logging,
)

# Advanced, skipped, and Nan iterations.
if update_successful:
Expand Down Expand Up @@ -257,6 +309,11 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
metrics_key = PhaseType.training
metrics[metrics_key] = {
"batch_size": self._batch_size,
**(
{"num_documents": step_num_documents, "documents_seen": total_documents_seen}
if step_num_documents is not None
else {}
),
**{
name: (value / advanced_iters if advanced_iters > 0 else float("nan"))
for name, value in total_losses.items()
Expand Down
7 changes: 7 additions & 0 deletions fast_llm/layers/language_model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ class LanguageModelHeadConfig(BlockConfig):
hint=FieldHint.architecture,
valid=skip_valid_if_none(check_field(Assert.gt, 0)),
)
fp32_lm_head: bool = Field(
default=False,
desc="Upcast input and weight to float32 before the lm_head linear. "
"Matches vLLM's bf16_last_layer_fp32 quantization so new_logprobs and old_logprobs "
"are computed at the same numerical precision, keeping the IS ratio near 1 at init.",
hint=FieldHint.feature,
)
prediction_heads: int = Field(
default=1,
desc="Prediction heads.",
Expand Down
34 changes: 28 additions & 6 deletions fast_llm/layers/language_model/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss
from fast_llm.tensor import TensorMeta
from fast_llm.tensor import TensorMeta, accumulate_gradient
from fast_llm.utils import Assert, safe_merge_dicts

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -252,9 +252,17 @@ def _logits_loss_forward_backward_partial(
split_index: int = 0,
return_logits: bool = False,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if self._config.fp32_lm_head:
input_dtype = input_.dtype
input_ = input_.to(torch.float32)
# detach → requires_grad=False → output_parallel_linear_backward skips weight grad
weight = self.output_weights.detach().to(torch.float32)
else:
weight = self.output_weights

logits, context = output_parallel_linear_forward(
input_=input_,
weight=self.output_weights,
weight=weight,
bias=None,
group=self._parallel_dim.group if self._vocab_parallel else None,
sequence_parallel=self._sequence_parallel and self._vocab_parallel,
Expand Down Expand Up @@ -285,12 +293,26 @@ def _logits_loss_forward_backward_partial(
if loss_value is not None:
losses_.append(loss_value.detach())

if grad is not None and self._config.final_logit_softcap is not None:
if not self.training or grad is None:
return sum(losses_) if losses_ else None, None

if self._config.final_logit_softcap is not None:
grad = _softcap_backward(grad, logits, self._config.final_logit_softcap)

return sum(losses_) if losses_ else None, (
output_parallel_linear_backward(grad, context) if self.training else None
)
input_grad = output_parallel_linear_backward(grad, context)
if self._config.fp32_lm_head:
# Weight grad was skipped because weight.requires_grad=False; accumulate manually.
# context: (input_, weight, bias, group, sequence_parallel, ...)
saved_input = context[0]
if context[4]: # sequence_parallel
from fast_llm.core.ops import gather_op

saved_input = gather_op(saved_input, context[3], dim=0)
grad_weight = grad.flatten(0, -2).t().mm(saved_input.flatten(0, -2))
accumulate_gradient(self.output_weights, grad_weight.to(self.output_weights.dtype))
input_grad = input_grad.to(input_dtype)

return sum(losses_) if losses_ else None, input_grad

def get_loss_definitions(self) -> list[LossDef]:
return [
Expand Down
Loading
Loading