From 19c6c8aae71115b632c94c322099be16345a19e5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 20 Mar 2026 17:33:55 -0400 Subject: [PATCH 01/33] fixes --- fast_llm/layers/ssm/gdn.py | 2 +- fast_llm_external_models/apriel2/modeling_apriel2.py | 2 +- tests/utils/model_configs.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index f694d80a6..cf5bc0bc4 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -227,7 +227,7 @@ def __init__( self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft ) - if _fast_gdn_available: + if _fast_gdn_available and distributed_config.use_cuda: self.chunk_gated_delta_rule = chunk_gated_delta_rule else: logger.warning( diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index ea0611953..9e82dfc4f 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -2839,7 +2839,7 @@ def forward( # Reshape back to [batch, num_patches, text_hidden] image_features = image_features.squeeze(0).view(batch_size, num_patches_per_image, -1) - return image_features, (*all_hidden_states, hidden_states, image_features) + return image_features, (*all_hidden_states, hidden_states, image_features) if output_hidden_states else None class SimpleMLP(nn.Module): diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 42802f1c7..3e6910b6f 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -802,7 +802,7 @@ def update_and_add_testing_config( # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! skip_tests=("sdp", "ms", TP_NO_STP), - requires_cuda=False, + requires_cuda=True, # GDN available on CPU, but not in the converted model (also runs very slow). ) _gdn_block = MODEL_CONFIGS["apriel2_gdn"].config_dict["model"]["base_model"]["decoder"]["block"] From 1b6fcd01864263bb408d9f467c1529ac6fb4d2ad Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 20 Mar 2026 17:39:32 -0400 Subject: [PATCH 02/33] fix --- tests/utils/distributed_configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index b085f0994..933ea8f8e 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -70,7 +70,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon if torch.cuda.is_available() else { (None, "norm"): get_config(ignore_tensors=True), - (None, "word_embeddings_weight"): get_config(8e-2, 1e-4), + (None, "embeddings_weight"): get_config(8e-2, 1e-4), } ), (None, "bias"): get_config(2e-2, 1e-3) if torch.cuda.is_available() else get_config(2e-2, 2e-3), From 573c6d84e7be8bdcfa39d6c4aa2a4bccde807f83 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 20 Mar 2026 19:59:21 -0400 Subject: [PATCH 03/33] fix --- fast_llm/data/dataset/streaming.py | 3 +++ tests/models/test_streaming.py | 3 ++- tests/utils/redis.py | 2 -- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index 8835612ec..e3fce4eb3 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -1,5 +1,6 @@ import functools import json +import logging import time import typing @@ -14,6 +15,8 @@ from fast_llm.data.document.token_data import TokenDataDocument from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + @config_class() class RedisStreamingDocumentData(Config): diff --git a/tests/models/test_streaming.py b/tests/models/test_streaming.py index 7b39a62f2..0c40f0a48 100644 --- a/tests/models/test_streaming.py +++ b/tests/models/test_streaming.py @@ -132,7 +132,7 @@ def _run_model_streaming_configs( model_testing_config, None, updates={ - ("data", "datasets"): {"training": {"port": port}}, + ("data", "datasets"): {"training": {"port": port, "timeout": 1.0}}, ("training", "export"): {"format": model_testing_config.checkpoint_format.name, "interval": 1}, "callbacks": { "streaming": { @@ -143,6 +143,7 @@ def _run_model_streaming_configs( "external_world_size": config.consumer_count, }, "export": {"format": model_testing_config.checkpoint_format.name}, + "timeout": 1.0, } }, # Disable tensor logging. diff --git a/tests/utils/redis.py b/tests/utils/redis.py index 8160ef8c0..2dc09bee2 100644 --- a/tests/utils/redis.py +++ b/tests/utils/redis.py @@ -66,8 +66,6 @@ def producer_loop(): @contextlib.contextmanager def fake_redis_server(config: RedisConfig): - # We search for free port as port from previous test can still be not free even after server shutdown - # ----- Monkey-patch handler to suppress broken pipes ----- orig_handle = fakeredis._tcp_server.TCPFakeRequestHandler.handle From 3658c028017ad4791209e33fdd82c4ef42347bbe Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 20 Mar 2026 20:29:56 -0400 Subject: [PATCH 04/33] fix --- fast_llm_external_models/tests/test_apriel2/test_equivalence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py index c5268f23c..8734aa02c 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py +++ b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py @@ -481,7 +481,7 @@ def test_batch_processing_behavior(self, model_pair): with torch.no_grad(): # Batch processing batch_src = get_pixtral_vision_features(source, pixel_values) - batch_tgt, _ = target.get_image_features(pixel_values).view(-1, batch_src.shape[-1]) + batch_tgt = target.get_image_features(pixel_values)[0].view(-1, batch_src.shape[-1]) # Sequential processing singles_src = [get_pixtral_vision_features(source, pixel_values[i : i + 1]) for i in range(3)] From ab39e26fae0156bd91d035492184ec33ef36066f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 24 Mar 2026 06:21:48 -0400 Subject: [PATCH 05/33] stuff --- fast_llm/data/document/abstract.py | 10 ++ fast_llm/data/document/config.py | 7 +- fast_llm/data/document/language_model.py | 145 ++++++++++-------- fast_llm/data/document/range.py | 6 +- fast_llm/data/document/token.py | 41 ++++- fast_llm/engine/base_model/base_model.py | 5 +- fast_llm/engine/inference/runner.py | 6 +- fast_llm/engine/schedule/runner.py | 22 +-- fast_llm/engine/schedule/schedule.py | 14 +- fast_llm/layers/language_model/config.py | 1 + fast_llm/layers/language_model/loss/config.py | 1 + fast_llm/models/gpt/huggingface.py | 5 +- fast_llm/models/gpt/model.py | 79 +++++----- 13 files changed, 202 insertions(+), 140 deletions(-) diff --git a/fast_llm/data/document/abstract.py b/fast_llm/data/document/abstract.py index 85014452f..06ea0534b 100644 --- a/fast_llm/data/document/abstract.py +++ b/fast_llm/data/document/abstract.py @@ -9,6 +9,7 @@ if typing.TYPE_CHECKING: import torch + from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import TensorMeta @@ -59,6 +60,15 @@ def to_kwargs(self) -> dict[str, typing.Any]: AttentionKwargs.presents: self.presents, } + @classmethod + def share_batch_data(cls, model_inputs: "list[ModelInput]", distributed: "Distributed"): + """ + Gather values depending on the entire data-parallel batch, ex. the total number of labels or documents. + Should be called in the main process because distributed operations are not available during preprocessing. + Implemented as a class method so quantities shared by all models inputs are only computed once. + TODO: ====== Use as entry point for batch broadcasting? ====== + """ + @dataclasses.dataclass(kw_only=True) class Batch(Document): diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py index 8967227e8..352311b51 100644 --- a/fast_llm/data/document/config.py +++ b/fast_llm/data/document/config.py @@ -29,6 +29,11 @@ class LengthPreprocessingConfig(BatchPreprocessingConfig): return_position_index: bool = Field(default=False) +@config_class() +class TokenPreprocessingConfig(LengthPreprocessingConfig): + return_document_count: bool = Field(default=False) + + @config_class() class ImageNormalizationConfig(Config): scale: float = Field(default=255.0) @@ -62,7 +67,7 @@ def get_batch_meta(self, size: int = 1) -> "PatchBatch": @config_class() -class LanguageModelBatchPreprocessingConfig(LengthPreprocessingConfig): +class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig): _abstract = False phase: PhaseType = Field(default=PhaseType.training) micro_batch_splits: int = Field(default=1) diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 00040e576..076f3abb3 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -4,12 +4,14 @@ import torch +from fast_llm.core.distributed import allreduce_scalar from fast_llm.data.document.abstract import ModelInput from fast_llm.data.document.config import LanguageModelBatchPreprocessingConfig from fast_llm.data.document.patch import PatchBatch, PatchDocument, PatchModelInput from fast_llm.data.document.range import RangeBatch, RangeDocument from fast_llm.data.document.token import TokenBatch, TokenDocument, TokenModelInput from fast_llm.data.document.token_data import TokenDataBatch, TokenDataDocument +from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.utils import div @@ -33,6 +35,20 @@ class LanguageModelTargetInput(ModelInput): advantages: torch.Tensor | None = None old_log_probabilities: torch.Tensor | None = None label_counts: torch.Tensor | None = None + num_labels: int | None = None + num_labels_in_batch: int | None = None + + @classmethod + def share_batch_data(cls, model_inputs: "list[LanguageModelTargetInput]", distributed: "Distributed"): + if model_inputs[0].num_labels is not None and model_inputs[0].num_labels_in_batch is None: + # We sum over sequences but not within a sequence. + num_labels_in_batch = allreduce_scalar( + sum(model_input.num_labels for model_input in model_inputs), + dtype=torch.int32, + group=distributed.batch_data_group, + ) + for model_input in model_inputs: + model_input.num_labels_in_batch = num_labels_in_batch @dataclasses.dataclass(kw_only=True) @@ -40,6 +56,15 @@ class LanguageModelInput(TokenModelInput): targets: list[LanguageModelTargetInput] = dataclasses.field(default_factory=list) image_patches: PatchModelInput | None = None + @classmethod + def share_batch_data(cls, model_inputs: "list[LanguageModelInput]", distributed: "Distributed"): + super().share_batch_data(model_inputs, distributed) + for targets in zip(*(model_input.targets for model_input in model_inputs), strict=True): + targets[0].share_batch_data(targets, distributed) + model_inputs[0].image_patches.share_batch_data( + [model_input.image_patches for model_input in model_inputs], distributed + ) + def set_children_attributes(self) -> None: if self.image_patches is not None: self.image_patches.set_parent_attributes(self) @@ -58,6 +83,7 @@ def to_kwargs(self) -> dict[str, typing.Any]: LanguageModelKwargs.advantages: [target.advantages for target in self.targets], LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets], LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets], + LanguageModelKwargs.num_labels_in_batch: [target.num_labels_in_batch for target in self.targets], } if self.image_patches is not None: out.update(self.image_patches.to_kwargs()) @@ -113,6 +139,12 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis ) ): model_input = self._get_model_input(sequence_k_past, sequence_k_past + local_input_length, config) + model_input.phase = config.phase + + if config.use_image_patches: + model_input.image_patches = self.image_patches.get_model_input( + sequence_k_past, sequence_k_past + local_input_length, config.vision_encoder + ) model_input.pasts = presents presents = None if micro_sequence_index == config.micro_batch_splits - 1 else [] @@ -121,73 +153,66 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis model_inputs.append(model_input) + self._set_target_inputs(model_inputs, config) + return model_inputs - def _get_model_input( - self, begin: int, end: int, config: LanguageModelBatchPreprocessingConfig - ) -> LanguageModelInput: - model_input = super()._get_model_input(begin, end, config) - model_input.phase = config.phase + def _set_target_inputs( + self, model_inputs: list[LanguageModelInput], config: LanguageModelBatchPreprocessingConfig + ): + labels = self.tokens.clone() - if config.use_image_patches: - model_input.image_patches = self.image_patches.get_model_input(begin, end, config.vision_encoder) + # Apply loss masking spans. + if config.use_loss_masking_spans and self.loss_masking_spans is not None: + for span_begin, span_end in self.loss_masking_spans.ranges: + labels[span_begin:span_end] = -100 for prediction_distance in range(1, config.num_labels + 1): - label_begin = begin + prediction_distance - label_end = end + prediction_distance - # Keep complete documents to simplify preprocessing. - _, first_document_begin, last_document_end = self._get_cropped_lengths(begin, label_end) - cropped_lengths, _, _ = self._get_cropped_lengths(first_document_begin, last_document_end) - labels = self.tokens[first_document_begin:last_document_end].clone() - labels_in_range = labels[label_begin - first_document_begin : label_end - first_document_begin] - - # Apply loss masking spans. - if config.use_loss_masking_spans and self.loss_masking_spans is not None: - for span_begin, span_end in self.loss_masking_spans.get_cropped_ranges( - first_document_begin, last_document_end - ): - labels[span_begin:span_end] = -100 - # Mask cross-document predictions. document_begin = 0 - for length in cropped_lengths: - labels[document_begin : document_begin + prediction_distance] = -100 + for length in self.lengths: + labels[document_begin + prediction_distance - 1] = -100 document_begin += length - if config.return_label_counts: - # Count the number of non-masked labels in each document through cumulative sums. - mask = labels >= 0 - mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)]) - length_cumsum = torch.tensor([0] + cropped_lengths, device=self.device).cumsum(0) - label_count_cumsum = mask_cumsum[length_cumsum] - labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1] - # Expand to one entry per token: find each token's document index via the sorted - # length cumsum, then look up that document's label count. - # TODO: Document index already computed in `LengthModelInputPreprocessor`. - document_index = torch.searchsorted( - length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right" - ) - label_counts = labels_per_document[document_index][ - label_begin - first_document_begin : label_end - first_document_begin - ] - mask = ( - mask[label_begin - first_document_begin : label_end - first_document_begin] - if config.return_prediction_mask - else None - ) - else: - label_counts = None - mask = labels_in_range >= 0 if config.return_prediction_mask else None - - # Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions. - target_input = LanguageModelTargetInput(tokens=labels_in_range, mask=mask, label_counts=label_counts) - - if config.use_grpo_data and not model_input.is_meta: - target_input.advantages = self.advantages.get_cropped_data(label_begin, label_end) - target_input.old_log_probabilities = self.old_log_probabilities.get_cropped_data( - label_begin, label_end + prediction_labels = labels[ + prediction_distance : len(self.tokens) - config.num_labels + prediction_distance + ].clone() + mask = prediction_labels >= 0 + label_counts = self._get_label_counts(mask) if config.return_label_counts else None + + for input_index, model_input in enumerate(model_inputs): + begin = model_input.sequence_k_dim.size + end = begin + model_input.token_dim.size + + # Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions. + target_input = LanguageModelTargetInput( + tokens=labels[begin:end], + mask=mask[begin:end] if config.return_prediction_mask else None, + label_counts=label_counts[begin:end] if config.return_label_counts else None, + # Set value for the first input only so `share_batch_data` generated the correct sum. + # TODO: ====== Make optional? + num_labels=mask.sum(dtype=torch.int32).item() if input_index == 0 else 0, ) - - model_input.targets.append(target_input) - - return model_input + if config.use_grpo_data and not model_input.is_meta: + target_input.advantages = self.advantages.get_cropped_data( + begin + prediction_distance, end + prediction_distance + ) + target_input.old_log_probabilities = self.old_log_probabilities.get_cropped_data( + begin + prediction_distance, end + prediction_distance + ) + + model_input.targets.append(target_input) + + def _get_label_counts(self, mask: torch.Tensor): + # Count the number of non-masked labels in each document through cumulative sums. + mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)]) + length_cumsum = torch.tensor([0] + self.lengths, device=self.device).cumsum(0) + label_count_cumsum = mask_cumsum[length_cumsum] + labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1] + # Expand to one entry per token: find each token's document index via the sorted + # length cumsum, then look up that document's label count. + # TODO: Document index already computed in `LengthModelInputPreprocessor`. + document_index = torch.searchsorted( + length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right" + ) + return labels_per_document[document_index] diff --git a/fast_llm/data/document/range.py b/fast_llm/data/document/range.py index ea5d0e7fd..ed2503455 100644 --- a/fast_llm/data/document/range.py +++ b/fast_llm/data/document/range.py @@ -32,6 +32,6 @@ def from_documents( document_begin += size return cls(ranges=ranges) if ranges else None - def get_cropped_ranges(self, begin: int, end: int) -> list[tuple[int, int]]: - cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in self.ranges) - return [(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_] + # def get_cropped_ranges(self, begin: int, end: int) -> list[tuple[int, int]]: + # cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in self.ranges) + # return [(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_] diff --git a/fast_llm/data/document/token.py b/fast_llm/data/document/token.py index 1871b2c83..8aeabb694 100644 --- a/fast_llm/data/document/token.py +++ b/fast_llm/data/document/token.py @@ -1,12 +1,14 @@ import dataclasses -import functools import typing import torch +from fast_llm.core.distributed import allreduce_scalar from fast_llm.data.document.abstract import Batch, Document from fast_llm.data.document.block import BlockModelInput, LengthModelInputPreprocessor -from fast_llm.data.document.config import LengthPreprocessingConfig +from fast_llm.data.document.config import TokenPreprocessingConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -22,12 +24,34 @@ def __len__(self) -> int: def device(self) -> torch.device: return self.tokens.device + @property + def is_meta(self) -> bool: + return self.device.type == "meta" + @dataclasses.dataclass(kw_only=True) class TokenModelInput(BlockModelInput, TokenDocument): - @functools.cached_property - def is_meta(self) -> bool: - return isinstance(self.tokens, TensorMeta) + num_documents: int | None = None + num_documents_in_batch: int | None = None + + @classmethod + def share_batch_data(cls, model_inputs: "list[TokenModelInput]", distributed: "Distributed"): + if model_inputs[0].num_documents is not None and model_inputs[0].num_documents_in_batch is None: + # We sum over sequences but not within a sequence. + num_documents_in_batch = allreduce_scalar( + sum(model_input.num_documents for model_input in model_inputs), + dtype=torch.int32, + group=distributed.batch_data_group, + ) + for model_input in model_inputs: + model_input.num_documents_in_batch = num_documents_in_batch + + def to_kwargs(self) -> dict[str, typing.Any]: + # TODO: Avoid conversion, use `LanguageModelMicroBatch` directly instead. + return { + **super().to_kwargs(), + LanguageModelKwargs.num_documents_in_batch: self.num_documents_in_batch, + } @dataclasses.dataclass(kw_only=True) @@ -74,10 +98,13 @@ def _get_cropped_lengths(self, begin: int, end: int) -> tuple[list[int], int, in return lengths, first_document_begin, document_end - def _get_model_input(self, begin: int, end: int, config: LengthPreprocessingConfig): + def _get_model_input(self, begin: int, end: int, config: TokenPreprocessingConfig): model_input = self._model_input_class(tokens=self.tokens[begin:end]) lengths, first_document_begin, last_document_end = self._get_cropped_lengths(begin, end) + if config.return_document_count: + model_input.num_documents = len(self.lengths) if begin == 0 else 0 + LengthModelInputPreprocessor( lengths=lengths, sequence_k_past=begin, @@ -89,7 +116,7 @@ def _get_model_input(self, begin: int, end: int, config: LengthPreprocessingConf ).preprocess(model_input, config) Assert.eq(model_input.token_dim.size, end - begin) - if self.tokens.device.type == "meta": + if self.is_meta: model_input.tokens = TensorMeta.from_dims( (model_input.token_dim,), tensor_name=f"tokens_{begin}_to_{end}", dtype=torch.int64 ) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index a12b68c17..a9f6887c7 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -178,14 +178,13 @@ def __init__( @abc.abstractmethod def preprocess_batch( self, - model_inputs: list[ModelInput], + model_input: ModelInput, *, phase: PhaseType, iteration: int, metrics: dict | None = None, extra_kwargs: dict[str, typing.Any] | None = None, - device: torch.device | None, - ) -> list[tuple[torch.Tensor, dict]]: + ) -> tuple[torch.Tensor, dict]: # TODO Move batch splitting elsewhere, align interface with LayerBase pass diff --git a/fast_llm/engine/inference/runner.py b/fast_llm/engine/inference/runner.py index f3b16c647..d9ed695ec 100644 --- a/fast_llm/engine/inference/runner.py +++ b/fast_llm/engine/inference/runner.py @@ -1,6 +1,7 @@ import abc import typing +from fast_llm.data.document.abstract import ModelInput from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.config import ScheduleConfig @@ -57,15 +58,14 @@ def setup(self): Assert.is_(self._runner._distributed, self._fast_llm_model.distributed) def forward( - self, input_, kwargs: dict, *, iteration: int = 1, return_metrics: bool = False + self, model_input: ModelInput, *, iteration: int = 1, return_metrics: bool = False ) -> tuple[dict[str, float | int], dict[str, typing.Any] | None]: # TODO: Return an actual model output. reduced_losses, update_successful, metrics = self._runner.run_step( - iter((((input_, kwargs),),)), + iter(((model_input,),)), self._schedule, iteration=iteration, return_metrics=return_metrics, - preprocessed=True, ) assert update_successful return reduced_losses, metrics diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 24b8b3d63..20a777a70 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -146,7 +146,6 @@ def run_step( *, iteration: int = 1, return_metrics: bool = False, - preprocessed: bool = False, ) -> tuple[dict[str, float | int], bool, dict[str, typing.Any] | None]: assert self._is_setup assert schedule._config is self._config # Noqa @@ -161,7 +160,7 @@ def run_step( losses={loss_def: [] for loss_def in self._loss_definitions}, metrics=metrics, ) - context.data_iterator = self._preprocess_data(context, data_iterator, preprocessed) + context.data_iterator = self._preprocess_data(context, data_iterator) if self._multi_stage.config.multi_stage.debug_activation_memory: log_pipeline_parallel_main_rank( @@ -328,16 +327,20 @@ def _train_step(self, context: BatchContext, step: Step) -> None: self._reduce(context, step) def _preprocess_data( - self, context: BatchContext, data_iterator: typing.Iterator, preprocessed: bool + self, context: BatchContext, data_iterator: typing.Iterator ) -> typing.Generator[None, None, None]: grad_output = ( self._optimizer.grad_scale / self._config.num_inputs if context.schedule.phase.is_training else None ) - for micro_batch in range(self._config.sequential_micro_batches): - micro_batch_data = next(data_iterator) - if not preprocessed: - micro_batch_data = self._multi_stage.base_model.preprocess_batch( - micro_batch_data, + model_inputs = [next(data_iterator) for _ in range(self._config.sequential_micro_batches)] + if not preprocessed: + model_inputs[0][0].share_batch_data(model_inputs, self._distributed) + + for micro_batch, model_inputs_ in enumerate(model_inputs): + Assert.eq(len(model_inputs_), self._config.micro_batch_splits) + for micro_batch_split, model_input in enumerate(model_inputs_): + input_, kwargs = self._multi_stage.base_model.preprocess_batch( + model_input, phase=context.phase, iteration=context.iteration, metrics=context.metrics, @@ -347,10 +350,7 @@ def _preprocess_data( "num_micro_batches": self._config.sequential_micro_batches, "micro_batch_splits": self._config.micro_batch_splits, }, - device=self._distributed.device, ) - Assert.eq(len(micro_batch_data), self._config.micro_batch_splits) - for micro_batch_split, (input_, kwargs) in enumerate(micro_batch_data): kwargs.update(micro_batch_split=micro_batch_split) data_index = micro_batch * self._config.micro_batch_splits + micro_batch_split if self._stages_owned[0]: diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index bc425520f..361772818 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -127,12 +127,14 @@ def __init__( warnings.warn("Not enough input to achieve true pipeline parallelism.") # Setup the activation metas. - self._preprocessed_meta = self._multi_stage.base_model.preprocess_batch( - batch_meta, - phase=self._phase, - iteration=0, - device=None, - ) + self._preprocessed_meta = [ + self._multi_stage.base_model.preprocess_batch( + model_input, + phase=self._phase, + iteration=0, + ) + for model_input in batch_meta + ] self._steps, self._first_grad_stage = self._create_steps() diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index a199ad154..4a8efdab6 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -24,6 +24,7 @@ class LanguageModelKwargs(LanguageModelLossKwargs): token_map = "token_map" sample_map = "sample_map" embedding_map = "embedding_map" + num_documents_in_batch = "num_documents_in_batch" # TODO: These are generic phase = "phase" loss_mask = "loss_mask" diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 99d4bce9a..5168aecfb 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -27,6 +27,7 @@ class LanguageModelLossKwargs(BlockKwargs): advantages = "advantages" old_log_probabilities = "old_log_probabilities" label_counts = "num_labels_in_seq" + num_labels_in_batch = "num_labels_in_batch" @config_class(registry=True) diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index def664d66..a53c234f0 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -114,11 +114,10 @@ def _inner_forward( use_cache, output_hidden_states, ) - ((input_, kwargs),) = self.fast_llm_base_model.preprocess_batch( - [model_input], + input_, kwargs = self.fast_llm_base_model.preprocess_batch( + model_input, phase=PhaseType.inference, iteration=iteration, - device=self._fast_llm_model.distributed.device, ) self._inference_runner.forward(input_, kwargs, iteration=iteration) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index fc4537ee7..a21bdee7e 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -1,3 +1,4 @@ +import dataclasses import functools import logging import re @@ -42,54 +43,46 @@ def __init__( def preprocess_batch( self, - model_inputs: list[LanguageModelInput], + model_input: LanguageModelInput, *, phase: PhaseType, iteration: int, metrics: dict | None = None, extra_kwargs: dict[str, typing.Any] | None = None, - device: torch.device | None, - ) -> list[tuple[torch.Tensor, dict]]: - reference_preprocessed_batches = {} - for name, reference_model in self._reference_models.items(): - reference_preprocessed_batches[name] = reference_model.fast_llm_model.base_model.preprocess_batch( - model_inputs, - phase=PhaseType.inference, - iteration=iteration, - device=device, - ) - - preprocessed = [] - for input_index, model_input in enumerate(model_inputs): - if device is not None: - model_input.to_device_(device) - kwargs = model_input.to_kwargs() - kwargs[LanguageModelKwargs.iteration] = iteration - if extra_kwargs is not None: - Assert.empty(kwargs.keys() & extra_kwargs.keys()) - kwargs.update(extra_kwargs) - if phase == PhaseType.inference: - kwargs[BlockKwargs.output_hidden_states].add(re.compile(r"head\..*logits.*$")) - - if not model_input.is_meta: - for name, reference_model in self._reference_models.items(): - reference_tokens, reference_kwargs = reference_preprocessed_batches[name][input_index] - if name in self._decoder_reference_models: - # TODO: Get the actual names - reference_kwargs[BlockKwargs.output_hidden_states].add( - re.compile(r"decoder\.\d+\.mixer_output$") - ) - - reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) - - kwargs[f"reference_{name}_hidden_states"] = { - layer_name: tensor - for layer_name, (meta, tensor) in reference_kwargs[BlockKwargs.hidden_states].items() - } - self.preprocess(kwargs) - preprocessed.append((model_input.tokens, kwargs)) - - return preprocessed + ) -> tuple[torch.Tensor, dict]: + if not model_input.is_meta: + model_input.to_device_(self._distributed.device) + kwargs = model_input.to_kwargs() + kwargs[LanguageModelKwargs.iteration] = iteration + if extra_kwargs is not None: + Assert.empty(kwargs.keys() & extra_kwargs.keys()) + kwargs.update(extra_kwargs) + if phase == PhaseType.inference: + kwargs[BlockKwargs.output_hidden_states].add(re.compile(r"head\..*logits.*$")) + + if not model_input.is_meta: + for name, reference_model in self._reference_models.items(): + output_hidden_states = set() + if name in self._head_reference_models: + output_hidden_states.add(re.compile(r"head\..*logits.*$")) + if name in self._decoder_reference_models: + # TODO: Get the actual names + output_hidden_states.add(re.compile(r"decoder\.\d+\.mixer_output$")) + assert len(output_hidden_states) >= 1 + reference_model_input = dataclasses.replace( + model_input, + output_hidden_states=output_hidden_states, + hidden_states={}, + ) + reference_model_input.set_children_attributes() + reference_model.forward(model_input, iteration=iteration) + + kwargs[f"reference_{name}_hidden_states"] = { + layer_name: tensor for layer_name, (meta, tensor) in reference_model_input.hidden_states.items() + } + self.preprocess(kwargs) + + return model_input.tokens, kwargs def get_tied_parameters(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: # TODO: Integrate to the `LayerBase` interface, move to `LanguageModel`, `MultiTokenPrediction`? From 2255845af53dcd399abf8c0bb213f9543022b089 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 25 Mar 2026 09:13:39 -0400 Subject: [PATCH 06/33] stuff --- fast_llm/core/ops.py | 2 +- fast_llm/data/document/abstract.py | 1 + fast_llm/data/document/language_model.py | 32 ++++---- fast_llm/engine/base_model/base_model.py | 4 +- fast_llm/engine/base_model/config.py | 77 ++++++++++++++++++- fast_llm/engine/schedule/runner.py | 41 ++++------ fast_llm/functional/entropy_loss.py | 12 +-- fast_llm/functional/linear.py | 2 +- fast_llm/functional/triton/entropy_loss.py | 44 +++++------ fast_llm/functional/triton/mlp.py | 2 +- fast_llm/functional/triton/normalization.py | 2 +- fast_llm/functional/triton/rotary.py | 2 +- fast_llm/functional/triton/sparse_copy.py | 2 +- fast_llm/functional/triton/z_loss.py | 5 +- fast_llm/functional/{autograd.py => utils.py} | 9 +++ fast_llm/layers/attention/attention.py | 2 +- fast_llm/layers/block/sequence.py | 10 +-- fast_llm/layers/common/linear/linear.py | 2 +- fast_llm/layers/common/peft/lora.py | 2 +- fast_llm/layers/decoder/block.py | 18 +---- .../layers/decoder/mlp/mixture_of_experts.py | 20 +---- fast_llm/layers/decoder/stochastic_mixer.py | 6 +- fast_llm/layers/language_model/head.py | 12 +-- .../layers/language_model/language_model.py | 10 +-- .../language_model/loss/entropy_loss.py | 2 + fast_llm/layers/language_model/loss/grpo.py | 20 ++--- fast_llm/layers/language_model/loss/loss.py | 19 ++--- fast_llm/layers/language_model/loss/z_loss.py | 11 ++- .../language_model/multi_token_prediction.py | 6 +- fast_llm/layers/vision/vision_encoder.py | 12 +-- fast_llm/logging.py | 2 +- fast_llm/models/gpt/huggingface.py | 15 +--- fast_llm/models/gpt/model.py | 2 +- tests/layers/test_lm_head.py | 5 +- 34 files changed, 216 insertions(+), 197 deletions(-) rename fast_llm/functional/{autograd.py => utils.py} (91%) diff --git a/fast_llm/core/ops.py b/fast_llm/core/ops.py index 7d361a22e..46dea8fce 100644 --- a/fast_llm/core/ops.py +++ b/fast_llm/core/ops.py @@ -16,7 +16,7 @@ def reduce_op( - input_: torch.Tensor, group: ProcessGroup | None, *, op: ReduceOp = ReduceOp.SUM, async_op: bool = False + input_: torch.Tensor, group: ProcessGroup | None, *, op: ReduceOp.RedOpType = ReduceOp.SUM, async_op: bool = False ) -> tuple[torch.Tensor, torch.distributed.Work] | torch.Tensor: if group: handle = all_reduce(input_, group=group, async_op=async_op, op=op) diff --git a/fast_llm/data/document/abstract.py b/fast_llm/data/document/abstract.py index 06ea0534b..6f546e9c3 100644 --- a/fast_llm/data/document/abstract.py +++ b/fast_llm/data/document/abstract.py @@ -66,6 +66,7 @@ def share_batch_data(cls, model_inputs: "list[ModelInput]", distributed: "Distri Gather values depending on the entire data-parallel batch, ex. the total number of labels or documents. Should be called in the main process because distributed operations are not available during preprocessing. Implemented as a class method so quantities shared by all models inputs are only computed once. + Note: this may be called more than once (ex. reference model preprocessing), so the method should be idempotent. TODO: ====== Use as entry point for batch broadcasting? ====== """ diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 076f3abb3..8f5c98801 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -61,9 +61,10 @@ def share_batch_data(cls, model_inputs: "list[LanguageModelInput]", distributed: super().share_batch_data(model_inputs, distributed) for targets in zip(*(model_input.targets for model_input in model_inputs), strict=True): targets[0].share_batch_data(targets, distributed) - model_inputs[0].image_patches.share_batch_data( - [model_input.image_patches for model_input in model_inputs], distributed - ) + if model_inputs[0].image_patches is not None: + model_inputs[0].image_patches.share_batch_data( + [model_input.image_patches for model_input in model_inputs], distributed + ) def set_children_attributes(self) -> None: if self.image_patches is not None: @@ -174,31 +175,28 @@ def _set_target_inputs( labels[document_begin + prediction_distance - 1] = -100 document_begin += length - prediction_labels = labels[ - prediction_distance : len(self.tokens) - config.num_labels + prediction_distance - ].clone() - mask = prediction_labels >= 0 + mask = labels >= 0 label_counts = self._get_label_counts(mask) if config.return_label_counts else None for input_index, model_input in enumerate(model_inputs): - begin = model_input.sequence_k_dim.size - end = begin + model_input.token_dim.size + label_end = model_input.sequence_k_dim.size + prediction_distance + label_begin = label_end - model_input.token_dim.size # Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions. target_input = LanguageModelTargetInput( - tokens=labels[begin:end], - mask=mask[begin:end] if config.return_prediction_mask else None, - label_counts=label_counts[begin:end] if config.return_label_counts else None, + tokens=labels[label_begin:label_end].clone(), + mask=mask[label_begin:label_end] if config.return_prediction_mask else None, + label_counts=label_counts[label_begin:label_end] if config.return_label_counts else None, # Set value for the first input only so `share_batch_data` generated the correct sum. # TODO: ====== Make optional? - num_labels=mask.sum(dtype=torch.int32).item() if input_index == 0 else 0, + num_labels=( + len(mask) if self.is_meta else mask.sum(dtype=torch.int32).item() if input_index == 0 else 0 + ), ) if config.use_grpo_data and not model_input.is_meta: - target_input.advantages = self.advantages.get_cropped_data( - begin + prediction_distance, end + prediction_distance - ) + target_input.advantages = self.advantages.get_cropped_data(label_begin, label_end) target_input.old_log_probabilities = self.old_log_probabilities.get_cropped_data( - begin + prediction_distance, end + prediction_distance + label_begin, label_end ) model_input.targets.append(target_input) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index a9f6887c7..4cb529463 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -48,11 +48,11 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c out += layer.get_compute_usage(input_, kwargs, config) return out - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: losses = [] for layer in self.get_layers(): if layer is not self: - losses += layer.get_loss_definitions(count) + losses += layer.get_loss_definitions() return losses def get_preprocessing_config(self) -> dict[str, typing.Any]: diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 0526b9dc2..30d783199 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -1,5 +1,6 @@ import abc import dataclasses +import enum import typing from fast_llm.config import MISSING, Config, Field, FieldHint, FieldVerboseLevel, config_class @@ -11,6 +12,7 @@ import torch from fast_llm.engine.base_model.base_model import BaseModel + from fast_llm.engine.distributed.distributed import Distributed @config_class() @@ -103,12 +105,79 @@ class ResourceUsageConfig: backward: int = 1 +class ReductionType(enum.StrEnum): + """ + An enum to represent data types independently of third party libraries, + so we can swap them more easily and allow for lazy imports. + """ + + sum = "float64" + average = "float32" + minimum = "float16" + maximum = "bfloat16" + + @property + def torch(self) -> "typing.Callable[[torch.Tensor], torch.Tensor]": + if not _TORCH_REDUCTION_MAP: + _set_torch_reduction_map() + return _TORCH_REDUCTION_MAP[self] + + @property + def distributed(self) -> "torch.distributed.ReduceOp.RedOpType": + if not _DISTRIBUTED_REDUCTION_MAP: + _set_distributed_reduction_map() + return _DISTRIBUTED_REDUCTION_MAP[self] + + +_TORCH_REDUCTION_MAP: dict[ReductionType, "typing.Callable[[torch.Tensor], torch.Tensor]"] = {} + + +def _set_torch_reduction_map() -> None: + import torch + + global _TORCH_REDUCTION_MAP + + _TORCH_REDUCTION_MAP = { + ReductionType.sum: torch.sum, + ReductionType.average: torch.mean, + ReductionType.minimum: torch.min, + ReductionType.maximum: torch.max, + } + + +_DISTRIBUTED_REDUCTION_MAP: dict[ReductionType, "torch.distributed.ReduceOp.RedOpType"] = {} + + +def _set_distributed_reduction_map() -> None: + import torch + + global _TORCH_REDUCTION_MAP + + _TORCH_REDUCTION_MAP = { + ReductionType.sum: torch.distributed.ReduceOp.SUM, + ReductionType.average: torch.distributed.ReduceOp.AVG, + ReductionType.minimum: torch.distributed.ReduceOp.MIN, + ReductionType.maximum: torch.distributed.ReduceOp.MAX, + } + + @dataclasses.dataclass() class LossDef: # A name for the loss name: str - formatted_name: str - # The number of times this loss is evaluated by the model for each micro-batch. Used as a denominator for averaging. - # TODO: Allow variable count? Would need a reduction across PP devices. - count: int = 1 dtype: DataType = DataType.float32 + reduction: ReductionType = ReductionType.sum + + def reduce(self, losses: list[torch.Tensor], distributed: "Distributed") -> torch.Tensor | None: + from fast_llm.core.ops import reduce_op + + if losses or distributed.pipeline_group: + if losses: + reduced_loss = losses[0] if len(losses) == 1 else self.reduction.torch(torch.stack(losses)) + reduce_op(reduced_loss, group=distributed.data_group, op=self.reduction.distributed) + else: + reduced_loss = torch.zeros([1], dtype=self.dtype.torch, device=distributed.device) + reduce_op(reduced_loss, group=distributed.pipeline_group, op=self.reduction.distributed) + return reduced_loss + else: + return None diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 20a777a70..7ad03b24c 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -284,30 +284,12 @@ def run_step( return self._reduce_losses(context), update_successful, metrics def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: - reduced_losses = {} - for name, losses in context.losses.items(): - if losses or self._distributed.pipeline_group: - if losses: - loss_count = ( - self._loss_definitions[name].count - * self._distributed_config.data_parallel - * context.schedule.config.num_inputs - ) - reduced_loss = torch.stack(losses).sum() / loss_count - if self._distributed.data_group: - all_reduce(reduced_loss, group=self._distributed.data_group) - else: - reduced_loss = torch.zeros( - [1], dtype=self._loss_definitions[name].dtype.torch, device=self._distributed.device - ) - if self._distributed.pipeline_group: - all_reduce(reduced_loss, group=self._distributed.pipeline_group) - else: - reduced_loss = 0.0 - reduced_losses[name] = reduced_loss + reduced_losses = { + name: self._loss_definitions[name].reduce(losses, self._distributed) + for name, losses in context.losses.items() + } return { - name: reduced_loss.item() if isinstance(reduced_loss, torch.Tensor) else reduced_loss - for name, reduced_loss in reduced_losses.items() + name: 0.0 if reduced_loss is None else reduced_loss.item() for name, reduced_loss in reduced_losses.items() } def _train_step(self, context: BatchContext, step: Step) -> None: @@ -329,12 +311,17 @@ def _train_step(self, context: BatchContext, step: Step) -> None: def _preprocess_data( self, context: BatchContext, data_iterator: typing.Iterator ) -> typing.Generator[None, None, None]: + # We multiply by the data-parallel size to improve numerical stability (reduce numerical underflow). + # This factor is canceled in the averaging during gradient reduction. grad_output = ( - self._optimizer.grad_scale / self._config.num_inputs if context.schedule.phase.is_training else None + self._optimizer.grad_scale * self._distributed_config.data_parallel + if context.schedule.phase.is_training + else None ) model_inputs = [next(data_iterator) for _ in range(self._config.sequential_micro_batches)] - if not preprocessed: - model_inputs[0][0].share_batch_data(model_inputs, self._distributed) + model_inputs[0][0].share_batch_data( + [model_input for model_inputs_ in model_inputs for model_input in model_inputs_], self._distributed + ) for micro_batch, model_inputs_ in enumerate(model_inputs): Assert.eq(len(model_inputs_), self._config.micro_batch_splits) @@ -408,7 +395,7 @@ def _recv(self, context: BatchContext, step: Step) -> None: step.recv_event.wait() self._record_event(context, EventType.compute_wait_pipe, step) - def _forward(self, context: BatchContext, step: Step) -> None: + def _forward(self, context: BatchContext, step: Step) -> torch.Tensor | None: output, grad_context = self._stages[step.stage].forward( self._get_forward_input(context, step), context.batch[step.index], diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index 65dcee32b..ea2ba9299 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -2,6 +2,7 @@ from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce from fast_llm.functional.config import EntropyLossType, TargetFormat +from fast_llm.functional.utils import reduce_losses from fast_llm.utils import Assert @@ -285,6 +286,7 @@ def fused_entropy_loss_forward_backward( temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, + divisor: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. @@ -296,7 +298,7 @@ def fused_entropy_loss_forward_backward( assert entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl) assert loss_mask is None loss_mask = target >= 0 - per_sample_loss, grad = _fused_cross_entropy_base_from_labels( + losses, grad = _fused_cross_entropy_base_from_labels( logits, target, loss_mask, @@ -305,7 +307,7 @@ def fused_entropy_loss_forward_backward( group, ) elif entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl): - per_sample_loss, grad = _fused_cross_entropy_base_from_distribution( + losses, grad = _fused_cross_entropy_base_from_distribution( logits, target, grad_output, @@ -316,7 +318,7 @@ def fused_entropy_loss_forward_backward( return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, ) elif entropy_loss_type == EntropyLossType.reverse_kl: - per_sample_loss, grad = _fused_reverse_kl_base_from_distribution( + losses, grad = _fused_reverse_kl_base_from_distribution( logits, target, grad_output, @@ -328,9 +330,7 @@ def fused_entropy_loss_forward_backward( else: raise NotImplementedError(entropy_loss_type) - if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask - loss = per_sample_loss.mean() + loss = reduce_losses(losses, divisor, loss_mask) if grad is not None: if loss_mask is not None: diff --git a/fast_llm/functional/linear.py b/fast_llm/functional/linear.py index 38658ffc5..e1742a1bb 100644 --- a/fast_llm/functional/linear.py +++ b/fast_llm/functional/linear.py @@ -8,7 +8,6 @@ from fast_llm.core.distributed import ProcessGroup from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op -from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.sparse_copy import SparseMap from fast_llm.functional.triton.sparse_linear import ( @@ -17,6 +16,7 @@ input_row_sparse_matmul, output_sparse_matmul, ) +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.tensor import accumulate_gradient, param_get_and_unset_is_zero diff --git a/fast_llm/functional/triton/entropy_loss.py b/fast_llm/functional/triton/entropy_loss.py index 3d9937439..4ed661606 100644 --- a/fast_llm/functional/triton/entropy_loss.py +++ b/fast_llm/functional/triton/entropy_loss.py @@ -2,6 +2,7 @@ from fast_llm.functional.config import EntropyLossType, TargetFormat from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.utils import reduce_losses @triton_jit() @@ -656,7 +657,7 @@ def _cross_entropy_loss_from_labels( sum_exp_logits: torch.Tensor, max_logits: torch.Tensor, ) -> torch.Tensor: - return torch.where(target.flatten() >= 0, sum_exp_logits.log() + max_logits - predicted_logits, 0).mean() + return torch.where(target.flatten() >= 0, sum_exp_logits.log() + max_logits - predicted_logits, 0) @torch.compile @@ -700,6 +701,7 @@ def triton_entropy_loss_forward_backward( entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, block_size: int | None = None, num_warps: int | None = None, + divisor: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -745,23 +747,22 @@ def triton_entropy_loss_forward_backward( **kwargs, **backward_kwargs, ) - loss = losses.mean() else: - partial_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - local_max_logits = torch.empty_like(partial_losses) - sum_exp_logits = torch.empty_like(partial_losses) + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + local_max_logits = torch.empty_like(losses) + sum_exp_logits = torch.empty_like(losses) triton_cross_entropy_forward_from_labels_parallel_kernel[(n_rows,)]( logits, target, max_logits_ptr=local_max_logits, sum_exp_logits_ptr=sum_exp_logits, - predicted_logits_ptr=partial_losses, + predicted_logits_ptr=losses, col_min=n_cols * group.rank(), **kwargs, ) max_logits, sum_exp_logits = parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) - torch.distributed.all_reduce(partial_losses, op=torch.distributed.ReduceOp.SUM, group=group) - loss = _cross_entropy_loss_from_labels(partial_losses, target, sum_exp_logits, max_logits) + torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM, group=group) + losses = _cross_entropy_loss_from_labels(losses, target, sum_exp_logits, max_logits) if grad_output is not None: triton_cross_entropy_forward_backward_from_labels_kernel[(n_rows,)]( logits, @@ -798,14 +799,13 @@ def triton_entropy_loss_forward_backward( **kwargs, **backward_kwargs, ) - loss = losses.mean() else: - partial_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - local_max_logits = torch.empty_like(partial_losses) - sum_exp_logits = torch.empty_like(partial_losses) + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + local_max_logits = torch.empty_like(losses) + sum_exp_logits = torch.empty_like(losses) if target_format == TargetFormat.logits: - local_target_max_logits = torch.empty_like(partial_losses) - target_sum_exp_logits = torch.empty_like(partial_losses) + local_target_max_logits = torch.empty_like(losses) + target_sum_exp_logits = torch.empty_like(losses) else: local_target_max_logits = target_sum_exp_logits = None @@ -823,7 +823,7 @@ def triton_entropy_loss_forward_backward( sum_exp_logits_ptr=sum_exp_logits, target_max_logits_ptr=local_target_max_logits, target_sum_exp_logits_ptr=target_sum_exp_logits, - partial_losses_ptr=partial_losses, + partial_losses_ptr=losses, target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, @@ -835,14 +835,12 @@ def triton_entropy_loss_forward_backward( target_sum_exp_logits, local_target_max_logits, group ) if entropy_loss_type != EntropyLossType.reverse_kl: - partial_losses = _rescale_predicted_logits( - partial_losses, local_target_max_logits, target_max_logits - ) + losses = _rescale_predicted_logits(losses, local_target_max_logits, target_max_logits) else: target_max_logits = None if entropy_loss_type == EntropyLossType.reverse_kl: - partial_losses = _rescale_predicted_logits(partial_losses, local_max_logits, max_logits) - torch.distributed.all_reduce(partial_losses, op=torch.distributed.ReduceOp.SUM, group=group) + losses = _rescale_predicted_logits(losses, local_max_logits, max_logits) + torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM, group=group) kernel[(n_rows,)]( logits, @@ -852,13 +850,13 @@ def triton_entropy_loss_forward_backward( sum_exp_logits_ptr=sum_exp_logits, target_max_logits_ptr=target_max_logits, target_sum_exp_logits_ptr=target_sum_exp_logits, - partial_losses_ptr=partial_losses, - losses_ptr=partial_losses, + partial_losses_ptr=losses, + losses_ptr=losses, target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, **kwargs, **backward_kwargs, ) - loss = partial_losses.mean() + loss = reduce_losses(losses, divisor) return loss, grad_logits diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index 7949faaf0..4a8c5f179 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -5,7 +5,6 @@ from fast_llm.core.distributed import ProcessGroup from fast_llm.core.ops import gather_op -from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig from fast_llm.functional.linear import ( input_parallel_linear_forward, @@ -23,6 +22,7 @@ copy_sparse_to_dense_forward, ) from fast_llm.functional.triton.sparse_linear import output_sparse_matmul +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.tensor import param_get_and_unset_is_zero diff --git a/fast_llm/functional/triton/normalization.py b/fast_llm/functional/triton/normalization.py index 9538a9275..7c25ce735 100644 --- a/fast_llm/functional/triton/normalization.py +++ b/fast_llm/functional/triton/normalization.py @@ -2,9 +2,9 @@ import torch -from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, tl_full, triton, triton_jit +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.tensor import param_get_and_unset_is_zero diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index 3d9c07145..f07046a52 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -1,8 +1,8 @@ import torch -from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.utils import div diff --git a/fast_llm/functional/triton/sparse_copy.py b/fast_llm/functional/triton/sparse_copy.py index e68692d9c..6af0c7828 100644 --- a/fast_llm/functional/triton/sparse_copy.py +++ b/fast_llm/functional/triton/sparse_copy.py @@ -3,9 +3,9 @@ import torch from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import MAX_DROPLESS_BLOCK_SIZE_ROW, TritonConfig from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.utils import wrap_forward_backward @dataclasses.dataclass() diff --git a/fast_llm/functional/triton/z_loss.py b/fast_llm/functional/triton/z_loss.py index cb3220131..4bce52119 100644 --- a/fast_llm/functional/triton/z_loss.py +++ b/fast_llm/functional/triton/z_loss.py @@ -6,6 +6,7 @@ triton_cross_entropy_forward_from_labels_parallel_kernel, triton_fused_softmax_base, ) +from fast_llm.functional.utils import reduce_losses @triton_jit() @@ -83,6 +84,7 @@ def triton_z_loss_forward_backward( logits_scale_factor: float = 1.0, block_size: int | None = None, num_warps: int | None = None, + divisor: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: assert logits.is_contiguous() if loss_mask is not None: @@ -141,4 +143,5 @@ def triton_z_loss_forward_backward( **kwargs, **backward_kwargs, ) - return losses.mean(), grad_logits + loss = reduce_losses(losses, divisor) + return loss, grad_logits diff --git a/fast_llm/functional/autograd.py b/fast_llm/functional/utils.py similarity index 91% rename from fast_llm/functional/autograd.py rename to fast_llm/functional/utils.py index 586f833b3..b2fc4589d 100644 --- a/fast_llm/functional/autograd.py +++ b/fast_llm/functional/utils.py @@ -69,3 +69,12 @@ def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float | Non @staticmethod def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: # noqa return grad_output, ctx.grad, None + + +@torch.compile +def reduce_losses( + losses: torch.Tensor, divisor: float | None = None, mask: torch.Tensor | None = None +) -> torch.Tensor: + if mask is not None: + losses = losses * mask + return losses.mean() if divisor is None else losses.sum() / divisor diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 16caf2d66..be40317f3 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -9,7 +9,7 @@ from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.functional.autograd import wrap_forward_backward +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias diff --git a/fast_llm/layers/block/sequence.py b/fast_llm/layers/block/sequence.py index b085961bf..d2a8c7f3b 100644 --- a/fast_llm/layers/block/sequence.py +++ b/fast_llm/layers/block/sequence.py @@ -69,10 +69,8 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: kwargs[BlockKwargs.num_blocks_in_sequence] = self._config.num_blocks self._layers_with_namespace[0].preprocess(kwargs) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return ( - self[0].get_loss_definitions(count=count * self._config.num_blocks) if self._config.num_blocks > 0 else [] - ) + def get_loss_definitions(self) -> list[LossDef]: + return self[0].get_loss_definitions() if self._config.num_blocks > 0 else [] class PatternBlockSequence[ConfigType: PatternBlockSequenceConfig](BlockBase[ConfigType], torch.nn.ModuleList): @@ -139,11 +137,11 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: kwargs[BlockKwargs.num_blocks_in_sequence] = self._config.expanded_pattern.count(name) self._layers_with_namespace[index].preprocess(kwargs) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: # TODO: Prevent name conflicts. return sum( ( - self[self._config.preprocessing_layers[name]].get_loss_definitions(count=count * count_) + self[self._config.preprocessing_layers[name]].get_loss_definitions() for name, count_ in collections.Counter(self._config.expanded_pattern).items() ), [], diff --git a/fast_llm/layers/common/linear/linear.py b/fast_llm/layers/common/linear/linear.py index d0ea7a681..f19e97a94 100644 --- a/fast_llm/layers/common/linear/linear.py +++ b/fast_llm/layers/common/linear/linear.py @@ -5,7 +5,6 @@ from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.distributed.config import DistributedDim -from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( input_parallel_linear_autograd, input_parallel_linear_backward, @@ -15,6 +14,7 @@ output_parallel_linear_backward, output_parallel_linear_forward, ) +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert diff --git a/fast_llm/layers/common/peft/lora.py b/fast_llm/layers/common/peft/lora.py index fcff5d496..eaf9f67f0 100644 --- a/fast_llm/layers/common/peft/lora.py +++ b/fast_llm/layers/common/peft/lora.py @@ -3,7 +3,7 @@ import torch from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.functional.autograd import wrap_forward_backward +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.layers.common.linear.linear import Linear, LinearBase from fast_llm.tensor import ParameterMeta diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index a9d213912..a2f2d3519 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -9,7 +9,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.functional.autograd import AuxiliaryLoss +from fast_llm.functional.utils import AuxiliaryLoss from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig @@ -216,18 +216,8 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: # TODO: add layer_index _distillation_loss_name = "activation_distillation_loss" - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: loss_definitions = [] if self._config.distillation_model is not None: - loss_definitions.append( - LossDef( - name=self._distillation_loss_name, - formatted_name=self._distillation_loss_name, - count=count, - ) - ) - return ( - loss_definitions - + self.mixer.get_loss_definitions(count=count) - + self.mlp.get_loss_definitions(count=count) - ) + loss_definitions.append(LossDef(name=self._distillation_loss_name)) + return loss_definitions + self.mixer.get_loss_definitions() + self.mlp.get_loss_definitions() diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 13ba79a7a..48bc5a5e1 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -9,9 +9,9 @@ from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.functional.autograd import AuxiliaryLoss from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map +from fast_llm.functional.utils import AuxiliaryLoss from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType @@ -247,24 +247,12 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c ) return super().get_compute_usage(moe_input, kwargs, config) + self.router.get_compute_usage(input_, config) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: loss_definitions = [] if self._config.routing == RoutingType.topk: - loss_definitions.append( - LossDef( - name=MLPLossNames.load_balancing_loss, - formatted_name="load balancing loss", - count=1, - ) - ) + loss_definitions.append(LossDef(name=MLPLossNames.load_balancing_loss)) if self._config.z_loss_coefficient: - loss_definitions.append( - LossDef( - name=MLPLossNames.router_z_loss, - formatted_name="router z loss", - count=1, - ) - ) + loss_definitions.append(LossDef(name=MLPLossNames.router_z_loss)) return loss_definitions diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 9def3895c..97bd1f477 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -231,7 +231,7 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c return int(expected_usage) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: """ Merge loss definitions from all mixers with namespacing. @@ -241,13 +241,11 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: """ all_losses = [] for mixer_name, mixer in self.mixers.items(): - mixer_losses = mixer.get_loss_definitions(count=count) + mixer_losses = mixer.get_loss_definitions() # Namespace each loss with the mixer name to avoid conflicts for loss_def in mixer_losses: namespaced_loss = LossDef( name=f"{mixer_name}/{loss_def.name}", - formatted_name=f"{mixer_name}/{loss_def.formatted_name}", - count=loss_def.count, dtype=loss_def.dtype, ) all_losses.append(namespaced_loss) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b6b749095..35498ac0a 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -10,8 +10,8 @@ from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.functional.autograd import AuxiliaryLoss, grad_is_context, wrap_forward_backward from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward +from fast_llm.functional.utils import AuxiliaryLoss, grad_is_context, wrap_forward_backward from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( @@ -277,14 +277,10 @@ def _logits_loss_forward_backward_partial( output_parallel_linear_backward(grad, context) if self.training else None ) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: return [ - LossDef(name=self._total_loss_name, formatted_name=self._total_loss_name, count=count), - *( - loss_ - for loss in self.losses - for loss_ in loss.get_loss_definitions(count * self._config.cross_entropy_splits) - ), + LossDef(name=self._total_loss_name), + *(loss_ for loss in self.losses for loss_ in loss.get_loss_definitions()), ] def _get_full_loss_name(self, name) -> str: diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index c3dd625ec..1f12c5b52 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -82,14 +82,14 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self.head.preprocess(kwargs) self.multi_token_prediction.preprocess(kwargs) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? return sum( ( - self.embeddings.get_loss_definitions(count), - self.decoder.get_loss_definitions(count), - self.head.get_loss_definitions(count), - self.multi_token_prediction.get_loss_definitions(count), + self.embeddings.get_loss_definitions(), + self.decoder.get_loss_definitions(), + self.head.get_loss_definitions(), + self.multi_token_prediction.get_loss_definitions(), ), [], ) diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index f16b6de44..48c1556f3 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -35,6 +35,7 @@ def _forward_backward( logits_scale_factor=self._logits_scale_factor, target_format=TargetFormat.labels, entropy_loss_type=self._config.loss_type, + divisor=self._get_label_count(kwargs), ) @@ -61,6 +62,7 @@ def _forward_backward( logits_scale_factor=self._logits_scale_factor, target_format=TargetFormat.logits, entropy_loss_type=self._config.loss_type, + divisor=self._get_label_count(kwargs), ) def get_preprocessing_config(self) -> dict[str, typing.Any]: diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 2136e7918..d055e1f1b 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -4,8 +4,8 @@ import torch from fast_llm.engine.base_model.config import LossDef -from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base +from fast_llm.functional.utils import reduce_losses from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs from fast_llm.layers.language_model.loss.loss import LanguageModelLoss @@ -35,6 +35,7 @@ def _forward_backward( if losses is None else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) ), + divisor=self._get_label_count(kwargs), ) self._register_loss( @@ -42,15 +43,8 @@ def _forward_backward( ) return loss, grad - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return super().get_loss_definitions(count) + [ - LossDef( - self._logprob_metric_name, - formatted_name=self._logprob_metric_name, - count=1, # This is an additive metric over the sequence. - dtype=DataType.float32, - ) - ] + def get_loss_definitions(self) -> list[LossDef]: + return super().get_loss_definitions() + [LossDef(self._logprob_metric_name)] def get_preprocessing_config( self, @@ -77,6 +71,7 @@ def fused_grpo_loss_forward_backward( num_labels_in_seq: ( torch.Tensor | None ) = None, # (*batch,) — response-span length broadcast per token, 0 for non-response + divisor: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor loss_mask = target >= 0 @@ -88,12 +83,11 @@ def fused_grpo_loss_forward_backward( new_log_probs = predicted_logits - sum_exp_logits.log() probability_ratio = (new_log_probs - old_log_probabilities).exp() - per_sample_loss = -torch.min( + losses = -torch.min( probability_ratio * advantages, torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * advantages, ) - per_sample_loss = per_sample_loss * loss_mask - loss = per_sample_loss.mean() + loss = reduce_losses(losses, divisor, loss_mask) # Sum of per-sequence mean log-probs, matching pipelinerl's new_logprobs metric: # sum_sum(new_logprobs / num_labels_in_seq, masks_shifted, segments) diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 990d4c3a1..1bcc7d807 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -6,7 +6,6 @@ from fast_llm.config import Configurable from fast_llm.core.ops import split_op from fast_llm.engine.base_model.config import LossDef -from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.loss.config import LanguageModelLossConfig, LanguageModelLossKwargs @@ -68,19 +67,8 @@ def _forward_backward( ) -> "tuple[torch.Tensor, torch.Tensor | None]": pass - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return ( - [ - LossDef( - name=self.name, - formatted_name=self.name, - count=count, - dtype=DataType.float32, - ) - ] - if self._do_register_loss - else [] - ) + def get_loss_definitions(self) -> list[LossDef]: + return [LossDef(name=self.name)] if self._do_register_loss else [] def get_preprocessing_config( self, @@ -139,6 +127,9 @@ def _get_grad_output(self, kwargs: dict[str, typing.Any]) -> float | None: def _get_labels(self, kwargs: dict[str, typing.Any], split_index: int = 0): return self._prepare_target(kwargs[LanguageModelLossKwargs.labels], split_index) + def _get_label_count(self, kwargs: dict[str, typing.Any]): + return kwargs[LanguageModelKwargs.label_counts][self._prediction_distance - 1] + def _get_loss_mask(self, kwargs: dict[str, typing.Any], split_index: int = 0): loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) return None if loss_mask is None else self._prepare_target(loss_mask, split_index) diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py index 5565294d5..d6a086ebb 100644 --- a/fast_llm/layers/language_model/loss/z_loss.py +++ b/fast_llm/layers/language_model/loss/z_loss.py @@ -5,6 +5,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.entropy_loss import fused_softmax_base from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward +from fast_llm.functional.utils import reduce_losses from fast_llm.layers.language_model.loss.config import LanguageModelZLossConfig from fast_llm.layers.language_model.loss.loss import LanguageModelLoss @@ -29,8 +30,12 @@ def _forward_backward( group=self._parallel_dim.group if self._vocab_parallel else None, logits_scale_factor=self._logits_scale_factor, grad_logits=grad_logits, + divisor=self._get_label_count(kwargs), ) + def get_preprocessing_config(self) -> dict[str, typing.Any]: + return {"return_prediction_mask": True} + @torch.compile def z_loss( @@ -54,6 +59,7 @@ def fused_z_loss_forward_backward( grad_output: float | None = None, group: torch.distributed.ProcessGroup | None = None, logits_scale_factor: float = 1.0, + divisor: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Z-loss = mean(logsumexp(logits, dim=-1) ** 2) @@ -63,10 +69,7 @@ def fused_z_loss_forward_backward( logits_norm, exp_logits, sum_exp_logits, logits_max = fused_softmax_base(logits, logits_scale_factor, group) log_sum_exp_logits = sum_exp_logits.log() + logits_max - per_sample_loss = log_sum_exp_logits**2 - if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask - loss = per_sample_loss.mean() + loss = reduce_losses(log_sum_exp_logits**2, divisor, loss_mask) if grad_output is not None: grad_base = 2 * grad_output * (log_sum_exp_logits / sum_exp_logits) diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py index c7be11b70..9766182b8 100644 --- a/fast_llm/layers/language_model/multi_token_prediction.py +++ b/fast_llm/layers/language_model/multi_token_prediction.py @@ -99,10 +99,10 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: if self._enabled: self._layers_with_namespace[0].preprocess(kwargs) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: return ( - self.blocks[0].get_loss_definitions(count=count * (self._config.prediction_heads - 1)) - + [loss_definition for head in self.heads for loss_definition in head.get_loss_definitions(count=count)] + self.blocks[0].get_loss_definitions() + + [loss_definition for head in self.heads for loss_definition in head.get_loss_definitions()] if self._enabled else [] ) diff --git a/fast_llm/layers/vision/vision_encoder.py b/fast_llm/layers/vision/vision_encoder.py index 3116702e6..0b94beec9 100644 --- a/fast_llm/layers/vision/vision_encoder.py +++ b/fast_llm/layers/vision/vision_encoder.py @@ -69,12 +69,12 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self.encoder.preprocess(kwargs) self.adapter.preprocess(kwargs) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: # Needed because the base class uses `get_layers` which may bypass the decoder. TODO: Avoidable? return ( - self.embeddings.get_loss_definitions(count) - + self.encoder.get_loss_definitions(count) - + self.adapter.get_loss_definitions(count) + self.embeddings.get_loss_definitions() + + self.encoder.get_loss_definitions() + + self.adapter.get_loss_definitions() ) @@ -123,8 +123,8 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self._vision_encoder_with_namespace.preprocess(kwargs) super().preprocess(kwargs) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self.vision_encoder.get_loss_definitions(count) + super().get_loss_definitions(count) + def get_loss_definitions(self) -> list[LossDef]: + return self.vision_encoder.get_loss_definitions() + super().get_loss_definitions() @functools.cached_property def _vision_encoder_namespace(self) -> str: diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 3f45c8184..2619883d6 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -112,7 +112,7 @@ def format_metrics( **{key: metrics.pop(key, _NAN) for key in _METRIC_FORMATS_KEYS[phase]}, ) ] - outputs.extend([f"{loss_def.formatted_name}: {metrics.pop(loss_def.name, _NAN):.5f}" for loss_def in loss_defs]) + outputs.extend([f"{loss_def.name}: {metrics.pop(loss_def.name, _NAN):.5f}" for loss_def in loss_defs]) if metrics: outputs.extend([f"{key}: {value}" for key, value in metrics.items()]) diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index a53c234f0..55c30c7ee 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -11,7 +11,6 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel -from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner @@ -81,7 +80,7 @@ def _get_batch( def _inner_forward( self, - batch: LanguageModelInput, + batch: LanguageModelBatch, input_shape: tuple[int], past_key_values=None, inputs_embeds: torch.FloatTensor | None = None, @@ -114,18 +113,12 @@ def _inner_forward( use_cache, output_hidden_states, ) - input_, kwargs = self.fast_llm_base_model.preprocess_batch( - model_input, - phase=PhaseType.inference, - iteration=iteration, - ) - - self._inference_runner.forward(input_, kwargs, iteration=iteration) + self._inference_runner.forward(model_input, iteration=iteration) # TODO: Make a proper way of returning the model output. hidden_states = { name: meta.local_to_global(tensor)[0].unflatten(0, input_shape) - for name, (meta, tensor) in kwargs[AttentionKwargs.hidden_states].items() + for name, (meta, tensor) in model_input.hidden_states.items() } # TODO: Handle MTP. @@ -134,7 +127,7 @@ def _inner_forward( output = transformers.modeling_outputs.CausalLMOutputWithPast( logits=logits, hidden_states=hidden_states or None, - past_key_values=kwargs[AttentionKwargs.presents], + past_key_values=model_input.presents, ) return ( output diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a21bdee7e..83abaca21 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -75,7 +75,7 @@ def preprocess_batch( hidden_states={}, ) reference_model_input.set_children_attributes() - reference_model.forward(model_input, iteration=iteration) + reference_model.forward(reference_model_input, iteration=iteration) kwargs[f"reference_{name}_hidden_states"] = { layer_name: tensor for layer_name, (meta, tensor) in reference_model_input.hidden_states.items() diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index c2bde6a8b..b0333e433 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -329,9 +329,10 @@ def test_lm_head(test_config: LMHeadTestConfig): Assert.eq(len(loss_definitions), len(loss_definitions_)) Assert.eq(losses.keys(), ref_losses.keys(), loss_definitions.keys()) + losses = {name: loss[0] if len(loss) == 1 else torch.stack(loss).sum() for name, loss in losses.items()} losses = { - name: loss[0] if len(loss) == 1 else torch.stack(loss).sum() / loss_definitions[name].count - for name, loss in losses.items() + name: loss_definition.reduce(losses[name], distributed) + for name, loss_definition in loss_definitions.items() } for name, loss in losses.items(): From 68b68b2cd7ebc98d0c1e5ab04181d559cd7e5aff Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 25 Mar 2026 16:23:13 +0000 Subject: [PATCH 07/33] Fix intermittent test_data_streaming failure with fakeredis 2.34+ fakeredis 2.34 introduced Resp3Writer hardcoded for all TCP connections regardless of protocol negotiation. When XREADGROUP BLOCK times out on an empty stream, Resp3Writer.dump(None) sends RESP3 null (b'_\r\n'). The redis-py RESP2 parser (used by default) raises Protocol Error: b'_'. Fix: monkey-patch TCPFakeRequestHandler.setup in fake_redis_server() to replace Resp3Writer with Resp2Writer, restoring correct RESP2 null encoding (b'*-1\r\n') for blocking timeouts. The patch is guarded on the presence of Resp3Writer (2.34+ only) and raises explicitly if Resp2Writer is missing so future breakage is immediately diagnosable. --- tests/utils/redis.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/utils/redis.py b/tests/utils/redis.py index 2dc09bee2..6004425dc 100644 --- a/tests/utils/redis.py +++ b/tests/utils/redis.py @@ -81,6 +81,34 @@ def safe_handle(self): fakeredis._tcp_server.TCPFakeRequestHandler.handle = safe_handle + # ----- Monkey-patch setup to use Resp2Writer instead of Resp3Writer ----- + # fakeredis 2.34+ hardcodes Resp3Writer for all connections, causing blocked + # XREADGROUP timeouts to return RESP3 null (b'_\r\n') even on RESP2 connections + # (i.e. clients that never sent HELLO 3). The redis-py RESP2 parser raises + # Protocol Error: b'_' on this byte. Fix: replace with Resp2Writer at setup time. + # The Resp2Writer class was introduced alongside the bug in 2.34, so use its + # presence as the version guard. + orig_setup = fakeredis._tcp_server.TCPFakeRequestHandler.setup + if hasattr(fakeredis._tcp_server, "Resp3Writer"): + # fakeredis 2.34+ hardcodes Resp3Writer for all connections, causing blocked + # XREADGROUP timeouts to return RESP3 null (b'_\r\n') even on RESP2 connections + # (i.e. clients that never sent HELLO 3). The redis-py RESP2 parser raises + # Protocol Error: b'_' on this byte. Fix: replace with Resp2Writer at setup time. + if not hasattr(fakeredis._tcp_server, "Resp2Writer"): + raise RuntimeError( + f"fakeredis {fakeredis.__version__} has Resp3Writer but not Resp2Writer — " + "the workaround for the RESP2/RESP3 null encoding bug no longer applies. " + "See tests/utils/redis.py for details." + ) + + def resp2_setup(self): + orig_setup(self) + if not isinstance(self.writer, fakeredis._tcp_server.Resp2Writer): + self.writer = fakeredis._tcp_server.Resp2Writer(self.client_address, self.wfile, self) + self.current_client.writer = self.writer + + fakeredis._tcp_server.TCPFakeRequestHandler.setup = resp2_setup + server = fakeredis.TcpFakeServer((config.host, config.port), server_type="redis") print(f"Starting fake redis server at {config.host}:{config.port}") thread = threading.Thread(target=server.serve_forever, daemon=True) @@ -94,3 +122,5 @@ def safe_handle(self): server.shutdown() server.server_close() thread.join() + fakeredis._tcp_server.TCPFakeRequestHandler.setup = orig_setup + fakeredis._tcp_server.TCPFakeRequestHandler.handle = orig_handle From ce3e85ae8281f4c331f4f204df87e882e3758748 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 25 Mar 2026 15:11:48 -0400 Subject: [PATCH 08/33] Reduce losses with token counts instead of sequence length - Add `divisor` parameter to fused loss functions (entropy, z-loss, grpo) to allow normalizing by actual token count rather than total sequence positions - Fix `_get_grad_output` to not pre-divide by parallel/split factors (handled by divisor) - Fix loss accumulation across cross-entropy splits in LM head - Fix variable naming bug in `_set_distributed_reduction_map` - Update tests to pass explicit divisor and match new normalization behavior Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/base_model/config.py | 8 +++-- fast_llm/functional/entropy_loss.py | 4 ++- fast_llm/functional/triton/entropy_loss.py | 4 ++- fast_llm/functional/triton/z_loss.py | 4 ++- fast_llm/layers/language_model/head.py | 2 +- fast_llm/layers/language_model/loss/config.py | 2 +- fast_llm/layers/language_model/loss/grpo.py | 4 ++- fast_llm/layers/language_model/loss/loss.py | 11 ++----- fast_llm/layers/language_model/loss/z_loss.py | 4 ++- tests/layers/test_lm_head.py | 31 +++++++++++-------- tests/layers/test_lm_losses.py | 3 +- 11 files changed, 44 insertions(+), 33 deletions(-) diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 30d783199..2770e67a2 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -151,9 +151,9 @@ def _set_torch_reduction_map() -> None: def _set_distributed_reduction_map() -> None: import torch - global _TORCH_REDUCTION_MAP + global _DISTRIBUTED_REDUCTION_MAP - _TORCH_REDUCTION_MAP = { + _DISTRIBUTED_REDUCTION_MAP = { ReductionType.sum: torch.distributed.ReduceOp.SUM, ReductionType.average: torch.distributed.ReduceOp.AVG, ReductionType.minimum: torch.distributed.ReduceOp.MIN, @@ -168,7 +168,9 @@ class LossDef: dtype: DataType = DataType.float32 reduction: ReductionType = ReductionType.sum - def reduce(self, losses: list[torch.Tensor], distributed: "Distributed") -> torch.Tensor | None: + def reduce(self, losses: "list[torch.Tensor]", distributed: "Distributed") -> "torch.Tensor | None": + import torch + from fast_llm.core.ops import reduce_op if losses or distributed.pipeline_group: diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index ea2ba9299..05eaae520 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -293,7 +293,9 @@ def fused_entropy_loss_forward_backward( It is an improvement over the pytorch implementation because of the fused casting, both in speed and memory, but still suboptimal because it needs multiple kernels. """ - grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor + if divisor is None: + divisor = logits.shape[:-1].numel() + grad_output = None if grad_output is None else grad_output / divisor * logits_scale_factor if target_format == TargetFormat.labels: assert entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl) assert loss_mask is None diff --git a/fast_llm/functional/triton/entropy_loss.py b/fast_llm/functional/triton/entropy_loss.py index 4ed661606..9ec13a7d4 100644 --- a/fast_llm/functional/triton/entropy_loss.py +++ b/fast_llm/functional/triton/entropy_loss.py @@ -714,6 +714,8 @@ def triton_entropy_loss_forward_backward( assert target.is_contiguous() n_rows = logits.shape[:-1].numel() n_cols = logits.size(-1) + if divisor is None: + divisor = n_rows if block_size is None: block_size = min(triton.next_power_of_2(n_cols), 32768) if num_warps is None: @@ -732,7 +734,7 @@ def triton_entropy_loss_forward_backward( grad_logits = torch.empty_like(logits) if grad_logits is None else grad_logits backward_kwargs = { "grad_logits_ptr": grad_logits, - "grad_losses": grad_output / n_rows, + "grad_losses": grad_output / divisor, "grad_logits_stride_0": grad_logits.stride(-2), "accumulate": accumulate, } diff --git a/fast_llm/functional/triton/z_loss.py b/fast_llm/functional/triton/z_loss.py index 4bce52119..d9592a4f4 100644 --- a/fast_llm/functional/triton/z_loss.py +++ b/fast_llm/functional/triton/z_loss.py @@ -91,6 +91,8 @@ def triton_z_loss_forward_backward( assert loss_mask.is_contiguous() n_rows = logits.shape[:-1].numel() n_cols = logits.size(-1) + if divisor is None: + divisor = logits.shape[:-1].numel() if block_size is None: block_size = min(triton.next_power_of_2(n_cols), 32768) if num_warps is None: @@ -110,7 +112,7 @@ def triton_z_loss_forward_backward( backward_kwargs = { "grad_logits_ptr": grad_logits, - "grad_losses": grad_output / n_rows, + "grad_losses": grad_output / divisor, "grad_logits_stride_0": grad_logits.stride(-2), "accumulate": accumulate, } diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 35498ac0a..aee71ef0b 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -221,7 +221,7 @@ def _logits_loss_forward_backward( total_losses.append(total_loss_) # TODO: Avoid copy with explicit out argument. input_grad_.copy_(grad_) - total_loss = sum(total_losses) / self._config.cross_entropy_splits if total_losses else None + total_loss = torch.stack(total_losses).sum() if total_losses else None # TODO: ====== Drop return value, treat as normal loss ====== # Return value only needed because stage expects a return tensor diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 5168aecfb..a2c067a95 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -26,7 +26,7 @@ class LanguageModelLossKwargs(BlockKwargs): rejected_spans = "rejected_spans" advantages = "advantages" old_log_probabilities = "old_log_probabilities" - label_counts = "num_labels_in_seq" + label_counts = "label_counts" num_labels_in_batch = "num_labels_in_batch" diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index d055e1f1b..62f591d9f 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -73,7 +73,9 @@ def fused_grpo_loss_forward_backward( ) = None, # (*batch,) — response-span length broadcast per token, 0 for non-response divisor: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: - grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor + if divisor is None: + divisor = logits.shape[:-1].numel() + grad_output = None if grad_output is None else grad_output / divisor * logits_scale_factor loss_mask = target >= 0 logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 1bcc7d807..985933ecc 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -115,20 +115,13 @@ def _prepare_target( def _get_grad_output(self, kwargs: dict[str, typing.Any]) -> float | None: grad_output = kwargs.get(LanguageModelKwargs.grad_output) - if grad_output is not None: - grad_output = ( - grad_output - * self._weight - / (self._parallel_dim.size if self._sequence_parallel else 1) - / self._num_splits - ) - return grad_output + return None if grad_output is None else grad_output * self._weight def _get_labels(self, kwargs: dict[str, typing.Any], split_index: int = 0): return self._prepare_target(kwargs[LanguageModelLossKwargs.labels], split_index) def _get_label_count(self, kwargs: dict[str, typing.Any]): - return kwargs[LanguageModelKwargs.label_counts][self._prediction_distance - 1] + return kwargs[LanguageModelKwargs.num_labels_in_batch][self._prediction_distance - 1] def _get_loss_mask(self, kwargs: dict[str, typing.Any], split_index: int = 0): loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py index d6a086ebb..2e5f90b1d 100644 --- a/fast_llm/layers/language_model/loss/z_loss.py +++ b/fast_llm/layers/language_model/loss/z_loss.py @@ -65,7 +65,9 @@ def fused_z_loss_forward_backward( Z-loss = mean(logsumexp(logits, dim=-1) ** 2) Grad = 2 * log_sum_exp_logits * softmax(logits) """ - grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor + if divisor is None: + divisor = logits.shape[:-1].numel() + grad_output = None if grad_output is None else grad_output / divisor * logits_scale_factor logits_norm, exp_logits, sum_exp_logits, logits_max = fused_softmax_base(logits, logits_scale_factor, group) log_sum_exp_logits = sum_exp_logits.log() + logits_max diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index b0333e433..73e9f4807 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -102,6 +102,11 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: torch.randint(0, 2, (NUM_TOKENS,), dtype=torch.bool, device=device) for _ in range(self.prediction_heads) ] + kwargs[LanguageModelKwargs.num_labels_in_batch] = [ + loss_mask.sum().item() for loss_mask in kwargs[LanguageModelKwargs.loss_mask] + ] + else: + kwargs[LanguageModelKwargs.num_labels_in_batch] = [NUM_TOKENS for _ in range(self.prediction_heads)] if self.actual_label_loss is not False or self.grpo_loss is not False: labels = [ torch.randint( @@ -166,32 +171,34 @@ def get_reference_outputs( names_losses_weights = [] + loss_mask = ( + kwargs[LanguageModelKwargs.loss_mask][head._prediction_distance - 1] + if LanguageModelKwargs.loss_mask in kwargs + else None + ) + if self.actual_label_loss is not False or self.grpo_loss is not False: labels = kwargs[LanguageModelKwargs.labels][head._prediction_distance - 1] if self.actual_label_loss is not False: - label_loss = torch.nn.functional.cross_entropy(logits, labels, reduction="none").mean() + label_loss = torch.nn.functional.cross_entropy(logits, labels) names_losses_weights.append(("label", label_loss, float(self.actual_label_loss))) - # total_loss = total_loss + float(self.actual_label_loss) * label_loss if self.distillation_loss is not False: distillation_loss = torch.nn.functional.cross_entropy( logits, torch.softmax(kwargs[f"reference_distillation_hidden_states"]["head.logits"], -1), - reduction="none", + reduction="mean" if loss_mask is None else "none", ) - if LanguageModelKwargs.loss_mask in kwargs: - distillation_loss = ( - distillation_loss * kwargs[LanguageModelKwargs.loss_mask][head._prediction_distance - 1] - ) - distillation_loss = distillation_loss.mean() + if loss_mask is not None: + distillation_loss = (distillation_loss * loss_mask).sum() / loss_mask.sum() names_losses_weights.append(("distillation", distillation_loss, float(self.distillation_loss))) if self.z_loss is not False: z_loss = torch.logsumexp(logits, dim=-1) ** 2 - if LanguageModelKwargs.loss_mask in kwargs: - z_loss = z_loss * kwargs[LanguageModelKwargs.loss_mask][head._prediction_distance - 1] - z_loss = z_loss.mean() + if loss_mask is not None: + z_loss = z_loss * loss_mask + z_loss = z_loss.mean() if loss_mask is None else (z_loss * loss_mask).sum() / loss_mask.sum() names_losses_weights.append(("z_loss", z_loss, float(self.z_loss))) if self.grpo_loss is not False: @@ -317,7 +324,6 @@ def test_lm_head(test_config: LMHeadTestConfig): losses = collections.defaultdict(list) output, context = stage.forward(head_input, kwargs, losses) - print(losses) stage.backward(output_grad, context) threshold = 1e-5 if distributed.config.compute_dtype == DataType.float32 else 5e-3 min_threshold = ( @@ -329,7 +335,6 @@ def test_lm_head(test_config: LMHeadTestConfig): Assert.eq(len(loss_definitions), len(loss_definitions_)) Assert.eq(losses.keys(), ref_losses.keys(), loss_definitions.keys()) - losses = {name: loss[0] if len(loss) == 1 else torch.stack(loss).sum() for name, loss in losses.items()} losses = { name: loss_definition.reduce(losses[name], distributed) for name, loss_definition in loss_definitions.items() diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index a719e44e8..3a68a999f 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -147,7 +147,7 @@ def reference_grpo_loss( # new_logprobs: sum of per-sequence mean log-probs log_probs = torch.nn.functional.log_softmax(logits_, -1).gather(-1, labels.unsqueeze(-1)).squeeze(-1) new_logprobs = (log_probs * loss_mask).sum() / max(float(loss_mask.sum()), 1.0) - return (loss * loss_mask).mean(), new_logprobs + return (loss * loss_mask).sum() / loss_mask.sum(), new_logprobs _BATCH_SHAPES = ((64,), (16, 8)) @@ -272,6 +272,7 @@ def _test_grpo_loss( grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, + divisor=(target >= 0).sum().item(), ) _compare_losses_and_grads(out_fused, out_ref, grad_output is not None, grad_fused, grad_ref, group=group) From b429c5ee87065929c7eade1a3e0baf0d80f93f10 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Mar 2026 12:07:40 -0400 Subject: [PATCH 09/33] Fix tflops units, loss reduce ops, and CPU test support - Fix schedule tflops divide by 1e12 (was reporting raw flops) - Change loss reductions from AVG to SUM (needed with token-count weighting) - Add CPU/gloo fallback support in distributed test configs - Fix pp tied weight bias ignore_duplicates - Adjust micro_batch_size and compare targets for distributed configs Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/distributed/config.py | 4 +++- fast_llm/engine/schedule/schedule.py | 4 ++-- fast_llm/layers/language_model/head.py | 4 ++-- fast_llm/layers/language_model/loss/loss.py | 2 +- tests/models/test_checkpoint.py | 16 ++++++++-------- tests/models/test_model.py | 9 +++------ tests/utils/distributed_configs.py | 15 ++++++++++----- tests/utils/model_configs.py | 6 ++++-- tests/utils/save_load_configs.py | 5 +++++ 9 files changed, 38 insertions(+), 27 deletions(-) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index c3950cedf..f5dc00a09 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -117,7 +117,9 @@ def from_sizes_and_strides(cls, name: str, global_rank: int, *sizes_and_strides: elif isinstance(global_ranks, range) and stride == global_ranks.stop - global_ranks.start: global_ranks = range(start, start + size * stride, global_ranks.step) else: - global_ranks = [rank0 + rank1 for rank1 in range(0, size * stride, stride) for rank0 in global_ranks] + global_ranks = tuple( + rank0 + rank1 for rank1 in range(0, size * stride, stride) for rank0 in global_ranks + ) Assert.eq(len(global_ranks), world_size) return DistributedDim(name=name, size=world_size, rank=rank, global_ranks=global_ranks) diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 361772818..e2a9c75b5 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -538,6 +538,6 @@ def compute_usage(self) -> tuple[int | None, int | None]: def get_compute_metrics(self, time_per_iteration: float) -> dict[str, float]: model_compute, hardware_compute = self.compute_usage return { - "model_tflops": math.nan if model_compute is None else model_compute / time_per_iteration, - "hardware_tflops": math.nan if hardware_compute is None else hardware_compute / time_per_iteration, + "model_tflops": math.nan if model_compute is None else model_compute / time_per_iteration / 1e12, + "hardware_tflops": math.nan if hardware_compute is None else hardware_compute / time_per_iteration / 1e12, } diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index aee71ef0b..d57b465bf 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -136,7 +136,7 @@ def forward( (scalar_dim,), tensor_name=f"{self.module_name} output", reductions=( - (self._distributed_config.get_distributed_dim(DistributedDimNames.data), ReduceOp.AVG), + (self._distributed_config.get_distributed_dim(DistributedDimNames.data), ReduceOp.SUM), ), ) else: @@ -227,7 +227,7 @@ def _logits_loss_forward_backward( # Return value only needed because stage expects a return tensor if self._sequence_parallel_logits: # TODO: Async - all_reduce(total_loss, op=ReduceOp.AVG, group=self._parallel_dim.group) + all_reduce(total_loss, op=ReduceOp.SUM, group=self._parallel_dim.group) if losses is not None: losses[self._total_loss_name].append(total_loss) diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 985933ecc..9a92661c9 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -76,7 +76,7 @@ def get_preprocessing_config( return {} def _register_loss( - self, name: str, value: torch.Tensor, losses: dict | None, reduce_op=torch.distributed.ReduceOp.AVG + self, name: str, value: torch.Tensor, losses: dict | None, reduce_op=torch.distributed.ReduceOp.SUM ): if losses is None: return diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 74c51719d..5f0f5a80f 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -472,13 +472,12 @@ def test_save_and_load_in_parallel(run_parallel_script, run_test_script_base_pat # Save and load checkpoints to and from various distributed configurations. # Combined in a single test to mitigate process creation overhead. # TODO: Test beyond 2 gpu configs? - if torch.cuda.device_count() < 2: - pytest.skip(f"Not enough GPUs2") run_parallel_script( _save_and_load_in_parallel, (run_test_script_base_path, model_testing_config), world_size=2, backend=model_testing_config.distributed_backend, + use_cuda=torch.cuda.is_available(), ) @@ -503,6 +502,7 @@ def test_load_parallel_checkpoint_in_single_gpu( load_and_compare_checkpoints, reference_distributed_shard, report_subtest, + testing_device, ): if ( model_testing_config.checkpoint_format is None @@ -514,16 +514,16 @@ def test_load_parallel_checkpoint_in_single_gpu( distributed_save_load_config = distributed_save_load_config.resolve( base_path=run_test_script_base_path, model_testing_config=model_testing_config ) - if torch.cuda.device_count() < distributed_save_load_config.num_gpus: - pytest.skip( - f"Not enough GPUs to run dependency: {torch.cuda.device_count()} < {distributed_save_load_config.num_gpus}" - ) - report_subtest(distributed_save_load_config.save_path, distributed_save_load_config.num_gpus) + report_subtest( + distributed_save_load_config.save_path, + distributed_save_load_config.num_gpus, + use_cuda=torch.cuda.is_available(), + ) load_and_compare_checkpoints( DistributedCheckpointFormat, distributed_save_load_config.save_path / DistributedCheckpointFormat.name, None, - reference_distributed_shard.to(device="cuda"), + reference_distributed_shard.to(device=testing_device), ) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 0c58afade..f3a9a1d7c 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -68,13 +68,12 @@ def _run_model_distributed( ModelTestingGroup.distributed, ) def test_run_model_distributed(run_parallel_script, model_testing_config, run_test_script_base_path): - if torch.cuda.device_count() < 2: - pytest.skip(f"Not enough GPUs") run_parallel_script( _run_model_distributed, (run_test_script_base_path, model_testing_config), - world_size=torch.cuda.device_count(), + world_size=torch.cuda.device_count() if torch.cuda.is_available() else 8, backend=model_testing_config.distributed_backend, + use_cuda=torch.cuda.is_available(), ) @@ -94,9 +93,7 @@ def test_model_distributed( config = DISTRIBUTED_TESTING_CONFIGS[config_name] if model_testing_config.should_skip(config): pytest.skip(f"Configuration not supported.") - if torch.cuda.device_count() < config.num_gpus: - pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < {config.num_gpus}") - report_subtest(run_test_script_base_path / config.name, config.num_gpus) + report_subtest(run_test_script_base_path / config.name, config.num_gpus, use_cuda=torch.cuda.is_available()) if config.compare is not None: if not check_subtest_success(run_test_script_base_path / config.compare): pytest.fail(f"Test {config.compare} failed", pytrace=False) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 933ea8f8e..405aa1bcd 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -18,6 +18,7 @@ class DistributedTestingConfig: compare_config: CompareConfig | None = None # Scale the comparison thresholds for specific distributed configs. compare_factor: float = 1.0 + requires_cuda: bool = False def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareConfig: @@ -53,6 +54,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon _compare_layer_mismatch_duplicate_gradients = copy.deepcopy(_compare_layer_mismatch) _compare_layer_mismatch_duplicate_gradients.sub_configs[(None, "bias")].ignore_duplicates = True _compare_layer_mismatch_duplicate_gradients.sub_configs[(None, "gradient")].ignore_duplicates = True +_pp_tied_weight_compare.sub_configs[(None, "bias")].ignore_duplicates = True _pp_tied_weight_compare.sub_configs[(None, "gradient")].ignore_duplicates = True _pp_tied_weight_compare.sub_configs[("init", None)].ignore_duplicates = True for tensor in ("fw", "bw"): @@ -113,6 +115,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon ) _SINGLE_GPU_TESTING_CONFIGS = [ + # TODO: 16-bit matmuls extremely slow on cpu DistributedTestingConfig( name="bf16", compare="simple", @@ -124,6 +127,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon ], num_gpus=1, compare_config=_bf16_compare, + requires_cuda=True, ), DistributedTestingConfig( name="fp16", @@ -131,6 +135,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=["model.distributed.compute_dtype=fp16", "data.micro_batch_size=4096"], num_gpus=1, compare_config=_fp16_compare, + requires_cuda=True, ), # Cross-entropy splits. DistributedTestingConfig( @@ -397,17 +402,17 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Simple DistributedTestingConfig( name="dp2_stp2_pp2s2_bf4", - compare="dp2_z2_df4", + compare="df8", config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", "model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=2", "schedule.breadth_first_micro_batches=4", - "data.micro_batch_size=412", + "data.micro_batch_size=512", ], num_gpus=8, - compare_config=_compare_layer_match, + compare_config=_compare_layer_mismatch, ), # Tied weights on different ranks DistributedTestingConfig( @@ -427,7 +432,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Micro-sequence DistributedTestingConfig( name="sdp2_stp2_pp2s2_ms4", - compare="df2", + compare="simple", config_args=[ "model.distributed.sequence_data_parallel=2", "model.distributed.tensor_parallel=2", @@ -435,7 +440,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon "model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=2", "schedule.micro_batch_splits=4", - "data.micro_batch_size=2048", + "data.micro_batch_size=4096", ], num_gpus=8, compare_config=_compare_layer_mismatch, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 3e6910b6f..6268ac194 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -154,7 +154,9 @@ def distributed_backend(self): return DistributedBackend(self.config_dict["model"]["distributed"]["backend"]) def should_skip(self, distributed_config: DistributedTestingConfig) -> bool: - return any(re.search(pattern, distributed_config.name) for pattern in self.skip_tests) + return (distributed_config.requires_cuda and not torch.cuda.is_available()) or any( + re.search(pattern, distributed_config.name) for pattern in self.skip_tests + ) def update_and_add_testing_config( @@ -264,7 +266,7 @@ def update_and_add_testing_config( "distributed": { "reproducible_init": True, "timeout": 20, - "backend": "nccl", + "backend": DistributedBackend.nccl if torch.cuda.device_count() >= 2 else DistributedBackend.gloo, "use_cuda": torch.cuda.is_available(), }, }, diff --git a/tests/utils/save_load_configs.py b/tests/utils/save_load_configs.py index 3e7cbf10f..6bc619825 100644 --- a/tests/utils/save_load_configs.py +++ b/tests/utils/save_load_configs.py @@ -4,6 +4,7 @@ import typing import pytest +import torch from fast_llm.engine.checkpoint.config import CheckpointFormat, DistributedCheckpointFormat, FastLLMCheckpointFormat from tests.utils.model_configs import ModelTestingConfig @@ -17,6 +18,10 @@ class DistributedSaveLoadConfig: distributed: dict[str, typing.Any] num_gpus: int = 2 + def __post_init__(self): + self.distributed["use_cuda"] = torch.cuda.is_available() + self.distributed["backend"] = "nccl" if torch.cuda.device_count() >= self.num_gpus else "gloo" + def resolve(self, base_path: pathlib.Path, model_testing_config: ModelTestingConfig) -> typing.Self: if model_testing_config.checkpoint_format is None: format = { From 647fbb7dc087ec1d2df1d5d1366cc8c4bcdca4b1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Mar 2026 12:33:09 -0400 Subject: [PATCH 10/33] Fix MTP Llama converter to map head.final_norm to model.mtp_norms.0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MTPLlamaModel uses mtp_norms[0] for the first prediction head instead of model.norm (as in standard Llama). The converter was inheriting the Llama mapping (head.final_norm → model.norm), so the native HuggingFace model loaded converted checkpoints with mtp_norms[0] uninitialized. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/models/gpt/conversion/mtp_llama.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index 5ce91fbac..cb9c5c1f2 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -13,6 +13,7 @@ LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, + get_parameter_converter, ) from fast_llm.utils import Assert, safe_merge_dicts @@ -38,7 +39,21 @@ def get_converters( config: LanguageModelConfig, exported_config: dict, ) -> list[WeightConverter]: - converters = super().get_converters(config, exported_config) + # Override: map head.final_norm to model.mtp_norms.0 (not model.norm as in standard Llama), + # since MTPLlamaModel uses mtp_norms[0] for the first prediction head. + converters = [ + *cls.normalization_converter_class.get_converters( + config.head.normalization, + "head.final_norm", + "model.mtp_norms.0", + ), + get_parameter_converter( + "head.output_weights", + "lm_head.weight", + drop_on_import=exported_config["tie_word_embeddings"], + drop_on_export=exported_config["tie_word_embeddings"], + ), + ] for prediction_distance in range(2, config.head.prediction_heads + 1): converters += cls.block_converter_class.get_converters( config.decoder.last_block_config, From d26d4ec34736768f950d19634207951c0fe20173 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Mar 2026 13:19:50 -0400 Subject: [PATCH 11/33] Fix DistributedDim pickling to allow DataLoader workers with streaming datasets Add __getstate__/__setstate__ to DistributedDim to drop the process group when pickling, so DataLoader worker processes can be spawned even when the dataset or collate_fn captures a DistributedConfig with active process groups. Also expand test_data_streaming to cover num_workers=1 and increase _NUM_BATCHES from 2 to 10 for better coverage. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/distributed/config.py | 10 ++++++++++ tests/data/test_streaming.py | 13 ++++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index f5dc00a09..a214e8e50 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -80,6 +80,16 @@ class DistributedDim: def __post_init__(self): self._is_setup = False + def __getstate__(self): + # Prevent process groups from being pickled, ex. in the data loader. + state = self.__dict__.copy() + if "_group" in state: + del state["_group"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + @property def group(self) -> "ProcessGroup|None": assert hasattr(self, "_group") diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index c7088eae3..83f7657a0 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -144,7 +144,7 @@ def test_streaming_sampled_dataset( assert batch.old_log_probabilities is None -_NUM_BATCHES = 2 +_NUM_BATCHES = 10 _SEQUENCE_LENGTH = 10 @@ -160,7 +160,9 @@ def _get_distributed_config(distributed_config_dict: dict[str, typing.Any], worl ) -def _run_test_data_streaming(path: pathlib.Path, distributed_config: DistributedConfig, port: int): +def _run_test_data_streaming( + path: pathlib.Path, distributed_config: DistributedConfig, port: int, num_workers: int = 1 +): redis_config = RedisConfig(port=port + 100, timeout=1) data = GPTData( @@ -186,7 +188,7 @@ def _run_test_data_streaming(path: pathlib.Path, distributed_config: Distributed distributed_config.batch_data_parallel * _NUM_BATCHES, ) data_iter = data.get_iterator( - "train", consumed_samples=0, num_workers=0, prefetch_factor=None, timeout=5, preprocess=False + "train", consumed_samples=0, num_workers=num_workers, prefetch_factor=None, timeout=5, preprocess=False ) batches = [next(data_iter) for _ in range(_NUM_BATCHES)] path.mkdir(parents=True, exist_ok=True) @@ -228,10 +230,11 @@ def _run_test_data_streaming_distributed( _run_test_data_streaming(base_path / name, distributed_config, port) -def test_data_streaming(result_path, worker_resources): +@pytest.mark.parametrize("num_workers", (0, 1)) +def test_data_streaming(result_path, worker_resources, num_workers): distributed_config = _get_distributed_config({}) path = result_path / "data_streaming/single_gpu" - _run_test_data_streaming(path, distributed_config, worker_resources.torchrun_port) + _run_test_data_streaming(path, distributed_config, worker_resources.torchrun_port, num_workers) check_data_streaming_results(path, distributed_config) From fa607fc7bc977defcc6ad77001c266908bd961a8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Mar 2026 13:46:55 -0400 Subject: [PATCH 12/33] Expand test_preprocessing.py with comprehensive coverage Add tests for padding, multi-token prediction, micro-batch splits, prediction mask, label counts, GRPO data, position index, inference phase, document count, and cumulative sequence lengths. Co-Authored-By: Claude Sonnet 4.6 --- tests/data/test_preprocessing.py | 197 ++++++++++++++++++++++++++++++- 1 file changed, 195 insertions(+), 2 deletions(-) diff --git a/tests/data/test_preprocessing.py b/tests/data/test_preprocessing.py index d0e56e3f0..0990377da 100644 --- a/tests/data/test_preprocessing.py +++ b/tests/data/test_preprocessing.py @@ -4,11 +4,11 @@ from fast_llm.data.document.config import LanguageModelBatchPreprocessingConfig from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument from fast_llm.data.document.range import RangeDocument +from fast_llm.data.document.token_data import TokenDataDocument +from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert -# TODO: Test padding, more scenarios -# TODO: Check rest of preprocessing output @pytest.mark.parametrize( ("tokens", "loss_masking_spans"), ( @@ -53,3 +53,196 @@ def test_preprocessing(tokens, loss_masking_spans): Assert.eq(len(model_input.targets), 1) Assert.all_equal(model_input.targets[0].tokens, torch.cat(label_tokens)[1:]) + + +def test_preprocessing_padding(): + # 5 real tokens padded to 8: padding tokens (-100) should be masked out of labels. + tokens = [100, 101, 102, 103, 104] + document = LanguageModelDocument(tokens=torch.tensor(tokens, dtype=torch.int64)) + + batch = LanguageModelBatch.from_documents([document], pad_to_size=8) + (model_input,) = batch.get_model_inputs(LanguageModelBatchPreprocessingConfig()) + + # total_input_length = 8 - 1 = 7; tokens[0:7] = [100, 101, 102, 103, 104, -100, -100] + Assert.all_equal( + model_input.tokens, + torch.tensor([100, 101, 102, 103, 104, -100, -100], dtype=torch.int64), + ) + + # labels = [100, 101, 102, 103, 104, -100, -100, -100], then labels[0]=-100 (cross-doc) + # target = labels[1:8] = [101, 102, 103, 104, -100, -100, -100] + Assert.all_equal( + model_input.targets[0].tokens, + torch.tensor([101, 102, 103, 104, -100, -100, -100], dtype=torch.int64), + ) + + +@pytest.mark.parametrize("predicted_tokens", [1, 2, 3]) +def test_preprocessing_multi_token_prediction(predicted_tokens): + # With predicted_tokens=d, there are d target sets. + # Target for distance d is tokens[d : d + total_input_length]. + # Cross-doc masking for distance d falls at index d-1, just outside each target window. + tokens = list(range(100, 111)) # 11 tokens + document = LanguageModelDocument(tokens=torch.tensor(tokens, dtype=torch.int64)) + + config = LanguageModelBatchPreprocessingConfig(predicted_tokens=predicted_tokens) + (model_input,) = LanguageModelBatch.from_documents([document]).get_model_inputs(config) + + total_input = len(tokens) - predicted_tokens + Assert.all_equal(model_input.tokens, torch.tensor(tokens[:total_input], dtype=torch.int64)) + Assert.eq(len(model_input.targets), predicted_tokens) + + for i, target in enumerate(model_input.targets): + d = i + 1 + # Cross-doc masking for all distances <=d falls at indices 0..d-1, outside window [d:d+total_input]. + Assert.all_equal(target.tokens, torch.tensor(tokens[d : d + total_input], dtype=torch.int64)) + + +def test_preprocessing_micro_batch_splits(): + # micro_batch_splits=2 produces two model inputs each covering half the sequence. + tokens = list(range(100, 113)) # 13 tokens → total_input_length=12, each split=6 + document = LanguageModelDocument(tokens=torch.tensor(tokens, dtype=torch.int64)) + + config = LanguageModelBatchPreprocessingConfig(micro_batch_splits=2) + model_inputs = LanguageModelBatch.from_documents([document]).get_model_inputs(config) + + Assert.eq(len(model_inputs), 2) + Assert.all_equal(model_inputs[0].tokens, torch.tensor(tokens[:6], dtype=torch.int64)) + Assert.all_equal(model_inputs[1].tokens, torch.tensor(tokens[6:12], dtype=torch.int64)) + + # labels[0]=-100 (cross-doc); targets are labels[1:7] and labels[7:13] + Assert.all_equal(model_inputs[0].targets[0].tokens, torch.tensor(tokens[1:7], dtype=torch.int64)) + Assert.all_equal(model_inputs[1].targets[0].tokens, torch.tensor(tokens[7:13], dtype=torch.int64)) + + +def test_preprocessing_prediction_mask(): + # return_prediction_mask exposes the boolean mask of non-masked label positions. + tokens = [100, 101, 102, 103, 104, 105] + document = LanguageModelDocument( + tokens=torch.tensor(tokens, dtype=torch.int64), + loss_masking_spans=RangeDocument(ranges=[(2, 4)]), # mask positions 2 and 3 + ) + + config = LanguageModelBatchPreprocessingConfig(return_prediction_mask=True) + (model_input,) = LanguageModelBatch.from_documents([document]).get_model_inputs(config) + + # labels = [100, 101, 102, 103, 104, 105] + # after span masking: labels[2:4] = -100 → [100, 101, -100, -100, 104, 105] + # after cross-doc: labels[0] = -100 → [-100, 101, -100, -100, 104, 105] + # target = labels[1:6] = [101, -100, -100, 104, 105] + # mask[1:6] = [True, False, False, True, True] + assert model_input.targets[0].mask is not None + Assert.all_equal( + model_input.targets[0].mask, + torch.tensor([True, False, False, True, True]), + ) + + +def test_preprocessing_label_counts(): + # return_label_counts gives each token the total count of valid labels in its document. + # Two documents each of length 4; cross-doc masking removes the first token of each, + # leaving 3 valid labels per document. + docs = [ + LanguageModelDocument(tokens=torch.tensor([100, 101, 102, 103], dtype=torch.int64)), + LanguageModelDocument(tokens=torch.tensor([200, 201, 202, 203], dtype=torch.int64)), + ] + + config = LanguageModelBatchPreprocessingConfig(return_label_counts=True) + (model_input,) = LanguageModelBatch.from_documents(docs).get_model_inputs(config) + + # labels after cross-doc masking: [-100, 101, 102, 103, -100, 201, 202, 203] + # doc1: 3 valid labels (indices 1,2,3); doc2: 3 valid labels (indices 5,6,7) + # target window: labels[1:8] → label_counts[1:8] = [3, 3, 3, 3, 3, 3, 3] + assert model_input.targets[0].label_counts is not None + Assert.all_equal( + model_input.targets[0].label_counts, + torch.full((7,), 3, dtype=model_input.targets[0].label_counts.dtype), + ) + + +def test_preprocessing_grpo_data(): + # use_grpo_data attaches per-token advantages and log-probabilities to the target, + # cropped to the label window [label_begin:label_end]. + tokens = [100, 101, 102, 103, 104, 105] + advantages_data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + log_probs_data = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6] + + document = LanguageModelDocument( + tokens=torch.tensor(tokens, dtype=torch.int64), + advantages=TokenDataDocument(data=torch.tensor(advantages_data)), + old_log_probabilities=TokenDataDocument(data=torch.tensor(log_probs_data)), + ) + + config = LanguageModelBatchPreprocessingConfig(use_grpo_data=True) + (model_input,) = LanguageModelBatch.from_documents([document]).get_model_inputs(config) + + # total_input_length=5; label_begin=1, label_end=6 + target = model_input.targets[0] + assert target.advantages is not None + assert target.old_log_probabilities is not None + Assert.rms_close(target.advantages, torch.tensor(advantages_data[1:]), 1e-6) + Assert.rms_close(target.old_log_probabilities, torch.tensor(log_probs_data[1:]), 1e-6) + + +def test_preprocessing_position_index(): + # return_position_index gives the within-document position of each input token, + # resetting to 0 at every document boundary. + docs = [ + LanguageModelDocument(tokens=torch.tensor([100, 101, 102, 103], dtype=torch.int64)), # len=4 + LanguageModelDocument(tokens=torch.tensor([200, 201, 202, 203], dtype=torch.int64)), # len=4 + ] + + config = LanguageModelBatchPreprocessingConfig(return_position_index=True) + (model_input,) = LanguageModelBatch.from_documents(docs).get_model_inputs(config) + + # total_input_length=7; input tokens: [100,101,102,103,200,201,202] + # positions: doc1 → [0,1,2,3], doc2 (first 3 tokens) → [0,1,2] + assert model_input.position_index is not None + Assert.all_equal( + model_input.position_index, + torch.tensor([0, 1, 2, 3, 0, 1, 2], dtype=torch.int32), + ) + + +def test_preprocessing_inference(): + # In inference phase num_labels=0, so the full token sequence is the input and there are no targets. + tokens = [100, 101, 102, 103, 104] + document = LanguageModelDocument(tokens=torch.tensor(tokens, dtype=torch.int64)) + + config = LanguageModelBatchPreprocessingConfig(phase=PhaseType.inference) + (model_input,) = LanguageModelBatch.from_documents([document]).get_model_inputs(config) + + Assert.all_equal(model_input.tokens, torch.tensor(tokens, dtype=torch.int64)) + Assert.eq(len(model_input.targets), 0) + + +def test_preprocessing_document_count(): + # return_document_count records how many documents are in the batch (first split only). + docs = [ + LanguageModelDocument(tokens=torch.tensor([100, 101, 102], dtype=torch.int64)), + LanguageModelDocument(tokens=torch.tensor([200, 201, 202], dtype=torch.int64)), + ] + + config = LanguageModelBatchPreprocessingConfig(return_document_count=True) + (model_input,) = LanguageModelBatch.from_documents(docs).get_model_inputs(config) + + Assert.eq(model_input.num_documents, 2) + + +def test_preprocessing_cumulative_sequence_lengths(): + # return_cumulative_sequence_lengths produces cu_seqlens tensors for flash-attention style kernels. + docs = [ + LanguageModelDocument(tokens=torch.tensor([100, 101, 102, 103], dtype=torch.int64)), # len=4 + LanguageModelDocument(tokens=torch.tensor([200, 201, 202, 203], dtype=torch.int64)), # len=4 + ] + + config = LanguageModelBatchPreprocessingConfig(return_cumulative_sequence_lengths=True) + (model_input,) = LanguageModelBatch.from_documents(docs).get_model_inputs(config) + + # total_input_length=7; lengths in this input: [4, 3] (doc2 is cut to 3 by the -1 label offset) + # cumulative_lengths_q = padded_cumsum([4, 3]) = [0, 4, 7] + # cumulative_lengths_k = [0, 4, 7] (sequence_k_past=0, first_document_begin=0) + assert model_input.cumulative_lengths_q is not None + assert model_input.cumulative_lengths_k is not None + Assert.all_equal(model_input.cumulative_lengths_q, torch.tensor([0, 4, 7], dtype=torch.int32)) + Assert.all_equal(model_input.cumulative_lengths_k, torch.tensor([0, 4, 7], dtype=torch.int32)) From 9af9675cb6b935891d03378e8895798387691cf6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Mar 2026 18:00:36 -0400 Subject: [PATCH 13/33] Fix cross-document masking bounds and padding document count bugs, expand preprocessing tests - Guard cross-document label masking against documents shorter than prediction distance - Fix num_documents to exclude the padding pseudo-document from the count - Add comprehensive test coverage: all split/target indices, predicted_tokens in (1,3), padding variants, and complex multi-document cases with loss masking spans and GRPO data - Refactor test helpers into cached properties indexed by [split_index][target_index] Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/data/document/language_model.py | 3 +- fast_llm/data/document/token.py | 5 +- tests/data/test_preprocessing.py | 630 ++++++++++++++--------- 3 files changed, 406 insertions(+), 232 deletions(-) diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 8f5c98801..7821b81c5 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -172,7 +172,8 @@ def _set_target_inputs( # Mask cross-document predictions. document_begin = 0 for length in self.lengths: - labels[document_begin + prediction_distance - 1] = -100 + if prediction_distance <= length: + labels[document_begin + prediction_distance - 1] = -100 document_begin += length mask = labels >= 0 diff --git a/fast_llm/data/document/token.py b/fast_llm/data/document/token.py index 8aeabb694..70261a152 100644 --- a/fast_llm/data/document/token.py +++ b/fast_llm/data/document/token.py @@ -103,7 +103,10 @@ def _get_model_input(self, begin: int, end: int, config: TokenPreprocessingConfi lengths, first_document_begin, last_document_end = self._get_cropped_lengths(begin, end) if config.return_document_count: - model_input.num_documents = len(self.lengths) if begin == 0 else 0 + # Exclude the padding "length" from the document count. + model_input.num_documents = ( + len(self.lengths) - (1 if self.unpadded_length < len(self.tokens) else 0) if begin == 0 else 0 + ) LengthModelInputPreprocessor( lengths=lengths, diff --git a/tests/data/test_preprocessing.py b/tests/data/test_preprocessing.py index 0990377da..ae58121ae 100644 --- a/tests/data/test_preprocessing.py +++ b/tests/data/test_preprocessing.py @@ -1,3 +1,6 @@ +import dataclasses +import functools + import pytest import torch @@ -6,243 +9,410 @@ from fast_llm.data.document.range import RangeDocument from fast_llm.data.document.token_data import TokenDataDocument from fast_llm.engine.distributed.config import PhaseType -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div + + +def _get_cropped_lengths(batch_lengths: list[int], begin: int, end: int) -> tuple[list[int], int]: + """Return (cropped_lengths, first_document_begin) for the token window [begin, end).""" + doc_begin = 0 + cropped = [] + first_doc_begin = 0 + for length in batch_lengths: + doc_end = doc_begin + length + crop = min(doc_end, end) - max(doc_begin, begin) + if crop > 0: + if not cropped: + first_doc_begin = doc_begin + cropped.append(crop) + if doc_end > end: + break + doc_begin = doc_end + return cropped, first_doc_begin + + +def _compute_label_counts(batch_lengths: list[int], labels: list[int]) -> torch.Tensor: + """For each token, compute the count of valid (non-negative) labels in its document.""" + result = [] + offset = 0 + for length in batch_lengths: + count = sum(1 for label in labels[offset : offset + length] if label >= 0) + result.extend([count] * length) + offset += length + return torch.tensor(result, dtype=torch.int64) + + +def _assert_tensor_equal_or_none(actual: torch.Tensor | None, expected: torch.Tensor | None) -> None: + if expected is None: + assert actual is None + else: + Assert.all_equal(actual, expected) + + +@dataclasses.dataclass +class PreprocessingTestConfig: + name: str + tokens: list[list[int]] + loss_masking_spans: list[list[tuple[int, int]] | None] | None = None + padding: int | None = None + advantages: list[list[float]] | None = None + log_probabilities: list[list[float]] | None = None + phase: PhaseType = PhaseType.training + predicted_tokens: int = 1 + micro_batch_splits: int = 1 + use_grpo_data: bool = False + return_prediction_mask: bool = False + return_label_counts: bool = False + return_position_index: bool = False + return_document_count: bool = False + return_cumulative_sequence_lengths: bool = False + + @functools.cached_property + def config_kwargs(self) -> dict: + return { + "phase": self.phase, + "predicted_tokens": self.predicted_tokens, + "micro_batch_splits": self.micro_batch_splits, + "use_grpo_data": self.use_grpo_data, + "return_prediction_mask": self.return_prediction_mask, + "return_label_counts": self.return_label_counts, + "return_position_index": self.return_position_index, + "return_document_count": self.return_document_count, + "return_cumulative_sequence_lengths": self.return_cumulative_sequence_lengths, + } + + @functools.cached_property + def padding_size(self) -> int: + return 0 if self.padding is None else self.padding + + @functools.cached_property + def unpadded_size(self) -> int: + return sum(self.unpadded_lengths) + + @functools.cached_property + def padded_size(self) -> int: + return self.unpadded_size + self.padding_size + + @functools.cached_property + def unpadded_lengths(self) -> list[int]: + return [len(tokens) for tokens in self.tokens] + + @functools.cached_property + def padded_lengths(self) -> list[int]: + return self.unpadded_lengths + ([self.padding_size] if self.padding_size > 0 else []) + + @functools.cached_property + def num_labels(self) -> int: + return self.padded_size - self.predicted_tokens + + @functools.cached_property + def split_size(self) -> int: + return div(self.num_labels, self.micro_batch_splits) + + @functools.cached_property + def all_flat_tokens(self) -> list[int]: + return sum(self.tokens, []) + [-100] * self.padding_size + + @functools.cached_property + def base_labels(self) -> list[int]: + """Tokens with loss masking spans applied, but no cross-document masking.""" + labels = list(self.all_flat_tokens) + if self.loss_masking_spans is not None: + offset = 0 + for doc_tokens, spans in zip(self.tokens, self.loss_masking_spans): + if spans is not None: + for begin, end in spans: + labels[offset + begin : offset + end] = [-100] * (end - begin) + offset += len(doc_tokens) + return labels + + @functools.cached_property + def labels_per_distance(self) -> list[list[int]]: + """For each prediction distance d, labels with cumulative cross-document masking.""" + result = [] + labels = list(self.base_labels) + for d in range(1, self.predicted_tokens + 1): + offset = 0 + for doc_tokens in self.tokens: + if d <= len(doc_tokens): + labels[offset + d - 1] = -100 + offset += len(doc_tokens) + result.append(list(labels)) + return result + + @functools.cached_property + def _split_ranges(self) -> list[tuple[int, int]]: + return [(i * self.split_size, (i + 1) * self.split_size) for i in range(self.micro_batch_splits)] + + @functools.cached_property + def _cropped_lengths_per_split(self) -> list[tuple[list[int], int]]: + return [_get_cropped_lengths(self.padded_lengths, begin, end) for begin, end in self._split_ranges] + + @functools.cached_property + def expected_input_tokens(self) -> list[torch.Tensor]: + all_tokens = torch.tensor(self.all_flat_tokens, dtype=torch.int64) + return [all_tokens[begin:end] for begin, end in self._split_ranges] + + @functools.cached_property + def expected_target_tokens(self) -> list[list[torch.Tensor]]: + labels_tensors = [torch.tensor(labels, dtype=torch.int64) for labels in self.labels_per_distance] + return [ + [ + labels_tensors[target_index][begin + d : end + d] + for target_index, d in enumerate(range(1, self.predicted_tokens + 1)) + ] + for begin, end in self._split_ranges + ] + + @functools.cached_property + def expected_target_mask(self) -> list[list[torch.Tensor | None]]: + if not self.return_prediction_mask: + return [[None] * self.predicted_tokens for _ in range(self.micro_batch_splits)] + return [[tokens >= 0 for tokens in split_targets] for split_targets in self.expected_target_tokens] + + @functools.cached_property + def expected_target_label_counts(self) -> list[list[torch.Tensor | None]]: + if not self.return_label_counts: + return [[None] * self.predicted_tokens for _ in range(self.micro_batch_splits)] + return [ + [ + _compute_label_counts(self.padded_lengths, self.labels_per_distance[target_index])[begin + d : end + d] + for target_index, d in enumerate(range(1, self.predicted_tokens + 1)) + ] + for begin, end in self._split_ranges + ] + + @functools.cached_property + def expected_advantages(self) -> list[list[torch.Tensor | None]]: + if self.advantages is None: + return [[None] * self.predicted_tokens for _ in range(self.micro_batch_splits)] + flat = torch.tensor(sum(self.advantages, []) + [0.0] * self.padding_size, dtype=torch.float32) + return [ + [flat[begin + d : end + d] for d in range(1, self.predicted_tokens + 1)] + for begin, end in self._split_ranges + ] + + @functools.cached_property + def expected_log_probabilities(self) -> list[list[torch.Tensor | None]]: + if self.log_probabilities is None: + return [[None] * self.predicted_tokens for _ in range(self.micro_batch_splits)] + flat = torch.tensor(sum(self.log_probabilities, []) + [0.0] * self.padding_size, dtype=torch.float32) + return [ + [flat[begin + d : end + d] for d in range(1, self.predicted_tokens + 1)] + for begin, end in self._split_ranges + ] + + @functools.cached_property + def expected_position_index(self) -> list[torch.Tensor | None]: + if not self.return_position_index: + return [None] * self.micro_batch_splits + result = [] + for split_index, (begin, _end) in enumerate(self._split_ranges): + cropped_lengths, first_doc_begin = self._cropped_lengths_per_split[split_index] + pos_in_doc = begin - first_doc_begin + positions = [] + remaining = cropped_lengths[0] if cropped_lengths else 0 + doc_index = 0 + for _ in range(self.split_size): + positions.append(pos_in_doc) + pos_in_doc += 1 + remaining -= 1 + if remaining == 0 and doc_index + 1 < len(cropped_lengths): + doc_index += 1 + remaining = cropped_lengths[doc_index] + pos_in_doc = 0 + result.append(torch.tensor(positions, dtype=torch.int32)) + return result + + @functools.cached_property + def expected_cumulative_lengths(self) -> list[tuple[torch.Tensor | None, torch.Tensor | None]]: + if not self.return_cumulative_sequence_lengths: + return [(None, None)] * self.micro_batch_splits + result = [] + for split_index, (begin, _end) in enumerate(self._split_ranges): + cropped_lengths, first_doc_begin = self._cropped_lengths_per_split[split_index] + cu_q = torch.tensor([0] + cropped_lengths, dtype=torch.int32).cumsum(dim=0) + cu_k = (cu_q + begin).clone() + cu_k[0] = first_doc_begin + result.append((cu_q, cu_k)) + return result + + @functools.cached_property + def expected_num_documents(self) -> list[int | None]: + if self.return_document_count: + return [len(self.tokens) if split_index == 0 else 0 for split_index in range(self.micro_batch_splits)] + else: + return [None] * self.micro_batch_splits + + +_BASE_TEST_CASES = [ + PreprocessingTestConfig( + name="simple", + tokens=[[100, 101, 102, 103, 104, 105, 106, 107, 108]], + ), + PreprocessingTestConfig( + name="negative_tokens", + tokens=[[100, 101, -100, -100, 104, 105, 106, 107, 108]], + ), + PreprocessingTestConfig( + name="loss_masking_span", + tokens=[[100, 101, 102, 103, 104, 105, 106, 107, 108]], + loss_masking_spans=[[(3, 5)]], + ), + PreprocessingTestConfig( + name="negative_tokens_and_loss_masking", + tokens=[[100, 101, 102, 103, -100, -100, 106, 107, 108]], + loss_masking_spans=[[(2, 3)]], + ), + PreprocessingTestConfig( + name="two_documents", + tokens=[[100, 101, -100, 103, -100, -100, 106, 107], [100, 101, 102, 103, 104, 105, 106, 107, 108]], + loss_masking_spans=[[(2, 3)], None], + ), + PreprocessingTestConfig( + name="three_documents", + tokens=[[100, 101, 102], [103, 104, 105], [106, 107, 108]], + loss_masking_spans=[[(1, 2)], None, [(0, 2)]], + ), + PreprocessingTestConfig( + # Document of length 1 is shorter than predicted_tokens=3; cross-document masking must not go out of bounds. + name="short_document", + tokens=[[100], [101, 102, 103, 104, 105, 106, 107, 108]], + ), + PreprocessingTestConfig( + name="multiple_loss_masking_spans", + tokens=[[100, 101, 102, 103, 104, 105, 106, 107, 108]], + loss_masking_spans=[[(1, 3), (5, 7)]], + ), + PreprocessingTestConfig( + # use_grpo_data attaches per-token advantages and log-probabilities to the target. + name="grpo_data", + tokens=[[100, 101, 102, 103, 104, 105, 106]], + advantages=[[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]], + log_probabilities=[[-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7]], + use_grpo_data=True, + ), + PreprocessingTestConfig( + name="two_documents_grpo_data", + tokens=[[100, 101, 102, 103], [104, 105, 106, 107, 108]], + advantages=[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8, 0.9]], + log_probabilities=[[-0.1, -0.2, -0.3, -0.4], [-0.5, -0.6, -0.7, -0.8, -0.9]], + use_grpo_data=True, + ), + PreprocessingTestConfig( + # In inference phase num_labels=0, so the full token sequence is the input and there are no targets. + name="inference", + tokens=[[100, 101, 102, 103, 104]], + phase=PhaseType.inference, + ), +] + +# Each base case is run with each return configuration and both values of micro_batch_splits, +# except inference which has no labels to return or split. +_RETURN_CONFIG_VARIANTS: dict[str, dict] = { + "": {}, + "prediction_mask": {"return_prediction_mask": True}, + "label_counts": {"return_label_counts": True}, + "position_index": {"return_position_index": True}, + "document_count": {"return_document_count": True}, + "cumulative_sequence_lengths": {"return_cumulative_sequence_lengths": True}, + "all": { + "return_prediction_mask": True, + "return_label_counts": True, + "return_position_index": True, + "return_document_count": True, + "return_cumulative_sequence_lengths": True, + }, +} + + +def _make_name( + base_name: str, return_name: str, predicted_tokens: int, micro_batch_splits: int, padding: int | None +) -> str: + parts = [base_name] + if return_name: + parts.append(f"return_{return_name}") + if predicted_tokens > 1: + parts.append(f"predicted_tokens_{predicted_tokens}") + if micro_batch_splits > 1: + parts.append(f"splits_{micro_batch_splits}") + if padding is not None: + parts.append(f"padding_{padding}") + return "_".join(parts) + + +_PREPROCESSING_TEST_CASES = [ + dataclasses.replace( + base_case, + name=_make_name(base_case.name, return_name, predicted_tokens, micro_batch_splits, padding), + predicted_tokens=predicted_tokens, + micro_batch_splits=micro_batch_splits, + padding=padding, + **return_config, + ) + for base_case in _BASE_TEST_CASES + for return_name, return_config in _RETURN_CONFIG_VARIANTS.items() + for predicted_tokens in (1, 3) + for micro_batch_splits in (1, 2) + for padding in (None, 0, 2) + if base_case.phase != PhaseType.inference + or (not return_config and predicted_tokens == 1 and micro_batch_splits == 1) +] @pytest.mark.parametrize( - ("tokens", "loss_masking_spans"), - ( - ([[100, 101, 102, 103, 104, 105, 106, 107]], [None]), # Simple case - ([[100, 101, -100, -100, 104, 105, 106, 107]], [None]), # Negative tokens - ([[100, 101, 102, 103, 104, 105, 106, 107]], [[(3, 5)]]), # Loss masking span - ([[100, 101, 102, 103, -100, -100, 106, 107]], [[(2, 3)]]), # Both - ( - [ - [100, 101, -100, 103, -100, -100, 106, 107], - [100, 101, 102, 103, 104, 105, 106, 107], - ], - [[(2, 3)], None], - ), # Two samples - ), + "test_config", [pytest.param(test_config, id=test_config.name) for test_config in _PREPROCESSING_TEST_CASES] ) -def test_preprocessing(tokens, loss_masking_spans): +def test_preprocessing(test_config: PreprocessingTestConfig): + config = LanguageModelBatchPreprocessingConfig(**test_config.config_kwargs) + documents = [ LanguageModelDocument( - tokens=torch.tensor(tokens_, dtype=torch.int64), - loss_masking_spans=None if loss_masking_spans_ is None else RangeDocument(ranges=loss_masking_spans_), + tokens=torch.tensor(tokens, dtype=torch.int64), + loss_masking_spans=None if spans is None else RangeDocument(ranges=spans), + advantages=None if doc_advantages is None else TokenDataDocument(data=torch.tensor(doc_advantages)), + old_log_probabilities=( + None if doc_log_probs is None else TokenDataDocument(data=torch.tensor(doc_log_probs)) + ), + ) + for tokens, spans, doc_advantages, doc_log_probs in zip( + test_config.tokens, + test_config.loss_masking_spans or [None] * len(test_config.tokens), + test_config.advantages or [None] * len(test_config.tokens), + test_config.log_probabilities or [None] * len(test_config.tokens), + strict=True, ) - for tokens_, loss_masking_spans_ in zip(tokens, loss_masking_spans, strict=True) - ] - - (model_input,) = LanguageModelBatch.from_documents(documents).get_model_inputs( - LanguageModelBatchPreprocessingConfig() - ) - - Assert.all_equal(model_input.tokens, torch.cat([document.tokens for document in documents])[:-1]) - - label_tokens = [] - for document in documents: - label_tokens_ = document.tokens.clone() - # Mask cross-document attention - label_tokens_[0] = -100 - # Loss masking spans - if document.loss_masking_spans is not None: - for begin, end in document.loss_masking_spans.ranges: - label_tokens_[begin:end] = -100 - label_tokens.append(label_tokens_) - - Assert.eq(len(model_input.targets), 1) - Assert.all_equal(model_input.targets[0].tokens, torch.cat(label_tokens)[1:]) - - -def test_preprocessing_padding(): - # 5 real tokens padded to 8: padding tokens (-100) should be masked out of labels. - tokens = [100, 101, 102, 103, 104] - document = LanguageModelDocument(tokens=torch.tensor(tokens, dtype=torch.int64)) - - batch = LanguageModelBatch.from_documents([document], pad_to_size=8) - (model_input,) = batch.get_model_inputs(LanguageModelBatchPreprocessingConfig()) - - # total_input_length = 8 - 1 = 7; tokens[0:7] = [100, 101, 102, 103, 104, -100, -100] - Assert.all_equal( - model_input.tokens, - torch.tensor([100, 101, 102, 103, 104, -100, -100], dtype=torch.int64), - ) - - # labels = [100, 101, 102, 103, 104, -100, -100, -100], then labels[0]=-100 (cross-doc) - # target = labels[1:8] = [101, 102, 103, 104, -100, -100, -100] - Assert.all_equal( - model_input.targets[0].tokens, - torch.tensor([101, 102, 103, 104, -100, -100, -100], dtype=torch.int64), - ) - - -@pytest.mark.parametrize("predicted_tokens", [1, 2, 3]) -def test_preprocessing_multi_token_prediction(predicted_tokens): - # With predicted_tokens=d, there are d target sets. - # Target for distance d is tokens[d : d + total_input_length]. - # Cross-doc masking for distance d falls at index d-1, just outside each target window. - tokens = list(range(100, 111)) # 11 tokens - document = LanguageModelDocument(tokens=torch.tensor(tokens, dtype=torch.int64)) - - config = LanguageModelBatchPreprocessingConfig(predicted_tokens=predicted_tokens) - (model_input,) = LanguageModelBatch.from_documents([document]).get_model_inputs(config) - - total_input = len(tokens) - predicted_tokens - Assert.all_equal(model_input.tokens, torch.tensor(tokens[:total_input], dtype=torch.int64)) - Assert.eq(len(model_input.targets), predicted_tokens) - - for i, target in enumerate(model_input.targets): - d = i + 1 - # Cross-doc masking for all distances <=d falls at indices 0..d-1, outside window [d:d+total_input]. - Assert.all_equal(target.tokens, torch.tensor(tokens[d : d + total_input], dtype=torch.int64)) - - -def test_preprocessing_micro_batch_splits(): - # micro_batch_splits=2 produces two model inputs each covering half the sequence. - tokens = list(range(100, 113)) # 13 tokens → total_input_length=12, each split=6 - document = LanguageModelDocument(tokens=torch.tensor(tokens, dtype=torch.int64)) - - config = LanguageModelBatchPreprocessingConfig(micro_batch_splits=2) - model_inputs = LanguageModelBatch.from_documents([document]).get_model_inputs(config) - - Assert.eq(len(model_inputs), 2) - Assert.all_equal(model_inputs[0].tokens, torch.tensor(tokens[:6], dtype=torch.int64)) - Assert.all_equal(model_inputs[1].tokens, torch.tensor(tokens[6:12], dtype=torch.int64)) - - # labels[0]=-100 (cross-doc); targets are labels[1:7] and labels[7:13] - Assert.all_equal(model_inputs[0].targets[0].tokens, torch.tensor(tokens[1:7], dtype=torch.int64)) - Assert.all_equal(model_inputs[1].targets[0].tokens, torch.tensor(tokens[7:13], dtype=torch.int64)) - - -def test_preprocessing_prediction_mask(): - # return_prediction_mask exposes the boolean mask of non-masked label positions. - tokens = [100, 101, 102, 103, 104, 105] - document = LanguageModelDocument( - tokens=torch.tensor(tokens, dtype=torch.int64), - loss_masking_spans=RangeDocument(ranges=[(2, 4)]), # mask positions 2 and 3 - ) - - config = LanguageModelBatchPreprocessingConfig(return_prediction_mask=True) - (model_input,) = LanguageModelBatch.from_documents([document]).get_model_inputs(config) - - # labels = [100, 101, 102, 103, 104, 105] - # after span masking: labels[2:4] = -100 → [100, 101, -100, -100, 104, 105] - # after cross-doc: labels[0] = -100 → [-100, 101, -100, -100, 104, 105] - # target = labels[1:6] = [101, -100, -100, 104, 105] - # mask[1:6] = [True, False, False, True, True] - assert model_input.targets[0].mask is not None - Assert.all_equal( - model_input.targets[0].mask, - torch.tensor([True, False, False, True, True]), - ) - - -def test_preprocessing_label_counts(): - # return_label_counts gives each token the total count of valid labels in its document. - # Two documents each of length 4; cross-doc masking removes the first token of each, - # leaving 3 valid labels per document. - docs = [ - LanguageModelDocument(tokens=torch.tensor([100, 101, 102, 103], dtype=torch.int64)), - LanguageModelDocument(tokens=torch.tensor([200, 201, 202, 203], dtype=torch.int64)), - ] - - config = LanguageModelBatchPreprocessingConfig(return_label_counts=True) - (model_input,) = LanguageModelBatch.from_documents(docs).get_model_inputs(config) - - # labels after cross-doc masking: [-100, 101, 102, 103, -100, 201, 202, 203] - # doc1: 3 valid labels (indices 1,2,3); doc2: 3 valid labels (indices 5,6,7) - # target window: labels[1:8] → label_counts[1:8] = [3, 3, 3, 3, 3, 3, 3] - assert model_input.targets[0].label_counts is not None - Assert.all_equal( - model_input.targets[0].label_counts, - torch.full((7,), 3, dtype=model_input.targets[0].label_counts.dtype), - ) - - -def test_preprocessing_grpo_data(): - # use_grpo_data attaches per-token advantages and log-probabilities to the target, - # cropped to the label window [label_begin:label_end]. - tokens = [100, 101, 102, 103, 104, 105] - advantages_data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] - log_probs_data = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6] - - document = LanguageModelDocument( - tokens=torch.tensor(tokens, dtype=torch.int64), - advantages=TokenDataDocument(data=torch.tensor(advantages_data)), - old_log_probabilities=TokenDataDocument(data=torch.tensor(log_probs_data)), - ) - - config = LanguageModelBatchPreprocessingConfig(use_grpo_data=True) - (model_input,) = LanguageModelBatch.from_documents([document]).get_model_inputs(config) - - # total_input_length=5; label_begin=1, label_end=6 - target = model_input.targets[0] - assert target.advantages is not None - assert target.old_log_probabilities is not None - Assert.rms_close(target.advantages, torch.tensor(advantages_data[1:]), 1e-6) - Assert.rms_close(target.old_log_probabilities, torch.tensor(log_probs_data[1:]), 1e-6) - - -def test_preprocessing_position_index(): - # return_position_index gives the within-document position of each input token, - # resetting to 0 at every document boundary. - docs = [ - LanguageModelDocument(tokens=torch.tensor([100, 101, 102, 103], dtype=torch.int64)), # len=4 - LanguageModelDocument(tokens=torch.tensor([200, 201, 202, 203], dtype=torch.int64)), # len=4 ] - config = LanguageModelBatchPreprocessingConfig(return_position_index=True) - (model_input,) = LanguageModelBatch.from_documents(docs).get_model_inputs(config) - - # total_input_length=7; input tokens: [100,101,102,103,200,201,202] - # positions: doc1 → [0,1,2,3], doc2 (first 3 tokens) → [0,1,2] - assert model_input.position_index is not None - Assert.all_equal( - model_input.position_index, - torch.tensor([0, 1, 2, 3, 0, 1, 2], dtype=torch.int32), + batch = LanguageModelBatch.from_documents( + documents, pad_to_size=test_config.padded_size if test_config.padding is not None else None ) - - -def test_preprocessing_inference(): - # In inference phase num_labels=0, so the full token sequence is the input and there are no targets. - tokens = [100, 101, 102, 103, 104] - document = LanguageModelDocument(tokens=torch.tensor(tokens, dtype=torch.int64)) - - config = LanguageModelBatchPreprocessingConfig(phase=PhaseType.inference) - (model_input,) = LanguageModelBatch.from_documents([document]).get_model_inputs(config) - - Assert.all_equal(model_input.tokens, torch.tensor(tokens, dtype=torch.int64)) - Assert.eq(len(model_input.targets), 0) - - -def test_preprocessing_document_count(): - # return_document_count records how many documents are in the batch (first split only). - docs = [ - LanguageModelDocument(tokens=torch.tensor([100, 101, 102], dtype=torch.int64)), - LanguageModelDocument(tokens=torch.tensor([200, 201, 202], dtype=torch.int64)), - ] - - config = LanguageModelBatchPreprocessingConfig(return_document_count=True) - (model_input,) = LanguageModelBatch.from_documents(docs).get_model_inputs(config) - - Assert.eq(model_input.num_documents, 2) - - -def test_preprocessing_cumulative_sequence_lengths(): - # return_cumulative_sequence_lengths produces cu_seqlens tensors for flash-attention style kernels. - docs = [ - LanguageModelDocument(tokens=torch.tensor([100, 101, 102, 103], dtype=torch.int64)), # len=4 - LanguageModelDocument(tokens=torch.tensor([200, 201, 202, 203], dtype=torch.int64)), # len=4 - ] - - config = LanguageModelBatchPreprocessingConfig(return_cumulative_sequence_lengths=True) - (model_input,) = LanguageModelBatch.from_documents(docs).get_model_inputs(config) - - # total_input_length=7; lengths in this input: [4, 3] (doc2 is cut to 3 by the -1 label offset) - # cumulative_lengths_q = padded_cumsum([4, 3]) = [0, 4, 7] - # cumulative_lengths_k = [0, 4, 7] (sequence_k_past=0, first_document_begin=0) - assert model_input.cumulative_lengths_q is not None - assert model_input.cumulative_lengths_k is not None - Assert.all_equal(model_input.cumulative_lengths_q, torch.tensor([0, 4, 7], dtype=torch.int32)) - Assert.all_equal(model_input.cumulative_lengths_k, torch.tensor([0, 4, 7], dtype=torch.int32)) + model_inputs = batch.get_model_inputs(config) + + # Inference: full token sequence as input, no targets. + if config.phase == PhaseType.inference: + Assert.eq(len(model_inputs), 1) + Assert.all_equal(model_inputs[0].tokens, batch.tokens) + Assert.eq(len(model_inputs[0].targets), 0) + return + + Assert.eq(len(model_inputs), test_config.micro_batch_splits) + for split_index, model_input in enumerate(model_inputs): + Assert.all_equal(model_input.tokens, test_config.expected_input_tokens[split_index]) + Assert.eq(len(model_input.targets), test_config.predicted_tokens) + + for target_index, target in enumerate(model_input.targets): + Assert.all_equal(target.tokens, test_config.expected_target_tokens[split_index][target_index]) + _assert_tensor_equal_or_none(target.mask, test_config.expected_target_mask[split_index][target_index]) + _assert_tensor_equal_or_none( + target.label_counts, test_config.expected_target_label_counts[split_index][target_index] + ) + _assert_tensor_equal_or_none(target.advantages, test_config.expected_advantages[split_index][target_index]) + _assert_tensor_equal_or_none( + target.old_log_probabilities, test_config.expected_log_probabilities[split_index][target_index] + ) + + _assert_tensor_equal_or_none(model_input.position_index, test_config.expected_position_index[split_index]) + cu_q, cu_k = test_config.expected_cumulative_lengths[split_index] + _assert_tensor_equal_or_none(model_input.cumulative_lengths_q, cu_q) + _assert_tensor_equal_or_none(model_input.cumulative_lengths_k, cu_k) + Assert.eq(model_input.num_documents, test_config.expected_num_documents[split_index]) From f3975a5f131d9ab9735a6f7a74ebc16eb79076e2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 12:47:23 -0400 Subject: [PATCH 14/33] Rename FieldUpdate to FieldOverride, convert derived fields to cached_property, overhaul config/data tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename `FieldUpdate` → `FieldOverride` throughout: clearer that it overrides inherited field metadata at class-definition time, distinct from the runtime `UpdateType.update` config-merge mechanism - Convert `Field(init=False, hint=FieldHint.derived)` fields in `DistributedConfig` and `WandbAlertConfig` to `functools.cached_property`, removing computed state from `_validate()` in favour of lazy evaluation - Overhaul `tests/config/`: consolidate fixture configs into `common.py`, replace weak repr/to_logs tests with parametrized checks against explicit expected dicts, restructure `test_field.py` with `FieldTestCase`/`ValidCase` dataclasses, add comprehensive `UpdateType` test cases in `test_update.py` - Replace `result_path` with domain-scoped `data_result_path` fixture in all data tests to avoid `-n 12` worker collisions Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/config.py | 33 +- fast_llm/engine/checkpoint/config.py | 10 +- fast_llm/engine/distributed/config.py | 224 +++++------ fast_llm/engine/training/config.py | 37 +- fast_llm/models/gpt/config.py | 10 +- fast_llm/models/multimodal/config.py | 8 +- tests/config/common.py | 6 + tests/config/test_config.py | 279 +++++++++++++- tests/config/test_field.py | 510 ++++++++++++++++---------- tests/config/test_update.py | 39 +- tests/conftest.py | 2 +- tests/data/common.py | 2 +- tests/data/conftest.py | 8 + tests/data/test_blending.py | 6 +- tests/data/test_concatenate.py | 3 +- tests/data/test_dataset_discovery.py | 6 +- tests/data/test_fim.py | 3 +- tests/data/test_random.py | 3 +- tests/data/test_sampling.py | 185 +++++++++- tests/data/test_slice.py | 3 +- tests/data/test_streaming.py | 12 +- tests/test_config.py | 2 +- 22 files changed, 993 insertions(+), 398 deletions(-) create mode 100644 tests/data/conftest.py diff --git a/fast_llm/config.py b/fast_llm/config.py index 6b947bce5..eeaa6c7d3 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -47,10 +47,9 @@ class UpdateType(str, enum.Enum): update = "update" -class FieldHint: +class FieldHint(enum.StrEnum): """ A label defined for each config field, to let the user and some methods know how important each field is. - * core: """ core = "core" @@ -127,7 +126,7 @@ def __init__( *, desc: str | None = None, doc: str | None = None, - hint: str = FieldHint.unknown, + hint: FieldHint = FieldHint.unknown, # Validation function on the field to satisfy. # Should raise an Exception in case of failure, and return the validated value. # Run before the default validation (type check). @@ -164,9 +163,9 @@ def __init__( # self.auto_instantiate = auto_instantiate -class FieldUpdate(dict): +class FieldOverride(dict): """ - Specify some entries in the field that should be updated from the base class. + Override some entries in the field inherited from the base class. Useful for changing the default or description in a derived class. Processed in `__init_subclass__`. """ @@ -185,20 +184,6 @@ def valid(x): return valid -def test_field(fn, *args, **kwargs): - """ - Helper function to define a condition that a config field should satisfy, - in the form of a function that returns a boolean. - """ - - def valid(x): - if not fn(x, *args, **kwargs): - raise ValueError(fn, x, args, kwargs) - return x - - return valid - - def process_field(fn, *args, **kwargs): """ Helper function to apply non-standard processing during validation, @@ -536,7 +521,7 @@ def _validate_array(cls, value, type_, name: str): ) else: if not issubclass(origin, tuple) and len(args) != 1: - FieldTypeError(f"Invalid array specification") + raise FieldTypeError(f"Invalid array specification") new_value = origin( cls._validate_nested(value_, args[0], f"{name}[{i}]", None, errors, True) for i, value_ in enumerate(value) @@ -649,8 +634,8 @@ def _add_field_to_args( all_fields: bool = False, serializable: bool = True, ) -> None: - if field is not None and (not field.init or field._field_type != dataclasses._FIELD) and not all_fields: - # Exclude class variables and derived fields unless requested explicitly. + if field is not None and (field._field_type != dataclasses._FIELD or (not field.init and not all_fields)): + # Always exclude class variables; exclude derived (init=False) fields unless all_fields=True. return explicit_field = ( field is None @@ -865,7 +850,7 @@ def _from_dict_array(cls, value, type_, strict: bool): new_value += value[len(value) - len(new_value) :] else: if not issubclass(origin, tuple) and len(args) != 1: - FieldTypeError(f"Invalid array specification") + raise FieldTypeError(f"Invalid array specification") new_value = origin(cls._from_dict_nested(value_, args[0], strict) for i, value_ in enumerate(value)) return new_value @@ -973,7 +958,7 @@ def __init_subclass__(cls): for name in list(cls.__dict__): value = getattr(cls, name) - if isinstance(value, FieldUpdate): + if isinstance(value, FieldOverride): # In case of multiple inheritance, the base class field may not appear in `cls.__dataclass_fields__`. # so we iterate over superclasses following mro and use the first match. base_class_field = None diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 98303539e..04b1dff46 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -7,7 +7,7 @@ import yaml -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, FieldOverride, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -142,8 +142,8 @@ class CheckpointSaveConfigBase(CheckpointConfigBase): @config_class() class CheckpointStateSaveConfigBase(CheckpointSaveConfigBase, CheckpointStateConfigBase): _abstract = False - model_weights: bool = FieldUpdate(desc="Save the model weights.") - optimizer_state: bool = FieldUpdate(desc="Save the optimizer state. Default: save if supported by the `format`.") + model_weights: bool = FieldOverride(desc="Save the model weights.") + optimizer_state: bool = FieldOverride(desc="Save the optimizer state. Default: save if supported by the `format`.") def _validate(self) -> None: if self.optimizer_state is None and hasattr(self.format, "support_optimizer"): @@ -196,8 +196,8 @@ class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): _abstract = False - model_weights: bool = FieldUpdate(desc="Load the model weights.") - optimizer_state: bool = FieldUpdate(default=False, desc="Load the optimizer state.") + model_weights: bool = FieldOverride(desc="Load the model weights.") + optimizer_state: bool = FieldOverride(default=False, desc="Load the optimizer state.") def _validate(self) -> None: super()._validate() diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index a214e8e50..b5ae7b4f5 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -1,5 +1,6 @@ import dataclasses import enum +import functools import logging import os import typing @@ -172,11 +173,6 @@ class DistributedConfig(Config): pipeline_parallel: int = Field( default=1, desc="Pipeline parallelism group size.", hint=FieldHint.performance, valid=check_field(Assert.gt, 0) ) - data_parallel: int = Field(init=False, desc="Data parallelism group size.", hint=FieldHint.derived) - model_parallel: int = Field( - init=False, desc="Model parallelism group size (tensor * pipeline).", hint=FieldHint.derived - ) - num_nodes: int = Field(init=False, desc="Number of GPU nodes.", hint=FieldHint.derived) sequence_tensor_parallel: bool = Field( default=False, desc="Enable sequence tensor parallelism.", hint=FieldHint.performance ) @@ -186,7 +182,6 @@ class DistributedConfig(Config): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - batch_data_parallel: int = Field(init=False, desc="Batch data parallelism group size.", hint=FieldHint.performance) world_size: int = Field( default=None, desc="Size of the world group, e.e., total number of GPUs. Typically provided by torchrun or equivalent through the `WORLD_SIZE` environment variable.", @@ -199,23 +194,6 @@ class DistributedConfig(Config): hint=FieldHint.expert, valid=check_field(Assert.geq, 0), ) - data_rank: int = Field(init=False, desc="Data-parallel rank of the local process.", hint=FieldHint.derived) - pipeline_rank: int = Field(init=False, desc="Pipeline-parallel rank of the local process.", hint=FieldHint.derived) - tensor_rank: int = Field(init=False, desc="Tensor-parallel rank of the local process.", hint=FieldHint.derived) - sequence_data_rank: int = Field( - init=False, desc="Sequence-data-parallel rank of the local process.", hint=FieldHint.derived - ) - batch_data_rank: int = Field( - init=False, desc="Batch-data-parallel rank of the local process.", hint=FieldHint.derived - ) - distributed_dims: dict[str, DistributedDim] = Field( - init=False, desc="The `DistributedDim` objects for the distributed dimensions.", hint=FieldHint.derived - ) - local_rank: int = Field( - init=False, - desc="The rank of the process on the current node.", - hint=FieldHint.derived, - ) local_world_size: int = Field( default=None, desc="Number of GPUs in each node. Typically provided by torchrun or equivalent through the `LOCAL_WORLD_SIZE` environment variable.", @@ -310,6 +288,112 @@ class DistributedConfig(Config): hint=FieldHint.derived, ) + @functools.cached_property + def model_parallel(self) -> int: + return self.tensor_parallel * self.pipeline_parallel + + @functools.cached_property + def data_parallel(self) -> int: + return div(self.world_size, self.model_parallel) + + @functools.cached_property + def num_nodes(self) -> int: + return div(self.world_size, self.local_world_size) + + @functools.cached_property + def local_rank(self) -> int: + return self.rank % self.local_world_size + + @functools.cached_property + def tensor_rank(self) -> int: + return self.rank % self.tensor_parallel + + @functools.cached_property + def data_rank(self) -> int: + if self.pipeline_first: + # Smaller models can be more demanding on pipeline parallel. + return (self.rank // self.tensor_parallel) // self.pipeline_parallel + else: + # Larger models are more demanding on data parallel. + return (self.rank // self.tensor_parallel) % self.data_parallel + + @functools.cached_property + def pipeline_rank(self) -> int: + if self.pipeline_first: + return (self.rank // self.tensor_parallel) % self.pipeline_parallel + else: + return (self.rank // self.tensor_parallel) // self.data_parallel + + @functools.cached_property + def batch_data_parallel(self) -> int: + return div(self.data_parallel, self.sequence_data_parallel) + + @functools.cached_property + def sequence_data_rank(self) -> int: + return self.data_rank % self.sequence_data_parallel + + @functools.cached_property + def batch_data_rank(self) -> int: + return self.data_rank // self.sequence_data_parallel + + @functools.cached_property + def distributed_dims(self) -> dict[str, "DistributedDim"]: + if self.reference_config is not None: + return self.reference_config.distributed_dims + dims: dict[str, DistributedDim] = {} + tensor_stride = 1 + sequence_data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1) + batch_data_stride = sequence_data_stride * self.sequence_data_parallel + pipeline_stride = self.tensor_parallel * (1 if self.pipeline_first else self.data_parallel) + self._add_distributed_dim_from_sizes_and_strides(dims, DistributedDimNames.world, (self.world_size, 1)) + self._add_distributed_dim_from_sizes_and_strides( + dims, + DistributedDimNames.data, + (self.sequence_data_parallel, sequence_data_stride), + (self.batch_data_parallel, batch_data_stride), + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, DistributedDimNames.pipeline, (self.pipeline_parallel, pipeline_stride) + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, DistributedDimNames.tensor, (self.tensor_parallel, tensor_stride) + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, DistributedDimNames.sequence_data, (self.sequence_data_parallel, sequence_data_stride) + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, DistributedDimNames.batch_data, (self.batch_data_parallel, batch_data_stride) + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, + DistributedDimNames.tensor_and_sequence_data, + (self.tensor_parallel, tensor_stride), + (self.sequence_data_parallel, sequence_data_stride), + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, + DistributedDimNames.tensor_and_data, + (self.tensor_parallel, tensor_stride), + (self.sequence_data_parallel, sequence_data_stride), + (self.batch_data_parallel, batch_data_stride), + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, + DistributedDimNames.model_and_sequence_data, + (self.tensor_parallel, tensor_stride), + ( + (self.pipeline_parallel, pipeline_stride) + if self.pipeline_first + else (self.sequence_data_parallel, sequence_data_stride) + ), + ( + (self.sequence_data_parallel, sequence_data_stride) + if self.pipeline_first + else (self.pipeline_parallel, pipeline_stride) + ), + ) + return dims + def _validate(self) -> None: if self.world_size is None: self.world_size = self.default_world_size @@ -317,112 +401,36 @@ def _validate(self) -> None: self.rank = self.default_rank if self.local_world_size is None: self.local_world_size = self.default_local_world_size - self.model_parallel = self.tensor_parallel * self.pipeline_parallel - self.data_parallel = div(self.world_size, self.model_parallel) - self.num_nodes = div(self.world_size, self.local_world_size) - self.local_rank = self.rank % self.local_world_size - Assert.multiple(self.local_world_size, self.tensor_parallel) - - if self.pipeline_first: - # Smaller models can be more demanding on pipeline parallel. - self.data_rank = (self.rank // self.tensor_parallel) // self.pipeline_parallel - self.pipeline_rank = (self.rank // self.tensor_parallel) % self.pipeline_parallel - else: - # Larger models are more demanding on data parallel. - self.data_rank = (self.rank // self.tensor_parallel) % self.data_parallel - self.pipeline_rank = (self.rank // self.tensor_parallel) // self.data_parallel - self.sequence_data_rank = self.data_rank % self.sequence_data_parallel - self.batch_data_parallel = div(self.data_parallel, self.sequence_data_parallel) - self.batch_data_rank = self.data_rank // self.sequence_data_parallel - - self.tensor_rank = self.rank % self.tensor_parallel if self.tensor_parallel == 1 and self.sequence_tensor_parallel: self.sequence_tensor_parallel = False - if self.reference_config is not None: self.reference_config.validate() if self.reference_config.reference_config is not None: self.reference_config = self.reference_config.reference_config assert self.reference_config.reference_config is None - self.distributed_dims = self.reference_config.distributed_dims - else: - self.distributed_dims = {} - - tensor_stride = 1 - sequence_data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1) - batch_data_stride = sequence_data_stride * self.sequence_data_parallel - pipeline_stride = self.tensor_parallel * (1 if self.pipeline_first else self.data_parallel) - - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.world, - (self.world_size, 1), - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.data, - (self.sequence_data_parallel, sequence_data_stride), - (self.batch_data_parallel, batch_data_stride), - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.pipeline, (self.pipeline_parallel, pipeline_stride) - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.tensor, (self.tensor_parallel, tensor_stride) - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.sequence_data, - (self.sequence_data_parallel, sequence_data_stride), - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.batch_data, (self.batch_data_parallel, batch_data_stride) - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.tensor_and_sequence_data, - (self.tensor_parallel, tensor_stride), - (self.sequence_data_parallel, sequence_data_stride), - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.tensor_and_data, - (self.tensor_parallel, tensor_stride), - (self.sequence_data_parallel, sequence_data_stride), - (self.batch_data_parallel, batch_data_stride), - ) - - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.model_and_sequence_data, - (self.tensor_parallel, tensor_stride), - ( - (self.pipeline_parallel, pipeline_stride) - if self.pipeline_first - else (self.sequence_data_parallel, sequence_data_stride) - ), - ( - (self.sequence_data_parallel, sequence_data_stride) - if self.pipeline_first - else (self.pipeline_parallel, pipeline_stride) - ), - ) - super()._validate() + Assert.multiple(self.local_world_size, self.tensor_parallel) if self.reference_config is not None: self.compare(self.reference_config, ValueError) Assert.in_range(self.rank, 0, self.world_size) Assert.in_range(self.local_rank, 0, self.local_world_size) - def _add_distributed_dim_from_sizes_and_strides(self, name: str, *sizes_and_strides: tuple[int, int]) -> None: - self._add_distributed_dim(DistributedDim.from_sizes_and_strides(name, self.rank, *sizes_and_strides)) + def _add_distributed_dim_from_sizes_and_strides( + self, dims: dict[str, DistributedDim], name: str, *sizes_and_strides: tuple[int, int] + ) -> None: + self._add_distributed_dim(dims, DistributedDim.from_sizes_and_strides(name, self.rank, *sizes_and_strides)) - def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None: + def _add_distributed_dim(self, dims: dict[str, DistributedDim], distributed_dim: DistributedDim) -> None: Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank, msg=distributed_dim) - try: distributed_dim.check_ranks_in_range(0, self.world_size) except: logger.info(str(self)) raise - if distributed_dim.name in self.distributed_dims: - Assert.eq(distributed_dim, self.distributed_dims[distributed_dim.name]) + if distributed_dim.name in dims: + Assert.eq(distributed_dim, dims[distributed_dim.name]) else: - self.distributed_dims[distributed_dim.name] = distributed_dim + dims[distributed_dim.name] = distributed_dim def get_distributed_dim(self, name: str) -> DistributedDim: return self.distributed_dims[name] diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index d5c6fbc7c..fecee4615 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -1,4 +1,5 @@ import abc +import functools import os import pathlib import shlex @@ -10,7 +11,7 @@ Configurable, Field, FieldHint, - FieldUpdate, + FieldOverride, NoAutoValidate, check_field, config_class, @@ -71,12 +72,12 @@ def run(self) -> None: @config_class() class WandbAlertConfig(IntervalConfig): - interval = FieldUpdate( + interval = FieldOverride( desc="The number of training iterations between each Wandb status post (alert)." " Setting to None will disable iteration-based wandb alerts." " Must be a sub-interval of the logging interval." ) - offset = FieldUpdate( + offset = FieldOverride( desc="Offset for the first Wandb status post (alert)." " Must be compatible with the logging offset.", ) status_updates: bool | None = Field( @@ -85,22 +86,20 @@ class WandbAlertConfig(IntervalConfig): "The update may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.feature, ) - post_alerts: bool = Field(init=False) - def _validate(self) -> None: - if self.status_updates is None: - self.post_alerts = self.enabled() - super()._validate() + @functools.cached_property + def post_alerts(self) -> bool: + return self.status_updates if self.status_updates is not None else self.enabled() @config_class() class MetricsLogsConfig(IntervalConfig): - interval = FieldUpdate( + interval = FieldOverride( default=100, desc="The number of training iterations between each metric logs." " Setting to None will disable metric logging.", ) - offset = FieldUpdate(desc="Offset for the first metric logs.") + offset = FieldOverride(desc="Offset for the first metric logs.") @config_class() @@ -159,12 +158,12 @@ def to_delete(self, iterations: list[int]) -> list[int]: class TrainingCheckpointConfig(TrainingCheckpointBaseConfig): _abstract = False save_name: typing.ClassVar[str] = "checkpoint" - interval = FieldUpdate( + interval = FieldOverride( desc="The number of training iterations between each checkpoint. Setting to None will disable checkpoints." ) - offset = FieldUpdate(desc="Offset for the first checkpoint.") - callback: CallbackConfig = FieldUpdate(desc="Callback (shell script) to run after checkpoint.") - keep: int | None = FieldUpdate(default=5) + offset = FieldOverride(desc="Offset for the first checkpoint.") + callback: CallbackConfig = FieldOverride(desc="Callback (shell script) to run after checkpoint.") + keep: int | None = FieldOverride(default=5) def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path: return experiment_directory / "checkpoint" @@ -192,11 +191,11 @@ def get_load_config(self, path: pathlib.Path, timeout: float | None) -> Checkpoi class TrainingExportConfig(TrainingCheckpointBaseConfig, CheckpointStateSaveConfigBase): _abstract = False save_name: typing.ClassVar[str] = "export" - interval = FieldUpdate( + interval = FieldOverride( desc="The number of training iterations between each export." " Setting to None will disable exports." ) - offset = FieldUpdate(desc="Offset for the first export.") - callback: CallbackConfig = FieldUpdate(desc="Callback (shell script) to run after export.") + offset = FieldOverride(desc="Offset for the first export.") + callback: CallbackConfig = FieldOverride(desc="Callback (shell script) to run after export.") def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path: return experiment_directory / "export" / self.format.name @@ -207,12 +206,12 @@ def get_save_config(self, path: pathlib.Path, timeout: float | None) -> Checkpoi @config_class() class ShutdownConfig(IntervalConfig): - interval = FieldUpdate( + interval = FieldOverride( desc="The number of training iterations between each automated shutdown." " Setting to None will disable automated shutdowns." " Must be a sub-interval of the checkpoint interval." ) - offset = FieldUpdate( + offset = FieldOverride( desc="Offset for the first automated shutdown." " Must be compatible with the checkpoint offset." ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 16222b3c5..72cace032 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -1,7 +1,7 @@ import logging import typing -from fast_llm.config import Field, FieldHint, FieldUpdate, config_class +from fast_llm.config import Field, FieldHint, FieldOverride, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointFormat @@ -58,7 +58,7 @@ def base_model_class(self) -> type["GPTBaseModel"]: class GPTModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "gpt" - base_model: GPTBaseModelConfig = FieldUpdate() + base_model: GPTBaseModelConfig = FieldOverride() checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + ( AutoGPTHuggingfaceCheckpointFormat, LlamaCheckpointFormat, @@ -94,14 +94,14 @@ def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelF @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: GPTModelConfig = FieldUpdate() + model: GPTModelConfig = FieldOverride() @config_class(dynamic_type={RunnableConfig: "train_gpt", TrainerConfig: "gpt"}) class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate() + data: GPTDataConfig = FieldOverride() # TODO: Use dynamic model type? - reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() + reference_models: dict[str, PretrainedGPTModelConfig] = FieldOverride() def _validate(self) -> None: if self.model.base_model.use_megatron_initialization: diff --git a/fast_llm/models/multimodal/config.py b/fast_llm/models/multimodal/config.py index becdcacbb..15d62ad9e 100644 --- a/fast_llm/models/multimodal/config.py +++ b/fast_llm/models/multimodal/config.py @@ -1,7 +1,7 @@ import logging import typing -from fast_llm.config import FieldUpdate, config_class +from fast_llm.config import FieldOverride, config_class from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig @@ -40,7 +40,7 @@ def base_model_class(self) -> type["MultiModalBaseModel"]: class MultiModalModelConfig(GPTModelConfig): _abstract = False model_name: typing.ClassVar[str] = "multimodal" - base_model: MultiModalBaseModelConfig = FieldUpdate() + base_model: MultiModalBaseModelConfig = FieldOverride() checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + ( LlavaCheckpointFormat, LlavaHybridSSMCheckpointFormat, @@ -69,13 +69,13 @@ def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceMultiModa @config_class() class PretrainedMultiModalModelConfig(PretrainedGPTModelConfig): _abstract = False - model: MultiModalModelConfig = FieldUpdate() + model: MultiModalModelConfig = FieldOverride() @config_class(dynamic_type={RunnableConfig: "train_multimodal", TrainerConfig: "multimodal"}) class MultiModalTrainerConfig(PretrainedMultiModalModelConfig, GPTTrainerConfig): # TODO: Use dynamic model type? - reference_models: dict[str, PretrainedMultiModalModelConfig] = FieldUpdate() + reference_models: dict[str, PretrainedMultiModalModelConfig] = FieldOverride() @classmethod def get_trainer_class(cls) -> type["MultiModalTrainer"]: diff --git a/tests/config/common.py b/tests/config/common.py index b341bd0cb..4b54fbc5e 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -33,6 +33,12 @@ class ExampleConfig(Config): core_field: int = Field(default=4, hint=FieldHint.core) complex_field: dict[str, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional) + @classmethod + def _from_dict(cls, default: dict, strict: bool = True): + cls._handle_renamed_field(default, "old_int_field", "int_field") + cls._handle_renamed_field(default, "original_float_field", "float_field", fn=lambda value: value * 2) + return super()._from_dict(default, strict) + def _validate(self) -> None: with self._set_implicit_default(): if self.implicit_field is None: diff --git a/tests/config/test_config.py b/tests/config/test_config.py index 4c473fa6d..4f7722c00 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -1,7 +1,54 @@ +import typing + import pytest +import yaml + +from fast_llm.config import ( + Config, + Field, + FieldHint, + FieldVerboseLevel, + NoAutoValidate, + UpdateType, + ValidationError, + config_class, +) +from fast_llm.utils import Assert, check_equal_nested, header +from tests.config.common import ExampleConfig, ExampleNestedConfig + +# --- Dynamic dispatch fixtures --- + + +@config_class(registry=True) +class AnimalConfig(Config): + name: str = Field(default="", hint=FieldHint.optional) + + +@config_class(dynamic_type={AnimalConfig: "dog"}) +class DogConfig(AnimalConfig): + breed: str = Field(default="mutt", hint=FieldHint.optional) + + +@config_class(dynamic_type={AnimalConfig: "cat"}) +class CatConfig(AnimalConfig): + indoor: bool = Field(default=True, hint=FieldHint.optional) + -from fast_llm.config import NoAutoValidate -from tests.config.common import ExampleConfig +# --- Verbose level fixtures --- + + +@config_class() +class ExampleHintConfig(Config): + """One field at each FieldHint importance level for testing verbose output filtering.""" + + core_field: int = Field(default=1, hint=FieldHint.core) + architecture_field: int = Field(default=2, hint=FieldHint.architecture) + optional_field: int = Field(default=3, hint=FieldHint.optional) + performance_field: int = Field(default=4, hint=FieldHint.performance) + expert_field: int = Field(default=5, hint=FieldHint.expert) + + +# --- Lifecycle --- def test_auto_validate(): @@ -28,3 +75,231 @@ def test_auto_validate(): assert not (config := ExampleConfig.from_dict({}))._validated config.validate() assert config._validated + + +def test_multiple_validation_errors_all_reported(): + with pytest.raises(ValidationError) as exc_info: + ExampleConfig.from_dict({"int_field": "not_an_int", "float_field": "not_a_float"}) + error_message = str(exc_info.value) + assert "int_field" in error_message + assert "float_field" in error_message + + +# --- compare() --- + + +def test_compare_equal_returns_none(): + config_a = ExampleConfig.from_dict({"int_field": 5}) + config_b = ExampleConfig.from_dict({"int_field": 5}) + assert config_a.compare(config_b) is None + + +def test_compare_different(): + config_a = ExampleConfig.from_dict({"int_field": 5}) + config_b = ExampleConfig.from_dict({"int_field": 7}) + with pytest.raises(ValueError): + config_a.compare(config_b) + # Custom log_fn receives the difference instead of raising. + messages = [] + config_a.compare(config_b, log_fn=messages.append) + assert len(messages) == 1 + + +# --- strict mode --- + + +@pytest.mark.parametrize( + ("config_dict", "cls"), + [ + ({"int_field": 3, "unknown_field": 5}, ExampleConfig), + ({"nested_field": {"int_field": 3, "unknown_sub_field": 5}}, ExampleNestedConfig), + ], + ids=["top_level", "nested"], +) +def test_strict_unknown_field_raises(config_dict, cls): + with pytest.raises(ValidationError): + cls.from_dict(config_dict) + + +def test_strict_false_unknown_field_ignored(): + config = ExampleConfig.from_dict({"int_field": 3, "unknown_field": 5}, strict=False) + assert config.int_field == 3 + assert not hasattr(config, "unknown_field") + + +def test_strict_false_unknown_nested_field_ignored(): + config = ExampleNestedConfig.from_dict({"nested_field": {"int_field": 3, "unknown_sub_field": 5}}, strict=False) + assert config.nested_field.int_field == 3 + + +# --- Dynamic dispatch --- + + +@pytest.mark.parametrize( + ("input_dict", "expected_cls", "expected_field", "expected_value"), + [ + ({"type": "dog", "breed": "labrador"}, DogConfig, "breed", "labrador"), + ({"type": "cat", "indoor": False}, CatConfig, "indoor", False), + ], + ids=["dog", "cat"], +) +def test_dynamic_dispatch_selects_subclass(input_dict, expected_cls, expected_field, expected_value): + config = AnimalConfig.from_dict(input_dict) + assert isinstance(config, expected_cls) + Assert.eq(getattr(config, expected_field), expected_value) + + +def test_dynamic_dispatch_type_serialized(): + config = DogConfig.from_dict({"breed": "poodle"}) + result = config.to_dict() + Assert.eq(result["type"], "dog") + Assert.eq(result["breed"], "poodle") + + +def test_dynamic_dispatch_unknown_type_raises(): + with pytest.raises(ValidationError): + AnimalConfig.from_dict({"type": "fish"}) + + +def test_dynamic_dispatch_roundtrip(): + original = DogConfig.from_dict({"breed": "husky"}) + roundtrip = AnimalConfig.from_dict(original.to_dict()) + assert isinstance(roundtrip, DogConfig) + Assert.eq(roundtrip.breed, "husky") + + +# --- Renamed fields --- + + +def test_renamed_field(): + with pytest.warns(DeprecationWarning, match="old_int_field"): + config = ExampleConfig.from_dict({"old_int_field": 5}) + Assert.eq(config.int_field, 5) + # New name works without a deprecation warning. + Assert.eq(ExampleConfig.from_dict({"int_field": 7}).int_field, 7) + + +def test_renamed_field_with_transform(): + with pytest.warns(DeprecationWarning, match="original_float_field"): + config = ExampleConfig.from_dict({"original_float_field": 4.0}) + Assert.eq(config.float_field, 8.0) + + +# --- Verbose levels --- + +# At verbose >= optional (10), the base Config.type field (hint=feature, importance=10) also appears. +_VERBOSE_LEVEL_CASES = [ + (FieldVerboseLevel.explicit, {}), + (FieldVerboseLevel.core, {"core_field": 1, "architecture_field": 2}), + (FieldVerboseLevel.optional, {"core_field": 1, "architecture_field": 2, "optional_field": 3, "type": None}), + ( + FieldVerboseLevel.performance, + {"core_field": 1, "architecture_field": 2, "optional_field": 3, "performance_field": 4, "type": None}, + ), + ( + FieldVerboseLevel.debug, + { + "core_field": 1, + "architecture_field": 2, + "optional_field": 3, + "performance_field": 4, + "expert_field": 5, + "type": None, + }, + ), +] + + +@pytest.mark.parametrize(("verbose", "expected"), _VERBOSE_LEVEL_CASES) +def test_verbose_level(verbose, expected): + check_equal_nested(ExampleHintConfig.from_dict({}).to_dict(verbose=verbose), expected) + + +# --- Field definition error fixtures --- + + +with pytest.raises(ValueError, match="default_factory"): + # Defining this at module level triggers Field.__init__ validation immediately. + @config_class() + class _BothDefaultAndFactoryConfig(Config): + x: list = Field(default=[], default_factory=list, hint=FieldHint.optional) + + +with pytest.raises(ValueError, match="default_factory"): + + @config_class() + class _ConfigAsDefaultFactoryConfig(Config): + nested: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.optional) + + +with pytest.raises(TypeError, match="__post_init__"): + + @config_class() + class _PostInitConfig(Config): + def __post_init__(self): + pass + + +@config_class() +class _AbstractConfig(Config): + _abstract: typing.ClassVar[bool] = True + + +# --- Abstract config --- + + +def test_abstract_config_raises(): + with pytest.raises(ValidationError, match="abstract"): + _AbstractConfig() + + +# --- Delete on validated config --- + + +def test_delattr_after_validation_raises(): + config = ExampleConfig.from_dict({}) + with pytest.raises(RuntimeError, match="delete"): + del config.int_field + + +# --- to_logs / __repr__ --- + + +@pytest.mark.parametrize( + ("cls", "config_dict", "expected_core_dict"), + [ + (ExampleConfig, {}, {"core_field": 4}), + (ExampleConfig, {"int_field": 3}, {"int_field": 3, "core_field": 4}), + ( + ExampleConfig, + {"int_field": 3, "str_field": "hello"}, + {"int_field": 3, "str_field": "hello", "core_field": 4}, + ), + ( + ExampleNestedConfig, + {"nested_field": {"int_field": 5}}, + {"core_field": 4, "nested_field": {"int_field": 5, "core_field": 4}}, + ), + ], +) +def test_repr_and_to_logs(cls, config_dict, expected_core_dict): + config = cls.from_dict(config_dict) + expected = ( + f"\n{header(config._get_class_name(), 80, '-')}" + f"\n{yaml.safe_dump(expected_core_dict, sort_keys=False)}" + f"{header('end', 80, '-')}" + ) + Assert.eq(repr(config), expected) + messages = [] + config.to_logs(log_fn=messages.append) + Assert.eq(len(messages), 1) + Assert.eq(messages[0], expected) + + +# --- Validated config as update --- + + +def test_validated_config_as_update_raises(): + validated = ExampleConfig.from_dict({"int_field": 1}) + with pytest.raises(ValueError, match="Validated"): + ExampleConfig.from_dict({}, validated, update_type=UpdateType.update) diff --git a/tests/config/test_field.py b/tests/config/test_field.py index bc5881167..2a49c8c60 100644 --- a/tests/config/test_field.py +++ b/tests/config/test_field.py @@ -1,203 +1,331 @@ +import dataclasses +import functools import math import pathlib +from typing import Any import numpy import pytest -from fast_llm.config import FieldVerboseLevel +from fast_llm.config import ( + Config, + Field, + FieldHint, + FieldOverride, + FieldVerboseLevel, + check_field, + config_class, + process_field, + skip_valid_if_none, +) from fast_llm.utils import Assert, check_equal_nested -from tests.config.common import ExampleConfig, ExampleEnum, ExampleVerboseConfig, check_config, check_invalid_config - - -def test_create_and_serialize_config(): - Assert.eq(ExampleConfig.from_dict({}).to_dict(), {}) - - -@pytest.mark.parametrize("value", (0, -6, 3)) -def test_int_field(value): - check_config({"int_field": value}) - - -@pytest.mark.parametrize("value", (4.0, math.inf, "1", None, [4], True)) -def test_int_field_invalid(value): - check_invalid_config({"int_field": value}) - - -@pytest.mark.parametrize("value", (True, False)) -def test_bool_field(value): - check_config({"bool_field": value}) - - -@pytest.mark.parametrize("value", (1, "True", None, [True])) -def test_bool_field_invalid(value): - check_invalid_config({"bool_field": value}) - - -@pytest.mark.parametrize("value", ("", "text", "1")) -def test_str_field(value): - check_config({"str_field": str(value)}, {"str_field": value}) - - -@pytest.mark.parametrize("value", (1, True, None, ["text"], pathlib.Path("a"), ExampleEnum.a)) -def test_str_field_invalid(value): - check_invalid_config({"str_field": value}) - - -@pytest.mark.parametrize("value", (".", "text", "/a/b/c.d")) -def test_path_field(value): - check_config({"path_field": pathlib.Path(value)}, {"path_field": value}) - - -@pytest.mark.parametrize("value", (1, True, None, [pathlib.Path("a")])) -def test_path_field_invalid(value): - check_invalid_config({"path_field": value}) - - -@pytest.mark.parametrize("value", (4.0, math.pi, math.inf, 3, math.nan)) -def test_float_field(value): - check_config( - {"float_field": float(value)}, {"float_field": value}, serialized_config={"float_field": float(value)} - ) - - -@pytest.mark.parametrize("value", (None, [4.7], "0.0", True, numpy.float64(3))) -def test_float_field_invalid(value): - check_invalid_config({"float_field": value}) - - -@pytest.mark.parametrize("value", ("", None, "text")) -def test_optional_field(value): - check_config({"optional_field": value}) - - -@pytest.mark.parametrize("value", (True, 6, [None])) -def test_optional_field_invalid(value): - check_invalid_config({"optional": value}) - - -@pytest.mark.parametrize("value", ("", 0, "text", 7)) -def test_union_field(value): - check_config({"union_field": value}) - - -@pytest.mark.parametrize("value", (6.0, [""], True)) -def test_union_field_invalid(value): - check_invalid_config({"optional": value}) - - -def test_implicit_field_value(): - Assert.eq(ExampleConfig.from_dict({}).implicit_field, "implicit") - - -@pytest.mark.parametrize("value", ("implicit", "", "text")) -def test_implicit_field(value): - check_config({"implicit_field": value}) - - -ARRAY_VALUES = ((), (1,), (3, 4, 6), (4, 5, 4)) -ARRAY_VALUES_INVALID = (6.0, {}, True, "text") - - -@pytest.mark.parametrize("value", ARRAY_VALUES) -def test_list_field(value): - check_config( - {"list_field": list(value)}, - {"list_field": value}, - serialized_config={"list_field": list(value)}, - ) - - -@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) -def test_list_field_invalid(value): - check_invalid_config({"list_field": value}) - - -@pytest.mark.parametrize("value", ARRAY_VALUES) -def test_tuple_field(value): - check_config( - {"tuple_field": list(value)}, - {"tuple_field": value}, - serialized_config={"tuple_field": list(value)}, - ) - - -@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) -def test_tuple_field_invalid(value): - check_invalid_config({"tuple_field": value}) - - -@pytest.mark.parametrize("value", ARRAY_VALUES) -def test_set_field(value): - check_config( - {"set_field": list(set(value))}, - {"set_field": set(value)}, - {"set_field": list(value)}, - {"set_field": tuple(value)}, - serialized_config={"set_field": list(set(value))}, - ) - - -@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) -def test_tuple_field_invalid(value): - check_invalid_config({"set_field": value}) - - -@pytest.mark.parametrize("value", ({}, {1: 2, 3: 4})) -def test_dict_field(value): - check_config({"dict_field": value}) - - -@pytest.mark.parametrize("value", ({True: 2}, {4: "3"}, {4: {1: 4}}, None, 4, {1}, [5, 7], "text")) -def test_dict_field_invalid(value): - check_invalid_config({"dict_field": value}) +from tests.config.common import ( + ExampleConfig, + ExampleEnum, + ExampleNestedConfig, + ExampleVerboseConfig, + check_config, + check_invalid_config, +) -class IntClass(int): +class IntSubclass(int): pass -@pytest.mark.parametrize("value", (int, bool, IntClass)) -def test_type_field(value): - check_config({"type_field": value}, serialized_config={"type_field": str(value)}) - - -@pytest.mark.parametrize("value", (5, None, [], "text")) -def test_type_field_invalid(value): - check_invalid_config({"type_field": value}) +# --- Validator configs (referenced in _FIELD_TEST_CASES) --- -@pytest.mark.parametrize("value", (ExampleEnum.a, ExampleEnum.b, ExampleEnum.c)) -def test_enum_field(value): - check_config({"enum_field": value}, {"enum_field": str(value)}) +@config_class() +class ExampleCheckFieldConfig(Config): + positive_field: int = Field(default=0, hint=FieldHint.optional, valid=check_field(Assert.geq, 0)) -@pytest.mark.parametrize("value", (5, None, [], "text")) -def test_enum_field_invalid(value): - check_invalid_config({"type_field": value}) +@config_class() +class ExampleSkipIfNoneConfig(Config): + optional_positive_field: int | None = Field( + default=None, + hint=FieldHint.optional, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) -def test_core_field(): - Assert.eq(ExampleConfig.from_dict({}).to_dict(verbose=FieldVerboseLevel.core), {"core_field": 4}) +@config_class() +class ExampleProcessFieldConfig(Config): + doubled_field: int = Field(default=0, hint=FieldHint.optional, valid=process_field(lambda value: value * 2)) + + +# --- FieldOverride configs --- + + +@config_class() +class ExampleUpdatedDefaultConfig(ExampleConfig): + int_field = FieldOverride(default=42) + + +@config_class() +class ExampleUpdatedHintConfig(ExampleConfig): + # Promote str_field from optional to core so it appears at verbose=core. + str_field = FieldOverride(hint=FieldHint.core) + + +@dataclasses.dataclass +class ValidCase: + # Canonical Python-side value. Used as input to from_dict() and as expected to_dict(serialized=False) result. + internal: Any + # Expected to_dict() result. Defaults to internal. + serialized: Any = None + # Other input values that should produce the same internal/serialized result. + alternates: tuple = () + + def __post_init__(self): + if self.serialized is None: + self.serialized = self.internal + + +@dataclasses.dataclass +class FieldTestCase: + field_name: str + valid: list[ValidCase] + invalid: list[Any] + cls: type = ExampleConfig + # When the config class has other fields with non-empty defaults, check only this field. + fields: list[str] | None = None + + @functools.cached_property + def params(self) -> list: + return [ + *( + pytest.param( + self.field_name, + self.cls, + valid_case.internal, + valid_case, + self.fields, + id=f"{self.field_name}-{valid_case.internal!r}", + ) + for valid_case in self.valid + ), + *( + pytest.param( + self.field_name, + self.cls, + invalid_value, + None, + self.fields, + id=f"{self.field_name}-invalid-{invalid_value!r}", + ) + for invalid_value in self.invalid + ), + ] + + +_FIELD_TEST_CASES: list[FieldTestCase] = [ + FieldTestCase( + field_name="int_field", + valid=[ValidCase(0), ValidCase(-6), ValidCase(3)], + # Rejects float (even if whole number), bool, string, None, list. + invalid=[4.0, math.inf, "1", None, [4], True], + ), + FieldTestCase( + field_name="bool_field", + valid=[ValidCase(True), ValidCase(False)], + # Rejects int (bool is a subclass of int but the reverse is not accepted), string, None, list. + invalid=[1, "True", None, [True]], + ), + FieldTestCase( + field_name="str_field", + valid=[ValidCase(""), ValidCase("text"), ValidCase("1")], + # Rejects int, bool, None, list, Path, Enum. + invalid=[1, True, None, ["text"], pathlib.Path("a"), ExampleEnum.a], + ), + FieldTestCase( + field_name="path_field", + valid=[ + # Stores as pathlib.Path; serializes to string; accepts string input. + ValidCase(pathlib.Path("."), serialized=".", alternates=(".",)), + ValidCase(pathlib.Path("text"), serialized="text", alternates=("text",)), + ValidCase(pathlib.Path("/a/b/c.d"), serialized="/a/b/c.d", alternates=("/a/b/c.d",)), + ], + # Rejects int, bool, None, list. + invalid=[1, True, None, [pathlib.Path("a")]], + ), + FieldTestCase( + field_name="float_field", + valid=[ + # Accepts int and float; stores and serializes as float; inf and nan are valid. + ValidCase(4.0), + ValidCase(math.pi), + ValidCase(math.inf), + ValidCase(math.nan), + ValidCase(3.0, alternates=(3,)), # int input coerced to float + ], + # Rejects None, list, string, bool, numpy scalar. + invalid=[None, [4.7], "0.0", True, numpy.float64(3)], + ), + FieldTestCase( + field_name="optional_field", + valid=[ValidCase(None), ValidCase(""), ValidCase("text")], + # Rejects bool, int, list. + invalid=[True, 6, [None]], + ), + FieldTestCase( + field_name="union_field", + valid=[ValidCase(""), ValidCase(0), ValidCase("text"), ValidCase(7)], + # Rejects float, list, bool. + invalid=[6.0, [""], True], + ), + FieldTestCase( + field_name="implicit_field", + valid=[ + # _validate() sets "implicit" when not provided; explicit value overrides. + ValidCase("implicit"), + ValidCase(""), + ValidCase("text"), + ], + invalid=[], # Any string is valid; invalids are covered by str_field tests. + ), + FieldTestCase( + field_name="list_field", + valid=[ + # Stores as list; accepts tuple input; duplicates preserved. + ValidCase([]), + ValidCase([1], alternates=((1,),)), + ValidCase([3, 4, 6], alternates=((3, 4, 6),)), + ValidCase([4, 5, 4], alternates=((4, 5, 4),)), + ], + # Rejects float, dict, bool, string. + invalid=[6.0, {}, True, "text"], + ), + FieldTestCase( + field_name="tuple_field", + valid=[ + # Stores as tuple; serializes as list; accepts list or tuple input. + ValidCase([], serialized=[], alternates=((),)), + ValidCase([1], serialized=[1], alternates=((1,),)), + ValidCase([3, 4, 6], serialized=[3, 4, 6], alternates=((3, 4, 6),)), + ValidCase([4, 5, 4], serialized=[4, 5, 4], alternates=((4, 5, 4),)), + ], + # Rejects float, dict, bool, string. + invalid=[6.0, {}, True, "text"], + ), + FieldTestCase( + field_name="set_field", + valid=[ + # Deduplicates; serializes as list; accepts list/tuple/set input. + # Note: CPython iterates small-int sets in insertion/hash order, matching sorted order here. + ValidCase([], serialized=[], alternates=(set(), ())), + ValidCase([1], serialized=[1], alternates=({1}, (1,))), + ValidCase([3, 4, 6], serialized=[3, 4, 6], alternates=({3, 4, 6}, (3, 4, 6))), + ValidCase([4, 5], serialized=[4, 5], alternates=({4, 5}, [4, 5, 4], (4, 5, 4))), # deduplication + ], + # Rejects float, dict, bool, string. + invalid=[6.0, {}, True, "text"], + ), + FieldTestCase( + field_name="dict_field", + valid=[ValidCase({}), ValidCase({1: 2, 3: 4})], + # Rejects bool keys, wrong value types, nested dict values, None, int, set, list, string. + invalid=[{True: 2}, {4: "3"}, {4: {1: 4}}, None, 4, {1}, [5, 7], "text"], + ), + FieldTestCase( + field_name="type_field", + valid=[ + # Accepts type objects that are subclasses of int; serializes as repr string. + ValidCase(int, serialized=str(int)), + ValidCase(bool, serialized=str(bool)), + ValidCase(IntSubclass, serialized=str(IntSubclass)), + ], + # Rejects non-type values. + invalid=[5, None, [], "text"], + ), + FieldTestCase( + field_name="enum_field", + valid=[ + # Accepts enum values and their string equivalents; serializes as string. + ValidCase(ExampleEnum.a, serialized="a", alternates=("a",)), + ValidCase(ExampleEnum.b, serialized="b", alternates=("b",)), + ValidCase(ExampleEnum.c, serialized="c", alternates=("c",)), + ], + # Rejects non-string, None, list, and strings not in the enum. + invalid=[5, None, [], "d"], + ), + FieldTestCase( + field_name="complex_field", + valid=[ + ValidCase({}), + ValidCase({"3": None, "text": [], "0": [["", 3], ["a", -7]]}), + ValidCase({"0": [[".", 8]]}), + ], + # Rejects non-string dict keys. + invalid=[{False: [["", 3]]}], + ), + FieldTestCase( + field_name="tuple_fixed_length_field", + valid=[ + # Fixed-length (int, str) tuple; stores and serializes as list; accepts list or tuple input. + ValidCase([0, ""], alternates=((0, ""),)), + ValidCase([5, "text"], alternates=((5, "text"),)), + ValidCase([7, "True"], alternates=((7, "True"),)), + ], + # Rejects wrong length (too short/long) and wrong element types. + invalid=[(), (5,), ("", 0), ("0", "True"), (0, "", "text")], + cls=ExampleVerboseConfig, + fields=["tuple_fixed_length_field"], + ), + FieldTestCase( + field_name="nested_field", + valid=[ + # Non-empty sub-configs only: empty nested config serializes back to {} (no nested_field key). + ValidCase({"int_field": 3}), + ValidCase({"int_field": 3, "str_field": "text"}), + ValidCase({"list_field": [1, 2], "dict_field": {1: 2}}), + ], + # Rejects None, non-dict, and dicts with invalid sub-field values. + invalid=[None, 5, {"int_field": "not_an_int"}, {"int_field": True}], + cls=ExampleNestedConfig, + ), + FieldTestCase( + field_name="positive_field", + valid=[ValidCase(0), ValidCase(5)], + # Rejects values failing check_field(>=0); type invalids already covered by int_field tests. + invalid=[-1], + cls=ExampleCheckFieldConfig, + ), + FieldTestCase( + field_name="optional_positive_field", + valid=[ValidCase(None), ValidCase(0), ValidCase(5)], + # Rejects negative values; None bypasses the validator (skip_valid_if_none). + invalid=[-1], + cls=ExampleSkipIfNoneConfig, + ), +] @pytest.mark.parametrize( - "value", - ( - {}, - {"3": None, "text": [], "0": [["", 3], ["a", -7]]}, - {"0": [[".", 8]]}, - ), + ("field_name", "cls", "value", "expected", "fields"), + [case for field_test_case in _FIELD_TEST_CASES for case in field_test_case.params], ) -def test_complex_field(value): - check_config({"complex_field": value}) +def test_field(field_name: str, cls: type, value: Any, expected: ValidCase | None, fields: list[str] | None): + if expected is None: + check_invalid_config({field_name: value}, cls=cls) + else: + check_config( + {field_name: value}, + *({field_name: alternate} for alternate in expected.alternates), + serialized_config={field_name: expected.serialized}, + cls=cls, + fields=fields, + ) -@pytest.mark.parametrize( - "value", - ({"3": None, "text": [], False: [["", 3], ["a", -7]]},), -) -def test_complex_field_invalid(value): - check_invalid_config({"complex_field": value}) +def test_implicit_field_value(): + # When implicit_field is not provided, _validate() fills it in as "implicit". + config = ExampleConfig.from_dict({}) + Assert.eq(config.implicit_field, "implicit") + # Implicitly-set fields are not included in the serialized dict; all other fields are default, + # so the empty config serializes back to {}. + Assert.eq(config.to_dict(), {}) def test_verbose_config_default(): @@ -214,21 +342,23 @@ def test_verbose_config_default(): check_equal_nested(config.to_dict(serialized=False), default_values) -@pytest.mark.parametrize("value", ((0, ""), (5, "text"), (7, "True"))) -def test_tuple_fixed_length_field(value): - check_config( - {"tuple_fixed_length_field": list(value)}, - {"tuple_fixed_length_field": value}, - serialized_config={"tuple_fixed_length_field": list(value)}, - cls=ExampleVerboseConfig, - fields=["tuple_fixed_length_field"], - ) +def test_nested_field_empty(): + # An empty sub-config is accepted; sub-fields take their defaults. + config = ExampleNestedConfig.from_dict({"nested_field": {}}) + Assert.eq(config.nested_field.int_field, 0) + Assert.eq(config.nested_field.str_field, "") + + +def test_process_field_transforms_value(): + # process_field transforms the value during validation; input 5 is stored as 10. + Assert.eq(ExampleProcessFieldConfig.from_dict({"doubled_field": 5}).doubled_field, 10) -@pytest.mark.parametrize("value", ((), (5,), ("", 0), ("0", "True"), (0, "", "text"))) -def test_tuple_fixed_length_field_invalid(value): - check_invalid_config({"tuple_fixed_length_field": value}, cls=ExampleVerboseConfig) +def test_field_update_default(): + Assert.eq(ExampleUpdatedDefaultConfig.from_dict({}).int_field, 42) + Assert.eq(ExampleConfig.from_dict({}).int_field, 0) # parent default unchanged -# TODO: Test other fields with defaults. -# TODO: Test nested fields. +def test_field_update_hint(): + assert "str_field" in ExampleUpdatedHintConfig.from_dict({}).to_dict(verbose=FieldVerboseLevel.core) + assert "str_field" not in ExampleConfig.from_dict({}).to_dict(verbose=FieldVerboseLevel.core) # parent unchanged diff --git a/tests/config/test_update.py b/tests/config/test_update.py index 525c47694..eea1a5252 100644 --- a/tests/config/test_update.py +++ b/tests/config/test_update.py @@ -6,27 +6,58 @@ TEST_CONFIGS = ( ( - # Empty config + # Empty config: updating nothing changes nothing. {}, {}, {}, None, ), ( - # Update unset field; don't update set field; update + # Flat fields: update adds new fields and overwrites shared fields; unrelated base fields survive. {"int_field": 4, "str_field": "text"}, {"float_field": 3.0, "str_field": ""}, {"int_field": 4, "float_field": 3.0, "str_field": ""}, None, ), ( - # Update/override nested field. + # Nested field: update merges sub-fields; override replaces the whole nested config. {"nested_field": {"int_field": 4, "str_field": "text"}}, {"nested_field": {"float_field": 3.0, "str_field": ""}}, {"nested_field": {"int_field": 4, "float_field": 3.0, "str_field": ""}}, {"nested_field": {"float_field": 3.0, "str_field": ""}}, ), - # TODO: Add more complex cases + ( + # Top-level and nested fields together: top-level fields and nested sub-fields both update correctly. + {"int_field": 1, "nested_field": {"int_field": 4, "str_field": "text"}}, + {"str_field": "new", "nested_field": {"float_field": 3.0}}, + { + "int_field": 1, + "str_field": "new", + "nested_field": {"int_field": 4, "float_field": 3.0, "str_field": "text"}, + }, + {"int_field": 1, "str_field": "new", "nested_field": {"float_field": 3.0}}, + ), + ( + # Update from empty: base has no fields set; all update fields appear in result. + {}, + {"int_field": 7, "str_field": "hello"}, + {"int_field": 7, "str_field": "hello"}, + None, + ), + ( + # Update to empty: update has no fields set; base is preserved unchanged. + {"int_field": 7, "str_field": "hello"}, + {}, + {"int_field": 7, "str_field": "hello"}, + None, + ), + ( + # Collection fields: list and dict fields in update replace their counterparts in base. + {"int_field": 1, "list_field": [1, 2, 3]}, + {"list_field": [4, 5]}, + {"int_field": 1, "list_field": [4, 5]}, + None, + ), ) diff --git a/tests/conftest.py b/tests/conftest.py index 1d3264103..43f1fc65f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -277,7 +277,7 @@ def worker_resources(request) -> WorkerResources: return request.config.worker_resources -@pytest.mark.trylast +@pytest.hookimpl(trylast=True) def pytest_xdist_make_scheduler(config, log): # Always use grouped load balancing to handle dependencies, and make it work with `-n`. assert config.getvalue("dist") == "load" diff --git a/tests/data/common.py b/tests/data/common.py index 5771f9b11..1a17695bc 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -167,7 +167,7 @@ def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_s token_ids = torch.stack([LanguageModelBatch.from_documents(sampled[i]).tokens for i in range(len(sampled))]).to( torch.int64 ) - Assert.all_equal(token_ids, validate_samples) + Assert.all_equal(token_ids, np.array(validate_samples)) if expected_samples is not None: Assert.all_equal(token_ids, expected_samples) return token_ids diff --git a/tests/data/conftest.py b/tests/data/conftest.py new file mode 100644 index 000000000..306351471 --- /dev/null +++ b/tests/data/conftest.py @@ -0,0 +1,8 @@ +import pathlib + +import pytest + + +@pytest.fixture(scope="session") +def data_result_path(result_path: pathlib.Path) -> pathlib.Path: + return result_path / "data" diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 5d72c7152..edbe479cc 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -106,7 +106,7 @@ def test_blending(probs): Assert.all_equal(samples, samples_alt) -def test_gpt_blended(): +def test_gpt_blended(data_result_path): # Make sure dataset blending works and check for unintended changes in behavior. _, config, _, preprocessing = get_common_test_dataset() _, alt_config, _, _ = get_alt_test_dataset() @@ -127,10 +127,11 @@ def test_gpt_blended(): sequence_length=5, expected_samples=GPT_BLENDED_SAMPLES, preprocessing=preprocessing, + cache_directory=data_result_path / "blended", ) -def test_gpt_blended_mixed(): +def test_gpt_blended_mixed(data_result_path): # Make sure dataset blending works and check for unintended changes in behavior. _, config, _, preprocessing = get_common_test_dataset() # Random dataset needs an explicit vocab size. @@ -155,4 +156,5 @@ def test_gpt_blended_mixed(): sequence_length=5, expected_samples=GPT_BLENDED_MIXED_SAMPLES, preprocessing=preprocessing, + cache_directory=data_result_path / "blended_mixed", ) diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 200a771f7..6774374bb 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -23,7 +23,7 @@ ] -def test_gpt_concatenate(): +def test_gpt_concatenate(data_result_path): # Make sure the dataset concatenation works and check for unintended changes in behavior. _, config, _, preprocessing = get_common_test_dataset() memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() @@ -47,4 +47,5 @@ def test_gpt_concatenate(): sequence_length=5, expected_samples=GPT_CONCATENATED_SAMPLES, preprocessing=preprocessing, + cache_directory=data_result_path / "concatenate", ) diff --git a/tests/data/test_dataset_discovery.py b/tests/data/test_dataset_discovery.py index e94da8499..0dd9c31a4 100644 --- a/tests/data/test_dataset_discovery.py +++ b/tests/data/test_dataset_discovery.py @@ -141,11 +141,11 @@ ), ) def test_dataset_discovery( - result_path: pathlib.Path, name: str, paths: tuple[pathlib.Path], ignore_paths, expected_config: dict + data_result_path: pathlib.Path, name: str, paths: tuple[pathlib.Path], ignore_paths, expected_config: dict ): """Test end-to-end discovery with multiple datasets in various structure.""" test_dataset_path = [get_common_test_dataset()[0], get_alt_test_dataset()[0]] - (dataset_path := result_path / f"dataset_discovery/{name}").mkdir(parents=True) + (dataset_path := data_result_path / f"dataset_discovery/{name}").mkdir(parents=True) for index, path in enumerate(paths): (path_ := dataset_path / path).mkdir(parents=True, exist_ok=True) shutil.copy( @@ -157,7 +157,7 @@ def test_dataset_discovery( # Run dataset discovery config = DatasetDiscoveryConfig( directory=dataset_path, - output=result_path / f"dataset_discovery/configs/{name}.yaml", + output=data_result_path / f"dataset_discovery/configs/{name}.yaml", ignore_paths=ignore_paths, ) config.run() diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index ec6ac3011..25e42fb97 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -20,7 +20,7 @@ ] -def test_gpt_fim(): +def test_gpt_fim(data_result_path): # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior. _, config, _, preprocessing = get_common_test_dataset() # The test tokenizer doesn't have fim tokens, so we work around it. @@ -45,4 +45,5 @@ def test_gpt_fim(): sequence_length=5, expected_samples=GPT_FIM_SAMPLES, preprocessing=preprocessing, + cache_directory=data_result_path / "fim", ) diff --git a/tests/data/test_random.py b/tests/data/test_random.py index 9f0d4d9c6..9b7941600 100644 --- a/tests/data/test_random.py +++ b/tests/data/test_random.py @@ -15,7 +15,7 @@ ] -def test_gpt_random_dataset(): +def test_gpt_random_dataset(data_result_path): # Make sure the random dataset works and check for unintended changes in behavior. preprocessing = LanguageModelBatchPreprocessingConfig(vocab_size=8192) sampled = get_dataset_config(config := {"type": "random"}, GPTRandomDatasetConfig).build_and_sample( @@ -30,4 +30,5 @@ def test_gpt_random_dataset(): sequence_length=7, expected_samples=RANDOM_DATASET_EXPECTED_SAMPLES, preprocessing=preprocessing, + cache_directory=data_result_path / "random", ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 2c753a98f..c45160ac2 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -1,9 +1,13 @@ +import dataclasses +import functools +import pathlib + import numpy as np import pytest import torch from fast_llm.data.dataset.config import ShufflingType -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingConfig from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument from fast_llm.utils import Assert @@ -35,24 +39,6 @@ ] -def test_gpt_sampled(): - # Make sure the memmap dataset works and check for unintended changes in behavior. - _, config, _, preprocessing = get_common_test_dataset() - sampled = get_dataset_config( - dataset_config := config, GPTDatasetFromFileConfig[LanguageModelDocument] - ).build_and_sample(*get_sampling_config(8, sequence_length=5, preprocessing=preprocessing)) - validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) - - # Test in data. - get_test_data_and_compare_samples( - {"datasets": {"training": dataset_config}}, - 8, - sequence_length=5, - expected_samples=GPT_MEMMAP_SAMPLES, - preprocessing=preprocessing, - ) - - class SimpleGPTIndexedDataset[DocumentType: LanguageModelDocument](IndexedDataset[DocumentType]): # TODO: worth adding to the main codebase? def __init__(self, samples): @@ -72,6 +58,7 @@ def get_document_sizes(self) -> torch.Tensor: def get_document_size(self, index: int) -> int: return len(self._samples[index]) + @property def name(self) -> str: return "dataset" @@ -87,6 +74,130 @@ def name(self) -> str: ] ) +# Document sizes: 3, 5, 2, 4, 6. +# With maximum_document_length=4, truncate_documents=False: docs of size 5 and 6 are dropped. +# With maximum_document_length=4, truncate_documents=True: docs of size 5 and 6 are split into chunks of ≤4. +TRUNCATE_DATASET = SimpleGPTIndexedDataset( + [ + [0, 1, 2], # length 3 — fits + [3, 4, 5, 6, 7], # length 5 — exceeds maximum_document_length=4 + [8, 9], # length 2 — fits + [10, 11, 12, 13], # length 4 — exactly at limit + [14, 15, 16, 17, 18, 19], # length 6 — exceeds + ] +) + + +@dataclasses.dataclass +class SamplingTestConfig: + name: str + num_samples: int + sequence_length: int = 5 + seed: int = 54983 + shuffle: ShufflingType = ShufflingType.epoch + truncate_documents: bool = True + maximum_document_length: int | None = None + expected_samples: list[list[int]] | None = None + # Tokens that must not appear in any sample (validated for drop/filter cases). + # Defaults to empty — the check is always run but trivially passes. + forbidden_tokens: frozenset[int] = frozenset() + # Tokens that must collectively appear across all samples (validated for truncate cases). + # Defaults to empty — the check is always run but trivially passes. + required_tokens: frozenset[int] = frozenset() + requires_extension: bool = False + dataset: SimpleGPTIndexedDataset | None = dataclasses.field(default=None, compare=False, repr=False) + + @functools.cached_property + def sampling_config_overrides(self) -> dict: + if self.maximum_document_length is not None: + return {"maximum_document_length": self.maximum_document_length} + return {} + + +_SAMPLING_TEST_CASES = [ + SamplingTestConfig( + name="simple", + num_samples=20, + ), + SamplingTestConfig( + # With truncate_documents=False, documents exceeding maximum_document_length are dropped entirely. + # Only the 3 docs with length ≤ 4 contribute tokens: [0,1,2], [8,9], [10,11,12,13] = 9 tokens. + name="maximum_document_length_drop", + num_samples=2, + sequence_length=4, + shuffle=ShufflingType.disabled, + truncate_documents=False, + maximum_document_length=4, + forbidden_tokens=frozenset(range(3, 8)) | frozenset(range(14, 20)), + dataset=TRUNCATE_DATASET, + requires_extension=True, + ), + SamplingTestConfig( + # With truncate_documents=True, documents exceeding maximum_document_length are split into chunks. + # All tokens should appear in the output; none should be dropped. + name="maximum_document_length_truncate", + num_samples=10, + sequence_length=4, + shuffle=ShufflingType.disabled, + truncate_documents=True, + maximum_document_length=4, + required_tokens=frozenset(range(20)), + dataset=TRUNCATE_DATASET, + ), +] + + +@pytest.mark.parametrize("test_config", [pytest.param(c, id=c.name) for c in _SAMPLING_TEST_CASES]) +def test_sampling(test_config: SamplingTestConfig): + if test_config.requires_extension and not _extension_available: + pytest.skip("CPP Extension not available") + + dataset = test_config.dataset if test_config.dataset is not None else TEST_DATASET + base_config, num_samples, seed = get_sampling_config( + test_config.num_samples, + sequence_length=test_config.sequence_length, + seed=test_config.seed, + shuffle=test_config.shuffle, + truncate_documents=test_config.truncate_documents, + ) + sampling_config = GPTSamplingConfig.from_dict(base_config.to_dict(), test_config.sampling_config_overrides) + sampled = dataset.sample(sampling_config, num_samples, seed) + + # validate_indexed_dataset_sampling's reference implementation concatenates tokens without padding, + # so it only applies when truncate_documents=True (no padding between documents). + if test_config.truncate_documents: + tokens = validate_indexed_dataset_sampling(sampled, test_config.expected_samples) + else: + tokens = torch.stack( + [ + LanguageModelBatch.from_documents(sampled[i], test_config.sequence_length + 1).tokens + for i in range(len(sampled)) + ] + ) + + valid_tokens = set(tokens[tokens >= 0].tolist()) + assert test_config.forbidden_tokens.isdisjoint(valid_tokens) + assert test_config.required_tokens.issubset(valid_tokens) + + +def test_gpt_sampled(data_result_path: pathlib.Path): + # Make sure the memmap dataset works and check for unintended changes in behavior. + _, config, _, preprocessing = get_common_test_dataset() + sampled = get_dataset_config( + dataset_config := config, GPTDatasetFromFileConfig[LanguageModelDocument] + ).build_and_sample(*get_sampling_config(8, sequence_length=5, preprocessing=preprocessing)) + validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) + + # Test in data. + get_test_data_and_compare_samples( + {"datasets": {"training": dataset_config}}, + 8, + sequence_length=5, + expected_samples=GPT_MEMMAP_SAMPLES, + preprocessing=preprocessing, + cache_directory=data_result_path / "sampling/gpt_sampled", + ) + @pytest.mark.parametrize("seed", (0, 32, 88)) @pytest.mark.parametrize( @@ -111,6 +222,42 @@ def test_gpt_sample(seed, shuffle): previous_samples = samples +@pytest.mark.parametrize("token_cumsum_rate", (1, 3, 7, 20)) +def test_token_cumsum_rate(token_cumsum_rate): + # Different token_cumsum_rate values are a performance/memory tradeoff only — + # sampling output must be identical regardless of the rate chosen. + config, num_samples, seed = get_sampling_config(20, sequence_length=5) + reference = validate_indexed_dataset_sampling(TEST_DATASET.sample(config, num_samples, seed)) + + config_with_rate = GPTSamplingConfig.from_dict(config.to_dict(), {"token_cumsum_rate": token_cumsum_rate}) + result = validate_indexed_dataset_sampling(TEST_DATASET.sample(config_with_rate, num_samples, seed)) + Assert.all_equal(result, reference) + + +def test_cache_directory(data_result_path: pathlib.Path): + # Verify that the cache is written on first run and reused on subsequent runs. + cache_dir = data_result_path / "sampling/cache_directory" + config, num_samples, seed = get_sampling_config(20, sequence_length=5, cache_directory=cache_dir) + + first = validate_indexed_dataset_sampling(TEST_DATASET.sample(config, num_samples, seed)) + assert cache_dir.exists() and any(cache_dir.iterdir()) + + # Second run with the same config must produce identical output (reads from cache). + second = validate_indexed_dataset_sampling(TEST_DATASET.sample(config, num_samples, seed)) + Assert.all_equal(first, second) + + +def test_cache_invalidated_on_config_change(data_result_path: pathlib.Path): + # Changing a sampling parameter should raise rather than silently return stale data. + cache_dir = data_result_path / "sampling/cache_invalidation" + config, num_samples, seed = get_sampling_config(20, sequence_length=5, cache_directory=cache_dir) + TEST_DATASET.sample(config, num_samples, seed) + + config_changed = GPTSamplingConfig.from_dict(config.to_dict(), {"token_cumsum_rate": 3}) + with pytest.raises(RuntimeError, match="Invalid dataset cache"): + TEST_DATASET.sample(config_changed, num_samples, seed) + + @pytest.mark.skipif(not _extension_available, reason="CPP Extension not available") def test_build_padded_token_cumsum(): sizes = np.array([100, 256, 580, 600, 550, 89, 339, 430, 400, 795, 680, 50], dtype=np.int32) diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 838562f64..d5d09f58c 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -29,7 +29,7 @@ ] -def test_gpt_slice(): +def test_gpt_slice(data_result_path): # Make sure dataset splitting works and check for unintended changes in behavior. _, config, _, preprocessing = get_common_test_dataset() memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() @@ -73,4 +73,5 @@ def test_gpt_slice(): "validation": GPT_SLICE_VALIDATION_SAMPLES, }, preprocessing=preprocessing, + cache_directory=data_result_path / "slice", ) diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index 83f7657a0..9ad96b961 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -231,9 +231,9 @@ def _run_test_data_streaming_distributed( @pytest.mark.parametrize("num_workers", (0, 1)) -def test_data_streaming(result_path, worker_resources, num_workers): +def test_data_streaming(data_result_path, worker_resources, num_workers): distributed_config = _get_distributed_config({}) - path = result_path / "data_streaming/single_gpu" + path = data_result_path / f"data_streaming/single_gpu_workers_{num_workers}" _run_test_data_streaming(path, distributed_config, worker_resources.torchrun_port, num_workers) check_data_streaming_results(path, distributed_config) @@ -254,10 +254,10 @@ def test_data_streaming(result_path, worker_resources, num_workers): @pytest.mark.slow @pytest.mark.depends_on(on=["test_data_streaming"]) -def test_run_data_streaming_distributed(run_parallel_script, result_path, worker_resources): +def test_run_data_streaming_distributed(run_parallel_script, data_result_path, worker_resources): run_parallel_script( _run_test_data_streaming_distributed, - (result_path / "data_streaming", worker_resources.torchrun_port), + (data_result_path / "data_streaming", worker_resources.torchrun_port), world_size=4, backend=DistributedBackend.gloo, use_cuda=False, # Disable device count check. @@ -267,7 +267,7 @@ def test_run_data_streaming_distributed(run_parallel_script, result_path, worker @pytest.mark.slow @pytest.mark.depends_on(on=["test_data_streaming"]) @pytest.mark.parametrize(("name", "num_gpus", "distributed_config_dict"), _DISTRIBUTED_TESTING_CONFIGS) -def test_data_streaming_distributed(result_path, name, num_gpus, distributed_config_dict, report_subtest): - report_subtest(path := result_path / f"data_streaming/{name}", num_gpus, use_cuda=False) +def test_data_streaming_distributed(data_result_path, name, num_gpus, distributed_config_dict, report_subtest): + report_subtest(path := data_result_path / f"data_streaming/{name}", num_gpus, use_cuda=False) distributed_config = _get_distributed_config(distributed_config_dict, num_gpus) check_data_streaming_results(path, distributed_config) diff --git a/tests/test_config.py b/tests/test_config.py index bf76595f9..492a57b02 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -70,7 +70,7 @@ def test_serialize_default_config_updates(cls): @pytest.mark.parametrize("load_config", tuple(ModelConfigType)) def test_pretrained_config(load_config: ModelConfigType, result_path): - config_path = result_path / "pretrained_config" + config_path = result_path / "pretrained_config" / load_config.value pretrained_model_config = GPTModelConfig.from_dict( { "base_model": { From fe630662dbfcda77447ad5d4761cd452aa269869 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 12:49:04 -0400 Subject: [PATCH 15/33] Various fixes and improvements across data, schedule, docs, and build MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Make TOKEN_CUMSUM_RATE configurable via SamplingConfig.token_cumsum_rate - Suppress expected non-writable buffer UserWarning in RedisStreamingDocumentData - Fix cast warning in build_padded_token_cumsum (size_t → int64_t) - Use copy-on-write memmap mode ("c") for MemmapDataset to allow worker writes - Only pin_memory when CUDA is available in GPTData DataLoader - Remove deprecated estimate_critical_batch field from ScheduleConfig - Add pytest filterwarnings for PYTHONHASHSEED and itertools deprecation noise - Update docs: mkdocs inline citations, exclude docs/README.md, add release guide - Update README and recipe docs with revised benchmark numbers and config names - Add setup.cfg/pyproject.toml tweaks Co-Authored-By: Claude Sonnet 4.6 --- README.md | 10 +-- docs/help.md | 4 +- docs/index.md | 2 +- docs/quick-start.md | 99 +++++++++++----------- docs/recipes/continue-training.md | 32 ++++---- docs/recipes/data-preparation.md | 4 +- docs/recipes/generate.md | 4 +- docs/recipes/instruction-finetuning.md | 24 +++--- docs/recipes/train.md | 109 ++++++++++++++----------- examples/fast-llm.pytorchjob.yaml | 4 +- examples/fast-llm.sbat | 2 +- fast_llm/csrc/data.cpp | 2 +- fast_llm/data/data/gpt/data.py | 2 +- fast_llm/data/dataset/config.py | 7 ++ fast_llm/data/dataset/memmap/memmap.py | 2 +- fast_llm/data/dataset/sampled.py | 26 +++--- fast_llm/data/dataset/streaming.py | 22 +++-- fast_llm/engine/schedule/config.py | 6 +- mkdocs.yaml | 5 ++ pyproject.toml | 6 ++ setup.cfg | 1 + 21 files changed, 199 insertions(+), 174 deletions(-) diff --git a/README.md b/README.md index d02e7f95e..2a1b9f4e2 100644 --- a/README.md +++ b/README.md @@ -60,12 +60,12 @@ As a truly open-source project, Fast-LLM allows full customization and extension We'll walk you through how to use Fast-LLM to train a large language model on a cluster with multiple nodes and GPUs. We'll show an example setup using a Slurm cluster and a Kubernetes cluster. -For this demo, we will train a Mistral-7B model from scratch for 100 steps on random data. The config file `examples/mistral-4-node-benchmark.yaml` is pre-configured for a multi-node setup with 4 DGX nodes, each with 8 A100-80GB or H100-80GB GPUs. +For this demo, we will train a Mistral-7B model from scratch for 100 steps on random data. The config file `examples/mistral.yaml` defines the model architecture and training settings, while the example launch scripts are pre-configured for a 4-node setup with 8 GPUs per node. > [!NOTE] > Fast-LLM scales from a single GPU to large clusters. You can start small and expand based on your resources. -Expect to see a significant speedup in training time compared to other libraries! For training Mistral-7B, Fast-LLM is expected to achieve a throughput of **9,800 tokens/s/H100** (batch size 32, sequence length 8k) on a 4-node cluster with 32 H100s. +Expect to see a significant speedup in training time compared to other libraries! For training Mistral-7B, Fast-LLM is expected to achieve a throughput of **9,800 tokens/s/H100** (micro-batch size 8k tokens, total batch size 256k tokens) on a 4-node cluster with 32 H100s. ### Running Fast-LLM on a Slurm Cluster @@ -77,7 +77,7 @@ Expect to see a significant speedup in training time compared to other libraries #### Steps -1. Deploy the [nvcr.io/nvidia/pytorch:24.07-py3](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) Docker image to all nodes (recommended), because it contains all the necessary dependencies. +1. Deploy the [nvcr.io/nvidia/pytorch:25.11-py3](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) Docker image to all nodes (recommended), because it contains all the necessary dependencies. 2. Install Fast-LLM on all nodes: ```bash @@ -88,7 +88,7 @@ Expect to see a significant speedup in training time compared to other libraries #SBATCH --ntasks=$(scontrol show node | grep -c NodeName) #SBATCH --exclusive - srun bash -c 'pip install --no-cache-dir -e "git+https://github.com/ServiceNow/Fast-LLM.git#egg=llm[CORE,OPTIONAL,DEV]"' + srun bash -c 'pip install --no-cache-dir "fast-llm[CORE,OPTIONAL] @ git+https://github.com/ServiceNow/Fast-LLM.git"' EOF ``` @@ -115,7 +115,7 @@ Now, you can sit back and relax while Fast-LLM trains your model at full speed! #### Steps -1. Create a Kubernetes [PersistentVolumeClaim](https://kubernetes.io/docs/concepts/storage/persistent-volumes/) (PVC) named `fast-llm-home` that will be mounted to `/home/fast-llm` in the container using [examples/fast-llm-pvc.yaml](examples/fast-llm-pvc.yaml): +1. Create a Kubernetes [PersistentVolumeClaim](https://kubernetes.io/docs/concepts/storage/persistent-volumes/) (PVC) named `pvc-fast-llm-home` that will be mounted to `/home/fast-llm` in the container using [examples/fast-llm-pvc.yaml](examples/fast-llm-pvc.yaml): ```bash kubectl apply -f examples/fast-llm-pvc.yaml diff --git a/docs/help.md b/docs/help.md index ed59dffa7..e368e349e 100644 --- a/docs/help.md +++ b/docs/help.md @@ -10,7 +10,7 @@ Welcome to the Fast-LLM Help Center! Here, you'll find fixes for common hiccups, Let's stay one step ahead of those pesky gotchas. Here's a list of common issues and quick fixes: -- **CUDA Out of Memory**: When the GPU throws a fit, a few tweaks can help. First, try lowering `micro_batch_size` or `sequence_length` in the configuration to fit within the available memory. Still stuck? Try setting the `mlp_recompute_level` option to `activation` or `full` to save memory in the backward pass, or experiment with higher ZeRO stages for reduced memory usage. And if that's not enough, tensor or model parallelism may be your friend. +- **CUDA Out of Memory**: When the GPU throws a fit, a few tweaks can help. First, try lowering `micro_batch_size` or `maximum_document_length` under `data:` in the configuration to fit within the available memory. Still stuck? Try setting the `recompute_level` option under `model: base_model: decoder: block: mlp:` to `activation` or `full` to save memory in the backward pass, or experiment with higher ZeRO stages for reduced memory usage. And if that's not enough, tensor or model parallelism may be your friend. - **Python Hash Seed Sync Error**: Encountering an error like @@ -28,7 +28,7 @@ Let's stay one step ahead of those pesky gotchas. Here's a list of common issues Watchdog caught collective operation timeout: WorkNCCL(SeqNum=408951, OpType=_ALLGATHER_BASE, … , Timeout(ms)=600000) ran for 600351 milliseconds before timing out ``` - appearing across all GPU workers, it usually means one or more hosts failed to complete a NCCL operation, causing others to block. NCCL errors can be frustrating to diagnose since they rarely specify which node or GPU caused the issue. It is difficult to surface which messages and operations are in progress during these crashes. If the issue happens at a specific moment of training like dataset preparation or model export, the issue might be that this specific procedure took too long and timed out other processes (e.g. when preparing large datasets for long training runs, or saving large models on slow storage). In this case, it can help to increase the timeout `distributed_timeout: 3600`. + appearing across all GPU workers, it usually means one or more hosts failed to complete a NCCL operation, causing others to block. NCCL errors can be frustrating to diagnose since they rarely specify which node or GPU caused the issue. It is difficult to surface which messages and operations are in progress during these crashes. If the issue happens at a specific moment of training like dataset preparation or model export, the issue might be that this specific procedure took too long and timed out other processes (e.g. when preparing large datasets for long training runs, or saving large models on slow storage). In this case, it can help to increase the timeout by setting `model: distributed: timeout: 3600` in your config. In some other cases, the best we can do is to restart the training job and hope it doesn't happen again. If the issue persists, it might be because of network congestion or a problematic GPU. If the worker that crashed is consistent across multiple runs, it's likely a hardware issue. If you can't resolve it, open an issue on GitHub, and we'll help you troubleshoot. For more detailed solutions, check out our GitHub Issues page. Odds are someone's already tackled a similar problem, and you might find the exact fix you need. diff --git a/docs/index.md b/docs/index.md index 80277ffd2..de3698f1a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -34,7 +34,7 @@ Fast-LLM isn't just another library, **it's a platform for powering the next gen Fast-LLM offers all the capabilities you need to accelerate your LLM training and **push the boundaries of what's possible**: -- **🚀 Speed Like No Other:** Achieve record-breaking training throughput with Fast-LLM. For instance, train Mistral-7B at **10,350 tokens/s/GPU** on a 4-node cluster with 32 H100 GPUs (batch size 64, sequence length 8k). Our optimized kernels, advanced parallelism, and memory-efficient techniques drastically reduce training time and cost. +- **🚀 Speed Like No Other:** Achieve record-breaking training throughput with Fast-LLM. For instance, train Mistral-7B at **10,350 tokens/s/GPU** on a 4-node cluster with 32 H100 GPUs (micro-batch size 8k tokens, total batch size 256k tokens). Our optimized kernels, advanced parallelism, and memory-efficient techniques drastically reduce training time and cost. - **📡 Unmatched Scalability:** Seamlessly scale from a single GPU to large compute clusters. Fast-LLM supports 3D parallelism (data, tensor, and pipeline), sequence length parallelism, and ZeRO-1,2,3 techniques for maximum memory efficiency. Scale to the size you need without sacrificing performance. diff --git a/docs/quick-start.md b/docs/quick-start.md index 20fc1a2b1..68ae9c056 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -11,7 +11,7 @@ To follow this guide, you'll need: - **Hardware**: At least one NVIDIA GPU, preferably with Ampere architecture or newer. Note that this tutorial is designed for 80 GB A100s or H100 GPUs, and some adjustments are needed to run it with less memory or an earlier architecture. - **Software**: Depending on your setup, you'll need one of the following: - **Docker**: If you're using the prebuilt Docker image on your local machine. - - **Python 3.10**: If you're setting up a custom environment (virtual environment, bare-metal, etc.) on your local machine. + - **Python 3.12**: If you're setting up a custom environment (virtual environment, bare-metal, etc.) on your local machine. - **Cluster Setup**: Access to a Docker-enabled Slurm cluster or to a Kubernetes cluster with Kubeflow if you're using those environments. ## 🏗 Step 1: Initial Setup @@ -69,7 +69,7 @@ Now, select the compute environment that matches your setup or preferred workflo Install PyTorch and pybind11 to meet Fast-LLM's requirements: ```bash - pip install pybind11 "torch>=2.2.2" + pip install pybind11 "torch>=2.9.0" ``` 4. **Install NVIDIA APEX**: @@ -86,7 +86,7 @@ Now, select the compute environment that matches your setup or preferred workflo Finally, install Fast-LLM along with its remaining dependencies, including [FlashAttention-2](https://github.com/Dao-AILab/flash-attention): ```bash - pip install --no-build-isolation "git+https://github.com/ServiceNow/Fast-LLM.git#egg=fast_llm[CORE,OPTIONAL,DEV]" + pip install --no-build-isolation "fast-llm[CORE,OPTIONAL] @ git+https://github.com/ServiceNow/Fast-LLM.git" ``` 6. **Verify the Installation**: @@ -220,7 +220,7 @@ Choose based on your goals for this tutorial. git clone https://huggingface.co/meta-llama/Llama-3.1-8B ./fast-llm-tutorial/pretrained-model ``` -## 📚 Step 3: Prepare the Training Data +## 📚 Step 4: Prepare the Training Data For this tutorial, we'll use text from the [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset. This dataset is a free approximation of the WebText data OpenAI used for GPT-2, and it's perfect for our test run! @@ -471,7 +471,7 @@ Fast-LLM ships with a `prepare` command that will download and preprocess the da You can follow the job's progress by running `kubectl get pods` and checking the logs with `kubectl logs fast-llm-prepare-master-0`. -## ⚙️ Step 4: Configure Fast-LLM +## ⚙️ Step 5: Configure Fast-LLM Next, we'll create a configuration file for Fast-LLM. @@ -481,7 +481,7 @@ Next, we'll create a configuration file for Fast-LLM. !!! warning "Micro-Batch Size" - The `micro_batch_size` in the configuration below is optimized for 80GB GPUs. If you're using GPUs with less memory, you will need to lower this value. Alternatively, you can decrease the `sequence_length` to reduce the memory footprint. + The `micro_batch_size` in the configuration below is optimized for 80GB GPUs. If you're using GPUs with less memory, you will need to lower this value. Alternatively, you can decrease `maximum_document_length` under `data:` to reduce the memory footprint. Save the following as `fast-llm-tutorial/train-config.yaml`: @@ -506,31 +506,31 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: project_name: fast-llm-tutorial group_name: Small entity_name: null - batch: - micro_batch_size: 60 # (4)! - sequence_length: 1024 - batch_size: 480 # (5)! data: datasets: training: type: file - path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (6)! + path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (5)! validation: type: file - path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (6)! + path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (5)! + micro_batch_size: 61440 # (4)! + maximum_document_length: 1024 optimizer: learning_rate: base: 6.0e-04 pretrained: - format: llama # (7)! + format: llama # (6)! path: fast-llm-tutorial/pretrained-model - model_weights: no # (8)! + model_weights: no # (7)! model: base_model: - transformer: - use_flash_attention: yes # (9)! + decoder: + block: + mixer: + use_flash_attention: yes # (8)! distributed: - training_dtype: bf16 # (10)! + compute_dtype: bf16 # (9)! run: experiment_dir: fast-llm-tutorial/experiment ``` @@ -538,13 +538,12 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: 1. For the small run, we'll stop after 100 iterations. 2. The trained model will be saved in `Transformers` Llama format to `fast-llm-tutorial/experiment/export/llama/100` at the end of the small run. You can also save as a `Fast-LLM` checkpoint by setting the `format` to `fast_llm`. 3. Entirely optional, but it's a good idea to track your training progress with Weights & Biases. Replace `null` with your own W&B entity name. If you don't want to use W&B, just ignore this section. - 4. Adjust the number of sequences per GPU based on GPU memory. For SmolLM2-135M at 1024 sequenced length and a 80GB GPU, a `micro_batch_size` of 60 should work well. - 5. Must be divisible by the number of GPUs and the `micro_batch_size`. At 1024 tokens per sequence, 480 corresponds to about 500,000 tokens per batch. - 6. Location of the dataset metadata files generated in Step 4. - 7. Format of the pretrained model. Since SmolLM is a Llama model, we set this to `llama`. - 8. We'll train SmolLM2-135M from scratch. You can set to `yes` to continue training from a checkpoint (if you put one in the model directory). - 9. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`. - 10. `bf16` (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use `fp16` (half-precision floating point) for training instead of `bf16`. + 4. Adjust the micro-batch size based on GPU memory. For SmolLM2-135M with a maximum document length of 1024 tokens and a 80GB GPU, a `micro_batch_size` of 61440 tokens should work well. At 1024 tokens per document, this corresponds to about 500,000 tokens per batch on 8 GPUs. + 5. Location of the dataset metadata files generated in Step 4. + 6. Format of the pretrained model. Since SmolLM is a Llama model, we set this to `llama`. + 7. We'll train SmolLM2-135M from scratch. You can set to `yes` to continue training from a checkpoint (if you put one in the model directory). + 8. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`. + 9. `bf16` (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use `fp16` (half-precision floating point) for training instead of `bf16`. === "Big" @@ -563,7 +562,6 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: checkpoint: interval: 1000 keep: 5 - test_iters: 0 export: # (2)! format: llama interval: 20_000 @@ -571,59 +569,56 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: project_name: fast-llm-tutorial group_name: Big entity_name: null - batch: - micro_batch_size: 2 # (4)! - sequence_length: 4096 - batch_size: 512 # (5)! data: datasets: training: type: file - path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (6)! + path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (5)! validation: type: file - path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (6)! - optimizer: # (7)! + path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (5)! + micro_batch_size: 8192 # (4)! + maximum_document_length: 4096 + optimizer: # (6)! weight_decay: 0.1 beta_1: 0.9 beta_2: 0.95 - learning_rate: # (8)! + learning_rate: # (7)! base: 6.0e-04 minimum: 6.0e-05 decay_style: cosine decay_iterations: 100_000 warmup_iterations: 2000 pretrained: - format: llama # (9)! + format: llama # (8)! path: fast-llm-tutorial/pretrained-model - model_weights: yes # (10)! + model_weights: yes # (9)! model: base_model: - transformer: - use_flash_attention: yes # (11)! - cross_entropy_impl: fused # (12)! + decoder: + block: + mixer: + use_flash_attention: yes # (10)! multi_stage: - zero_stage: 2 # (13)! + zero_stage: 2 # (11)! distributed: - training_dtype: bf16 # (14)! + compute_dtype: bf16 # (12)! run: experiment_dir: fast-llm-tutorial/experiment ``` - 1. Total number of training tokens will be approximately 210B: 100,000 iterations * 512 * 4096 tokens per batch. + 1. Total number of training tokens will be approximately 26B: 100,000 iterations × 32 GPUs × 8,192 tokens per micro-batch. 2. A permanent model checkpoint in `Transformers` Llama format will be saved to `fast-llm-tutorial/experiment/export/llama/[iteration]/` every 20,000 iterations. You can also save as a `Fast-LLM` checkpoint by setting the `format` to `fast_llm`. 3. Entirely optional, but it's a good idea to track your training progress with Weights & Biases. Replace `null` with your own W&B entity name. If you don't want to use W&B, just ignore this section. - 4. Adjust the number of sequences per GPU based on GPU memory. Considering a 4k token sequence length and 80GB GPUs, a `micro_batch_size` of 1 should work well. - 5. Must be divisible by the number of GPUs and the `micro_batch_size`. At 4k tokens per sequence, 512 corresponds to about 2.1 million tokens per batch. - 6. Location of the dataset metadata file generated in Step 4. - 7. These are good default optimizer settings for training models. - 8. We are using a cosine decay schedule with linear warmup. After reaching the peak learning rate `base` at `warmup_iterations`, the learning rate will decay to `minimum` at `decay_iterations`, following a cosine curve. The minimum learning rate should be 1/10th of the base learning rate per Chinchilla. - 9. Format of the pretrained model. Since it's a Llama model, we set this to `llama`. - 10. We want to continue training Llama-3.1-8B from a checkpoint. If you're training from scratch, set this to `no`. - 11. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`. - 12. Configure Fast-LLM to use the fused cross-entropy loss implementation rather than the default Triton implementation for models with a large vocabulary size such as Llama-3.1-8B. This avoids issues with block size limitations in our current Triton code. - 13. We are using ZeRO stage 2 for this tutorial. You can set this to `1`, `2`, or `3` for ZeRO-1, ZeRO-2, or ZeRO-3, respectively. - 14. `bf16` (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use `fp16` (half-precision floating point) for training instead of `bf16`. + 4. Adjust the micro-batch size based on GPU memory. Considering a maximum document length of 4096 tokens and 80GB GPUs, a `micro_batch_size` of 8192 tokens should work well. + 5. Location of the dataset metadata file generated in Step 4. + 6. These are good default optimizer settings for training models. + 7. We are using a cosine decay schedule with linear warmup. After reaching the peak learning rate `base` at `warmup_iterations`, the learning rate will decay to `minimum` at `decay_iterations`, following a cosine curve. The minimum learning rate should be 1/10th of the base learning rate per Chinchilla. + 8. Format of the pretrained model. Since it's a Llama model, we set this to `llama`. + 9. We want to continue training Llama-3.1-8B from a checkpoint. If you're training from scratch, set this to `no`. + 10. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`. + 11. We are using ZeRO stage 2 for this tutorial. You can set this to `1`, `2`, or `3` for ZeRO-1, ZeRO-2, or ZeRO-3, respectively. + 12. `bf16` (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use `fp16` (half-precision floating point) for training instead of `bf16`. ## 🔑 (Optional) Step 6: Add Your Weights & Biases API Key diff --git a/docs/recipes/continue-training.md b/docs/recipes/continue-training.md index d7df7a196..6c1e347ce 100644 --- a/docs/recipes/continue-training.md +++ b/docs/recipes/continue-training.md @@ -48,15 +48,12 @@ This is not much different from a pretraining config. We will: checkpoint: interval: 1000 keep: 5 - test_iters: 0 export: # (1)! format: llama interval: 20_000 - batch: - micro_batch_size: 2 - sequence_length: 4096 - batch_size: 256 data: + micro_batch_size: 4096 + maximum_document_length: 4096 datasets: training: type: file @@ -80,13 +77,14 @@ This is not much different from a pretraining config. We will: model_weights: yes # (5)! model: base_model: - transformer: - use_flash_attention: yes - cross_entropy_impl: fused + decoder: + block: + mixer: + use_flash_attention: yes multi_stage: zero_stage: 2 distributed: - training_dtype: bf16 + compute_dtype: bf16 run: experiment_dir: fast-llm-tutorial/Llama-3.1-8B-cpt ``` @@ -107,15 +105,12 @@ This is not much different from a pretraining config. We will: checkpoint: interval: 1000 keep: 5 - test_iters: 0 export: # (1)! format: qwen2 interval: 20_000 - batch: - micro_batch_size: 1 - sequence_length: 8192 - batch_size: 256 data: + micro_batch_size: 8192 + maximum_document_length: 8192 datasets: training: type: file @@ -139,13 +134,14 @@ This is not much different from a pretraining config. We will: model_weights: yes # (5)! model: base_model: - transformer: - use_flash_attention: yes - cross_entropy_impl: fused + decoder: + block: + mixer: + use_flash_attention: yes multi_stage: zero_stage: 2 distributed: - training_dtype: bf16 + compute_dtype: bf16 run: experiment_dir: fast-llm-tutorial/qwen-2.5-7B-cpt ``` diff --git a/docs/recipes/data-preparation.md b/docs/recipes/data-preparation.md index be0f8ef00..b3e3274f2 100644 --- a/docs/recipes/data-preparation.md +++ b/docs/recipes/data-preparation.md @@ -12,7 +12,7 @@ For this guide, you would need: - **Software**: Depending on your setup, you'll need one of the following: - **Docker**: If you're using the prebuilt Docker image on your local machine. - - **Python 3.10**: If you're setting up a custom environment (virtual environment, bare-metal, etc.) on your local machine. + - **Python 3.12**: If you're setting up a custom environment (virtual environment, bare-metal, etc.) on your local machine. - **Cluster Setup**: Access to a Docker-enabled Slurm cluster or to a Kubernetes cluster with Kubeflow if you're using those environments. ## 📚 Step 1: Download the dataset from Huggingface @@ -104,7 +104,7 @@ Fast-LLM's prepare command processes the dataset by tokenizing and saving it in === "Custom Installation" - Please follow the instructions in the [Quick-Start guide](quick-start/#step-1-initial-setup-custom-installation) to set up Fast-LLM in your environment. + Please follow the instructions in the [Quick-Start guide](quick-start/#step-1-initial-setup) to set up Fast-LLM in your environment. Then, run the following command: diff --git a/docs/recipes/generate.md b/docs/recipes/generate.md index d6d2333e1..77ea609a2 100644 --- a/docs/recipes/generate.md +++ b/docs/recipes/generate.md @@ -37,8 +37,8 @@ tokenizer = AutoTokenizer.from_pretrained(model_path) # Optional: updates to Fast-LLM config before loading the model updates = { - ("base_model", "transformer", "use_flash_attention"): True, - ("distributed", "training_dtype"): "bf16" + ("base_model", "decoder", "block", "mixer", "use_flash_attention"): True, + ("distributed", "compute_dtype"): "bf16" } # Load the model from the checkpoint with the given configuration diff --git a/docs/recipes/instruction-finetuning.md b/docs/recipes/instruction-finetuning.md index 2c58a987d..0e28b7dc8 100644 --- a/docs/recipes/instruction-finetuning.md +++ b/docs/recipes/instruction-finetuning.md @@ -107,7 +107,7 @@ splits: ## ⚙️ Step 4: Configure Fast-LLM -It's time to configure the Fast-LLM training config. This is very similar to [Quick Start](../quick-start.md) with two additional options, namely, `truncate_documents` and `cross_document_attention` which are important for improving the task performance of instruction-tuned models. +It's time to configure the Fast-LLM training config. This is very similar to [Quick Start](../quick-start.md) with one additional option, namely `truncate_documents`, which is important for improving the task performance of instruction-tuned models. ```yaml training: @@ -124,16 +124,12 @@ training: checkpoint: interval: 1000 keep: 5 - test_iters: 0 export: format: llama interval: 1000 -batch: - micro_batch_size: 1 - sequence_length: 4096 - batch_size: 32 - cross_document_attention: no # (1)! data: + micro_batch_size: 4096 + maximum_document_length: 4096 datasets: training: type: file @@ -141,7 +137,7 @@ data: validation: type: file path: ./sft-tutorial/tokenized/Llama-3.1-8B/fast_llm_config_validation.yaml - truncate_documents: no # (2)! + truncate_documents: no # (1)! sampling: use_loss_masking_spans: yes optimizer: @@ -160,19 +156,19 @@ pretrained: model_weights: yes model: base_model: - transformer: - use_flash_attention: yes - cross_entropy_impl: fused + decoder: + block: + mixer: + use_flash_attention: yes multi_stage: zero_stage: 3 distributed: timeout: 3600 - training_dtype: bf16 + compute_dtype: bf16 run: experiment_dir: ./sft-tutorial/llama-3.1-8b-instruct-magpie ``` -1. Prevents paying attention to other documents in a packed sequence -2. Avoids truncating documents to fit into a packed sequence and starts a new sequence instead. Documents longer than sequence length will be skipped altogether. +1. Avoids truncating documents to fit into a packed sequence and starts a new sequence instead. Documents longer than sequence length will be skipped altogether. Launching the training run is similar to Step 7 in the [Quick Start](../quick-start.md) guide. diff --git a/docs/recipes/train.md b/docs/recipes/train.md index efdf6111b..224fd23a1 100644 --- a/docs/recipes/train.md +++ b/docs/recipes/train.md @@ -28,15 +28,12 @@ Let's start from the following training configuration: checkpoint: interval: 1000 keep: 5 - test_iters: 0 export: format: llama interval: 20_000 - batch: - micro_batch_size: 2 - sequence_length: 4096 - batch_size: 256 data: + micro_batch_size: 4096 + maximum_document_length: 4096 datasets: training: type: file @@ -55,12 +52,10 @@ Let's start from the following training configuration: decay_iterations: 100_000 warmup_iterations: 2000 model: - base_model: - cross_entropy_impl: fused multi_stage: zero_stage: 2 distributed: - training_dtype: bf16 + compute_dtype: bf16 run: experiment_dir: fast-llm-tutorial/experiment ``` @@ -80,15 +75,12 @@ Let's start from the following training configuration: checkpoint: interval: 1000 keep: 5 - test_iters: 0 export: format: qwen2 interval: 20_000 - batch: - micro_batch_size: 1 - sequence_length: 8192 - batch_size: 256 data: + micro_batch_size: 8192 + maximum_document_length: 8192 datasets: training: type: file @@ -107,12 +99,10 @@ Let's start from the following training configuration: decay_iterations: 100_000 warmup_iterations: 2000 model: - base_model: - cross_entropy_impl: fused multi_stage: zero_stage: 2 distributed: - training_dtype: bf16 + compute_dtype: bf16 run: experiment_dir: fast-llm-tutorial/experiment ``` @@ -155,47 +145,72 @@ Alternatively, we define the model architecture ourselves as follows: ```yaml model: base_model: - tie_word_embeddings: false - use_position_embeddings: false - vocab_size: 128256 - transformer: - activation_type: silu - add_linear_biases: false - ffn_hidden_size: 14336 - gated: true - head_groups: 8 - hidden_size: 4096 # (1)! - kv_channels: 128 + tied_embedding_weight: false + hidden_size: 4096 # (1)! + embeddings: + vocab_size: 128256 + decoder: + num_blocks: 32 + block: + mixer: + heads: 32 + head_groups: 8 + head_size: 128 + add_linear_biases: false + rotary: + type: llama3 + theta: 500_000 + mlp: + intermediate_size: 14336 + gated: true + activation: silu + add_linear_biases: false + normalization: + type: rms_norm + head: normalization: type: rms_norm - num_attention_heads: 32 - num_layers: 32 - rotary: - type: llama3 - theta: 500_000 ``` === "Qwen 2.5 7B" ```yaml model: base_model: - tie_word_embeddings: false - use_position_embeddings: false - vocab_size: 152064 - transformer: - activation_type: silu - add_linear_biases: only_attn_qkv - ffn_hidden_size: 18944 - gated: true - head_groups: 4 - hidden_size: 3584 # (1)! + tied_embedding_weight: false + hidden_size: 3584 # (1)! + embeddings: + vocab_size: 152064 + decoder: + num_blocks: 28 + block: + mixer: + heads: 28 + head_groups: 4 + head_size: 128 + add_linear_biases: false + query_layer: + bias: + enabled: true + key_layer: + bias: + enabled: true + value_layer: + bias: + enabled: true + rotary: + type: default + theta: 1_000_000 + mlp: + intermediate_size: 18944 + gated: true + activation: silu + add_linear_biases: false + normalization: + type: rms_norm + epsilon: 1e-06 + head: normalization: type: rms_norm epsilon: 1e-06 - num_attention_heads: 28 - num_layers: 28 - rotary: - type: default - theta: 1_000_000 ``` 1. Hidden-size/num-layers will be used to provide good defaults for weight initialization std. diff --git a/examples/fast-llm.pytorchjob.yaml b/examples/fast-llm.pytorchjob.yaml index 13a7a4df8..03c51fcb9 100644 --- a/examples/fast-llm.pytorchjob.yaml +++ b/examples/fast-llm.pytorchjob.yaml @@ -42,7 +42,7 @@ spec: --rdzv_conf=timeout=3600 \ --no_python \ fast-llm train gpt \ - --config examples/mistral-4-node-benchmark.yaml + --config examples/mistral.yaml env: - name: NCCL_DEBUG value: "INFO" @@ -102,7 +102,7 @@ spec: --rdzv_conf=timeout=3600 \ --no_python \ fast-llm train gpt \ - --config examples/mistral-4-node-benchmark.yaml + --config examples/mistral.yaml env: - name: NCCL_DEBUG value: "INFO" diff --git a/examples/fast-llm.sbat b/examples/fast-llm.sbat index 13a966ec3..8099bc141 100644 --- a/examples/fast-llm.sbat +++ b/examples/fast-llm.sbat @@ -34,4 +34,4 @@ srun --gpus-per-node=$SLURM_GPUS_PER_NODE \ --rdzv_conf=timeout=3600 \ --no_python \ fast-llm train gpt \ - --config examples/mistral_4_node_benchmark.yaml" + --config examples/mistral.yaml" diff --git a/fast_llm/csrc/data.cpp b/fast_llm/csrc/data.cpp index a1a24c7c9..1696af449 100644 --- a/fast_llm/csrc/data.cpp +++ b/fast_llm/csrc/data.cpp @@ -181,7 +181,7 @@ py::array build_padded_token_cumsum(const py::array_t& sizes_, }); const auto byte_size = sizeof(int64_t); - return py::array(std::vector{token_cumsum.size()}, + return py::array(std::vector{static_cast(token_cumsum.size())}, {byte_size}, token_cumsum_result, free_when_done); diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 415ddc195..a25aede78 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -112,7 +112,7 @@ def get_iterator( ), num_workers=num_workers, prefetch_factor=prefetch_factor, - pin_memory=True, + pin_memory=self._distributed_config.use_cuda, collate_fn=functools.partial(self._collate_fn, dataset_name=dataset_name, preprocess=preprocess), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index bfe1509d6..ea30e3fc0 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -79,6 +79,13 @@ class SamplingConfig(SamplingConfigBase): # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. predicted_tokens: int = Field(default=1) + token_cumsum_rate: int = Field( + default=10, + desc="Sampling interval for the token cumulative sum index." + " A smaller value reduces per-sample seek time at the cost of a larger index.", + hint=FieldHint.performance, + valid=check_field(Assert.gt, 0), + ) cache_directory: pathlib.Path | None = Field(default=None) dataset_name: str = Field(default="dataset") world_size: int = Field(default=1) diff --git a/fast_llm/data/dataset/memmap/memmap.py b/fast_llm/data/dataset/memmap/memmap.py index d44ed9093..69175c893 100644 --- a/fast_llm/data/dataset/memmap/memmap.py +++ b/fast_llm/data/dataset/memmap/memmap.py @@ -38,7 +38,7 @@ def _init(self, name: str, path: pathlib.Path | str) -> None: config_bytes = stream.read(int.from_bytes(stream.read(4), signed=False)) reader_config = MemmapIndexDatasetReaderConfig.from_dict(json.loads(config_bytes.decode("utf-8"))) - self._memmap = np.memmap(self._path, mode="r") + self._memmap = np.memmap(self._path, mode="c") self._reader = reader_config.get_reader(memoryview(self._memmap)) def __getstate__(self) -> tuple[str, pathlib.Path]: diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index bffa9ff66..db123a354 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -63,10 +63,6 @@ def _lazy_load(self): self._array = np.load(self._path, mmap_mode="r") -# TODO: Make configurable? -TOKEN_CUMSUM_RATE = 10 - - class SampledIndexedDataset[DocumentType: Document](SampledDataset[DocumentType]): """ A sampled dataset. @@ -253,9 +249,9 @@ def _sample(self) -> None: # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. - # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. - # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. - # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` + # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals (`token_cumsum_rate`). + # A larger rate reduces pre-computation overhead at the cost of more runtime scanning per sample. + # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::token_cumsum_rate]))` if unshuffled_epochs > 0: token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( document_sizes, @@ -288,7 +284,7 @@ def _sample(self) -> None: ) self._token_cumsum_shuffled.save(token_cumsum_shuffled) self._document_shuffling.save( - document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy( + document_shuffling[: (token_cumsum_shuffled.size + 1) * self._config.token_cumsum_rate].numpy( force=self._config.gpu ) ) @@ -298,10 +294,12 @@ def _sample(self) -> None: def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]: if self._config.truncate_documents: # Create the output tensor. - out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype.torch) + out = sizes.new_empty(sizes.numel() // self._config.token_cumsum_rate + 1, dtype=dtype.torch) # Get partial sums for regular intervals, excluding the last incomplete interval. torch.sum( - sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE), + sizes[: sizes.numel() - sizes.numel() % self._config.token_cumsum_rate].view( + -1, self._config.token_cumsum_rate + ), dim=1, out=out[1:], ) @@ -319,7 +317,9 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - return out.numpy(force=self._config.gpu), None else: # TODO: dynamically handle int64 or int32 in CPP - out = build_padded_token_cumsum(sizes.cpu().numpy(), self._config.sample_size, TOKEN_CUMSUM_RATE, offset) + out = build_padded_token_cumsum( + sizes.cpu().numpy(), self._config.sample_size, self._config.token_cumsum_rate, offset + ) num_tokens = out[-1] out = out[:-1][ : np.clip( @@ -358,7 +358,9 @@ def __getitem__(self, index: int) -> list[DocumentType]: # Find the rightmost location `token_start_cumsum_index` in `token_cumsum` with `token_cumsum[token_start_cumsum_index] <= token_start` token_start_cumsum_index = np.searchsorted(token_start_array, token_start, side="right").item() - 1 - document_sampling_index = token_start_cumsum_index * TOKEN_CUMSUM_RATE + token_start_array_document_offset + document_sampling_index = ( + token_start_cumsum_index * self._config.token_cumsum_rate + token_start_array_document_offset + ) token_count = token_start_array[token_start_cumsum_index].item() diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index e3fce4eb3..ec8fe7bd1 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -3,6 +3,7 @@ import logging import time import typing +import warnings import redis import torch.utils.data @@ -33,20 +34,25 @@ class RedisStreamingDocumentData(Config): def _validate(self): # Decode message - if isinstance(self.tokens, bytes): - self.tokens = torch.frombuffer(self.tokens, dtype=torch.int64) - elif isinstance(self.tokens, (list, tuple)): - self.tokens = torch.tensor(self.tokens, dtype=torch.int64) + with warnings.catch_warnings(): + # The tensors are read-only in practice; the non-writable-buffer warning is expected. + warnings.simplefilter("ignore", UserWarning) + if isinstance(self.tokens, bytes): + self.tokens = torch.frombuffer(self.tokens, dtype=torch.int64) + elif isinstance(self.tokens, (list, tuple)): + self.tokens = torch.tensor(self.tokens, dtype=torch.int64) if isinstance(self.loss_masking_spans, str): self.loss_masking_spans = json.loads(self.loss_masking_spans) if isinstance(self.chosen_span, str): self.chosen_span = json.loads(self.chosen_span) if isinstance(self.rejected_span, str): self.rejected_span = json.loads(self.rejected_span) - if isinstance(self.old_log_probabilities, bytes): - self.old_log_probabilities = torch.frombuffer(self.old_log_probabilities, dtype=torch.float32) - elif isinstance(self.old_log_probabilities, (list, tuple)): - self.old_log_probabilities = torch.tensor(self.old_log_probabilities, dtype=torch.float32) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + if isinstance(self.old_log_probabilities, bytes): + self.old_log_probabilities = torch.frombuffer(self.old_log_probabilities, dtype=torch.float32) + elif isinstance(self.old_log_probabilities, (list, tuple)): + self.old_log_probabilities = torch.tensor(self.old_log_probabilities, dtype=torch.float32) super()._validate() if self.old_log_probabilities is not None: Assert.eq(len(self.old_log_probabilities), self.num_tokens) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 48714db40..11fe11d11 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -1,7 +1,7 @@ import enum import functools -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, test_field +from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.utils import Assert @@ -69,10 +69,6 @@ class ScheduleConfig(Config): desc="Detailed time table for the schedule execution (cpu and gpu times).", hint=FieldHint.logging, ) - # TODO: Remove - estimate_critical_batch: bool = Field( - default=False, desc="No longer supported.", hint=FieldHint.deprecated, valid=test_field(lambda x: not x) - ) # Skip the weight update and related ops (debug) skip_step: bool = Field( default=False, diff --git a/mkdocs.yaml b/mkdocs.yaml index 00e52a011..bd8b7d4c4 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -157,6 +157,10 @@ plugins: branch: main - bibtex: bib_file: "docs/refs.bib" + enable_inline_citations: false + +exclude_docs: | + README.md nav: - Welcome: index.md @@ -190,5 +194,6 @@ nav: - Style Guide: contributing/style-guide.md - Development Practices: contributing/dev-practices.md - Testing: contributing/testing.md + - How to Release: contributing/how-to-release.md - About Us: about-us.md - Join Us: join-us.md diff --git a/pyproject.toml b/pyproject.toml index 8488623d7..c119daad4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,12 @@ testpaths = [ "fast_llm_external_models/tests" # External models tests ] norecursedirs = ["Megatron-LM"] +filterwarnings = [ + # PYTHONHASHSEED is not set by pytest; DataLoader workers will use a deterministic seed anyway. + "ignore:PYTHONHASHSEED should be set:UserWarning", + # Python 3.14 will remove pickle/copy support from itertools; comes from multiprocessing internals. + "ignore:Pickle, copy, and deepcopy support will be removed from itertools:DeprecationWarning", +] [tool.isort] profile = "black" diff --git a/setup.cfg b/setup.cfg index 955702907..e035cc0c1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,7 @@ install_requires = CORE = # Available through the nvidia base image torch>=2.9.0 + # apex # Available through the nvidia base image, requires manual build with --cuda_ext --fast_layer_norm numpy>=2.1.0 # Used for checkpoints safetensors>=0.6.2 From b935eb5539e6816665e58c57cecb6e313c022979 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 15:47:19 -0400 Subject: [PATCH 16/33] Fix bugs in engine/base_model and engine/config_utils - ReductionType enum values were copy-pasted from DataType (wrong string literals) - _serialize_architecture_field dropped dict keys (set comprehension instead of dict) - LayerBaseWithNamespace/LayerWithNamespace namespace param had spurious default=None - ParameterConfig/OptionalParameterConfig lacked _abstract=False (inherited True from ModuleConfig); empty _validate() bodies were accidentally suppressing the abstract check; stray pass before actual code in OptionalParameterConfig.get_parameter - FillInitializationConfig docstring was copy-pasted from NormalInitializationConfig - NormalInitializationConfig.max field desc said "Min value" (copy-paste from min) - UniformInitializationConfig.scale had wrong default=None and FieldHint.optional; mean validator Assert.geq(0) was wrong (mean can be negative) - TensorLogsConfig.max_elements had skip_valid_if_none despite being non-optional int - RunnableConfig._load_url opened the auth token file twice (second open unused) - UpdateType converted to StrEnum for consistency Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/config.py | 2 +- fast_llm/engine/base_model/base_model.py | 4 ++-- fast_llm/engine/base_model/config.py | 15 +++++---------- fast_llm/engine/config_utils/initialization.py | 10 ++++------ fast_llm/engine/config_utils/logging.py | 4 ++-- fast_llm/engine/config_utils/parameter.py | 9 +-------- fast_llm/engine/config_utils/runnable.py | 5 ++--- tests/utils/model_configs.py | 2 +- 8 files changed, 18 insertions(+), 33 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index eeaa6c7d3..61e22737a 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -40,7 +40,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): _AUTO_VALIDATE = self._old_value -class UpdateType(str, enum.Enum): +class UpdateType(enum.StrEnum): # Override entries no matter what they contain. override = "override" # Override atomic entries and lists, but update dicts recursively by setting or overriding only the specified entries. diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 4cb529463..d0d634b63 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -97,7 +97,7 @@ class LayerBaseWithNamespace(LayerBase): TODO: Consider namespace for losses and metrics? """ - def __init__(self, layer: LayerBase, namespace: str = None): + def __init__(self, layer: LayerBase, namespace: str): super().__init__(layer._distributed_config) self._layer = layer self._namespace = namespace @@ -139,7 +139,7 @@ def _layers_with_namespace(self) -> list[Layer]: class LayerWithNamespace(LayerBaseWithNamespace, Layer): _layer: Layer - def __init__(self, layer: Layer, namespace: str = None): + def __init__(self, layer: Layer, namespace: str): super().__init__(layer, namespace) self.layer_count = self._layer.layer_count diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 2770e67a2..a68c9ebc8 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -57,7 +57,7 @@ def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: elif isinstance(value, (list, tuple, set)): return [self._serialize_architecture_field(value_) for value_ in value] elif isinstance(value, dict): - return {self._serialize_architecture_field(value_) for name, value_ in value.items()} + return {name: self._serialize_architecture_field(value_) for name, value_ in value.items()} else: return self._serialize_value(value) @@ -106,15 +106,10 @@ class ResourceUsageConfig: class ReductionType(enum.StrEnum): - """ - An enum to represent data types independently of third party libraries, - so we can swap them more easily and allow for lazy imports. - """ - - sum = "float64" - average = "float32" - minimum = "float16" - maximum = "bfloat16" + sum = "sum" + average = "average" + minimum = "minimum" + maximum = "maximum" @property def torch(self) -> "typing.Callable[[torch.Tensor], torch.Tensor]": diff --git a/fast_llm/engine/config_utils/initialization.py b/fast_llm/engine/config_utils/initialization.py index 2f12a45d2..0395324f6 100644 --- a/fast_llm/engine/config_utils/initialization.py +++ b/fast_llm/engine/config_utils/initialization.py @@ -46,7 +46,7 @@ def get_initializer(self) -> "Initializer": @config_class(dynamic_type={InitializationConfig: "fill"}) class FillInitializationConfig(InitializationConfig): """ - Normal initialization: normal(mean, std).clamp(min,max) + Fill initialization: fills the tensor with a constant value. """ _abstract = False @@ -88,7 +88,7 @@ class NormalInitializationConfig(InitializationConfig): ) max: float | None = Field( default=None, - desc="Min value for initialization clamping.", + desc="Max value for initialization clamping.", hint=FieldHint.optional, ) @@ -105,16 +105,14 @@ class UniformInitializationConfig(InitializationConfig): _abstract = False scale: float = Field( - default=None, desc="Initialization scale.", - hint=FieldHint.optional, + hint=FieldHint.core, valid=check_field(Assert.geq, 0), ) mean: float = Field( - default=None, + default=0.0, desc="Initialization mean.", hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), ) def get_initializer(self) -> "Initializer": diff --git a/fast_llm/engine/config_utils/logging.py b/fast_llm/engine/config_utils/logging.py index 943b8de38..32deb4562 100644 --- a/fast_llm/engine/config_utils/logging.py +++ b/fast_llm/engine/config_utils/logging.py @@ -4,7 +4,7 @@ import pathlib import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import Assert @@ -73,7 +73,7 @@ class TensorLogsConfig(Config): default=8, desc="Maximum number of tensor values to print for each tensor when posting tensor logs to stdout.", hint=FieldHint.logging, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), + valid=check_field(Assert.gt, 0), ) full_tensors: bool = Field(default=False, desc="Save and/or print entire tensors.") diff --git a/fast_llm/engine/config_utils/parameter.py b/fast_llm/engine/config_utils/parameter.py index c0910c09a..3e2b61120 100644 --- a/fast_llm/engine/config_utils/parameter.py +++ b/fast_llm/engine/config_utils/parameter.py @@ -38,6 +38,7 @@ def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): @config_class() class ParameterConfig(ModuleConfig): + _abstract = False initialization: InitializationConfig = Field( desc="If provided, override the default initialization method set by the parent layer.", hint=FieldHint.feature, @@ -50,9 +51,6 @@ class ParameterConfig(ModuleConfig): ) # TODO: Initialization, lr_scale - def _validate(self) -> None: - pass - def get_parameter( self, dims: tuple[TensorDim, ...], @@ -83,9 +81,6 @@ class OptionalParameterConfig(ParameterConfig): default=None, ) - def _validate(self) -> None: - pass - def get_parameter( self, dims: tuple[TensorDim, ...], @@ -97,8 +92,6 @@ def get_parameter( default_enabled: bool = False, peft: PeftConfig | None, ) -> "ParameterMeta|None": - pass - if (self.enabled is None and default_enabled) or self.enabled: return super().get_parameter( dims, diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py index 58c490cb9..74fa0a2ae 100644 --- a/fast_llm/engine/config_utils/runnable.py +++ b/fast_llm/engine/config_utils/runnable.py @@ -165,9 +165,8 @@ def _load_url(cls, config_url: str, config_auth_token_file: pathlib.Path | None headers = {"Accept": "application/vnd.github.v3.raw"} if config_auth_token_file is not None: - config_auth_token = config_auth_token_file.open("r").read().strip() - with open(config_auth_token_file) as f: - headers["Authorization"] = f"token {config_auth_token}" + config_auth_token = config_auth_token_file.read_text().strip() + headers["Authorization"] = f"token {config_auth_token}" response = requests.get(config_url, headers=headers) if response.status_code == 200: return response.text diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 6268ac194..0f89d9323 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -706,7 +706,7 @@ def update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, ModelTestingGroup.streaming: ModelTestingGroupAction.normal, }, ) From cada39f95fa81a87d6d56163fcf11d33f9f9a65e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 15:48:02 -0400 Subject: [PATCH 17/33] Add docstrings to config classes, update docs and mkdocs Co-Authored-By: Claude Sonnet 4.6 --- docs/developer_guide/conversion.md | 194 ++++++-------- fast_llm/data/config.py | 2 +- fast_llm/data/dataset/config.py | 14 +- fast_llm/engine/checkpoint/config.py | 18 ++ fast_llm/engine/multi_stage/config.py | 10 +- fast_llm/engine/optimizer/config.py | 5 + fast_llm/engine/schedule/config.py | 10 +- fast_llm/engine/training/config.py | 28 +- fast_llm/layers/attention/config.py | 2 + fast_llm/layers/common/linear/config.py | 5 + .../layers/common/normalization/config.py | 12 +- fast_llm/layers/decoder/config.py | 14 +- fast_llm/layers/decoder/mlp/config.py | 6 +- fast_llm/models/gpt/config.py | 6 + fast_llm/profile.py | 2 +- mkdocs.yaml | 244 ++++++++++++++++++ 16 files changed, 448 insertions(+), 124 deletions(-) diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 6f42d8b6a..4ce982110 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -76,124 +76,99 @@ class AwesomeHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler) ### Configuration conversion -The configuration conversion utility interfaces between two configurations in the form of nested dictionaries: -a serialized Fast-LLM configuration and an external configuration. -The `_load_config` method is expected to read the configuration on disk, as expected by the checkpoint format, -and return the same configuration in the forma of a nested dictionary, -with `_save_config` handling the reverse operation. -See the [Hugging Face implementation](https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/engine/checkpoint/huggingface.py) for an example. - -To perform the conversion, the checkpoint handler relies on a list of `ParamConverter` objects, -which describe how individual parameters (or in some case multiple ones) should be converted. -The `ParamConverter` base interface is a dataclass consisting of two variables and two methods: - -* `fast_llm_names: tuple[tuple[str, ...], ...]`: An array of entry names on the Fast-LLM side, in tuple format. -For example, `((transformer, head_groups),)` refers to the single entry `config["transformer"]["head_groups"]`. -* `export_names: tuple[tuple[str, ...], ...]`: An array of entry names on the external side, in the same tuple format. -* `export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]`: -This method takes the configuration parameters corresponding to `fast_llm_names` (in the same order), -and returns converted parameters corresponding to `export_names`. -* `import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]`: -The converse of`export_params`, converting parameters corresponding to `export_names` into those corresponding to `fast_llm_names`. - -While not strictly part of the interface, it may also be useful to define a dataclass `__post_init__`, -for example to restrict the number of parameters in `fast_llm_names` and `export_names`. - -Fast-LLM offers several generic configuration converter classes, including: - -* `RenameParamConverter`: A simple 1-1 mapping between parameters, with optional renaming but identical value. -Typically, most converters are of this type. -* `ConstantImportParamConverter`: A 1-0 mapping for Fast-LLM parameters that without an equivalent in the external format, -that must take a specific value `fast_llm_value` for conversion to make sense (i.e., they take a hard-coded value in the external format). -This type of converter is common for Hugging Face converters, as Hugging Face models support much fewer configuration parameters. -* `ConstantExportParamConverter`: A 0-1 mapping, the converse of `ConstantImportParamConverter` -* `MappedConfigParamConverter`: A 1-1 mapping similar to `RenameParamConverter`, but with a non-trivial relation between values. - -In addition to those, you may need to implement your own custom converter. -Here is an example that associates several Fast-LLM variables with a tuple. +Configuration conversion is handled by a `HuggingFaceBaseModelConverter` subclass, +which is linked to the handler via a `base_model_converter_class` class variable. +The converter implements three class methods: -```python -@dataclasses.dataclass(kw_only=True) -class PackingParamConverter(ParamConverter): - def __post_init__(self): - # There may be any number of Fast-LLM variables, but only one external one - Assert.eq(len(self.export_names), 1) - - def export_params(self, fast_llm_values): - # Pack the values into a single tuple. - return (fast_llm_values,) - - def import_params(self, export_values): - # Unpack the values. We can safely assume `export_values` has length one because of the assertion in `__post_init__` - return export_values[0] -``` +* `import_config(cls, config: dict) -> dict`: +Reads the external (e.g., Hugging Face) configuration dict and returns a Fast-LLM `base_model` config dict. +* `export_config(cls, config: BaseModelConfig) -> dict`: +Takes a Fast-LLM `BaseModelConfig` object and returns the corresponding external configuration dict. +* `get_converters(cls, config: BaseModelConfig, exported_config: dict) -> list[WeightConverter]`: +Returns the list of weight converters for this model (described in the next section). -Now that we've seen how parameter converters work, we're ready to add them to our handler class. -We do so by creating a list of converters in the `_create_config_converters` class method. -Continuing our `AwesomeModel` handler example, we define: +The `_load_config` and `_save_config` methods on the handler read and write the external configuration file. +See the [Hugging Face implementation](https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/engine/checkpoint/huggingface.py) for their default implementation. + +Continuing our `AwesomeModel` example, the base model converter class could look like: ```python +class AwesomeBaseModelConverter(HuggingFaceBaseModelConverter): @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - # For Hugging Face handlers, we need to call the superclass method. - return super()._create_config_converters() + [ - # A trivial example where both the name and value are the same on both sides. - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - # A non-trivial example of `RenameParamConverter` with renaming and handling of nested dictionaries. - RenameParamConverter( - fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) - ), - # A constant import example indicating that the external format does not support absolute positional embeddings. - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - # The `architectures` parameter is a common use case for `ConstantExportParamConverter` in Hugging Face models. - ConstantExportParamConverter(export_names=(("architectures",),), export_value=["AwesomeModelForCausalLM"]), - # A value mapping example, where we match Fast-LLM activation types with their Hugging Face equivalents. - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=(("hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - # A more hypothetical example using `PackingParamConverter` to pack two parameters `epsilon_1`, `epsilon_2` into a tuple `eps`. - PackingParamConverter( - fast_llm_names=(("epsilon_1",),("epsilon_2",)), - export_names=(("eps",),), - ), - ] -``` + def import_config(cls, config: dict) -> dict: + # Build and return a Fast-LLM base_model config dict from the external config. + return { + "hidden_size": config["hidden_size"], + "embeddings": {"vocab_size": config["vocab_size"]}, + "decoder": { + "num_blocks": config["num_hidden_layers"], + "block": { + "mixer": { + "heads": config["num_attention_heads"], + "head_groups": config.get("num_key_value_heads", config["num_attention_heads"]), + "rotary": {"type": "default", "theta": config.get("rope_theta", 10000)}, + "add_linear_biases": False, + }, + "mlp": { + "intermediate_size": config["intermediate_size"], + "gated": True, + "activation": ActivationType.from_hf_name(config["hidden_act"]), + "add_linear_biases": False, + }, + "normalization": {"type": "rms_norm", "epsilon": config["rms_norm_eps"]}, + }, + }, + "head": {"normalization": {"type": "rms_norm", "epsilon": config["rms_norm_eps"]}}, + "tied_embedding_weight": config.get("tie_word_embeddings", False), + } -!!! note "How conversion works" - The once the converters are defined, the conversion utility takes it from there. - Exporting works as follows (importing work similarly): - *The handler creates an empty export config dict, then loops over its list of converters. For each converter, it: - * Reads the value of each parameter defined in `fast_llm_names`, and gathers them in a tuple. - *Calls `converter.export_params`, providing the set of read values as argument. - * Ensure that the returned value has the correct length (that of `export_names`) - * Set the respective values in the export config dict. + @classmethod + def export_config(cls, config: AwesomeBaseModelConfig) -> dict: + # Build and return the external config dict from the Fast-LLM config object. + decoder_block = config.decoder.block + return { + "model_type": "awesome_model", + "architectures": ["AwesomeModelForCausalLM"], + "hidden_size": config.hidden_size, + "vocab_size": config.embeddings.vocab_size, + "num_hidden_layers": config.decoder.num_blocks, + "num_attention_heads": decoder_block.mixer.heads, + "num_key_value_heads": decoder_block.mixer.head_groups, + "rope_theta": decoder_block.mixer.rotary.theta, + "intermediate_size": decoder_block.mlp.intermediate_size, + "hidden_act": decoder_block.mlp.activation.hf_name, + "rms_norm_eps": decoder_block.normalization.epsilon, + "tie_word_embeddings": config.tied_embedding_weight, + } -!!! note "About `MISSING` and `DEFAULT`" - If a value is not found during import, it will be replaced by the `MISSING` tag. - The converter's `import_params` has the opportunity to handle this missing value, - and if a `MISSING`, the handler will throw an error because it does not know what value to set on the Fast-LLM side. + @classmethod + def get_converters(cls, config: AwesomeBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + # Described in the next section. + ... +``` - The `MISSING` tag is also supported during export, - but has a different meaning as the value is always expected to be found in the Fast-LLM configuration. - Instead, `export_params` may return a `MISSING` tag indicating that no value should not be added to the Fast-LLM config. - It may also return `DEFAULT`, which will be replaced by the default value for the configuration parameter. +Then wire the converter into the handler via `base_model_converter_class`: - Note that the handling of `MISSING` and `DEFAULT` is experimental and may be improved in the future. +```python +class AwesomeHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + _model_class = AwesomeModelConfig + architecture = "AwesomeModelForCausalLM" + base_model_converter_class = AwesomeBaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls): + from transformers import AutoConfig + return AutoConfig +``` ### State conversion State conversion follows the same principle as configuration conversion, but acts on flat dictionaries of state tensors. Converters are defined by subclassing `WeightConverter`, with the interface: -* `fast_llm_name: str | tuple[str, ...]`: An entry name or array of entry names on the Fast-LLM side. -For example, `((transformer, head_groups),)` refers to the single entry `config["transformer"]["head_groups"]`. -* `export_name: str | tuple[str, ...]`: An entry name or array of entry names on the external side. +* `fast_llm_name: str | tuple[str, ...]`: A state dict key, or tuple of keys, on the Fast-LLM side. +For example, `"layers.0.mixer.weight"` or `("layers.0.weight_1", "layers.0.weight_2")`. +* `export_name: str | tuple[str, ...]`: A state dict key, or tuple of keys, on the external side. * `export_weight(self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]) -> tuple[torch.Tensor | SafeTensorSlice, ...]`: This method takes the state dict entries corresponding to `fast_llm_name` (in the same order), and returns converted entries corresponding to `export_name`. @@ -225,19 +200,20 @@ class TransposeWeightConverter(WeightConverter): return (weight[0][:].transpose().contiguous(),) ``` -We define the list of weight converters in the `_create_weight_converters` method. -Continuing our `AwesomeModel` handler example, we define: +We define the list of weight converters in the `get_converters` class method of the base model converter. +Continuing our `AwesomeModel` example, we define: ```python - def _create_weight_converters(self) -> list[WeightConverter]: + @classmethod + def get_converters(cls, config: AwesomeBaseModelConfig, exported_config: dict) -> list[WeightConverter]: converters = [] - # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. - num_layers = len(self._model.config.base_model.decoder) + # The set of converters may depend on the base model configuration. + num_layers = config.decoder.num_blocks # A simple renaming example, for the word embeddings. converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - # We usually want to loop dynamically over layers + # We usually want to loop dynamically over layers. for i in range(num_layers): # A `SplitWeightConverter` example, splitting a weight in two. converters.append(SplitWeightConverter( diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 78bc20636..98444d149 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -1,7 +1,7 @@ import enum -class MultiprocessingContext(str, enum.Enum): +class MultiprocessingContext(enum.StrEnum): # Fast but risk of segfaults due to interactions with triton # (for example https://github.com/openai/triton/issues/2088). fork = "fork" diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index ea30e3fc0..bef80f468 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -18,7 +18,9 @@ logger = logging.getLogger(__name__) -class ShufflingType(str, enum.Enum): +class ShufflingType(enum.StrEnum): + """Strategy for shuffling dataset samples across training epochs.""" + # Shuffle all epochs together. Not extendable. full = "full" # Shuffle all epochs separately. Default mode, recommended if the dataset doesn't come pre-shuffled. @@ -115,6 +117,8 @@ def sampling_maximum_document_length(self) -> int: @config_class() class DatasetConfig[DocumentType: Document](Config): + """Abstract base configuration for all dataset types.""" + _abstract: typing.ClassVar[bool] = True @@ -130,6 +134,8 @@ def build_and_sample(self, config: SamplingConfig, num_samples: int, seed: int) @config_class() class SamplableDatasetConfig[DocumentType: Document](SampledDatasetConfig[DocumentType]): + """Abstract configuration for datasets that can be built and then sampled.""" + def build(self) -> SamplableDataset[DocumentType]: raise NotImplementedError() @@ -139,6 +145,8 @@ def build_and_sample(self, config: SamplingConfig, num_samples: int, seed: int) @config_class() class IndexedDatasetConfig[DocumentType: Document](SamplableDatasetConfig[DocumentType]): + """Abstract configuration for indexed datasets that support random access by index.""" + def build(self) -> "IndexedDataset[DocumentType]": raise NotImplementedError() @@ -211,6 +219,8 @@ def build(self) -> "DatasetSlice": @config_class(dynamic_type={SampledDatasetConfig: "blended"}) class BlendedDatasetConfig[DocumentType: Document](SampledDatasetConfig[DocumentType]): + """Mixes multiple datasets together, sampling from each according to specified weights.""" + _abstract = False name: str = Field( default="blended", @@ -265,6 +275,8 @@ def build_and_sample(self, config: SamplingConfig, num_samples: int, seed: int) @config_class() class RedisConfig(Config): + """Configuration for connecting to a Redis server (host, port, timeout).""" + REDIS_FIELD: typing.ClassVar[str] = "data" REDIS_FIELD_B: typing.ClassVar[bytes] = REDIS_FIELD.encode() REDIS_GROUP_NAME: typing.ClassVar[str] = "fast_llm_group" diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 04b1dff46..190be62f1 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -92,6 +92,8 @@ def load_fast_llm(self) -> bool: @config_class() class CheckpointConfigBase(Config): + """Abstract base configuration for all checkpoint operations, holding the checkpoint format.""" + _abstract = True # Note: the `format` may be a str when configuring from file or cli. # The actual class should be set through `setup` in a parent config validation. @@ -117,6 +119,8 @@ def setup(self, model_config: "FastLLMModelConfig| type[FastLLMModelConfig]") -> @config_class() class CheckpointStateConfigBase(CheckpointConfigBase): + """Abstract base configuration for checkpoint operations that include model weights and/or optimizer state.""" + _abstract = True # Defaults and descriptions are set in derived classes. model_weights: bool = Field(default=True, hint=FieldHint.feature) @@ -125,6 +129,8 @@ class CheckpointStateConfigBase(CheckpointConfigBase): @config_class() class CheckpointSaveConfigBase(CheckpointConfigBase): + """Abstract base configuration for saving checkpoints, with file-size and dtype options.""" + _abstract = True parameters_per_file: int = Field( default=2**32, @@ -141,6 +147,8 @@ class CheckpointSaveConfigBase(CheckpointConfigBase): @config_class() class CheckpointStateSaveConfigBase(CheckpointSaveConfigBase, CheckpointStateConfigBase): + """Configuration for saving model weights and/or optimizer state to a checkpoint.""" + _abstract = False model_weights: bool = FieldOverride(desc="Save the model weights.") optimizer_state: bool = FieldOverride(desc="Save the optimizer state. Default: save if supported by the `format`.") @@ -157,6 +165,8 @@ def _validate(self) -> None: @config_class() class CheckpointPathConfigBase(CheckpointConfigBase): + """Abstract base configuration for checkpoint operations that require a filesystem path and optional timeout.""" + _abstract = True path: pathlib.Path | None = Field( default=None, @@ -173,16 +183,22 @@ class CheckpointPathConfigBase(CheckpointConfigBase): @config_class() class CheckpointSaveMetadataConfig(CheckpointPathConfigBase): + """Configuration for saving checkpoint metadata (path and format) without weights or optimizer state.""" + _abstract = False @config_class() class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConfigBase): + """Full configuration for saving a checkpoint: path, format, weights, optimizer state, and file options.""" + _abstract = False @config_class() class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): + """Configuration for loading checkpoint metadata, controlling which config sections are loaded.""" + _abstract = False # TODO: Set default to model? (Not backward compatible) load_config: ModelConfigType = Field( @@ -194,6 +210,8 @@ class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): @config_class() class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): + """Full configuration for loading a checkpoint: path, format, and which state to restore.""" + _abstract = False model_weights: bool = FieldOverride(desc="Load the model weights.") diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 41736aed6..c642203fc 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -42,7 +42,7 @@ class ShardName: grads = "grads" -class StageMode(str, enum.Enum): +class StageMode(enum.StrEnum): # Allow forward and backward passes and optimizer. # TODO: Add mode for forward and backward but not optimizer? training = "training" @@ -72,6 +72,8 @@ def on_device(self) -> bool: @config_class() class StageConfig(Config): + """Configuration for a single model stage: gradient precision, frozen weight storage, and debug logging.""" + full_precision_gradients: bool = Field( default=True, desc="Reduce and accumulate gradients in fp32 to improve numerical stability.", @@ -141,6 +143,8 @@ class StageConfig(Config): @config_class() class MultiStageConfig(StageConfig): + """Configuration for the multi-stage model layout: layers per stage, ZeRO sharding, and buffer counts.""" + layers_per_stage: float = Field( default=1.0, desc="Number of layers to include in each Fast LLM stage.", @@ -206,6 +210,8 @@ def _validate(self) -> None: @config_class(registry=True) class FastLLMModelConfig(Config): + """Abstract base configuration for a Fast-LLM model: base model, multi-stage layout, and distributed config.""" + _abstract = True checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = ( DistributedCheckpointFormat, @@ -283,6 +289,8 @@ def save_metadata(self, config: CheckpointSaveMetadataConfig, **kwargs) -> None: @config_class() class PretrainedFastLLMModelConfig(Config): + """Configuration wrapper that optionally loads model weights and config from a pretrained checkpoint.""" + # TODO: Generalize data, schedule, logging, etc. _abstract = True # This configs may be overridden with the pretrained config during validation, so we should be careful about accessing them before. diff --git a/fast_llm/engine/optimizer/config.py b/fast_llm/engine/optimizer/config.py index f4303a5d3..2b0e8709b 100644 --- a/fast_llm/engine/optimizer/config.py +++ b/fast_llm/engine/optimizer/config.py @@ -17,6 +17,8 @@ class LearningRateStageType: @config_class() class LearningRateScheduleConfig(Config): + """Configuration for the learning rate schedule (warmup, decay style, and bounds).""" + base: float = Field(default=0.0001, desc="Base learning rate for the optimizer.", hint=FieldHint.core) decay_style: str = Field(default="constant", desc="The learning rate decay formula.", hint=FieldHint.feature) decay_iterations: int | None = Field( @@ -38,6 +40,8 @@ class LearningRateScheduleConfig(Config): @config_class() class GradientScalerConfig(Config): + """Configuration for loss scaling, either fixed (constant) or dynamic (for fp16 training).""" + constant: float | None = Field( default=None, desc="Constant multiplier applied to the loss. Setting this disables dynamic scaling.", @@ -72,6 +76,7 @@ class GradientScalerConfig(Config): @config_class() class OptimizerConfig(Config): + """Configuration for the AdamW optimizer: learning rate schedule, gradient scaling, and hyperparameters.""" learning_rate: LearningRateScheduleConfig = Field( desc="A schedule for the learning rate.", diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 11fe11d11..f56c00f28 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -5,13 +5,15 @@ from fast_llm.utils import Assert -class StepType(str, enum.Enum): +class StepType(enum.StrEnum): forward = "forward" backward = "backward" @config_class() class ScheduleConfig(Config): + """Configuration for the micro-batch execution schedule: pipeline overlap, CPU throttling, and debug options.""" + depth_first_micro_batches: int = Field( default=1, desc="Size of individual micro-batches. May be derived or constrained be other quantities.", @@ -85,18 +87,18 @@ def num_inputs(self) -> int: return self.sequential_micro_batches * self.micro_batch_splits -class StreamType(str, enum.Enum): +class StreamType(enum.StrEnum): compute = "compute" data = "data" pipeline = "pipeline" -class StepScheduleType(str, enum.Enum): +class StepScheduleType(enum.StrEnum): breadth_first = "breadth_first" depth_first = "depth_first" -class EventType(str, enum.Enum): +class EventType(enum.StrEnum): # Global events batch_begin = "batch_begin" batch_end = "batch_end" diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index fecee4615..bece3cb49 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -51,6 +51,8 @@ def _validate_script(value: str | list[str]) -> list[str]: @config_class() class CallbackConfig(Config): + """Configuration for an optional shell script callback invoked after a checkpoint, export, or shutdown event.""" + script: list[str] | None = Field( default=None, desc="Shell script to run.", @@ -72,6 +74,8 @@ def run(self) -> None: @config_class() class WandbAlertConfig(IntervalConfig): + """Configuration for periodic Weights & Biases status alerts during training.""" + interval = FieldOverride( desc="The number of training iterations between each Wandb status post (alert)." " Setting to None will disable iteration-based wandb alerts." @@ -94,6 +98,8 @@ def post_alerts(self) -> bool: @config_class() class MetricsLogsConfig(IntervalConfig): + """Configuration for training metric logging interval (loss, throughput, etc.).""" + interval = FieldOverride( default=100, desc="The number of training iterations between each metric logs." @@ -104,6 +110,8 @@ class MetricsLogsConfig(IntervalConfig): @config_class() class WandbConfig(Config): + """Configuration for Weights & Biases experiment tracking (project, entity, alerts).""" + alert: WandbAlertConfig = Field( desc="Configuration for Wandb alerts." " The alerts may be posted by email and/or slack depending on the Wandb account configuration.", @@ -116,6 +124,8 @@ class WandbConfig(Config): @config_class() class TrainingCheckpointBaseConfig(IntervalConfig): + """Abstract base configuration for periodic saving operations (checkpoints and exports).""" + _abstract = True save_name: typing.ClassVar[str] = "save" callback: CallbackConfig = Field( @@ -156,6 +166,8 @@ def to_delete(self, iterations: list[int]) -> list[int]: @config_class() class TrainingCheckpointConfig(TrainingCheckpointBaseConfig): + """Configuration for saving full training checkpoints (weights + optimizer state) at a fixed interval.""" + _abstract = False save_name: typing.ClassVar[str] = "checkpoint" interval = FieldOverride( @@ -189,6 +201,8 @@ def get_load_config(self, path: pathlib.Path, timeout: float | None) -> Checkpoi @config_class() class TrainingExportConfig(TrainingCheckpointBaseConfig, CheckpointStateSaveConfigBase): + """Configuration for exporting model weights to an external format (e.g. HuggingFace) at a fixed interval.""" + _abstract = False save_name: typing.ClassVar[str] = "export" interval = FieldOverride( @@ -206,6 +220,8 @@ def get_save_config(self, path: pathlib.Path, timeout: float | None) -> Checkpoi @config_class() class ShutdownConfig(IntervalConfig): + """Configuration for automatic training shutdown after a checkpoint, useful for preemptible jobs.""" + interval = FieldOverride( desc="The number of training iterations between each automated shutdown." " Setting to None will disable automated shutdowns." @@ -218,6 +234,8 @@ class ShutdownConfig(IntervalConfig): @config_class() class TrainingConfig(Config): + """Configuration for training phases: iterations, checkpoints, exports, logging, evaluators, and W&B.""" + evaluators: dict[str, EvaluatorConfig] = Field( default_factory=dict, desc="A dictionary of evaluation dataset names and their configurations for the validation phase.", @@ -260,6 +278,8 @@ def _validate(self) -> None: @config_class(registry=True) class TrainerCallbackConfig(Config): + """Abstract base configuration for trainer callbacks that hook into training events.""" + def get_callback(self, model: "FastLLMModel") -> "TrainerCallback": raise NotImplementedError() @@ -269,6 +289,8 @@ def setup(self, config: "TrainerConfig") -> None: @config_class() class WeightsBroadcastConfig(Config): + """Configuration for broadcasting model weights to an external process via NCCL (used in online RL pipelines).""" + # TODO: Have the external model send these instead? host: str = Field( default="localhost", @@ -294,9 +316,7 @@ class WeightsBroadcastConfig(Config): @config_class(dynamic_type={TrainerCallbackConfig: "streaming"}) class StreamingTrainerCallbackConfig(TrainerCallbackConfig, RedisConfig): - """ - Aggregates all trainer-side Redis-based event configurations. - """ + """Trainer callback for online RL: exports and broadcasts model weights via Redis after each update.""" broadcast: WeightsBroadcastConfig = Field( desc="Configuration for signaling weight-ready events via Redis.", @@ -319,6 +339,8 @@ def setup(self, config: "TrainerConfig") -> None: @config_class(registry=True, dynamic_type={RunnableConfig: "train"}) class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): + """Abstract base configuration for a training run: model, data, schedule, optimizer, callbacks, and checkpointing.""" + _abstract = True # TODO: Generalize data, schedule, logging, etc. training: TrainingConfig = Field( diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 86469c3d9..fcb5bfaf6 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -42,6 +42,8 @@ class AttentionImplementation(enum.StrEnum): @config_class(dynamic_type={MixerConfig: "attention"}) class AttentionConfig(MixerConfig): + """Configuration for multi-head and grouped-query attention with optional rotary embeddings.""" + # TODO: Make mixer class dynamic. _abstract = False diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index e7c6d9e92..4c64f0816 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -45,6 +45,8 @@ class AffineLinearBaseConfig(LinearBaseConfig): @config_class() class LinearConfig(LinearBaseConfig): + """Configuration for a linear (weight-only, no bias) layer with optional PEFT and tensor-parallelism support.""" + apply_peft: bool | None = Field( default=None, desc="Wrap this layer ." @@ -104,6 +106,8 @@ def get_layer( @config_class() class AffineLinearConfig(AffineLinearBaseConfig, LinearConfig): + """Configuration for an affine linear layer (weight + optional bias) with optional PEFT and tensor-parallelism support.""" + def get_layer( self, in_dim: TensorDim, @@ -175,6 +179,7 @@ class CausalConv1dConfig(AffineLinearBaseConfig): ) activation: ActivationType | None = Field( default=None, + desc="Activation function applied after the convolution. None means no activation.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 4b8edaebe..274215bf2 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -14,7 +14,7 @@ from fast_llm.layers.common.normalization.normalization import Normalization -class NormalizationImplementation(str, enum.Enum): +class NormalizationImplementation(enum.StrEnum): """ An enum for the available implementations of layer norm. """ @@ -28,6 +28,8 @@ class NormalizationImplementation(str, enum.Enum): @config_class(registry=True) class NormalizationConfig(ModuleConfig): + """Abstract base configuration for normalization layers. Use `type: layer_norm`, `rms_norm`, `gated_rms_norm`, or `none`.""" + lr_scale: float | None = Field( default=None, desc="Scaling factor for the layer learning rate." @@ -62,6 +64,8 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi @config_class(dynamic_type={NormalizationConfig: "none"}) class NoNormalizationConfig(NormalizationConfig): + """Disables normalization entirely (identity pass-through).""" + _abstract = False @property @@ -106,6 +110,8 @@ def module_class(self): @config_class(dynamic_type={NormalizationConfig: "layer_norm"}) class LayerNormalizationConfig(LayerNormalizationBaseConfig): + """Configuration for standard layer normalization (mean and variance, with learnable weight and bias).""" + bias: ParameterConfig = Field( desc="Configuration for the weight.", hint=FieldHint.architecture, @@ -121,6 +127,8 @@ def module_class(self): @config_class(dynamic_type={NormalizationConfig: "rms_norm"}) class RMSNormalizationConfig(LayerNormalizationBaseConfig): + """Configuration for RMS normalization (variance only, no mean subtraction, no bias).""" + _abstract = False @property @@ -132,6 +140,8 @@ def module_class(self): @config_class(dynamic_type={NormalizationConfig: "gated_rms_norm"}) class GatedRMSNormalizationConfig(RMSNormalizationConfig): + """Configuration for gated RMS normalization, which applies a learned activation gate alongside the norm weight.""" + _abstract = False activation: ActivationType = Field( diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 4cab2d39b..6ab259b2b 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -56,6 +56,8 @@ def get_layer( @config_class(registry=True) class MLPBaseConfig(BlockWithBiasConfig): + """Abstract base configuration for MLP (feedforward) layers. Use `type: mlp` or `type: moe` to select a variant.""" + _abstract = True def get_layer( @@ -200,9 +202,17 @@ def layer_class(self) -> "type[StochasticMixer]": @config_class(dynamic_type={BlockConfig: "decoder"}) class DecoderBlockConfig(BlockConfig): + """Configuration for a transformer decoder block (attention + MLP + normalization + residual).""" + _abstract = False - mixer: MixerConfig = Field() - mlp: MLPBaseConfig = Field() + mixer: MixerConfig = Field( + desc="Configuration for the attention/mixer layer.", + hint=FieldHint.architecture, + ) + mlp: MLPBaseConfig = Field( + desc="Configuration for the feedforward (MLP) layer.", + hint=FieldHint.architecture, + ) # TODO: Review names normalization: NormalizationConfig = Field( desc="Configuration for the block normalization layers.", diff --git a/fast_llm/layers/decoder/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py index 36841b45b..28198f2e4 100644 --- a/fast_llm/layers/decoder/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -18,13 +18,15 @@ class MLPLossNames: router_z_loss = "router_z_loss" -class RoutingType(str, enum.Enum): +class RoutingType(enum.StrEnum): topk = "aux_loss" sinkhorn = "sinkhorn" @config_class(dynamic_type={MLPBaseConfig: "mlp"}) class MLPConfig(MLPBaseConfig): + """Configuration for a dense feedforward (MLP) layer with optional gating and activation recomputation.""" + # TODO: Review names # TODO: Separate MoE? _abstract = False @@ -81,6 +83,8 @@ def layer_class(self) -> "type[MLP]": @config_class(dynamic_type={MLPBaseConfig: "moe"}) class MoEMLPConfig(MLPConfig): + """Configuration for a Mixture-of-Experts (MoE) feedforward layer with top-k token routing.""" + router: LinearConfig = Field( # TODO: Improve default? desc="Configuration for the MoE router.", diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 72cace032..770139816 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -56,6 +56,8 @@ def base_model_class(self) -> type["GPTBaseModel"]: @config_class(dynamic_type={FastLLMModelConfig: "gpt"}) class GPTModelConfig(FastLLMModelConfig): + """Configuration for the GPT model, including distributed, multi-stage, and HuggingFace checkpoint formats.""" + _abstract = False model_name: typing.ClassVar[str] = "gpt" base_model: GPTBaseModelConfig = FieldOverride() @@ -93,12 +95,16 @@ def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelF @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): + """Configuration for a GPT model together with an optional pretrained checkpoint to load.""" + _abstract = False model: GPTModelConfig = FieldOverride() @config_class(dynamic_type={RunnableConfig: "train_gpt", TrainerConfig: "gpt"}) class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): + """Top-level configuration for training a GPT model. Entry point for `fast-llm train gpt`.""" + data: GPTDataConfig = FieldOverride() # TODO: Use dynamic model type? reference_models: dict[str, PretrainedGPTModelConfig] = FieldOverride() diff --git a/fast_llm/profile.py b/fast_llm/profile.py index a3902cf1e..58a72764d 100644 --- a/fast_llm/profile.py +++ b/fast_llm/profile.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -class ProfileType(str, enum.Enum): +class ProfileType(enum.StrEnum): cpu = "cpu" cuda = "cuda" diff --git a/mkdocs.yaml b/mkdocs.yaml index bd8b7d4c4..0ad00ccef 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -77,6 +77,7 @@ theme: # Hooks hooks: - docs/overrides/hooks/shortcodes.py + - docs/overrides/hooks/generate_config_docs_hook.py # Additional configuration extra: @@ -195,5 +196,248 @@ nav: - Development Practices: contributing/dev-practices.md - Testing: contributing/testing.md - How to Release: contributing/how-to-release.md + # BEGIN AUTO-GENERATED CONFIG REFERENCE + - Configuration Reference: + - reference/configuration/index.md + - Data: + - reference/configuration/data/index.md + - Data: + - reference/configuration/data/data/index.md + - reference/configuration/data/data/DataConfig.md + - Gpt: + - reference/configuration/data/data/gpt/index.md + - reference/configuration/data/data/gpt/GPTDataConfig.md + - Dataset: + - reference/configuration/data/dataset/index.md + - reference/configuration/data/dataset/BlendedDatasetConfig.md + - reference/configuration/data/dataset/ConcatenatedDatasetConfig.md + - reference/configuration/data/dataset/DatasetConfig.md + - reference/configuration/data/dataset/DatasetSliceConfig.md + - reference/configuration/data/dataset/IndexedDatasetConfig.md + - reference/configuration/data/dataset/RedisConfig.md + - reference/configuration/data/dataset/SamplableDatasetConfig.md + - reference/configuration/data/dataset/SampledDatasetConfig.md + - reference/configuration/data/dataset/SamplingConfig.md + - reference/configuration/data/dataset/SamplingConfigBase.md + - reference/configuration/data/dataset/StreamingDatasetConfig.md + - Gpt: + - reference/configuration/data/dataset/gpt/index.md + - reference/configuration/data/dataset/gpt/FimConfig.md + - reference/configuration/data/dataset/gpt/GPTDatasetFromFileConfig.md + - reference/configuration/data/dataset/gpt/GPTFimSampledDatasetConfig.md + - reference/configuration/data/dataset/gpt/GPTRandomDatasetConfig.md + - reference/configuration/data/dataset/gpt/GPTSamplingConfig.md + - reference/configuration/data/dataset/gpt/GPTTestSlowDatasetConfig.md + - Memmap: + - reference/configuration/data/dataset/memmap/index.md + - reference/configuration/data/dataset/memmap/LanguageModelReaderConfig.md + - reference/configuration/data/dataset/memmap/MemmapDatasetConfig.md + - reference/configuration/data/dataset/memmap/MemmapIndexDatasetReaderConfig.md + - reference/configuration/data/dataset/memmap/MemmapReaderBaseConfig.md + - reference/configuration/data/dataset/memmap/MemmapReaderConfig.md + - reference/configuration/data/dataset/memmap/NullReaderConfig.md + - reference/configuration/data/dataset/memmap/PatchReaderBaseConfig.md + - reference/configuration/data/dataset/memmap/PatchReaderConfig.md + - reference/configuration/data/dataset/memmap/RangeReaderBaseConfig.md + - reference/configuration/data/dataset/memmap/RangeReaderConfig.md + - reference/configuration/data/dataset/memmap/TokenDataReaderConfig.md + - reference/configuration/data/dataset/memmap/TokenReaderConfig.md + - Document: + - reference/configuration/data/document/index.md + - reference/configuration/data/document/BatchPreprocessingConfig.md + - reference/configuration/data/document/ImageNormalizationConfig.md + - reference/configuration/data/document/LanguageModelBatchPreprocessingConfig.md + - reference/configuration/data/document/LengthPreprocessingConfig.md + - reference/configuration/data/document/PatchPreprocessingConfig.md + - reference/configuration/data/document/TokenPreprocessingConfig.md + - Preparation: + - reference/configuration/data/preparation/index.md + - reference/configuration/data/preparation/DatasetPreparatorConfig.md + - Dataset Discovery: + - reference/configuration/data/preparation/dataset_discovery/index.md + - reference/configuration/data/preparation/dataset_discovery/DatasetDiscoveryConfig.md + - Gpt Memmap: + - reference/configuration/data/preparation/gpt_memmap/index.md + - reference/configuration/data/preparation/gpt_memmap/ConversationSourceConfig.md + - reference/configuration/data/preparation/gpt_memmap/DatasetPreparatorDistributedConfig.md + - reference/configuration/data/preparation/gpt_memmap/DocumentSourceConfig.md + - reference/configuration/data/preparation/gpt_memmap/GPTHuggingfaceDatasetConfig.md + - reference/configuration/data/preparation/gpt_memmap/GPTMemmapDatasetPreparatorConfig.md + - reference/configuration/data/preparation/gpt_memmap/LanguageModelSourceConfig.md + - Image Patch: + - reference/configuration/data/preparation/image_patch/index.md + - reference/configuration/data/preparation/image_patch/ImagePreparationConfig.md + - Tokenizer: + - reference/configuration/data/preparation/tokenizer/index.md + - reference/configuration/data/preparation/tokenizer/TokenizerConfig.md + - Engine: + - reference/configuration/engine/index.md + - Base Model: + - reference/configuration/engine/base_model/index.md + - reference/configuration/engine/base_model/BaseModelConfig.md + - reference/configuration/engine/base_model/ModuleConfig.md + - Checkpoint: + - reference/configuration/engine/checkpoint/index.md + - reference/configuration/engine/checkpoint/CheckpointConfigBase.md + - reference/configuration/engine/checkpoint/CheckpointLoadConfig.md + - reference/configuration/engine/checkpoint/CheckpointLoadMetadataConfig.md + - reference/configuration/engine/checkpoint/CheckpointPathConfigBase.md + - reference/configuration/engine/checkpoint/CheckpointSaveConfig.md + - reference/configuration/engine/checkpoint/CheckpointSaveConfigBase.md + - reference/configuration/engine/checkpoint/CheckpointSaveMetadataConfig.md + - reference/configuration/engine/checkpoint/CheckpointStateConfigBase.md + - reference/configuration/engine/checkpoint/CheckpointStateSaveConfigBase.md + - Config Utils: + - reference/configuration/engine/config_utils/index.md + - Initialization: + - reference/configuration/engine/config_utils/initialization/index.md + - reference/configuration/engine/config_utils/initialization/DefaultInitializationConfig.md + - reference/configuration/engine/config_utils/initialization/FillInitializationConfig.md + - reference/configuration/engine/config_utils/initialization/InitializationConfig.md + - reference/configuration/engine/config_utils/initialization/NormalInitializationConfig.md + - reference/configuration/engine/config_utils/initialization/UniformInitializationConfig.md + - Interval: + - reference/configuration/engine/config_utils/interval/index.md + - reference/configuration/engine/config_utils/interval/IntervalConfig.md + - Logging: + - reference/configuration/engine/config_utils/logging/index.md + - reference/configuration/engine/config_utils/logging/TensorLogsConfig.md + - Parameter: + - reference/configuration/engine/config_utils/parameter/index.md + - reference/configuration/engine/config_utils/parameter/OptionalParameterConfig.md + - reference/configuration/engine/config_utils/parameter/ParameterConfig.md + - Run: + - reference/configuration/engine/config_utils/run/index.md + - reference/configuration/engine/config_utils/run/ExperimentConfig.md + - reference/configuration/engine/config_utils/run/RunConfig.md + - Runnable: + - reference/configuration/engine/config_utils/runnable/index.md + - reference/configuration/engine/config_utils/runnable/RunnableConfig.md + - Distributed: + - reference/configuration/engine/distributed/index.md + - reference/configuration/engine/distributed/DistributedConfig.md + - Evaluation: + - reference/configuration/engine/evaluation/index.md + - reference/configuration/engine/evaluation/EvaluatorConfig.md + - reference/configuration/engine/evaluation/LmEvalEvaluatorConfig.md + - reference/configuration/engine/evaluation/LossEvaluatorConfig.md + - Multi Stage: + - reference/configuration/engine/multi_stage/index.md + - reference/configuration/engine/multi_stage/CheckpointMetadata.md + - reference/configuration/engine/multi_stage/FastLLMModelConfig.md + - reference/configuration/engine/multi_stage/MultiStageConfig.md + - reference/configuration/engine/multi_stage/PretrainedFastLLMModelConfig.md + - reference/configuration/engine/multi_stage/StageConfig.md + - Optimizer: + - reference/configuration/engine/optimizer/index.md + - reference/configuration/engine/optimizer/GradientScalerConfig.md + - reference/configuration/engine/optimizer/LearningRateScheduleConfig.md + - reference/configuration/engine/optimizer/OptimizerConfig.md + - Schedule: + - reference/configuration/engine/schedule/index.md + - reference/configuration/engine/schedule/ScheduleConfig.md + - Training: + - reference/configuration/engine/training/index.md + - reference/configuration/engine/training/CallbackConfig.md + - reference/configuration/engine/training/MetricsLogsConfig.md + - reference/configuration/engine/training/ShutdownConfig.md + - reference/configuration/engine/training/StreamingTrainerCallbackConfig.md + - reference/configuration/engine/training/TrainerCallbackConfig.md + - reference/configuration/engine/training/TrainerConfig.md + - reference/configuration/engine/training/TrainingCheckpointBaseConfig.md + - reference/configuration/engine/training/TrainingCheckpointConfig.md + - reference/configuration/engine/training/TrainingConfig.md + - reference/configuration/engine/training/TrainingExportConfig.md + - reference/configuration/engine/training/WandbAlertConfig.md + - reference/configuration/engine/training/WandbConfig.md + - reference/configuration/engine/training/WeightsBroadcastConfig.md + - Layers: + - reference/configuration/layers/index.md + - Attention: + - reference/configuration/layers/attention/index.md + - reference/configuration/layers/attention/AttentionConfig.md + - Rotary: + - reference/configuration/layers/attention/rotary/index.md + - reference/configuration/layers/attention/rotary/DefaultRotaryConfig.md + - reference/configuration/layers/attention/rotary/Llama3RotaryConfig.md + - reference/configuration/layers/attention/rotary/NoRotaryConfig.md + - reference/configuration/layers/attention/rotary/Rotary2DConfig.md + - reference/configuration/layers/attention/rotary/RotaryConfig.md + - reference/configuration/layers/attention/rotary/YarnRotaryConfig.md + - Block: + - reference/configuration/layers/block/index.md + - reference/configuration/layers/block/BlockConfig.md + - reference/configuration/layers/block/BlockSequenceConfig.md + - reference/configuration/layers/block/FixedBlockSequenceConfig.md + - reference/configuration/layers/block/PatternBlockSequenceConfig.md + - Common: + - reference/configuration/layers/common/index.md + - Linear: + - reference/configuration/layers/common/linear/index.md + - reference/configuration/layers/common/linear/AffineLinearBaseConfig.md + - reference/configuration/layers/common/linear/AffineLinearConfig.md + - reference/configuration/layers/common/linear/CausalConv1dConfig.md + - reference/configuration/layers/common/linear/LinearBaseConfig.md + - reference/configuration/layers/common/linear/LinearConfig.md + - Normalization: + - reference/configuration/layers/common/normalization/index.md + - reference/configuration/layers/common/normalization/GatedRMSNormalizationConfig.md + - reference/configuration/layers/common/normalization/LayerNormalizationBaseConfig.md + - reference/configuration/layers/common/normalization/LayerNormalizationConfig.md + - reference/configuration/layers/common/normalization/NoNormalizationConfig.md + - reference/configuration/layers/common/normalization/NormalizationConfig.md + - reference/configuration/layers/common/normalization/RMSNormalizationConfig.md + - Peft: + - reference/configuration/layers/common/peft/index.md + - reference/configuration/layers/common/peft/LoRAConfig.md + - reference/configuration/layers/common/peft/NoPeftConfig.md + - reference/configuration/layers/common/peft/PeftConfig.md + - Decoder: + - reference/configuration/layers/decoder/index.md + - reference/configuration/layers/decoder/BlockWithBiasConfig.md + - reference/configuration/layers/decoder/DecoderBlockConfig.md + - reference/configuration/layers/decoder/MLPBaseConfig.md + - reference/configuration/layers/decoder/MixerConfig.md + - reference/configuration/layers/decoder/StochasticMixerConfig.md + - Mlp: + - reference/configuration/layers/decoder/mlp/index.md + - reference/configuration/layers/decoder/mlp/MLPConfig.md + - reference/configuration/layers/decoder/mlp/MoEMLPConfig.md + - Language Model: + - reference/configuration/layers/language_model/index.md + - reference/configuration/layers/language_model/LanguageModelConfig.md + - reference/configuration/layers/language_model/LanguageModelEmbeddingsConfig.md + - reference/configuration/layers/language_model/LanguageModelHeadConfig.md + - Loss: + - reference/configuration/layers/language_model/loss/index.md + - reference/configuration/layers/language_model/loss/LanguageModelDPOLossConfig.md + - reference/configuration/layers/language_model/loss/LanguageModelDistillationLossConfig.md + - reference/configuration/layers/language_model/loss/LanguageModelGRPOLossConfig.md + - reference/configuration/layers/language_model/loss/LanguageModelLabelEntropyLossConfig.md + - reference/configuration/layers/language_model/loss/LanguageModelLossConfig.md + - reference/configuration/layers/language_model/loss/LanguageModelZLossConfig.md + - Vision: + - reference/configuration/layers/vision/index.md + - reference/configuration/layers/vision/PatchEmbeddingsConfig.md + - reference/configuration/layers/vision/VisionEncoderConfig.md + - reference/configuration/layers/vision/VisionMultiModalModelConfig.md + - Models: + - reference/configuration/models/index.md + - Gpt: + - reference/configuration/models/gpt/index.md + - reference/configuration/models/gpt/GPTBaseModelConfig.md + - reference/configuration/models/gpt/GPTModelConfig.md + - reference/configuration/models/gpt/GPTTrainerConfig.md + - reference/configuration/models/gpt/PretrainedGPTModelConfig.md + - Multimodal: + - reference/configuration/models/multimodal/index.md + - reference/configuration/models/multimodal/MultiModalBaseModelConfig.md + - reference/configuration/models/multimodal/MultiModalModelConfig.md + - reference/configuration/models/multimodal/MultiModalTrainerConfig.md + - reference/configuration/models/multimodal/PretrainedMultiModalModelConfig.md + - Profile: + - reference/configuration/profile/index.md + - reference/configuration/profile/ProfilingConfig.md + # END AUTO-GENERATED CONFIG REFERENCE - About Us: about-us.md - Join Us: join-us.md From d752f03834c5203844190e748be048b55b2ab569 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 15:52:56 -0400 Subject: [PATCH 18/33] Add config docs generator script and MkDocs hook, ignore generated output Co-Authored-By: Claude Sonnet 4.6 --- .gitignore | 1 + .../hooks/generate_config_docs_hook.py | 27 + tools/generate_config_docs.py | 772 ++++++++++++++++++ 3 files changed, 800 insertions(+) create mode 100644 docs/overrides/hooks/generate_config_docs_hook.py create mode 100644 tools/generate_config_docs.py diff --git a/.gitignore b/.gitignore index f468ffd00..8a01a55f7 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ __pycache__/ # Doc build .cache site +docs/reference/ # Distribution / packaging *.egg-info/ diff --git a/docs/overrides/hooks/generate_config_docs_hook.py b/docs/overrides/hooks/generate_config_docs_hook.py new file mode 100644 index 000000000..39af6fd05 --- /dev/null +++ b/docs/overrides/hooks/generate_config_docs_hook.py @@ -0,0 +1,27 @@ +"""MkDocs hook: regenerate config reference docs before each build.""" + +import importlib.util +import pathlib +import sys + +_REPO_ROOT = pathlib.Path(__file__).parent.parent.parent.parent +_SCRIPT = _REPO_ROOT / "tools" / "generate_config_docs.py" + +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + + +def _load_gen(): + spec = importlib.util.spec_from_file_location("generate_config_docs", _SCRIPT) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def on_pre_build(config) -> None: # noqa: ANN001 + """Regenerate config reference markdown before the build processes files.""" + gen = _load_gen() + # Regenerate pages but do not update mkdocs.yaml — nav must be updated + # manually by running `python tools/generate_config_docs.py` when config + # classes are added or modules are restructured. + gen.generate(update_nav=False, verbose=False) diff --git a/tools/generate_config_docs.py b/tools/generate_config_docs.py new file mode 100644 index 000000000..d797eb4ae --- /dev/null +++ b/tools/generate_config_docs.py @@ -0,0 +1,772 @@ +#!/usr/bin/env python3 +""" +Generate markdown documentation for Fast-LLM configuration classes. + +Walks the fast_llm package, finds all @config_class-decorated classes, and writes +one markdown file per class under docs/reference/configuration/, mirroring the +package structure. Also writes index.md files per directory and updates the nav +section in mkdocs.yaml. + +Usage: + python tools/generate_config_docs.py +""" + +import dataclasses +import importlib +import pathlib +import pkgutil +import re +import sys +import types +import typing + +from fast_llm.config import Config, Field, FieldHint, FieldHintImportance # noqa: E402 + +# --------------------------------------------------------------------------- +# Paths +# --------------------------------------------------------------------------- + +REPO_ROOT = pathlib.Path(__file__).parent.parent +OUTPUT_DIR = REPO_ROOT / "docs" / "reference" / "configuration" +MKDOCS_YAML = REPO_ROOT / "mkdocs.yaml" + +sys.path.insert(0, str(REPO_ROOT)) + + +# --------------------------------------------------------------------------- +# Field filtering +# --------------------------------------------------------------------------- + +# Hints that describe internal/computed/testing fields — not useful in config docs. +EXCLUDED_HINTS: set[FieldHint] = {FieldHint.derived, FieldHint.setup, FieldHint.testing} + +# Field names that are always excluded regardless of hint. +EXCLUDED_FIELD_NAMES: set[str] = {"type"} + + +def is_user_field(name: str, field: Field) -> bool: + """Return True if this field should appear in user-facing documentation.""" + if name.startswith("_"): + return False + if name in EXCLUDED_FIELD_NAMES: + return False + if not field.init or field._field_type is not dataclasses._FIELD: # noqa: SLF001 + return False + if getattr(field, "hint", None) in EXCLUDED_HINTS: + return False + return True + + +# --------------------------------------------------------------------------- +# Module collection +# --------------------------------------------------------------------------- + + +def import_all_config_modules() -> None: + """Import every module in the fast_llm package so all Config subclasses are registered.""" + import fast_llm # noqa: F401 + + for module_info in pkgutil.walk_packages( + path=[str(REPO_ROOT / "fast_llm")], + prefix="fast_llm.", + onerror=lambda name: None, + ): + # Only import config modules — they are safe to import without GPU. + if not module_info.name.endswith(".config"): + continue + try: + importlib.import_module(module_info.name) + except Exception as exc: # noqa: BLE001 + print(f" [skip] {module_info.name}: {exc}", file=sys.stderr) + + +def collect_config_classes() -> dict[type, dict]: + """ + Return a dict mapping each Config subclass to metadata: + { + "module": str, + "fields": list[(name, Field, resolved_type)], + "registry": dict[str, type] | None, # subclasses if this has a registry + "registered_in": list[(base_cls, type_key)], # registries this class is in + "abstract": bool, + } + """ + import fast_llm.config as config_module + + config_base = config_module.Config + + # Collect all Config subclasses that have been processed by @config_class. + found: dict[type, dict] = {} + for cls in _all_subclasses(config_base): + if not getattr(cls, "__class_validated__", False): + continue + if cls.__module__ == "builtins": + continue + found[cls] = { + "module": cls.__module__, + "fields": [], + "registry": None, + "registered_in": [], + "abstract": bool(getattr(cls, "_abstract", False)), + } + + # Resolve type hints and build field lists. + for cls, info in found.items(): + try: + hints = typing.get_type_hints(cls) + except Exception: # noqa: BLE001 + hints = {} + for name, field in cls.fields(): + if not is_user_field(name, field): + continue + resolved = hints.get(name, field.type) + info["fields"].append((name, field, resolved)) + # Sort by hint importance (lower = more important), then alphabetically. + info["fields"].sort( + key=lambda t: (FieldHintImportance.get(getattr(t[1], "hint", FieldHint.unknown), 20), t[0]) + ) + + # Build registry info. + for cls, info in found.items(): + registry = getattr(cls, "_registry", None) + if registry is not None: + info["registry"] = {key: found_cls for key in registry if (found_cls := registry[key]) in found} + + # Build registered_in back-references. + for cls, info in found.items(): + registry = getattr(cls, "_registry", None) + if registry is None: + continue + for key in registry: + subclass = registry[key] + if subclass in found: + found[subclass]["registered_in"].append((cls, key)) + + return found + + +def _all_subclasses(cls: type) -> list[type]: + """Recursively collect all subclasses of a class.""" + result = [] + queue = list(cls.__subclasses__()) + seen = set() + while queue: + sub = queue.pop() + if sub in seen: + continue + seen.add(sub) + result.append(sub) + queue.extend(sub.__subclasses__()) + return result + + +# --------------------------------------------------------------------------- +# Back-reference computation +# --------------------------------------------------------------------------- + + +def build_back_refs(found: dict[type, dict]) -> dict[type, list[tuple[type, str]]]: + """ + For each config class, find all (owner_class, field_name) pairs that reference it + as part of their field type. + """ + back_refs: dict[type, list[tuple[type, str]]] = {cls: [] for cls in found} + + for owner_cls, info in found.items(): + for name, _field, resolved_type in info["fields"]: + for referenced_cls in _extract_config_types(resolved_type, found): + back_refs[referenced_cls].append((owner_cls, name)) + + return back_refs + + +def _extract_config_types(annotation, found: dict[type, dict]) -> list[type]: + """Extract all Config subclass types referenced in an annotation.""" + results = [] + if isinstance(annotation, type) and annotation in found: + results.append(annotation) + elif isinstance(annotation, types.UnionType) or ( + hasattr(annotation, "__origin__") and annotation.__origin__ is typing.Union + ): + for arg in typing.get_args(annotation): + results.extend(_extract_config_types(arg, found)) + elif hasattr(annotation, "__origin__"): + for arg in typing.get_args(annotation): + results.extend(_extract_config_types(arg, found)) + return results + + +# --------------------------------------------------------------------------- +# Type rendering +# --------------------------------------------------------------------------- + + +def render_type( + annotation, + found: dict[type, dict], + cls_output_paths: dict[type, pathlib.Path], + own_path: pathlib.Path, +) -> str: + """Render a type annotation as a markdown string, linking to Config class pages.""" + if annotation is type(None): + return "`None`" + if annotation is typing.Any: + return "`Any`" + if isinstance(annotation, type): + if annotation in found: + rel_path = cls_output_paths.get(annotation) + if rel_path is not None: + link = _relative_link(own_path, rel_path) + return f"[{annotation.__name__}]({link})" + return f"`{annotation.__name__}`" + if issubclass(annotation, type): + return "`type`" + return f"`{annotation.__name__}`" + if isinstance(annotation, types.UnionType) or ( + hasattr(annotation, "__origin__") and annotation.__origin__ is typing.Union + ): + args = [a for a in typing.get_args(annotation) if a is not type(None)] + none_part = " or `None`" if type(None) in typing.get_args(annotation) else "" + inner = " or ".join(render_type(a, found, cls_output_paths, own_path) for a in args) + return inner + none_part + if hasattr(annotation, "__origin__"): + origin = annotation.__origin__ + args = typing.get_args(annotation) + if origin is list: + return f"list[{render_type(args[0], found, cls_output_paths, own_path)}]" if args else "`list`" + if origin is dict: + k = render_type(args[0], found, cls_output_paths, own_path) if args else "`Any`" + v = render_type(args[1], found, cls_output_paths, own_path) if len(args) > 1 else "`Any`" + return f"dict[{k}, {v}]" + if origin is tuple: + inner = ", ".join(render_type(a, found, cls_output_paths, own_path) for a in args) + return f"tuple[{inner}]" + if origin is set: + return f"set[{render_type(args[0], found, cls_output_paths, own_path)}]" if args else "`set`" + # Fallback for other generics + return f"`{getattr(origin, '__name__', str(origin))}`" + return f"`{annotation}`" + + +def render_default(field: Field, resolved_type, found: dict[type, dict]) -> str: + """Render the default value of a field as a string.""" + if field.default is not dataclasses.MISSING: + value = field.default + if isinstance(value, str): + return f'`"{value}"`' + if value is None: + return "`None`" + # Class objects: show the class name, not `` + if isinstance(value, type): + return f"`{value.__name__}`" + # Large integers: insert underscores every 3 digits for readability + if isinstance(value, int) and abs(value) > 999_999: + return f"`{value:_}`" + return f"`{value}`" + if field.default_factory is not dataclasses.MISSING: + factory = field.default_factory + # A factory that is itself a Config class means "instantiate with defaults". + if isinstance(factory, type) and factory in found: + return "*(sub-fields optional)*" + if hasattr(factory, "__name__"): + return f"`{factory.__name__}()`" + # If the type itself is a Config class, the value is still required in YAML + # but every sub-field within it has its own default — don't say "required". + core_type = _unwrap_optional(resolved_type) + if isinstance(core_type, type) and core_type in found: + return "*(sub-fields optional)*" + return "*(required)*" + + +def _unwrap_optional(annotation) -> type | None: + """Return the inner type of Optional[X] / X | None, or the annotation itself.""" + if isinstance(annotation, types.UnionType) or ( + hasattr(annotation, "__origin__") and annotation.__origin__ is typing.Union + ): + args = [a for a in typing.get_args(annotation) if a is not type(None)] + if len(args) == 1: + return args[0] + return annotation + + +# --------------------------------------------------------------------------- +# Output path computation +# --------------------------------------------------------------------------- + + +def get_module_dir(module_name: str) -> pathlib.Path: + """ + Convert a module name like 'fast_llm.engine.distributed.config' to a + relative output path like 'engine/distributed'. + """ + parts = module_name.split(".") + # Strip 'fast_llm' prefix. + if parts and parts[0] == "fast_llm": + parts = parts[1:] + # Strip trailing 'config'. + if parts and parts[-1] == "config": + parts = parts[:-1] + return pathlib.Path(*parts) if parts else pathlib.Path(".") + + +def compute_output_paths(found: dict[type, dict]) -> dict[type, pathlib.Path]: + """ + Return a dict mapping each class to its output path relative to OUTPUT_DIR, + e.g. engine/distributed/DistributedConfig.md + """ + return {cls: get_module_dir(info["module"]) / f"{cls.__name__}.md" for cls, info in found.items()} + + +# --------------------------------------------------------------------------- +# Markdown rendering +# --------------------------------------------------------------------------- + + +def render_hint_badge(hint: FieldHint) -> str: + badge_map = { + FieldHint.core: "core", + FieldHint.architecture: "architecture", + FieldHint.optional: "optional", + FieldHint.performance: "performance", + FieldHint.stability: "stability", + FieldHint.feature: "feature", + FieldHint.expert: "expert", + FieldHint.logging: "logging", + FieldHint.deprecated: "deprecated", + FieldHint.wip: "wip", + FieldHint.unknown: "", + } + label = badge_map.get(hint, str(hint)) + return f"`{label}`" if label else "" + + +def render_class_page( + cls: type, + info: dict, + back_refs: list[tuple[type, str]], + found: dict[type, dict], + cls_output_paths: dict[type, pathlib.Path], + own_path: pathlib.Path, +) -> str: + """Render the full markdown page for a config class.""" + lines = [] + + # Title + lines.append(f"# {cls.__name__}\n") + + # Abstract badge + if info["abstract"]: + lines.append( + '!!! note "Abstract"\n This class cannot be instantiated directly. Use one of the variants listed below.\n' + ) + + # Module + lines.append(f"**Module:** `{cls.__module__}`\n") + + # Registered as / variant of + if info["registered_in"]: + for base_cls, type_key in info["registered_in"]: + base_path = cls_output_paths.get(base_cls) + if base_path is not None: + rel = _relative_link(own_path, base_path) + lines.append(f"**Variant of:** [{base_cls.__name__}]({rel}) — select with `type: {type_key}`\n") + else: + lines.append(f"**Variant of:** `{base_cls.__name__}` — select with `type: {type_key}`\n") + + # Inheritance (Config parents only, skip Config itself and internal bases) + config_parents = [ + base + for base in cls.__mro__[1:] + if base is not cls + and isinstance(base, type) + and issubclass(base, Config) + and base.__name__ != "Config" + and base in found + ] + if config_parents: + parent_links = [] + for parent in config_parents[:3]: # limit to 3 to avoid noise + p_path = cls_output_paths.get(parent) + if p_path is not None: + rel = _relative_link(own_path, p_path) + parent_links.append(f"[{parent.__name__}]({rel})") + else: + parent_links.append(f"`{parent.__name__}`") + lines.append(f"**Inherits from:** {', '.join(parent_links)}\n") + + lines.append("") + + # Fields — definition list, one entry per field + user_fields = info["fields"] + if user_fields: + lines.append("## Fields\n") + for name, field, resolved_type in user_fields: + type_str = render_type(resolved_type, found, cls_output_paths, own_path) + default_str = render_default(field, resolved_type, found) + hint = getattr(field, "hint", FieldHint.unknown) + hint_str = render_hint_badge(hint) + desc = getattr(field, "desc", None) or "" + doc = getattr(field, "doc", None) + if doc: + desc = f"{desc} {doc}".strip() if desc else doc + # Flatten multi-line descriptions (newlines break def-list indentation). + desc = " ".join(desc.split()) + # Term: field name + hint badge (omit separator when hint is empty) + term = f"`{name}`" + (f" — {hint_str}" if hint_str else "") + lines.append(term) + # Definition: metadata line, then description as a separate paragraph. + meta = f"**Type:** {type_str}    **Default:** {default_str}" + lines.append(f": {meta}") + if desc: + # Blank line + 4-space indent = new paragraph within the definition. + lines.append(f"") + lines.append(f" {desc}") + lines.append("") + else: + lines.append("*No user-configurable fields.*\n") + + # Variants table (if this class has a registry) + registry = info.get("registry") + if registry: + lines.append("## Variants\n") + lines.append("Select a variant by setting `type:` to one of the following values.\n") + lines.append("| `type` value | Class | Description |") + lines.append("|--------------|-------|-------------|") + for key in sorted(registry): + subclass = registry[key] + sub_path = cls_output_paths.get(subclass) + if sub_path is not None: + rel = _relative_link(own_path, sub_path) + class_link = f"[{subclass.__name__}]({rel})" + else: + class_link = f"`{subclass.__name__}`" + sub_info = found.get(subclass, {}) + desc = _class_one_liner(subclass, sub_info) + lines.append(f"| `{key}` | {class_link} | {desc} |") + lines.append("") + + # Used in (back-references) + if back_refs: + lines.append("## Used in\n") + seen = set() + for owner_cls, field_name in sorted(back_refs, key=lambda t: (t[0].__name__, t[1])): + key = (owner_cls, field_name) + if key in seen: + continue + seen.add(key) + owner_path = cls_output_paths.get(owner_cls) + if owner_path is not None: + rel = _relative_link(own_path, owner_path) + lines.append(f"- [`{field_name}`]({rel}) in [{owner_cls.__name__}]({rel})") + else: + lines.append(f"- `{field_name}` in `{owner_cls.__name__}`") + lines.append("") + + return "\n".join(lines) + + +def _class_one_liner(cls: type, info: dict) -> str: + """Return a short description for a class, or empty string if none is available.""" + doc = getattr(cls, "__doc__", None) + if doc: + first_line = doc.strip().split("\n")[0].strip().rstrip(".") + # Skip auto-generated __init__ signatures like "ClassName(**kwargs)" + if first_line and not re.match(r"^\w.*\(.*\)\s*$", first_line): + return first_line + return "" + + +def _relative_link(from_path: pathlib.Path, to_path: pathlib.Path) -> str: + """ + Compute a relative markdown link from one page to another, + both paths relative to OUTPUT_DIR. + """ + from_dir = from_path.parent + try: + rel = pathlib.Path(to_path).relative_to(from_dir) + except ValueError: + # Go up from from_dir to the common ancestor + parts_from = from_dir.parts + parts_to = to_path.parts + # Find common prefix length + common = 0 + for a, b in zip(parts_from, parts_to): + if a == b: + common += 1 + else: + break + up = len(parts_from) - common + rel = pathlib.Path(*[".."] * up, *parts_to[common:]) + return str(rel).replace("\\", "/") + + +# --------------------------------------------------------------------------- +# Index page rendering +# --------------------------------------------------------------------------- + + +def render_index_page( + directory: pathlib.Path, + classes_in_dir: list[tuple[type, dict]], + cls_output_paths: dict[type, pathlib.Path], + subdirs: list[pathlib.Path], +) -> str: + """Render an index.md for a directory.""" + lines = [] + + # Title: use the directory name + if directory == pathlib.Path("."): + title = "Configuration Reference" + else: + title = " / ".join(p.replace("_", " ").title() for p in directory.parts) + lines.append(f"# {title}\n") + + directory / "index.md" + + # Subdirectory links + if subdirs: + lines.append("## Sections\n") + for subdir in sorted(subdirs): + section_name = subdir.name.replace("_", " ").title() + rel = str((subdir / "index.md").relative_to(directory)).replace("\\", "/") + lines.append(f"- [{section_name}]({rel})") + lines.append("") + + # Class table + if classes_in_dir: + lines.append("## Classes\n") + lines.append("| Class | Description |") + lines.append("|-------|-------------|") + for cls, info in sorted(classes_in_dir, key=lambda t: t[0].__name__): + cls_path = cls_output_paths[cls] + rel = str(cls_path.relative_to(directory)).replace("\\", "/") + desc = _class_one_liner(cls, info) + abstract_note = " *(abstract)*" if info["abstract"] else "" + lines.append(f"| [{cls.__name__}]({rel}){abstract_note} | {desc} |") + lines.append("") + + return "\n".join(lines) + + +def render_root_index( + found: dict[type, dict], + cls_output_paths: dict[type, pathlib.Path], + top_level_dirs: list[pathlib.Path], +) -> str: + """Render the top-level index.md.""" + lines = [ + "# Configuration Reference\n", + "This reference documents all configuration classes in Fast-LLM.", + "Configurations are YAML files passed to the `fast-llm` CLI.", + "The entry point is `GPTTrainerConfig`, which composes all other configurations.\n", + "## Sections\n", + ] + for d in sorted(top_level_dirs): + section_name = d.name.replace("_", " ").title() + lines.append(f"- [{section_name}]({d.name}/index.md)") + lines.append("") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Nav generation +# --------------------------------------------------------------------------- + + +def build_nav_tree(cls_output_paths: dict[type, pathlib.Path], found: dict[type, dict]) -> dict: + """ + Build a nested dict representing the nav tree: + { dir_path: { "index": index_path, "classes": [...], "subdirs": {subdir: ...} } } + """ + tree: dict = {} + + for cls, rel_path in cls_output_paths.items(): + parts = rel_path.parent.parts + node = tree + for part in parts: + node = node.setdefault(part, {}) + node.setdefault("_classes", []).append(cls) + + return tree + + +def nav_entries( + tree: dict, + cls_output_paths: dict[type, pathlib.Path], + prefix: pathlib.Path = pathlib.Path("."), +) -> list: + """Recursively build the mkdocs nav list for the config reference section.""" + entries = [] + + # Index for this directory + if prefix == pathlib.Path("."): + index_rel = "reference/configuration/index.md" + else: + index_rel = f"reference/configuration/{prefix}/index.md".replace("\\", "/") + entries.append(index_rel) + + # Classes directly in this directory + classes = tree.get("_classes", []) + for cls in sorted(classes, key=lambda c: c.__name__): + rel = cls_output_paths[cls] + entries.append(f"reference/configuration/{rel}".replace("\\", "/")) + + # Subdirectories + for key, subtree in sorted((k, v) for k, v in tree.items() if not k.startswith("_")): + subprefix = prefix / key if prefix != pathlib.Path(".") else pathlib.Path(key) + section_name = key.replace("_", " ").title() + sub_entries = nav_entries(subtree, cls_output_paths, subprefix) + entries.append({section_name: sub_entries}) + + return entries + + +def format_nav_yaml(entries: list, indent: int = 0) -> list[str]: + """Render nav entries as YAML lines.""" + lines = [] + pad = " " * indent + for entry in entries: + if isinstance(entry, str): + lines.append(f"{pad}- {entry}") + elif isinstance(entry, dict): + for key, sub_entries in entry.items(): + lines.append(f"{pad}- {key}:") + lines.extend(format_nav_yaml(sub_entries, indent + 1)) + return lines + + +# --------------------------------------------------------------------------- +# mkdocs.yaml nav update +# --------------------------------------------------------------------------- + +NAV_SENTINEL_START = " # BEGIN AUTO-GENERATED CONFIG REFERENCE" +NAV_SENTINEL_END = " # END AUTO-GENERATED CONFIG REFERENCE" + + +def update_mkdocs_nav(nav_lines: list[str]) -> None: + """ + Replace the auto-generated config reference section in mkdocs.yaml. + If the sentinels are not present, append the section to the nav. + """ + content = MKDOCS_YAML.read_text() + + new_block = "\n".join([NAV_SENTINEL_START] + nav_lines + [NAV_SENTINEL_END]) + + if NAV_SENTINEL_START in content and NAV_SENTINEL_END in content: + # Replace existing block + pattern = re.escape(NAV_SENTINEL_START) + r".*?" + re.escape(NAV_SENTINEL_END) + content = re.sub(pattern, new_block, content, flags=re.DOTALL) + else: + # Append before the last line of the nav section + # Find the nav: key and append at the end of its list + lines = content.splitlines() + # Find the last non-empty line inside the nav block (heuristic: insert before next top-level key) + insert_at = len(lines) + in_nav = False + for i, line in enumerate(lines): + if line.startswith("nav:"): + in_nav = True + elif in_nav and line and not line.startswith(" "): + insert_at = i + break + indent = " " + nav_indented = "\n".join(indent + l for l in new_block.splitlines()) + lines.insert(insert_at, nav_indented) + content = "\n".join(lines) + "\n" + + MKDOCS_YAML.write_text(content) + print(f"Updated nav in {MKDOCS_YAML}") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def generate(*, update_nav: bool = True, verbose: bool = True) -> None: + """Generate all config reference docs, optionally updating mkdocs.yaml nav.""" + + def log(msg: str) -> None: + if verbose: + print(msg) + + log("Importing fast_llm config modules...") + import_all_config_modules() + + log("Collecting config classes...") + found = collect_config_classes() + log(f" Found {len(found)} config classes") + + log("Computing output paths...") + cls_output_paths = compute_output_paths(found) + + log("Building back-references...") + back_refs = build_back_refs(found) + + # Group classes by output directory + dir_to_classes: dict[pathlib.Path, list[tuple[type, dict]]] = {} + for cls, info in found.items(): + directory = cls_output_paths[cls].parent + dir_to_classes.setdefault(directory, []).append((cls, info)) + + log(f"Writing to {OUTPUT_DIR} ...") + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + + # Write class pages + for cls, info in found.items(): + rel_path = cls_output_paths[cls] + out_path = OUTPUT_DIR / rel_path + out_path.parent.mkdir(parents=True, exist_ok=True) + content = render_class_page(cls, info, back_refs[cls], found, cls_output_paths, rel_path) + out_path.write_text(content) + + # Write index pages — include all ancestor directories, not just leaf dirs with classes. + leaf_dirs = {cls_output_paths[cls].parent for cls in found} + all_dirs: set[pathlib.Path] = set() + for directory in leaf_dirs: + all_dirs.add(directory) + for i in range(len(directory.parts)): + all_dirs.add(pathlib.Path(*directory.parts[:i]) if i > 0 else pathlib.Path(".")) + + # Find all top-level directories (direct children of output root) + top_level_dirs = sorted({d.parts[0] for d in all_dirs if d != pathlib.Path(".")}) + + for directory in sorted(all_dirs): + classes_in_dir = dir_to_classes.get(directory, []) + # Find immediate subdirectories + subdirs = sorted( + { + directory / d.parts[len(directory.parts)] + for d in all_dirs + if len(d.parts) > len(directory.parts) and d.parts[: len(directory.parts)] == directory.parts + } + ) + index_content = render_index_page(directory, classes_in_dir, cls_output_paths, subdirs) + index_path = OUTPUT_DIR / directory / "index.md" + index_path.parent.mkdir(parents=True, exist_ok=True) + index_path.write_text(index_content) + + # Write root index + root_index = render_root_index( + found, + cls_output_paths, + [pathlib.Path(d) for d in top_level_dirs], + ) + (OUTPUT_DIR / "index.md").write_text(root_index) + + if update_nav: + log("Updating mkdocs.yaml nav...") + tree = build_nav_tree(cls_output_paths, found) + nav_root = nav_entries(tree, cls_output_paths) + nav_yaml_lines = format_nav_yaml([{"Configuration Reference": nav_root}], indent=1) + update_mkdocs_nav(nav_yaml_lines) + + log("Done.") + + +def main() -> None: + generate(update_nav=True, verbose=True) + + +if __name__ == "__main__": + main() From d47904cea2e3dea675592665d21726480a5222a6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 18:18:21 -0400 Subject: [PATCH 19/33] Add unit tests for generate_config_docs.py 80 tests covering get_module_dir, _relative_link, _unwrap_optional, render_hint_badge, _class_one_liner, is_user_field, _extract_config_types, render_type, render_default, format_nav_yaml, and smoke tests for render_class_page / render_index_page. Co-Authored-By: Claude Sonnet 4.6 --- tests/tools/__init__.py | 0 tests/tools/test_generate_config_docs.py | 470 +++++++++++++++++++++++ 2 files changed, 470 insertions(+) create mode 100644 tests/tools/__init__.py create mode 100644 tests/tools/test_generate_config_docs.py diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tools/test_generate_config_docs.py b/tests/tools/test_generate_config_docs.py new file mode 100644 index 000000000..c9c7dc654 --- /dev/null +++ b/tests/tools/test_generate_config_docs.py @@ -0,0 +1,470 @@ +"""Unit tests for tools/generate_config_docs.py.""" + +import importlib.util +import pathlib +import typing + +import pytest + +from fast_llm.config import Config, Field, FieldHint, config_class + +# --------------------------------------------------------------------------- +# Load the generator module via importlib (it is not a package). +# --------------------------------------------------------------------------- + +_SCRIPT = pathlib.Path(__file__).parent.parent.parent / "tools" / "generate_config_docs.py" +_spec = importlib.util.spec_from_file_location("generate_config_docs", _SCRIPT) +_gen = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_gen) + + +# --------------------------------------------------------------------------- +# Minimal synthetic Config classes used across multiple tests. +# --------------------------------------------------------------------------- + + +@config_class() +class _InnerConfig(Config): + """A simple inner config for doc-generation tests.""" + + value: int = Field(default=0, hint=FieldHint.core, desc="A value.") + + +@config_class() +class _OuterConfig(Config): + """An outer config that references _InnerConfig.""" + + inner: _InnerConfig = Field(hint=FieldHint.core, desc="Inner config.") + required: str = Field(hint=FieldHint.core, desc="Required string field.") + inner_optional: _InnerConfig | None = Field(default=None, hint=FieldHint.feature, desc="Optional inner.") + string: str = Field(default="hello", hint=FieldHint.core, desc="A string.") + large_int: int = Field(default=2**32, hint=FieldHint.core, desc="A large integer.") + list_of_str: list[str] = Field(default_factory=list, hint=FieldHint.core, desc="A list of strings.") + dict_field: dict[str, int] = Field(default_factory=dict, hint=FieldHint.core, desc="A dict.") + + +# Minimal `found` and `cls_output_paths` dicts used in render_* tests. +_FOUND: dict = { + _InnerConfig: { + "module": "tests.tools._InnerConfig", + "fields": [], + "registry": None, + "registered_in": [], + "abstract": False, + }, + _OuterConfig: { + "module": "tests.tools._OuterConfig", + "fields": [], + "registry": None, + "registered_in": [], + "abstract": False, + }, +} +_CLS_OUTPUT_PATHS: dict[type, pathlib.Path] = { + _InnerConfig: pathlib.Path("tests/InnerConfig.md"), + _OuterConfig: pathlib.Path("tests/OuterConfig.md"), +} +_OWN_PATH = pathlib.Path("tests/SomeConfig.md") + +_OUTER_FIELDS = dict(_OuterConfig.fields()) +_OUTER_HINTS = typing.get_type_hints(_OuterConfig) + + +# --------------------------------------------------------------------------- +# get_module_dir +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "module_name, expected", + [ + ("fast_llm.config", pathlib.Path(".")), + ("fast_llm.engine.distributed.config", pathlib.Path("engine/distributed")), + ("fast_llm.data.dataset.config", pathlib.Path("data/dataset")), + ("fast_llm.models.gpt.config", pathlib.Path("models/gpt")), + ("fast_llm.engine.training.config", pathlib.Path("engine/training")), + # Module without trailing .config — just strip the fast_llm prefix. + ("fast_llm.profile", pathlib.Path("profile")), + ], +) +def test_get_module_dir(module_name, expected): + assert _gen.get_module_dir(module_name) == expected + + +# --------------------------------------------------------------------------- +# _relative_link +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "from_path, to_path, expected", + [ + # Same directory. + ("engine/distributed/A.md", "engine/distributed/B.md", "B.md"), + # Descend into child directory. + ("engine/A.md", "engine/distributed/B.md", "distributed/B.md"), + # Ascend to parent directory. + ("engine/distributed/A.md", "engine/B.md", "../B.md"), + # Sibling directory (up one, down one). + ("engine/distributed/A.md", "engine/training/B.md", "../training/B.md"), + # Deep cross-package link. + ("engine/training/runner/A.md", "data/dataset/B.md", "../../../data/dataset/B.md"), + # Top-level sibling packages. + ("engine/A.md", "data/B.md", "../data/B.md"), + ], +) +def test_relative_link(from_path, to_path, expected): + assert _gen._relative_link(pathlib.Path(from_path), pathlib.Path(to_path)) == expected + + +# --------------------------------------------------------------------------- +# _unwrap_optional +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "annotation, expected", + [ + (int | None, int), + (str | None, str), + (_InnerConfig | None, _InnerConfig), + (int, int), + (str, str), + ], +) +def test_unwrap_optional_strips_none(annotation, expected): + assert _gen._unwrap_optional(annotation) is expected + + +def test_unwrap_optional_union_unchanged(): + # Two non-None types: should not be simplified. + annotation = int | str + assert _gen._unwrap_optional(annotation) is annotation + + +def test_unwrap_optional_triple_union_unchanged(): + # Optional with two non-None types: should not be simplified. + annotation = int | str | None + assert _gen._unwrap_optional(annotation) is annotation + + +# --------------------------------------------------------------------------- +# render_hint_badge +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "hint, expected", + [ + (FieldHint.core, "`core`"), + (FieldHint.architecture, "`architecture`"), + (FieldHint.optional, "`optional`"), + (FieldHint.performance, "`performance`"), + (FieldHint.feature, "`feature`"), + (FieldHint.expert, "`expert`"), + (FieldHint.logging, "`logging`"), + (FieldHint.deprecated, "`deprecated`"), + (FieldHint.wip, "`wip`"), + # unknown → empty string (no badge). + (FieldHint.unknown, ""), + ], +) +def test_render_hint_badge(hint, expected): + assert _gen.render_hint_badge(hint) == expected + + +# --------------------------------------------------------------------------- +# _class_one_liner +# --------------------------------------------------------------------------- + + +class _DocOneLiner: + """A clean one-liner description.""" + + +class _DocMultiLine: + """First line only. + + More detail here that should be ignored. + """ + + +class _DocAutoSignature: + """SomeName(**kwargs)""" + + +class _DocNoDocstring: + pass + + +class _DocTrailingDot: + """Description ending with a dot.""" + + +@pytest.mark.parametrize( + "cls, expected", + [ + (_DocOneLiner, "A clean one-liner description"), + (_DocMultiLine, "First line only"), + (_DocAutoSignature, ""), # auto-generated __init__ signature — filtered out + (_DocNoDocstring, ""), + (_DocTrailingDot, "Description ending with a dot"), # trailing dot stripped + ], +) +def test_class_one_liner(cls, expected): + assert _gen._class_one_liner(cls, {}) == expected + + +# --------------------------------------------------------------------------- +# is_user_field — uses fields extracted from a synthetic Config class +# --------------------------------------------------------------------------- + + +@config_class() +class _IsUserFieldConfig(Config): + normal: str = Field(default="x", hint=FieldHint.core, desc="Normal field.") + feature: str = Field(default="x", hint=FieldHint.feature, desc="Feature field.") + derived: str = Field(default="x", hint=FieldHint.derived, desc="Derived field.") + testing: str = Field(default="x", hint=FieldHint.testing, desc="Testing field.") + setup_field: str = Field(default="x", hint=FieldHint.setup, desc="Setup field.") + + +_IS_USER_FIELD_FIELDS = dict(_IsUserFieldConfig.fields()) + + +@pytest.mark.parametrize( + "field_name, expected", + [ + ("normal", True), + ("feature", True), + ("derived", False), # excluded hint + ("testing", False), # excluded hint + ("setup_field", False), # excluded hint + ], +) +def test_is_user_field_hint(field_name, expected): + assert _gen.is_user_field(field_name, _IS_USER_FIELD_FIELDS[field_name]) == expected + + +@pytest.mark.parametrize( + "name, expected", + [ + ("_private", False), # underscore prefix → always excluded + ("type", False), # "type" is always excluded regardless of field content + ("normal_name", True), + ], +) +def test_is_user_field_name(name, expected): + # Use a valid public field object; only the name varies. + field = _IS_USER_FIELD_FIELDS["normal"] + assert _gen.is_user_field(name, field) == expected + + +# --------------------------------------------------------------------------- +# _extract_config_types +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "annotation, expected_set", + [ + (_InnerConfig, {_InnerConfig}), + (_InnerConfig | None, {_InnerConfig}), + (list[_InnerConfig], {_InnerConfig}), + (dict[str, _InnerConfig], {_InnerConfig}), + (_InnerConfig | _OuterConfig, {_InnerConfig, _OuterConfig}), + (int, set()), + (str | None, set()), + (list[str], set()), + ], +) +def test_extract_config_types(annotation, expected_set): + result = _gen._extract_config_types(annotation, _FOUND) + assert set(result) == expected_set + + +# --------------------------------------------------------------------------- +# render_type +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "annotation, expected", + [ + (str, "`str`"), + (int, "`int`"), + (bool, "`bool`"), + (type(None), "`None`"), + (typing.Any, "`Any`"), + (str | None, "`str` or `None`"), + (int | None, "`int` or `None`"), + (list[str], "list[`str`]"), + (list[int], "list[`int`]"), + (dict[str, int], "dict[`str`, `int`]"), + (tuple[str, int], "tuple[`str`, `int`]"), + (set[str], "set[`str`]"), + ], +) +def test_render_type_primitives(annotation, expected): + assert _gen.render_type(annotation, _FOUND, _CLS_OUTPUT_PATHS, _OWN_PATH) == expected + + +def test_render_type_config_produces_link(): + result = _gen.render_type(_InnerConfig, _FOUND, _CLS_OUTPUT_PATHS, _OWN_PATH) + # Should be a markdown link to the class page. + assert result.startswith("[_InnerConfig](") + assert result.endswith(")") + + +def test_render_type_config_not_in_found(): + # Config type absent from found → backtick name, no link. + result = _gen.render_type(_InnerConfig, {}, {}, _OWN_PATH) + assert result == "`_InnerConfig`" + + +def test_render_type_optional_config(): + result = _gen.render_type(_InnerConfig | None, _FOUND, _CLS_OUTPUT_PATHS, _OWN_PATH) + assert "[_InnerConfig](" in result + assert "or `None`" in result + + +# --------------------------------------------------------------------------- +# render_default +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "field_name, expected", + [ + ("string", '`"hello"`'), + ("large_int", "`4_294_967_296`"), # 2**32 with underscores + ("list_of_str", "`list()`"), + ("dict_field", "`dict()`"), + ], +) +def test_render_default_simple(field_name, expected): + field = _OUTER_FIELDS[field_name] + resolved = _OUTER_HINTS.get(field_name, field.type) + assert _gen.render_default(field, resolved, _FOUND) == expected + + +def test_render_default_none(): + field = _OUTER_FIELDS["inner_optional"] + assert _gen.render_default(field, _InnerConfig | None, _FOUND) == "`None`" + + +def test_render_default_required_primitive(): + field = _OUTER_FIELDS["required"] + assert _gen.render_default(field, str, _FOUND) == "*(required)*" + + +def test_render_default_config_field_sub_fields_optional(): + # Config-typed field with no default → sub-fields are optional. + field = _OUTER_FIELDS["inner"] + assert _gen.render_default(field, _InnerConfig, _FOUND) == "*(sub-fields optional)*" + + +@config_class() +class _TypeDefaultConfig(Config): + fmt: type = Field(default=_InnerConfig, hint=FieldHint.core, desc="A type default.") + + +def test_render_default_type_class(): + # A field whose default value is itself a type object. + fields = dict(_TypeDefaultConfig.fields()) + assert _gen.render_default(fields["fmt"], type, _FOUND) == "`_InnerConfig`" + + +# --------------------------------------------------------------------------- +# format_nav_yaml +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "entries, indent, expected", + [ + # Flat list of strings. + ( + ["reference/a.md", "reference/b.md"], + 0, + ["- reference/a.md", "- reference/b.md"], + ), + # Single nested section. + ( + [{"Section": ["reference/a.md"]}], + 0, + ["- Section:", " - reference/a.md"], + ), + # Double-nested sections. + ( + [{"Outer": [{"Inner": ["reference/a.md"]}]}], + 0, + ["- Outer:", " - Inner:", " - reference/a.md"], + ), + # Non-zero base indent. + ( + ["reference/a.md"], + 1, + [" - reference/a.md"], + ), + # Mixed strings and dicts. + ( + ["reference/index.md", {"Sub": ["reference/sub/a.md"]}], + 0, + ["- reference/index.md", "- Sub:", " - reference/sub/a.md"], + ), + ], +) +def test_format_nav_yaml(entries, indent, expected): + assert _gen.format_nav_yaml(entries, indent) == expected + + +# --------------------------------------------------------------------------- +# render_class_page smoke test +# --------------------------------------------------------------------------- + + +def test_render_class_page_contains_key_sections(): + info = _FOUND[_OuterConfig] + # Build minimal fields list as the generator would. + fields = [] + for name, field in _OuterConfig.fields(): + if _gen.is_user_field(name, field): + resolved = _OUTER_HINTS.get(name, field.type) + fields.append((name, field, resolved)) + info_with_fields = {**info, "fields": fields} + + content = _gen.render_class_page( + _OuterConfig, + info_with_fields, + back_refs=[], + found=_FOUND, + cls_output_paths=_CLS_OUTPUT_PATHS, + own_path=_CLS_OUTPUT_PATHS[_OuterConfig], + ) + + assert "# _OuterConfig" in content + assert "## Fields" in content + assert "`string`" in content + assert "`large_int`" in content + assert "*(sub-fields optional)*" in content # inner field + assert "*(required)*" in content # required field + + +# --------------------------------------------------------------------------- +# render_index_page smoke test +# --------------------------------------------------------------------------- + + +def test_render_index_page_lists_classes(): + classes_in_dir = list(_FOUND.items()) + content = _gen.render_index_page( + pathlib.Path("tests"), + classes_in_dir, + cls_output_paths=_CLS_OUTPUT_PATHS, + subdirs=[], + ) + + assert "## Classes" in content + assert "_InnerConfig" in content + assert "_OuterConfig" in content From e6eea7b0ce666f6bfe9fbd76c078e15a1ad5f9eb Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 18:29:38 -0400 Subject: [PATCH 20/33] Add Triton GRPO loss kernel with vocab-parallel support and tests - Add fast_llm/functional/triton/grpo_loss.py: Triton kernel for GRPO loss forward/backward, supporting both non-parallel and vocab-parallel (two-pass) modes, mirroring the entropy/z-loss Triton patterns - Add use_triton field to LanguageModelGRPOLossConfig and dispatch to Triton kernel in LanguageModelGRPOLoss._forward_backward - Update test_lm_losses.py: add num_labels_in_seq, test new_logprobs_mean, and add Triton kernel testing (guarded by triton_available) - Fix Triton interpreter bug in __init__.py: monkeypatch _patch_lang_tensor to use .item() instead of int() for tensor.__index__, fixing a pre-existing failure of all Triton tests under TRITON_INTERPRET=1 (constexpr int args to device functions arrived as 1-d numpy arrays, not scalars) Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/functional/triton/__init__.py | 15 +- fast_llm/functional/triton/grpo_loss.py | 229 ++++++++++++++++++ fast_llm/layers/language_model/loss/config.py | 5 + fast_llm/layers/language_model/loss/grpo.py | 9 +- tests/layers/test_lm_losses.py | 31 ++- 5 files changed, 285 insertions(+), 4 deletions(-) create mode 100644 fast_llm/functional/triton/grpo_loss.py diff --git a/fast_llm/functional/triton/__init__.py b/fast_llm/functional/triton/__init__.py index f5b394bfb..baf50099d 100644 --- a/fast_llm/functional/triton/__init__.py +++ b/fast_llm/functional/triton/__init__.py @@ -27,7 +27,20 @@ tl_arange = None tl_full = None elif triton_interpret: - # Workaround for a triton bug. + # Workaround for a triton interpreter bug: constexpr int arguments to device functions + # arrive as 1-d numpy arrays rather than scalars. The interpreter's _patch_lang_tensor sets + # tensor.__index__ = lambda self: int(self.handle.data), which fails for 1-d arrays. + # Patch _patch_lang_tensor to use .item() instead, which works for both 0-d and 1-d arrays. + import triton.runtime.interpreter as _triton_interpreter + + _orig_patch_lang_tensor = _triton_interpreter._patch_lang_tensor + + def _fixed_patch_lang_tensor(tensor): + _orig_patch_lang_tensor(tensor) + tensor.__index__ = lambda self: self.handle.data.item() + + _triton_interpreter._patch_lang_tensor = _fixed_patch_lang_tensor + @triton_jit def tl_arange(start, end): return tl.arange(int(start), int(end)) diff --git a/fast_llm/functional/triton/grpo_loss.py b/fast_llm/functional/triton/grpo_loss.py new file mode 100644 index 000000000..deb261f09 --- /dev/null +++ b/fast_llm/functional/triton/grpo_loss.py @@ -0,0 +1,229 @@ +import torch + +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton.entropy_loss import ( + parallel_sum_exp_logits, + triton_cross_entropy_forward_from_labels_parallel_kernel, + triton_fused_softmax_base, +) +from fast_llm.functional.utils import reduce_losses + + +@triton_jit() +def triton_grpo_loss_forward_backward_kernel( + logits_ptr, + labels_ptr, + advantages_ptr, + old_log_probs_ptr, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + block_size: tl_constexpr, + losses_ptr=None, + new_logprobs_mean_parts_ptr=None, + num_labels_in_seq_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + predicted_logits_ptr=None, + grad_losses=None, + grad_logits_ptr=None, + grad_logits_stride_0: tl_constexpr = None, + col_min: tl_constexpr = 0, + logits_scale_factor: tl_constexpr = 1.0, + epsilon_low: tl_constexpr = 0.2, + epsilon_high: tl_constexpr = 0.2, + accumulate: tl_constexpr = False, +): + block_idx = tl.program_id(0).to(tl.int64) + logits_ptr = logits_ptr + block_idx * logits_stride_0 + + label_idx = tl.load(labels_ptr + block_idx) + if label_idx < 0: + # Masked position. + if losses_ptr is not None: + tl.store(losses_ptr + block_idx, 0) + if new_logprobs_mean_parts_ptr is not None: + tl.store(new_logprobs_mean_parts_ptr + block_idx, 0) + if grad_losses is not None and not accumulate: + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=col_offsets < n_cols + ) + return + + label_idx -= col_min + + if max_logits_ptr is None or sum_exp_logits_ptr is None: + # Non-parallel: compute softmax and predicted logit in one forward pass. + exp_logits, sum_exp_logits, max_logits, col_offsets, mask = triton_fused_softmax_base( + logits_ptr, n_cols=n_cols, block_size=block_size, logits_scale_factor=logits_scale_factor + ) + if label_idx >= 0 and label_idx < n_cols: + predicted_logit = tl.load(logits_ptr + label_idx).to(tl.float32) + if logits_scale_factor != 1.0: + predicted_logit *= logits_scale_factor + else: + # Parallel case only: target not in local vocab shard. + predicted_logit = 0.0 + else: + # Parallel case: use globally reduced values from the first pass. + max_logits = tl.load(max_logits_ptr + block_idx) + sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) + predicted_logit = tl.load(predicted_logits_ptr + block_idx) + + # new_log_prob = log_softmax(logits * scale)[label] + # = logits[label]*scale - (max_logits + log(sum_exp_logits)) + new_log_prob = predicted_logit - max_logits - tl.log(sum_exp_logits) + old_log_prob = tl.load(old_log_probs_ptr + block_idx).to(tl.float32) + advantage = tl.load(advantages_ptr + block_idx).to(tl.float32) + + ratio = tl.exp(new_log_prob - old_log_prob) + clipped_ratio = tl.minimum(tl.maximum(ratio, 1.0 - epsilon_low), 1.0 + epsilon_high) + loss = -tl.minimum(ratio * advantage, clipped_ratio * advantage) + + if losses_ptr is not None: + tl.store(losses_ptr + block_idx, loss) + + if new_logprobs_mean_parts_ptr is not None: + num_labels = tl.load(num_labels_in_seq_ptr + block_idx).to(tl.float32) + tl.store(new_logprobs_mean_parts_ptr + block_idx, new_log_prob / tl.maximum(num_labels, 1.0)) + + if grad_losses is not None: + if logits_scale_factor != 1.0: + grad_losses *= logits_scale_factor + # effective_grad = probability_ratio_grad * ratio + # = (clamp_min(adv, 0) * (ratio <= 1+eps_high) + clamp_max(adv, 0) * (ratio >= 1-eps_low)) * ratio * grad_losses + effective_grad = ( + ( + tl.maximum(advantage, 0.0) * (ratio <= 1.0 + epsilon_high) + + tl.minimum(advantage, 0.0) * (ratio >= 1.0 - epsilon_low) + ) + * ratio + * grad_losses + ) + + # grad_logits_i = effective_grad * (p_i - delta_{i, label}) + col_offset_start: tl.constexpr = (n_cols - 1) // block_size * block_size + for col_offset in tl.static_range(col_offset_start, -1, -block_size): + if max_logits_ptr is not None or sum_exp_logits_ptr is not None or col_offset != col_offset_start: + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + exp_logits = tl.exp(logits - max_logits) + prob = exp_logits / sum_exp_logits + if label_idx < 0 or label_idx >= n_cols: + # Target not in local vocab shard (parallel case): no delta term. + grad_logits = effective_grad * prob + else: + grad_logits = effective_grad * tl.where(col_offsets == label_idx, prob - 1.0, prob) + grad_logits_col_ptr = grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets + if accumulate: + grad_logits += tl.load(grad_logits_col_ptr, mask=mask) + tl.store(grad_logits_col_ptr, grad_logits, mask=mask) + + +def triton_grpo_loss_forward_backward( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) + advantages: torch.Tensor, # (*batch,) + old_log_probabilities: torch.Tensor, # (*batch,) + grad_logits: torch.Tensor | None = None, + grad_output: float | None = None, + group: torch.distributed.ProcessGroup | None = None, + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + logits_scale_factor: float = 1.0, + num_labels_in_seq: torch.Tensor | None = None, + divisor: float | None = None, + block_size: int | None = None, + num_warps: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + assert logits.is_contiguous() + assert target.is_contiguous() + assert advantages.is_contiguous() + assert old_log_probabilities.is_contiguous() + n_rows = logits.shape[:-1].numel() + n_cols = logits.size(-1) + if divisor is None: + divisor = n_rows + if block_size is None: + block_size = min(triton.next_power_of_2(n_cols), 32768) + if num_warps is None: + num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) + kwargs = { + "logits_stride_0": logits.stride(-2), + "n_cols": n_cols, + "logits_scale_factor": logits_scale_factor, + "epsilon_low": epsilon_low, + "epsilon_high": epsilon_high, + "block_size": block_size, + "num_warps": num_warps, + } + if grad_output is None: + backward_kwargs = {} + else: + accumulate = grad_logits is not None + grad_logits = torch.empty_like(logits) if grad_logits is None else grad_logits + backward_kwargs = { + "grad_logits_ptr": grad_logits, + "grad_losses": grad_output / divisor, + "grad_logits_stride_0": grad_logits.stride(-2), + "accumulate": accumulate, + } + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + if num_labels_in_seq is not None: + assert num_labels_in_seq.is_contiguous() + new_logprobs_mean_parts = torch.empty(n_rows, dtype=torch.float, device=logits.device) + new_logprobs_mean_kwargs = { + "new_logprobs_mean_parts_ptr": new_logprobs_mean_parts, + "num_labels_in_seq_ptr": num_labels_in_seq, + } + else: + new_logprobs_mean_kwargs = {} + + if group is None: + triton_grpo_loss_forward_backward_kernel[(n_rows,)]( + logits, + target, + advantages, + old_log_probabilities, + losses_ptr=losses, + **kwargs, + **backward_kwargs, + **new_logprobs_mean_kwargs, + ) + else: + local_max_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) + sum_exp_logits = torch.empty_like(local_max_logits) + predicted_logits_local = torch.empty_like(local_max_logits) + triton_cross_entropy_forward_from_labels_parallel_kernel[(n_rows,)]( + logits, + target, + max_logits_ptr=local_max_logits, + sum_exp_logits_ptr=sum_exp_logits, + predicted_logits_ptr=predicted_logits_local, + col_min=n_cols * group.rank(), + **kwargs, + ) + max_logits, sum_exp_logits = parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) + torch.distributed.all_reduce(predicted_logits_local, op=torch.distributed.ReduceOp.SUM, group=group) + triton_grpo_loss_forward_backward_kernel[(n_rows,)]( + logits, + target, + advantages, + old_log_probabilities, + losses_ptr=losses, + max_logits_ptr=max_logits, + sum_exp_logits_ptr=sum_exp_logits, + predicted_logits_ptr=predicted_logits_local, + col_min=n_cols * group.rank(), + **kwargs, + **backward_kwargs, + **new_logprobs_mean_kwargs, + ) + + loss = reduce_losses(losses, divisor) + new_logprobs_mean = new_logprobs_mean_parts.sum() if num_labels_in_seq is not None else None + return loss, grad_logits, new_logprobs_mean diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index a2c067a95..4381aa5d9 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -200,6 +200,11 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): epsilon_low: float = Field(default=0.2, desc="Lower clip parameter for ratio of log probs") epsilon_high: float = Field(default=0.2, desc="Upper clip parameter for ratio of log probs") + use_triton: bool | None = Field( + default=None, + desc="Enable triton implementation. Default: use if available.", + hint=FieldHint.expert, + ) @property def loss_class(self) -> "type[LanguageModelGRPOLoss]": diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 62f591d9f..a933fec99 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -4,6 +4,7 @@ import torch from fast_llm.engine.base_model.config import LossDef +from fast_llm.functional.config import TritonConfig from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base from fast_llm.functional.utils import reduce_losses from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs @@ -19,7 +20,13 @@ def _forward_backward( split_index: int = 0, grad_logits: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - loss, grad, new_logprobs_mean = fused_grpo_loss_forward_backward( + if TritonConfig.enabled(logits.device, self._config.use_triton): + from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward + + fn = triton_grpo_loss_forward_backward + else: + fn = fused_grpo_loss_forward_backward + loss, grad, new_logprobs_mean = fn( logits, self._get_labels(kwargs, split_index), self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 3a68a999f..9b93aeb66 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -13,6 +13,7 @@ from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward from fast_llm.functional.triton import triton_available from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward +from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward from fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.layers.language_model.loss.grpo import fused_grpo_loss_forward_backward @@ -250,6 +251,13 @@ def _test_grpo_loss( logits, target, advantages, old_log_probabilities = _get_grpo_loss_inputs( num_columns, loss_masking, batch_shape, dtype ) + num_labels = int((target >= 0).sum().item()) + num_labels_in_seq = torch.where( + target >= 0, + torch.full(batch_shape, num_labels, dtype=torch.int32, device=target.device), + torch.zeros(batch_shape, dtype=torch.int32, device=target.device), + ) + divisor = max(num_labels, 1) out_ref, grad_ref = loss_forward_backward( grad_output, lambda *args, **kwargs: reference_grpo_loss(*args, **kwargs)[0], @@ -263,7 +271,7 @@ def _test_grpo_loss( previous_grad = torch.randn_like(grad_ref) grad_ref = grad_ref + previous_grad local_previous_grad = split_op(previous_grad, group, -1).contiguous() - out_fused, grad_fused, _ = fused_grpo_loss_forward_backward( + out_fused, grad_fused, new_logprobs_fused = fused_grpo_loss_forward_backward( split_op(logits, group, -1), target, advantages, @@ -272,10 +280,29 @@ def _test_grpo_loss( grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, - divisor=(target >= 0).sum().item(), + num_labels_in_seq=num_labels_in_seq, + divisor=divisor, ) _compare_losses_and_grads(out_fused, out_ref, grad_output is not None, grad_fused, grad_ref, group=group) + if not triton_available: + return + out_triton, grad_triton, new_logprobs_triton = triton_grpo_loss_forward_backward( + split_op(logits, group, -1).contiguous(), + target, + advantages, + old_log_probabilities, + grad_logits=local_previous_grad.clone() if accumulate else None, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + num_labels_in_seq=num_labels_in_seq, + divisor=divisor, + block_size=block_size, + ) + _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref, group=group) + Assert.rms_close_relative(new_logprobs_triton, new_logprobs_fused, 1e-5, 1e-6) + def _test_z_loss( batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size, accumulate, group=None From 8ae107a0544f729b86f532af888be18d7089378f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 18:54:55 -0400 Subject: [PATCH 21/33] Fix bugs in fast_llm/engine and related modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix `self._run.index` → `run_index` in lm_eval evaluator (AttributeError on every eval) - Fix `broadcast_kwargs` dict iteration missing `.items()` in huggingface inference (ValueError with TP) - Fix missing `self` parameter in `HuggingfacePreTrainedModel.inner_forward` - Fix `kwargs.get` → `kwargs.pop` for `gguf_file` in inference config - Fix `get_state_tensor_iterator` and `import_state_tensor` using weight shard sizes for all shards (wrong for optimizer state shards with frozen params) - Fix `reset_shard_pad` and debug logging only covering last FSDP in `initialize_weights` (missed frozen FSDP) - Fix `len(grads_norm_slices) < 0` → `> 0` in grad norm slice merging (always-false condition, no merging ever happened) - Fix `PowerLRStage._interpolate` and `CosineLRStage._interpolate` incorrectly marked `@abc.abstractmethod` - Fix `CosineLRStage.lr`/`end_lr` typed as `int` instead of `float` - Remove hardcoded debug `logger.info` for `layers.1.norm_1.weight` from `SafeLoad` - Replace open("r"/"w") file handles with `.read_text()`/`.write_text()`/`.touch()` throughout - Fix `pass` before `wandb.alert(...)` call that suppressed the alert - Fix duplicate word in docstring ("not not needed") - Fix misleading field descriptions for `depth_first_micro_batches`/`breadth_first_micro_batches` Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/checkpoint/convert.py | 8 +++--- fast_llm/engine/checkpoint/distributed.py | 6 ++--- fast_llm/engine/checkpoint/huggingface.py | 10 +++---- fast_llm/engine/checkpoint/safe_load.py | 9 ------- fast_llm/engine/checkpoint/state_dict.py | 10 +++---- fast_llm/engine/config_utils/run.py | 4 +-- fast_llm/engine/config_utils/runnable.py | 2 +- .../engine/evaluation/lm_eval/evaluator.py | 4 +-- .../evaluation/lm_eval/fast_llm_wrapper.py | 4 +-- fast_llm/engine/inference/config.py | 2 +- fast_llm/engine/inference/huggingface.py | 6 ++--- fast_llm/engine/multi_stage/config.py | 2 +- fast_llm/engine/multi_stage/multi_stage.py | 8 +++--- fast_llm/engine/multi_stage/stage_base.py | 26 ++++++++++--------- fast_llm/engine/optimizer/learning_rate.py | 6 ++--- fast_llm/engine/schedule/config.py | 5 ++-- fast_llm/engine/training/trainer.py | 2 +- fast_llm/engine/training/wandb.py | 8 +++--- tests/models/test_checkpoint.py | 4 +-- 19 files changed, 56 insertions(+), 70 deletions(-) diff --git a/fast_llm/engine/checkpoint/convert.py b/fast_llm/engine/checkpoint/convert.py index 103d9488c..728877792 100644 --- a/fast_llm/engine/checkpoint/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -70,7 +70,7 @@ def _convert_model_partial( logger.info(f"Saving {output.format} checkpoint to {output.path}...") output.path.mkdir(parents=True, exist_ok=self.exist_ok) model.save_checkpoint(output) - (output.path / "ok").open("w") + (output.path / "ok").touch() logger.info(f"Done!") def run(self): @@ -120,7 +120,7 @@ def run(self): global_rename_map = {} file_count = 0 for step_path in step_paths: - step_index = json.load((step_path / index_filename).open("r")) + step_index = json.loads((step_path / index_filename).read_text()) if len(index) == 0: index.update(step_index) index["weight_map"] = weight_map @@ -141,7 +141,7 @@ def run(self): path = self.output.path / index_filename # Save the index. - json.dump(index, path.open("w"), indent=4) + path.write_text(json.dumps(index, indent=4)) # Copy the config (step_paths[0] / config_filename).rename(self.output.path / config_filename) @@ -158,5 +158,5 @@ def run(self): step_path.rmdir() # All good! - (self.output.path / "ok").open("w") + (self.output.path / "ok").touch() logger.info(f">>> All done!") diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index fecc35ef7..22782e49c 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -32,17 +32,17 @@ class DistributedCheckpointHandler(CheckpointHandler): def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata): serialized_metadata = metadata.to_dict() config.path.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) + (config.path / "metadata.yaml").write_text(yaml.safe_dump(serialized_metadata)) @classmethod def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: - return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r"))) + return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").read_text())) def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: serialized_metadata = metadata.to_dict() config.path.mkdir(parents=True, exist_ok=True) if self._model.config.distributed.rank == 0: - yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) + (config.path / "metadata.yaml").write_text(yaml.safe_dump(serialized_metadata)) safetensors.torch.save_file( tensors={f"{shard_name}_shard": self._model.get_shard(shard_name) for shard_name in metadata.shards}, filename=config.path / f"rank_{self._model.config.distributed.rank}.safetensors", diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 270171755..5379e51d9 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -58,11 +58,7 @@ def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadat path = config.path / f"{cls.base_file_name}.safetensors.index.json" logger.info(f"Saving index to {path}") # Save the index. - json.dump( - {"metadata": metadata, "weight_map": index}, - path.open("w"), - indent=4, - ) + path.write_text(json.dumps({"metadata": metadata, "weight_map": index}, indent=4)) def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata) -> dict: huggingface_config = self._export_config(self._model.config) @@ -145,7 +141,7 @@ def _load_weights( logger.info(f"Loading index from {config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME}") paths = { config.path / path - for path in json.load((config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).open("r"))[ + for path in json.loads((config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).read_text())[ "weight_map" ].values() } @@ -155,7 +151,7 @@ def _load_weights( logger.info(f"Loading index from {config.path / transformers.utils.WEIGHTS_INDEX_NAME}") paths = { config.path / path - for path in json.load((config.path / transformers.utils.WEIGHTS_INDEX_NAME).open("r"))[ + for path in json.loads((config.path / transformers.utils.WEIGHTS_INDEX_NAME).read_text())[ "weight_map" ].values() } diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index d3f72a47c..9667fa98b 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -149,17 +149,12 @@ def _check_parameters(self, errors: list[str]) -> None: f' and shard "{shard_name}": loaded {counter}, expected {local_size}' ) - counter_ = counter # Accumulate in a list for global counter check. if ( not parameter_meta.is_tensor_parallel and self._distributed.config.tensor_rank != 0 ) or stage.is_tied_weight_copy: # Ignore the counter from duplicate tensors. counter = 0 - if parameter_name == "layers.1.norm_1.weight": - logger.info( - f"Parameter {parameter_name} local {counter_} keep {counter} (size {parameter_meta.numel()} / {parameter_meta.global_shape.numel()})" - ) counters.append(counter) # Check for unexpected parameters. @@ -179,10 +174,6 @@ def _check_parameters(self, errors: list[str]) -> None: for stage, fsdp, parameter_name, parameter_meta in self._model.stages_fsdp_parameters: for shard_name in self._self_shards if fsdp.requires_grad else [ShardName.weights]: counter = counters.pop(0) - if parameter_name == "layers.1.norm_1.weight": - logger.info( - f"Parameter {parameter_name} global {counter} (size {parameter_meta.numel()} / {parameter_meta.global_shape.numel()})" - ) parameter_size = parameter_meta.global_shape.numel() if counter != parameter_size: errors.append( diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 32eea2db6..8106d85dc 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -154,7 +154,7 @@ class FastLLMCheckpointHandler(StateDictCheckpointHandler): def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: path = config.path / f"metadata.yaml" logger.warning(f"Loading metadata from {path}") - return CheckpointMetadata.from_dict(yaml.safe_load(path.open("r"))) + return CheckpointMetadata.from_dict(yaml.safe_load(path.read_text())) @classmethod def _save_serialized_metadata( @@ -166,7 +166,7 @@ def _save_serialized_metadata( if "metadata" not in serialized_metadata: serialized_metadata["metadata"] = {} serialized_metadata["metadata"]["state_index"] = index - yaml.safe_dump(serialized_metadata, path.open("w")) + path.write_text(yaml.safe_dump(serialized_metadata)) @classmethod def _get_key(cls, parameter_name: str, shard_name: str) -> str: @@ -259,15 +259,15 @@ def _merge_index(self) -> None: if self._do_save and self._distributed_config.pipeline_parallel != 1: # Combine the indexes from all pipeline ranks. logger.info(f"Merging pipeline-parallel indexes.") - yaml.dump( - self._index, (self._config.path / f"index_{self._distributed_config.pipeline_rank}.yaml").open("w") + (self._config.path / f"index_{self._distributed_config.pipeline_rank}.yaml").write_text( + yaml.dump(self._index) ) safe_barrier(self._distributed.pipeline_group, "save state dict", timeout=self._config.timeout) self._index = {} if self._distributed_config.pipeline_rank == 0: for rank in range(self._distributed_config.pipeline_parallel): file_name = self._config.path / f"index_{rank}.yaml" - local_index = yaml.safe_load(file_name.open("r")) + local_index = yaml.safe_load(file_name.read_text()) for key, value in local_index.items(): assert key not in self._index, key self._index[key] = value diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index ab6f27489..e32d7cd46 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -144,9 +144,9 @@ def __init__( (self._experiment_directory / "runs").mkdir(exist_ok=True, parents=True) run = len(list((self._experiment_directory / "runs").iterdir())) (self._experiment_directory / "runs" / str(run)).mkdir() - yaml.safe_dump(config_dict, (self._experiment_directory / "config.yaml").open("w")) + (self._experiment_directory / "config.yaml").write_text(yaml.safe_dump(config_dict)) # Dumping a verbose version of the config - yaml.safe_dump(config_dict_verbose, (self._experiment_directory / "config_verbose.yaml").open("w")) + (self._experiment_directory / "config_verbose.yaml").write_text(yaml.safe_dump(config_dict_verbose)) else: run = 0 # Make sure all the workers agree on the run. This also acts as a barrier. diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py index 74fa0a2ae..a19893b40 100644 --- a/fast_llm/engine/config_utils/runnable.py +++ b/fast_llm/engine/config_utils/runnable.py @@ -153,7 +153,7 @@ def _load_default_config_dict(cls, parsed: argparse.Namespace) -> typing.Any: elif urllib.parse.urlparse(parsed.config).scheme == "https": return yaml.safe_load(cls._load_url(parsed.config, parsed.config_auth_token_file)) elif pathlib.Path(parsed.config).is_file(): - return yaml.safe_load(pathlib.Path(parsed.config).open("r").read()) + return yaml.safe_load(pathlib.Path(parsed.config).read_text()) else: raise FileNotFoundError(parsed.config) diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py index 4db258093..19ffc87c2 100644 --- a/fast_llm/engine/evaluation/lm_eval/evaluator.py +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -26,7 +26,7 @@ def setup( run_count: int, ) -> None: if "HUGGINGFACE_API_KEY_PATH" in os.environ: - os.environ["HF_TOKEN"] = pathlib.Path(os.environ["HUGGINGFACE_API_KEY_PATH"]).open("r").read().strip() + os.environ["HF_TOKEN"] = pathlib.Path(os.environ["HUGGINGFACE_API_KEY_PATH"]).read_text().strip() else: if not "HF_TOKEN" in os.environ: logger.warning( @@ -62,4 +62,4 @@ def run( metrics: dict[str, typing.Any], ) -> None: assert self._is_setup - self._flm_wrapper.run(self._config.cli_args, metrics.get("completed_steps", 0), self._run.index) + self._flm_wrapper.run(self._config.cli_args, metrics.get("completed_steps", 0), run_index) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 56a2588c0..511bf6359 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -116,7 +116,7 @@ def max_length(self): if isinstance(self._config.fast_llm_config.base_model.transformer.mixer.rotary, NoRotaryConfig): return self._config.fast_llm_config.base_model.max_position_embeddings - # check if tokenizer holds model sequence leigh info + # check if tokenizer holds model sequence length info if hasattr(self._tokenizer, "model_max_length"): if self._tokenizer.model_max_length == 1000000000000000019884624838656: return self._DEFAULT_MAX_LENGTH @@ -528,7 +528,7 @@ def tok_batch_encode( if left_truncate_len: original_lengths = encoding["input_ids"].size(1) if original_lengths > left_truncate_len: - logger.warn( + logger.warning( f"Left truncation applied. Original sequence length was {original_lengths}, " f"truncating to last {left_truncate_len} tokens. Some content will be lost.", ) diff --git a/fast_llm/engine/inference/config.py b/fast_llm/engine/inference/config.py index d19e2478d..b0b3b33a0 100644 --- a/fast_llm/engine/inference/config.py +++ b/fast_llm/engine/inference/config.py @@ -72,7 +72,7 @@ def _get_config_dict( kwargs.pop("_from_pipeline", None) kwargs.pop("_from_auto", False) kwargs.pop("_commit_hash", None) - kwargs.get("gguf_file", None) + kwargs.pop("gguf_file", None) # Get the pretrained config. if "pretrained" in kwargs: diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 5a07bd51b..67be46558 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -154,7 +154,7 @@ def forward( # TODO: Bypassed if passed as positional argument. assert kwargs.get("past_key_values") is None and not kwargs.get("use_cache") broadcast_kwargs = {**kwargs, **{i: arg for i, arg in enumerate(args)}, "continue_work": continue_work} - tensor_kwargs = {key: value for key, value in broadcast_kwargs if torch.is_tensor(value)} + tensor_kwargs = {key: value for key, value in broadcast_kwargs.items() if torch.is_tensor(value)} broadcast_object( [(key, tensor.shape, tensor.dtype) for key, tensor in tensor_kwargs.items()], distributed.tensor_group, @@ -162,7 +162,7 @@ def forward( ) for tensor in tensor_kwargs.values(): broadcast(tensor.to(distributed.device), 0, distributed.tensor_group) - non_tensor_kwargs = {key: value for key, value in broadcast_kwargs if key not in tensor_kwargs} + non_tensor_kwargs = {key: value for key, value in broadcast_kwargs.items() if key not in tensor_kwargs} broadcast_object( non_tensor_kwargs, distributed.tensor_group, @@ -240,6 +240,6 @@ def stop_workers(self): self.forward(coordinator_forward=True, continue_work=False) safe_barrier(distributed.world_group, "forward_work_end") - def inner_forward(*args, **kwargs) -> tuple | transformers.utils.generic.ModelOutput: + def inner_forward(self, *args, **kwargs) -> tuple | transformers.utils.generic.ModelOutput: # Meant to be overridden in derived classes raise NotImplementedError() diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index c642203fc..958a3d228 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -82,7 +82,7 @@ class StageConfig(Config): store_frozen_weights_in_optimization_precision: bool = Field( # TODO: Implement and set default to False default=True, - desc="Store frozen weights in full precision even if not not needed." + desc="Store frozen weights in full precision even if not needed." "Allows preserving the precision for saved checkpoints," " at the cost of memory and compute (copy) overheads.", hint=FieldHint.optional, diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index ed293b103..cd781beb7 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -487,13 +487,13 @@ def get_state_tensor_iterator( self, shard_names: tuple[str, ...], data_type: DataType | None = None ) -> typing.Generator[tuple[str, str, torch.Tensor], None, None]: for shard_name in shard_names: - shard_split = self._shards[shard_name].split(self._stage_weight_shard_sizes, 0) + shard_split = self._shards[shard_name].split(self._get_stage_shard_sizes(shard_name), 0) for shard_index, ((stage_index, stage), shard) in enumerate( zip(self._stages_on_device.items(), shard_split, strict=True) ): if stage_index in self._stages_owned: for name, tensor in stage._export_shard( - shard.split(self._fsdp_weight_shard_sizes[shard_index]), data_type=data_type + shard.split(self._get_fsdp_shard_sizes(shard_name)[shard_index]), data_type=data_type ): # noqa yield name, shard_name, tensor @@ -508,8 +508,8 @@ def import_state_tensor(self, parameter_name: str, shard_name: str, tensor: torc shard_index = self._stage_shard_indices[self._parameter_stages[parameter_name]] stage_shards = ( self._shards[shard_name] - .split(self._stage_weight_shard_sizes, 0)[shard_index] - .split(self._fsdp_weight_shard_sizes[shard_index]) + .split(self._get_stage_shard_sizes(shard_name), 0)[shard_index] + .split(self._get_fsdp_shard_sizes(shard_name)[shard_index]) ) return self.get_parameter_stage(parameter_name).import_state_tensor(parameter_name, stage_shards, tensor) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 23ee5d8bd..ea737524b 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -204,17 +204,19 @@ def initialize_weights(self) -> None: meta.init_parameter(parameter, self._distributed, debug=self._config.debug_param_init) if self.mode.on_device: - fsdp.reset_shard_pad(fsdp.weight_shard, ShardName.weights) + for fsdp in self._fsdps: + fsdp.reset_shard_pad(fsdp.weight_shard, ShardName.weights) if self._config.debug_param_init: if self._mode.on_device: - fsdp.log_shard( - name="param", - shard=fsdp.weight_shard, - distributed=self._distributed, - level=self._config.debug_param_init, - global_=self._config.debug_global_tensors, - ) + for fsdp in self._fsdps: + fsdp.log_shard( + name="param", + shard=fsdp.weight_shard, + distributed=self._distributed, + level=self._config.debug_param_init, + global_=self._config.debug_global_tensors, + ) def get_param_groups( self, optimizer_state_shards: dict[str, tuple[torch.Tensor]], param_group_cls: type[ParamGroup] @@ -238,9 +240,9 @@ def get_param_groups( continue chunk_size = div(parameter_meta.numel(), len(parameter_meta.lr_scale)) buffer_begin = fsdp.get_parameter_begin_in_buffer(parameter_meta.tensor_name) - for i, lr_scale in enumerate(parameter_meta.lr_scale): - begin = fsdp.index_buffer_to_shard(buffer_begin + i * chunk_size) - end = fsdp.index_buffer_to_shard(buffer_begin + (i + 1) * chunk_size) + for lr_scale_index, lr_scale in enumerate(parameter_meta.lr_scale): + begin = fsdp.index_buffer_to_shard(buffer_begin + lr_scale_index * chunk_size) + end = fsdp.index_buffer_to_shard(buffer_begin + (lr_scale_index + 1) * chunk_size) if lr_scale == 0 or begin == end: continue optimizer_params = (parameter_meta.param_weight_decay, lr_scale) @@ -279,7 +281,7 @@ def get_param_groups( grads_norm_slices = [] for name in grad_norm_names: begin, end = fsdp._get_parameter_range_in_shard(name) - if len(grads_norm_slices) < 0 and begin == grads_norm_slices[-1].stop: + if len(grads_norm_slices) > 0 and begin == grads_norm_slices[-1].stop: grads_norm_slices[-1] = slice(grads_norm_slices[-1].start, end) else: grads_norm_slices.append(slice(begin, end)) diff --git a/fast_llm/engine/optimizer/learning_rate.py b/fast_llm/engine/optimizer/learning_rate.py index c6912e4f1..3f58c953c 100644 --- a/fast_llm/engine/optimizer/learning_rate.py +++ b/fast_llm/engine/optimizer/learning_rate.py @@ -59,22 +59,20 @@ def __post_init__(self) -> None: super().__post_init__() Assert.gt(self.power, 0) - @abc.abstractmethod def _interpolate(self, coeff: float) -> float: return coeff**self.power @dataclasses.dataclass() class CosineLRStage(InterpolateLRStage): - lr: int - end_lr: int + lr: float + end_lr: float power: float = 1.0 def __post_init__(self) -> None: super().__post_init__() Assert.gt(self.power, 0) - @abc.abstractmethod def _interpolate(self, coeff: float) -> float: return 0.5 * (1.0 - math.cos(math.pi * coeff**self.power)) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index f56c00f28..29720b90b 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -16,13 +16,14 @@ class ScheduleConfig(Config): depth_first_micro_batches: int = Field( default=1, - desc="Size of individual micro-batches. May be derived or constrained be other quantities.", + desc="Number of micro-batches processed depth-first, i.e., each runs through all model stages before the next" + " begins. This is the standard way to perform gradient accumulation.", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) breadth_first_micro_batches: int = Field( default=1, - desc="Size of individual micro-batches. May be derived or constrained be other quantities.", + desc="Number of micro-batches processed breadth-first, i.e., interleaved across model stages.", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 1e341077f..deda813bb 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -402,7 +402,7 @@ def _save_checkpoint( ) # Mark the checkpoint as complete. if self._run.is_main_rank: - (checkpoint_directory / "ok").open("w") + (checkpoint_directory / "ok").touch() logger.info(f"Saved {config.save_name} to {checkpoint_directory}") to_delete = config.to_delete(sorted(int(path.name) for path in checkpoint_base_directory.iterdir())) diff --git a/fast_llm/engine/training/wandb.py b/fast_llm/engine/training/wandb.py index 724b5b718..3349cff26 100644 --- a/fast_llm/engine/training/wandb.py +++ b/fast_llm/engine/training/wandb.py @@ -19,14 +19,14 @@ def __init__(self, config: WandbConfig, run: Run, experiment_config: Config): # Wandb login from file api_key_path = os.environ.get("WANDB_API_KEY_PATH") if api_key_path: - os.environ["WANDB_API_KEY"] = pathlib.Path(api_key_path).open("r").read().strip() + os.environ["WANDB_API_KEY"] = pathlib.Path(api_key_path).read_text().strip() wandb_path = ( None if self._run.experiment_directory is None else self._run.experiment_directory / "wandb_config.yaml" ) if wandb_path is not None and wandb_path.is_file(): - wandb_config = yaml.safe_load(wandb_path.open("r")) + wandb_config = yaml.safe_load(wandb_path.read_text()) else: wandb_config = { "id": wandb.sdk.lib.runid.generate_id(16), @@ -38,7 +38,7 @@ def __init__(self, config: WandbConfig, run: Run, experiment_config: Config): "resume": "allow", } if wandb_path is not None: - yaml.safe_dump(wandb_config, wandb_path.open("w")) + wandb_path.write_text(yaml.safe_dump(wandb_config)) # TODO: Does wandb work with nested configs? self._wandb = wandb.init(config=experiment_config.to_dict(), **wandb_config) else: @@ -53,8 +53,6 @@ def log_metrics(self, completed_steps: int, metrics: dict[str, dict[str, float | def alert(self, title, text, level="INFO", wait=0.001) -> None: if self._wandb is not None and self._config.alert.post_alerts: - pass - self._wandb.alert( # noqa title=title() if callable(title) else title, text=f"[{self._config.project_name}/{self._run.experiment_name}, run {self._run.index}]" diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 5f0f5a80f..094cbc094 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -235,7 +235,7 @@ def test_load_pretrained( ): # Test that loadind a pretrained model from either converted checkpoint always yields the exact same model. reference_config = model_testing_config.model_config_class.from_dict( - yaml.safe_load(get_convert_path().parents[1].joinpath("config.yaml").open("r"))["model"] + yaml.safe_load(get_convert_path().parents[1].joinpath("config.yaml").read_text())["model"] ) reference_shard = safetensors.torch.load_file( get_convert_path() / "rank_0.safetensors", device=str(testing_device) @@ -260,7 +260,7 @@ def test_load_pretrained( "base_model": yaml.safe_load( get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format) .joinpath("metadata.yaml") - .open("r") + .read_text() )["config"]["base_model"] } ) From 9513bca6b2be52748406195bc57b382b86e39eff Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 19:11:41 -0400 Subject: [PATCH 22/33] Fix NameError in lora_linear forward_only with out_channel_begin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit layer_out was only assigned inside the isinstance(output, tuple) branch, causing NameError when out_channel_begin is set with a plain Linear. Also fix tuple unpack: layer_out, tp_bias = output[0] → output. Affects value-only LoRA in GQA attention (attention.py uses out_channel_begin). Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/layers/common/peft/lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/common/peft/lora.py b/fast_llm/layers/common/peft/lora.py index eaf9f67f0..badfc91f2 100644 --- a/fast_llm/layers/common/peft/lora.py +++ b/fast_llm/layers/common/peft/lora.py @@ -61,8 +61,9 @@ def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor input_ = input_.detach().requires_grad_() with torch.enable_grad(): output = old_forward(input_) + layer_out = output if isinstance(output, tuple): - layer_out, tp_bias = output[0] + layer_out, tp_bias = output assert tp_bias is None lora_out = (alpha / rank) * module.lora_1( module.lora_0(torch.dropout(input_, dropout, module.training) if dropout > 0.0 else input_) From eb113948e86dcc384d87fdfce77f6794e9a34416 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 19:18:20 -0400 Subject: [PATCH 23/33] Fix swapped args in MoE _add_shared_experts call and wrong error message in Llama import_config - MixtureOfExpertMLP._forward called _add_shared_experts(top_experts, scores) but the signature is _add_shared_experts(scores, top_experts). Passing integer indices as scores and float scores as the top_experts would produce wrong dtypes for both the shared expert index arange and the concatenated scores tensor. - LlamaAttentionConverter.import_config raised NotImplementedError with `type(config.rotary).__name__` where `config` is a dict (no .rotary attribute); use the `rope_type` variable already in scope instead. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/layers/decoder/mlp/mixture_of_experts.py | 2 +- fast_llm/models/gpt/conversion/llama.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 48bc5a5e1..04f0cd04a 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -116,7 +116,7 @@ def _forward( if self._config.routing == RoutingType.topk: scores, top_experts = self._topk_routing(logits, kwargs.get(BlockKwargs.grad_output), losses) if self._config.shared_experts > 0: - scores, top_experts = self._add_shared_experts(top_experts, scores) + scores, top_experts = self._add_shared_experts(scores, top_experts) elif self._config.routing == RoutingType.sinkhorn: scores, top_experts = self._sinkhorn_routing(logits) else: diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 38dc38586..491ddde6e 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -217,7 +217,7 @@ def import_config(cls, config: dict) -> dict: } ) else: - raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") + raise NotImplementedError(f"Unsupported rotary type: {rope_type}") out = { "rotary": rotary_config, "heads": config["num_attention_heads"], From edfe59fbc36a1a07af864708683efe9295931e5e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 19:23:19 -0400 Subject: [PATCH 24/33] Remove duplicate NotImplementedError check in LanguageModelDPOLoss Lines 16-17 were an exact duplicate of lines 12-13 (_prediction_distance > 1), dead code that never executes. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/layers/language_model/loss/dpo.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fast_llm/layers/language_model/loss/dpo.py b/fast_llm/layers/language_model/loss/dpo.py index af2149d36..059f808e5 100644 --- a/fast_llm/layers/language_model/loss/dpo.py +++ b/fast_llm/layers/language_model/loss/dpo.py @@ -13,8 +13,6 @@ def __init__(self, *args, **kwargs): raise NotImplementedError() if self._num_splits > 1: raise NotImplementedError() - if self._prediction_distance > 1: - raise NotImplementedError() if self._vocab_parallel: raise NotImplementedError() From a0b91d2fe2a9538270e763da18e05a4bbc093d99 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 19:25:10 -0400 Subject: [PATCH 25/33] Remove dead code in MLP and StochasticMixer - MLPBase._get_intermediate_dims: remove discarded TensorDim("gate_and_up", 2) created but immediately thrown away before the ConcatenatedTensorDim call. - StochasticMixer.get_preprocessing_config: remove first loop that called mixer.get_preprocessing_config() on each mixer but discarded every result, causing each mixer's method to be called twice. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/layers/decoder/mlp/mlp.py | 1 - fast_llm/layers/decoder/stochastic_mixer.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index 1048f7c2a..80599da97 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -71,7 +71,6 @@ def __init__( def _get_intermediate_dims(self): intermediate_2_dim = TensorDim("intermediate", self._config.intermediate_size, self._parallel_dim) if self._config.gated: - TensorDim("gate_and_up", 2) intermediate_1_dim = ConcatenatedTensorDim("gate_and_up", (intermediate_2_dim, intermediate_2_dim)) else: intermediate_1_dim = intermediate_2_dim diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 97bd1f477..a3ea8b846 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -161,8 +161,6 @@ def _forward( return self.mixers[mixer_name]._forward(input_, kwargs, losses, metrics) def get_preprocessing_config(self) -> dict[str, typing.Any]: - for mixer in self.mixers.values(): - mixer.get_preprocessing_config() return safe_merge_dicts(*(mixer.get_preprocessing_config() for mixer in self.mixers.values())) def _sample_allocation(self, num_layers: int, generator: torch.Generator) -> list[int]: From ac7258245a1a2b22dc6e94a17800a6863f08fe89 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 19:50:14 -0400 Subject: [PATCH 26/33] Fix grad scaler load and Apriel dt_rank auto formula - DynamicGradScaler.load: restore _scale from checkpoint (was never loaded, leaving _scale unset after resume from checkpoint) - ConstantGradScaler.load: handle case where _scale not yet set (when loading checkpoint without prior reset_state call) - AprielMambaConverter: fix dt_rank auto formula to use ceil(hidden_size / 16) matching the reference model Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/optimizer/optimizer.py | 8 ++++++-- fast_llm/models/gpt/conversion/apriel.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/fast_llm/engine/optimizer/optimizer.py b/fast_llm/engine/optimizer/optimizer.py index 0dd094390..80def6a28 100644 --- a/fast_llm/engine/optimizer/optimizer.py +++ b/fast_llm/engine/optimizer/optimizer.py @@ -242,8 +242,11 @@ def unscale_and_check_nans(self, tensor: torch.Tensor) -> None: class ConstantGradScaler(VariableGradScaler): def load(self, state, validate=True) -> None: - if validate: - Assert.eq(self._scale, state["scale"]) + if hasattr(self, "_scale"): + if validate: + Assert.eq(self._scale, state["scale"]) + else: + self._set_scale(state["scale"]) super().load(state, validate=validate) def _set_scale(self, value) -> None: @@ -282,6 +285,7 @@ def save(self) -> dict[str, typing.Any]: def load(self, state, validate=True) -> None: super().load(state, validate=validate) + self._set_scale(state["scale"]) self._growth_tracker = state["growth"] self._hysteresis_tracker = state["hysteresis"] diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 505c62d70..ac732ba22 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -34,7 +34,7 @@ def import_config(cls, config: dict) -> dict: "d_xb": config["ssm_cfg"].get("d_xb") or config["hidden_size"], "dt_layer": {"bias": {"enabled": config["ssm_cfg"].get("dt_proj_bias", True)}}, "dt_rank": ( - math.ceil(config["hidden_size"]) + math.ceil(config["hidden_size"] / 16) if config["ssm_cfg"].get("dt_rank", "auto") == "auto" else config["ssm_cfg"]["dt_rank"] ), From 021a71a460b7fc882999f7347e8aa0a2990ed559 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 20:04:28 -0400 Subject: [PATCH 27/33] Add parallelism documentation (user guide and developer guide) Co-Authored-By: Claude Sonnet 4.6 --- docs/developer_guide/parallelism.md | 212 ++++++++++++++++++++++++++++ docs/user_guide/parallelism.md | 179 ++++++++++++++++++++++- mkdocs.yaml | 1 + 3 files changed, 390 insertions(+), 2 deletions(-) create mode 100644 docs/developer_guide/parallelism.md diff --git a/docs/developer_guide/parallelism.md b/docs/developer_guide/parallelism.md new file mode 100644 index 000000000..651d0b4d6 --- /dev/null +++ b/docs/developer_guide/parallelism.md @@ -0,0 +1,212 @@ +--- +title: Parallelism Internals +--- + +This document describes how Fast-LLM's four parallelism strategies are implemented. It is aimed at contributors adding new layer types, modifying the distributed engine, or debugging performance issues. + +For user-facing configuration, see the [Parallelism guide](../user_guide/parallelism.md). + +--- + +## Rank Assignment + +All rank arithmetic lives in `fast_llm/engine/distributed/config.py`. Given `world_size`, `tensor_parallel`, `pipeline_parallel`, and `sequence_data_parallel`, the derived dimensions are: + +```python +data_parallel = world_size // (tensor_parallel * pipeline_parallel) +batch_data_parallel = data_parallel // sequence_data_parallel + +tensor_rank = rank % tensor_parallel +data_rank = (rank // tensor_parallel) % data_parallel +pipeline_rank = rank // (tensor_parallel * data_parallel) +batch_data_rank = data_rank // sequence_data_parallel +sequence_data_rank = data_rank % sequence_data_parallel +``` + +When `pipeline_first=True`, `data_rank` and `pipeline_rank` are swapped: + +```python +pipeline_rank = (rank // tensor_parallel) % pipeline_parallel +data_rank = (rank // tensor_parallel) // pipeline_parallel +``` + +This alternative ordering places pipeline stages nearer in global rank space, which can improve NUMA locality when each node runs multiple pipeline stages. + +--- + +## Process Groups + +`fast_llm/engine/distributed/distributed.py` constructs the NCCL (or Gloo for CPU) process groups from the `DistributedConfig`. Groups are de-duplicated through `ProcessGroupPool` — if two parallelism dimensions happen to cover the same set of ranks, they share a single underlying `torch.distributed.ProcessGroup`. + +The named groups used throughout the engine are: + +| Name | Members | Primary use | +| --- | --- | --- | +| `world` | All ranks | Global barriers | +| `tensor` | Same `data_rank`, `pipeline_rank` | TP all-reduces | +| `data` | Same `tensor_rank`, `pipeline_rank` | ZeRO reduce-scatter / all-gather | +| `pipeline` | Same `tensor_rank`, `data_rank` | Pipeline send/recv | +| `sequence_data` | Same `tensor_rank`, `pipeline_rank`, `batch_data_rank` | Sequence-parallel reduction | +| `batch_data` | Same `tensor_rank`, `pipeline_rank`, `sequence_data_rank` | Non-sequence data reduction | +| `tensor_and_data` | Same `pipeline_rank` | ZeRO with TP | +| `tensor_and_sequence_data` | Same `pipeline_rank`, `batch_data_rank` | Sequence-TP activations | +| `model_and_sequence_data` | Same `batch_data_rank` | Cross-pipeline sequence gradient | + +`Distributed.set_step(step, phase)` reseeds per-step generators (`pp_generator`, `tp_generator`) using large prime offsets per dimension, so dropout and other stochastic ops are deterministically reproducible across ranks and resumptions. + +--- + +## Tensor Parallelism + +### Sharded linear layers + +Tensor parallelism is implemented by two linear layer variants in `fast_llm/layers/common/linear/linear.py`: + +**`OutputParallelLinear`** (column split): + +- Weight shape: `[output_dim / tensor_parallel, input_dim]` +- Each rank computes a different slice of the output columns +- Forward: local `Y_local = X @ W_local`; output stays partitioned — no communication on the output +- If `sequence_parallel`: input is first **all-gathered** across the tensor group before the matmul +- Backward: grad_input is **all-reduced** (or **reduce-scattered** with sequence-TP) across the tensor group +- Used for: Q/K/V projections, MLP gate/up projections + +**`InputParallelLinear`** (row split): + +- Weight shape: `[output_dim, input_dim / tensor_parallel]` +- Each rank holds a row slice of the weight (a slice of the input dimension) +- Forward: local `Y_local = X_local @ W_local`, then **all-reduce** output across the tensor group (so every rank has the full output) +- If `sequence_parallel`: output is **reduce-scattered** instead of all-reduced, so each rank ends up with a sequence slice +- Custom autograd via `input_parallel_linear_autograd` to correctly handle gradient flow +- Used for: attention output projection, MLP down projection + +### Sequence-tensor parallelism + +The standard (non-sequence-TP) TP pattern replicates the full sequence on every rank between layers. Sequence-tensor parallelism keeps activations distributed across the sequence dimension between layers, reducing activation memory by a factor of `tensor_parallel`. + +At each transformer sub-layer (attention or MLP), the flow is: + +- **`OutputParallelLinear`**: **all-gather** the sequence-chunked input → full sequence × partial output columns per rank +- Attention / elementwise ops: operate on full-sequence slices +- **`InputParallelLinear`**: matmul → **reduce-scatter** the output → each rank returns to holding `seq_len / tensor_parallel` rows + +The total communication volume (all-gather + reduce-scatter) equals that of a single all-reduce, so there is no extra bandwidth cost. The benefit is smaller activation tensors between layers. + +### Adding a new tensor-parallel layer + +1. Declare weight dimensions using `TensorDim` objects from `fast_llm/engine/config_utils/tensor_dim.py`. Mark the sharded dimension with the appropriate `DistributedDim`. +2. Inherit from `OutputParallelLinear` or `InputParallelLinear`, or implement the custom forward/backward directly in `fast_llm/functional/`. +3. Ensure the layer's `forward()` uses the standard signature: `(input, kwargs, losses, metrics) → Tensor`. + +--- + +## Pipeline Parallelism + +### Stage splitting + +`MultiStageModel._split_into_stages()` (in `fast_llm/engine/multi_stage/multi_stage.py`) partitions the flat list of `Layer` objects returned by `BaseModel.get_layers()`. Each stage holds `layers_per_stage` consecutive layers. The mapping from stage index to pipeline rank is: + +```python +pipeline_rank = (stage_index // stages_per_pipeline_stage) % pipeline_parallel +``` + +Stages owned by this rank have full weights and compute. Stages on other pipeline ranks are metadata-only stubs (except for tied weights, see below). + +### Tied weights + +Embedding and LM-head weights are often shared. When these layers land on different pipeline stages, `Stage._tied_weight_copies` holds a reference-only copy: + +- Forward pass: tied weights are **broadcast** from the owning stage to all stages that need them. +- Backward pass: gradients from non-owning stages are **all-reduced** back to the owning stage. + +### Schedule + +The schedule (`fast_llm/engine/schedule/`) builds a DAG of `ScheduleNode` operations (forward, backward, send, recv, optimizer step) and executes them across three CUDA streams (compute, send, recv). Pipeline communication uses PyTorch `isend` / `irecv` for overlap with compute. + +`breadth_first_micro_batches` controls how many micro-batches are in-flight at once. With `N` pipeline stages and `breadth_first_micro_batches = N`, the pipeline bubble fraction approaches zero for large batches. + +--- + +## Data Parallelism and ZeRO/FSDP + +Data parallelism in Fast-LLM is inseparable from the ZeRO sharding implementation in `fast_llm/engine/multi_stage/fsdp.py`. The `FSDP` class owns the per-stage weight and gradient buffers and orchestrates all-gather / reduce-scatter across the data-parallel group. + +### Buffer layout + +Each `FSDP` instance maintains flat buffers for a pipeline stage's parameters: + +```text +_weight_shard : [total_params / data_parallel] # local shard, always resident +_weight_buffer : [total_params] # full weights, reconstructed on demand (ZeRO-3) +_grad_shard : [total_params / data_parallel] # local gradient shard +_grad_buffer : [total_params] # full gradient accumulation buffer +``` + +Every parameter is a view into the appropriate buffer slice, so there are no copies during forward/backward — the buffer *is* the parameter storage. + +Shards are padded to a multiple of `SHARD_PAD_TO_MULTIPLE` (32) per rank to ensure aligned communication. + +### Forward pass (`restore_parameters`) + +Before each forward pass through a stage: + +1. Copy `_weight_shard` into the local slice of `_weight_buffer` +2. If ZeRO stage 3: `all_gather(_weight_buffer)` across the data-parallel group + +With double-buffering (`num_weight_buffers=2`), the all-gather for stage `i+1` is issued asynchronously while stage `i` is computing. + +### Backward pass (`reduce_gradients`) + +After each backward pass: + +1. If sequence-parallel: `all_reduce` sequence-parallel gradient contributions across the tensor-and-sequence-data group +2. `reduce_scatter(_grad_buffer → _grad_shard)` across the data-parallel group (average reduction) +3. If the gradient shard dtype differs from the buffer dtype (e.g. fp32 grad shard, bf16 buffer), copy and cast + +With double-buffering (`num_grad_buffers=2`), the reduce-scatter for stage `i` is overlapped with the backward pass for stage `i-1`. + +### ZeRO stage effect on buffers + +| ZeRO stage | `_weight_buffer` | `_grad_buffer` | Communication | +| --- | --- | --- | --- | +| 1 | Not used (weights replicated) | Not used (grads replicated, then all-reduce) | All-reduce on grads | +| 2 | Not used | Used (grad reduce-scatter → shard) | Reduce-scatter on grads | +| 3 | Used (all-gather before forward) | Used | All-gather on weights + reduce-scatter on grads | + +--- + +## Sequence Data Parallelism + +Sequence data parallelism sub-divides the data-parallel group by the sequence dimension. The `sequence_data` process group covers ranks with the same `tensor_rank`, `pipeline_rank`, and `batch_data_rank`. + +During the forward pass of sequence-parallel layers, each rank holds a contiguous chunk of the sequence. When a layer requires the full sequence (e.g. attention), an all-gather is performed across the `sequence_data` group. After the layer, a reduce-scatter returns each rank to its sequence chunk. + +Gradient synchronization is handled in `FSDP.reduce_gradients`: gradients from the sequence-parallel group are all-reduced before the data-parallel reduce-scatter, so the sequence dimension is handled before any ZeRO sharding. + +--- + +## Seeding and Reproducibility + +`Distributed.set_step(step, phase)` is called at the start of each forward/backward pass. It reseeds two per-rank generators: + +- `pp_generator`: seeded by `(step, phase, tensor_rank, data_rank)` — ensures dropout is identical across pipeline ranks within the same TP group. +- `tp_generator`: seeded by `(step, phase, pipeline_rank, data_rank)` — ensures TP ranks sample the same dropout mask. + +Large prime offsets per dimension ensure seeds are distinct across all rank combinations. This guarantees deterministic training regardless of which GPU processes which micro-batch, and allows exact resumption from a checkpoint. + +--- + +## Key Source Files + +| File | Purpose | +| --- | --- | +| `fast_llm/engine/distributed/config.py` | `DistributedConfig`: rank arithmetic, derived fields | +| `fast_llm/engine/distributed/distributed.py` | `Distributed`: process group construction, `ProcessGroupPool`, seeding | +| `fast_llm/engine/multi_stage/fsdp.py` | `FSDP`: buffer management, all-gather, reduce-scatter, checkpoint loading | +| `fast_llm/engine/multi_stage/multi_stage.py` | `MultiStageModel`: stage splitting, tied weights | +| `fast_llm/engine/multi_stage/config.py` | `MultiStageConfig`: ZeRO stage, buffer counts | +| `fast_llm/layers/common/linear/linear.py` | `OutputParallelLinear`, `InputParallelLinear` | +| `fast_llm/functional/linear.py` | Functional forward/backward for TP linear ops | +| `fast_llm/engine/schedule/config.py` | `ScheduleConfig`: micro-batch and pipeline scheduling | +| `fast_llm/engine/schedule/runner.py` | `ScheduleRunner`: DAG execution, CUDA stream management | +| `tests/utils/distributed_configs.py` | 20+ reference configurations combining all strategies | diff --git a/docs/user_guide/parallelism.md b/docs/user_guide/parallelism.md index 406908cd8..1b34ff40d 100644 --- a/docs/user_guide/parallelism.md +++ b/docs/user_guide/parallelism.md @@ -2,6 +2,181 @@ title: Parallelism --- -!!! warning +Fast-LLM supports four complementary parallelism strategies that can be combined to train models at any scale. This guide explains each strategy, how to configure it, and when to use it. - Looking for the parallelism guide? It's on its way, come back soon! +For background on memory sharding (ZeRO), see the [Multi-Stage guide](multi-stage.md). The strategies below focus on how the computation itself is distributed. + +--- + +## Overview + +| Strategy | Config key | What it splits | Primary benefit | +| --- | --- | --- | --- | +| Data parallelism | `distributed.batch_data_parallel` (derived) | Batch across GPUs | Throughput | +| Tensor parallelism | `distributed.tensor_parallel` | Layers horizontally (weight matrices) | Model memory | +| Pipeline parallelism | `distributed.pipeline_parallel` | Layers vertically (by depth) | Model memory | +| Sequence data parallelism | `distributed.sequence_data_parallel` | Sequence dimension across GPUs | Activation memory | + +These strategies compose: a 64-GPU run might use 4-way tensor parallelism, 4-way pipeline parallelism, and 4-way data parallelism simultaneously. + +--- + +## Data Parallelism + +Data parallelism replicates the full model on every GPU and processes different micro-batches in parallel. Gradients are averaged across all replicas before the optimizer step. + +Fast-LLM infers the data-parallel degree automatically: + +```text +data_parallel = world_size / (tensor_parallel × pipeline_parallel) +``` + +You do not set `data_parallel` directly — it falls out from the other settings. + +Data parallelism is the baseline scaling strategy: it increases training throughput proportionally to the number of replicas (assuming sufficient batch size) and adds no memory pressure for the model itself. Any GPUs not used by tensor or pipeline parallelism are automatically used for data parallelism. Its only constraint is that the global batch size grows with the number of replicas. + +--- + +## Tensor Parallelism + +Tensor parallelism (sometimes called *intra-layer model parallelism*) shards individual weight matrices across GPUs within a group. Each GPU holds a slice of the weight and computes its portion of the output; an all-reduce synchronizes results. + +```yaml +model: + distributed: + tensor_parallel: 4 # shard weights across 4 GPUs +``` + +Valid values are 1 (disabled) or any integer that divides `world_size`. In practice, powers of two work best (1, 2, 4, 8). + +**When to use:** When a single model layer (e.g. attention projection or MLP) does not fit on one GPU, or when activation memory from large hidden dimensions is the bottleneck. Tensor parallelism requires high-bandwidth interconnects (NVLink within a node) because it adds an all-reduce communication on every forward *and* backward pass of every sharded layer. + +**Rule of thumb:** Keep tensor parallelism within a node (≤ 8 GPUs). Crossing node boundaries with tensor parallelism incurs heavy inter-node communication overhead. + +### Sequence-Tensor Parallelism + +When tensor parallelism is active, you can enable an additional optimization that keeps activations distributed along the sequence dimension between layers, rather than replicating the full sequence on every GPU: + +```yaml +model: + distributed: + tensor_parallel: 4 + sequence_tensor_parallel: true +``` + +With this enabled, each GPU holds only `1 / tensor_parallel` of the sequence at any given time. Activations are gathered before layers that need the full sequence, and scatter-reduced afterward. This reduces peak activation memory per GPU by a factor of `tensor_parallel`, at the same total communication cost as without the option. It is recommended whenever `tensor_parallel > 1`. + +--- + +## Pipeline Parallelism + +Pipeline parallelism splits the model by depth: each GPU holds a consecutive block of layers. Activations flow forward from stage to stage; gradients flow backward. Multiple micro-batches can be in-flight simultaneously to keep all stages busy. + +```yaml +model: + distributed: + pipeline_parallel: 4 # split model across 4 GPUs +``` + +The number of layers per pipeline stage is controlled by how the total layer count divides across stages (see the [Multi-Stage guide](multi-stage.md) for `layers_per_stage`). + +Pipeline parallelism works well across slow interconnects (e.g. InfiniBand between nodes) because point-to-point sends only occur at stage boundaries, and their volume is proportional to the activation size of a single layer rather than the full model. + +### Scheduling micro-batches + +To hide pipeline bubbles, Fast-LLM uses *breadth-first* scheduling: it keeps several micro-batches in flight simultaneously so each stage always has work to do. + +```yaml +schedule: + micro_batch_splits: 1 # sub-divide each micro-batch along the sequence + breadth_first_micro_batches: 4 # interleave this many micro-batches across stages + depth_first_micro_batches: 1 # gradient accumulation steps within one stage +``` + +A larger `breadth_first_micro_batches` reduces idle (bubble) time but increases activation memory, since activations from all in-flight micro-batches are held simultaneously. Start with a value equal to the number of pipeline stages. + +!!! note + The total number of micro-batches per step (`breadth_first_micro_batches × depth_first_micro_batches`) must be at least equal to `pipeline_parallel`. Otherwise some pipeline stages will be idle for the entire step. + +**When to use:** When the model is too large to fit on a single node, or when you want to spread memory across nodes without incurring the per-layer all-reduce cost of tensor parallelism. Pipeline parallelism is naturally suited to slow cross-node links. + +--- + +## Sequence Data Parallelism + +Sequence data parallelism sub-divides the data-parallel group along the sequence dimension. Instead of each GPU processing an independent sequence in full, a group of GPUs collectively processes one sequence by splitting it into chunks. + +```yaml +model: + distributed: + sequence_data_parallel: 2 # 2 GPUs share each sequence +``` + +`sequence_data_parallel` must divide `data_parallel`. The effective batch dimension is: + +```text +batch_data_parallel = data_parallel / sequence_data_parallel +``` + +**When to use:** When training on very long sequences and activation memory is the primary constraint. Sequence data parallelism reduces per-GPU activation memory roughly in proportion to its value, at the cost of added gradient synchronization along the sequence dimension. + +--- + +## Combining Strategies + +All four strategies compose freely. A typical large-scale configuration looks like: + +```yaml +model: + distributed: + tensor_parallel: 4 # within-node weight sharding + sequence_tensor_parallel: true # sequence-split activations + pipeline_parallel: 8 # cross-node layer sharding + sequence_data_parallel: 1 # each sequence lives on one GPU + # data_parallel is inferred: world_size / (4 × 8) = e.g. 4 for a 128-GPU run + +schedule: + breadth_first_micro_batches: 8 # match pipeline depth +``` + +### Choosing a configuration + +Start with the simplest setup that fits the model in memory and scale from there: + +1. **Single GPU**: no parallelism needed. +2. **Multi-GPU, single node**: add `tensor_parallel` up to the number of GPUs (typically 8). Enable `sequence_tensor_parallel` alongside it. +3. **Multi-node**: add `pipeline_parallel` across nodes. Keep `tensor_parallel` within nodes. +4. **Very long sequences**: add `sequence_data_parallel` to reduce activation memory. +5. **Still out of memory**: increase `zero_stage` (see [Multi-Stage guide](multi-stage.md)). + +### Rank ordering + +By default, Fast-LLM assigns global ranks in tensor → data → pipeline order. If pipeline stages are on different sockets of the same machine, setting `pipeline_first: true` can improve NUMA locality: + +```yaml +model: + distributed: + pipeline_first: true +``` + +--- + +## Configuration Reference + +All distributed settings live under `model.distributed`: + +| Field | Default | Description | +| --- | --- | --- | +| `tensor_parallel` | `1` | Size of the tensor-parallel group | +| `pipeline_parallel` | `1` | Number of pipeline stages | +| `sequence_data_parallel` | `1` | Sub-divide data-parallel group by sequence | +| `sequence_tensor_parallel` | `false` | Enable sequence-parallel activation splitting in TP layers | +| `pipeline_first` | `false` | Swap data and pipeline rank ordering for NUMA locality | + +Schedule settings live under `schedule`: + +| Field | Default | Description | +| --- | --- | --- | +| `breadth_first_micro_batches` | `1` | Micro-batches in flight simultaneously (reduces pipeline bubble) | +| `depth_first_micro_batches` | `1` | Gradient accumulation steps within a stage | +| `micro_batch_splits` | `1` | Sub-divide each micro-batch along the sequence dimension | diff --git a/mkdocs.yaml b/mkdocs.yaml index 0ad00ccef..56c79f520 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -187,6 +187,7 @@ nav: - Evaluators: user_guide/evaluators.md - Developer Guide: - Configuration: developer_guide/configuration.md + - Parallelism: developer_guide/parallelism.md - Model: - Model: developer_guide/model.md - Conversion: developer_guide/conversion.md From 8a665fdb4de0c8725eac09c116f48ef660d57f5d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 20:11:14 -0400 Subject: [PATCH 28/33] Fix .bin checkpoint loading in HuggingFace handler `yield from torch.load(...)` on a dict yields only keys (strings), not the `(parameter_name, shard_name, tensor)` tuples that `_load_weights` is expected to produce. Fix by iterating `.items()` and yielding the correct 3-tuple. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/checkpoint/huggingface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 5379e51d9..8cdb779dd 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -165,7 +165,7 @@ def _load_weights( for key in f.keys(): yield key, "weights", f.get_slice(key) elif path.suffix == ".bin": - # TODO: Confirm that loading works with `weights_only=True` - yield from torch.load(path, weights_only=True) + for key, tensor in torch.load(path, weights_only=True).items(): + yield key, "weights", tensor else: raise NotImplementedError(f"Unknown file format for {path}") From beca8815e8bd40a65420b89021f23ef85362b3e6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Mar 2026 21:04:48 -0400 Subject: [PATCH 29/33] Remove unnecessary .value calls on StrEnum members Since StrEnum values are str subclasses, .value is redundant when used in string contexts (f-strings, dict keys, DataLoader args, Triton kernels). Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/data/data/gpt/data.py | 2 +- fast_llm/engine/schedule/runner.py | 14 +++++++------- fast_llm/engine/schedule/schedule.py | 4 ++-- fast_llm/engine/training/trainer.py | 6 +++--- fast_llm/functional/triton/mlp.py | 2 +- fast_llm/layers/common/linear/convolution.py | 2 +- fast_llm/models/gpt/conversion/apriel2.py | 4 ++-- tests/functional/test_functional.py | 2 +- tests/test_config.py | 2 +- 9 files changed, 19 insertions(+), 19 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index a25aede78..9253d0311 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -114,7 +114,7 @@ def get_iterator( prefetch_factor=prefetch_factor, pin_memory=self._distributed_config.use_cuda, collate_fn=functools.partial(self._collate_fn, dataset_name=dataset_name, preprocess=preprocess), - multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, + multiprocessing_context=self._config.multiprocessing_context if num_workers > 0 else None, ) if self._datasets[dataset_name].requires_broadcast: data_loader = DistributedDataLoaderWrapper( diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 7ad03b24c..b95d39463 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -164,7 +164,7 @@ def run_step( if self._multi_stage.config.multi_stage.debug_activation_memory: log_pipeline_parallel_main_rank( - lambda: log_memory_usage(f"Beginning of {context.phase.value} iteration {iteration}", str) + lambda: log_memory_usage(f"Beginning of {context.phase} iteration {iteration}", str) ) self._multi_stage.train(context.is_training) self._distributed.set_step(iteration, schedule.phase) @@ -278,7 +278,7 @@ def run_step( if self._multi_stage.config.multi_stage.debug_activation_memory: log_pipeline_parallel_main_rank( - lambda: log_memory_usage(f"End of {context.phase.value} iteration {iteration}", str) + lambda: log_memory_usage(f"End of {context.phase} iteration {iteration}", str) ) return self._reduce_losses(context), update_successful, metrics @@ -487,12 +487,12 @@ def _handle_events(self, context: BatchContext) -> None: def _save_events(self, events, context: BatchContext) -> None: out = { "iteration": context.iteration, - "phase": context.phase.value, + "phase": context.phase, "rank": self._distributed_config.rank, "events": [ { - "event_type": type_.value, - "stream": stream.value, + "event_type": type_, + "stream": stream, "gpu_time": gpu_time, "cpu_time": cpu_time, **( @@ -500,7 +500,7 @@ def _save_events(self, events, context: BatchContext) -> None: if step is None else { "step_idx": step.global_index, - "step_type": step.type_.value, + "step_type": step.type_, "step_stage": step.stage, "step_depth_first_micro_batch": step.depth_first_micro_batch, "step_breadth_first_micro_batch": step.breadth_first_micro_batch, @@ -514,7 +514,7 @@ def _save_events(self, events, context: BatchContext) -> None: yaml.safe_dump( out, get_run().open_artifact( - f"schedule_profile_rank_{self._distributed_config.rank}_{context.phase.value}_step_{context.iteration}" + f"schedule_profile_rank_{self._distributed_config.rank}_{context.phase}_step_{context.iteration}" ), ) diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index e2a9c75b5..6f7bf1d95 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -149,7 +149,7 @@ def __init__( self._setup_metas() if self._config.debug_schedule: - logger.info(f"{self._phase.value} schedule:\n{self._steps}") + logger.info(f"{self._phase} schedule:\n{self._steps}") @property def phase(self) -> PhaseType: @@ -210,7 +210,7 @@ def _create_index(self) -> None: for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages): assert ( step_map.pop((type_, stage, data_index), None) is not None - ), f"Missing {type_.value} step with stage={stage}, data_index={data_index}" + ), f"Missing {type_} step with stage={stage}, data_index={data_index}" Assert.empty(step_map) # Related steps diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index deda813bb..00cf2fa0d 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -192,7 +192,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: interrupter = Interrupter(self._config.training.checkpoint.enabled()) train_iterator = self._get_data_iterator( - PhaseType.training.value, + PhaseType.training, self._completed_steps, self._config.training.prefetch_factor, ) @@ -254,7 +254,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: remaining_time = average_time_per_iteration * ( self._config.training.train_iters - self._completed_steps ) - metrics_key = PhaseType.training.value + metrics_key = PhaseType.training metrics[metrics_key] = { "batch_size": self._batch_size, **{ @@ -429,7 +429,7 @@ def _load_checkpoint(self, config: TrainingCheckpointConfig, iteration: int) -> self._optimizer.load(metadata["optimizer"]) if "schedules" in metadata: # Backward compatibility. - self._completed_steps = metadata["schedules"][PhaseType.training.value]["completed_steps"] + self._completed_steps = metadata["schedules"][PhaseType.training]["completed_steps"] else: self._completed_steps = metadata["completed_steps"] # TODO: Move barrier, ok file to FastLLMModel diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index 4a8c5f179..52af93bde 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -156,7 +156,7 @@ def triton_mlp_activation_forward( input_, output, gated=gated, # noqa - activation_type=activation_type.value, # noqa + activation_type=activation_type, # noqa n_cols=n_cols, # noqa block_size=TritonConfig.POINTWISE_BLOCK_SIZE, ) diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index 1c23d6d8a..fd9670d9f 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ -60,7 +60,7 @@ def _forward_causal_conv1d( input_, self.weight.squeeze(1), self.bias, - activation=(None if self._activation == ActivationType.identity else self._activation.value), + activation=(None if self._activation == ActivationType.identity else self._activation), seq_idx=document_index, ) diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 91e3be508..d073830c4 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -495,7 +495,7 @@ def export_config(cls, config: StochasticMixerConfig) -> dict: "type": "stochastic", "mixers": mixers, "main_mixer_name": config.main_mixer_name, - "sampling_strategy": config.sampling_strategy.value, + "sampling_strategy": config.sampling_strategy, } @classmethod @@ -620,7 +620,7 @@ def export_config(cls, config: DecoderBlockConfig) -> dict: mlp = { "type": "mlp", "intermediate_size": config.mlp.intermediate_size, - "activation": config.mlp.activation.value, + "activation": config.mlp.activation, "gated": config.mlp.gated, "add_linear_biases": config.mlp.add_linear_biases, } diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 7980f05bf..07b8768da 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -48,7 +48,7 @@ def test_mlp_recomputation(gated, activation, testing_device): param_grad_refs = [param.grad for param in params] for i, recompute_level in enumerate(MLPRecomputeLevel): - print(recompute_level.value) # noqa + print(recompute_level) # noqa input_.grad = None for param in params: param.grad = None diff --git a/tests/test_config.py b/tests/test_config.py index 492a57b02..792eab077 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -70,7 +70,7 @@ def test_serialize_default_config_updates(cls): @pytest.mark.parametrize("load_config", tuple(ModelConfigType)) def test_pretrained_config(load_config: ModelConfigType, result_path): - config_path = result_path / "pretrained_config" / load_config.value + config_path = result_path / "pretrained_config" / load_config pretrained_model_config = GPTModelConfig.from_dict( { "base_model": { From 7d594ef15d1a625907f8366c7f426f8e762e913c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 30 Mar 2026 17:44:40 -0400 Subject: [PATCH 30/33] Fix GatedDeltaNetConfig losing value_heads validation due to duplicate _validate The class defined _validate twice; Python uses the last definition, silently dropping Assert.multiple(self.value_heads, self.key_heads) from the first one. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/layers/ssm/config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index ce54685e8..9e690e668 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -89,9 +89,6 @@ def layer_class(self) -> "type[GatedDeltaNet]": return GatedDeltaNet - def _validate(self) -> None: - super()._validate() - @config_class(dynamic_type={MixerConfig: "kda"}) class KimiDeltaAttentionConfig(MixerConfig): From 0de1b6cdfae2734b0857b2a097cbb2ee95b05076 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 30 Mar 2026 18:38:06 -0400 Subject: [PATCH 31/33] Fix five logic bugs found in audit of core/, layers/block/, data/dataset/memmap/ - core/distributed.py: Remove spurious tensor.copy_() after Gloo+GPU send (copy-paste from recv; send should never write to the source tensor) - data/dataset/memmap/token.py: Fix _get_nearest_split always rounding down (fraction used cumsum[left]/cumsum[left+1] giving always-negative numerator; correct formula uses cumsum[left-1]/cumsum[left] as document span boundaries; also fixes OOB access when left == len(cumsum) - 1) - data/dataset/memmap/config.py: Fix blend_metadata writing rejected_spans data to wrong key "image_patches" and reading from wrong source key - layers/block/config.py: Fix raise warnings.warn() crashing with TypeError (raise None); should be warnings.warn() without raise - layers/block/block.py: Fix double-prefixed gradient tensor name in debug logging Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/core/distributed.py | 6 ++---- fast_llm/data/dataset/memmap/config.py | 4 ++-- fast_llm/data/dataset/memmap/token.py | 3 ++- fast_llm/layers/block/block.py | 2 +- fast_llm/layers/block/config.py | 2 +- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index 16f7d92c8..a71cbc306 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -164,9 +164,7 @@ def send( assert group is not None if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu": # send not supported for gloo on GPU. - tensor_cpu = tensor.cpu() - group.send([tensor_cpu], dst, tag).wait() - tensor.copy_(tensor_cpu) + group.send([tensor.cpu()], dst, tag).wait() return None work = group.send([tensor], dst, tag) if async_op: @@ -182,7 +180,7 @@ def recv( assert group is not None if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu": # recv not supported for gloo on GPU. - tensor_cpu = tensor.cpu() + tensor_cpu = tensor.new_empty(device="cpu") group.recv([tensor_cpu], src, tag).wait() tensor.copy_(tensor_cpu) return None diff --git a/fast_llm/data/dataset/memmap/config.py b/fast_llm/data/dataset/memmap/config.py index cc8665204..d043a20af 100644 --- a/fast_llm/data/dataset/memmap/config.py +++ b/fast_llm/data/dataset/memmap/config.py @@ -449,8 +449,8 @@ def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typi [metadata_["chosen_spans"] for metadata_ in metadata] ) if "rejected_spans" in metadata[0]: - out["image_patches"] = RangeReaderConfig.blend_metadata( - [metadata_["image_patches"] for metadata_ in metadata] + out["rejected_spans"] = RangeReaderConfig.blend_metadata( + [metadata_["rejected_spans"] for metadata_ in metadata] ) if "image_patches" in metadata[0]: out["image_patches"] = PatchReaderConfig.blend_metadata( diff --git a/fast_llm/data/dataset/memmap/token.py b/fast_llm/data/dataset/memmap/token.py index 3e8b86a3c..05a16d22d 100644 --- a/fast_llm/data/dataset/memmap/token.py +++ b/fast_llm/data/dataset/memmap/token.py @@ -54,7 +54,8 @@ def _get_nearest_split(cumsum: torch.Tensor, value: float) -> int: left = torch.searchsorted(cumsum, value, side="right") if left == len(cumsum): return left.item() - return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item() + prev = cumsum[left - 1].item() if left > 0 else 0 + return left.item() + 1 if (value - prev) / (cumsum[left].item() - prev) > 0.5 else left.item() class TokenWriter(MemmapWriter): diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index acf807c69..805eae1e5 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -68,7 +68,7 @@ def __call__( "", tensor, level=level, - meta=self._get_meta(tensor, name + f"{name}.grad", dims), + meta=self._get_meta(tensor, f"{name}.grad", dims), **logging_kwargs, ) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index b6ed2d851..aa47a5f2e 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -169,7 +169,7 @@ def _validate(self): if missing := used_blocks - available_blocks: raise ValueError(f"The following blocks are present in the pattern but undefined: {missing}") if extra := available_blocks - used_blocks: - raise warnings.warn(f"The following blocks are defined but unused: {extra}") + warnings.warn(f"The following blocks are defined but unused: {extra}") super()._validate() From d32b6b03b59a56e8dde1f5f2fe51c372db55b038 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 31 Mar 2026 00:35:14 -0400 Subject: [PATCH 32/33] Fix streaming tests: keep ProcessGroupPool alive and use dynamic xread timeout Three issues fixed: 1. ProcessGroupPool was created anonymously and immediately GC'd, calling shutdown() on the NCCL communicator before broadcasts completed. Fix by storing the pool in a variable (self._pool / pool) and shutting it down explicitly in the finally/cleanup block. 2. Consumer used torch.cuda.set_device() via ProcessGroupPool's old use_cuda path, corrupting the test context's CUDA device. Fix by adding a device parameter to ProcessGroupPool that accepts an explicit torch.device, so consumers can pass their already-set current device. 3. Consumer xread timeout was hardcoded at 10s, too short for the first training step of SDP/TP configs which require CUDA kernel JIT compilation. Fix by using streaming_config.timeout (120s) instead. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/distributed/distributed.py | 14 ++++-- fast_llm/engine/training/streaming.py | 14 ++++-- tests/models/test_streaming.py | 53 ++++++++++++++-------- tests/utils/subtest.py | 5 +- 4 files changed, 55 insertions(+), 31 deletions(-) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index b0ab08482..372fb7f68 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -27,9 +27,9 @@ def __init__( world_size: int | None = None, local_world_size: int | None = None, timeout: float = 60, - use_cuda: bool = True, init_method: str = "env://", backend: DistributedBackend = DistributedBackend.nccl, + device: torch.device | None = None, ): self._rank = DistributedConfig.default_rank if rank is None else rank @@ -38,20 +38,24 @@ def __init__( DistributedConfig.default_local_world_size if local_world_size is None else local_world_size ) self._timeout = timeout - self._use_cuda = use_cuda self._backend = backend self._process_groups = {} - if self._use_cuda: + if device is None: assert torch.cuda.is_available() Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count()) torch.cuda.init() self._device = torch.device(self._rank % self._local_world_size) torch.cuda.set_device(self._device) + elif device.type == "cuda": + assert torch.cuda.is_available() + torch.cuda.init() + self._device = device + torch.cuda.set_device(self._device) else: if backend == DistributedBackend.nccl: Assert.eq(self._world_size, 1) - self._device = torch.device("cpu") + self._device = device if self._world_size > 1: if self._rank == 0: @@ -165,8 +169,8 @@ def __init__(self, config: DistributedConfig): self._config.world_size, self._config.local_world_size, self._config.timeout, - self._config.use_cuda, backend=self._config.backend, + device=None if self._config.use_cuda else torch.device("cpu"), ) else: self._pool = _default_pool diff --git a/fast_llm/engine/training/streaming.py b/fast_llm/engine/training/streaming.py index 7870b45bc..aec14530f 100644 --- a/fast_llm/engine/training/streaming.py +++ b/fast_llm/engine/training/streaming.py @@ -2,6 +2,8 @@ import logging import typing +import torch + from fast_llm.core.distributed import broadcast as _broadcast from fast_llm.core.distributed import broadcast_object as _broadcast_object from fast_llm.engine.distributed.config import DistributedBackend @@ -26,15 +28,16 @@ def __init__(self, config: ConfigType, model: "FastLLMModel"): init_method = f"tcp://{config.broadcast.host}:{config.broadcast.port}" logger.info(f"Waiting for weights broadcast rendezvous at {init_method} ...") world_size = config.broadcast.external_world_size + 1 - self._process_group = ProcessGroupPool( + self._pool = ProcessGroupPool( rank=0, world_size=world_size, local_world_size=1, timeout=self._config.timeout, - use_cuda=self._config.broadcast.backend == DistributedBackend.nccl, + device=None if self._config.broadcast.backend == DistributedBackend.nccl else torch.device("cpu"), init_method=init_method, backend=self._config.broadcast.backend, - ).get_process_group(range(world_size), 0) + ) + self._process_group = self._pool.get_process_group(range(world_size), 0) logger.info(f"Weights broadcast rendezvous at {init_method} connected") def run_begin(self, step: int): @@ -61,8 +64,9 @@ def __del__(self): self._clear() def _clear(self): - if hasattr(self, "_process_group"): - self._process_group.shutdown() + if hasattr(self, "_pool"): + self._pool.shutdown() + del self._pool del self._process_group def _broadcast_weights(self, step: int): diff --git a/tests/models/test_streaming.py b/tests/models/test_streaming.py index 0c40f0a48..e65c128f6 100644 --- a/tests/models/test_streaming.py +++ b/tests/models/test_streaming.py @@ -9,6 +9,8 @@ import safetensors import torch +from fast_llm.core.distributed import broadcast as _broadcast +from fast_llm.core.distributed import broadcast_object as _broadcast_object from fast_llm.engine.distributed.config import DistributedBackend from fast_llm.engine.distributed.distributed import ProcessGroupPool from fast_llm.engine.training.config import StreamingTrainerCallbackConfig @@ -68,27 +70,34 @@ def _run_event_consumer( path.mkdir(parents=True, exist_ok=True) field = REDIS_TRAINING_FIELD.encode() # TODO: Create a custom process group instead. + pool = None try: world_size = streaming_config.broadcast.external_world_size + 1 + consumer_rank = consumer_index + 1 backend = DistributedBackend.nccl if torch.cuda.is_available() else DistributedBackend.gloo - process_group = ProcessGroupPool( - rank=0, + pool = ProcessGroupPool( + rank=consumer_rank, world_size=world_size, - local_world_size=world_size, timeout=streaming_config.timeout, - use_cuda=backend == DistributedBackend.nccl, init_method=init_method, backend=backend, - ).get_process_group(range(world_size), 0) + device=( + torch.device("cuda", torch.cuda.current_device()) + if backend == DistributedBackend.nccl + else torch.device("cpu") + ), + ) + process_group = pool.get_process_group(range(world_size), consumer_rank) + timeout_ms = int(streaming_config.timeout * 1000) last_id = "0-0" while True: result = client.xread( streams={REDIS_TRAINING_STREAM: last_id}, count=1, - block=10000, + block=timeout_ms, ) if not result: - raise TimeoutError("No message received after 10000 ms...") + raise TimeoutError(f"No message received after {timeout_ms} ms...") ((stream, events),) = result Assert.eq(stream.decode(), REDIS_TRAINING_STREAM) @@ -102,15 +111,14 @@ def _run_event_consumer( elif message["type"] == "weights_ready": weights = {} while True: - meta = [None] - torch.distributed.broadcast_object_list(meta, group=process_group, group_src=0) - if meta[0] is None: + meta = _broadcast_object(None, process_group, src=0) + if meta is None: print(f"Weight broadcast finished") break - logging.info(f"receiving {meta[0]}") - shard_name, layer_name, tensor_size, tensor_type = meta[0] + logging.info(f"receiving {meta}") + shard_name, layer_name, tensor_size, tensor_type = meta tensor = torch.zeros(tuple(tensor_size), dtype=tensor_type, device="cuda") - torch.distributed.broadcast(tensor, group=process_group, group_src=0) + _broadcast(tensor, 0, process_group) if shard_name == "weights": weights[layer_name] = tensor safetensors.torch.save_file( @@ -118,7 +126,8 @@ def _run_event_consumer( ) finally: - torch.distributed.destroy_process_group(process_group) + if pool is not None: + pool.shutdown() def _run_model_streaming_configs( @@ -127,23 +136,24 @@ def _run_model_streaming_configs( # Import all dynamic classes. import fast_llm.cli # noqa - for config in _DISTRIBUTED_STREAMING_CONFIGS: + for config_index, config in enumerate(_DISTRIBUTED_STREAMING_CONFIGS): + config_port = port + config_index model_testing_config = update_and_add_testing_config( model_testing_config, None, updates={ - ("data", "datasets"): {"training": {"port": port, "timeout": 1.0}}, + ("data", "datasets"): {"training": {"port": config_port, "timeout": 1.0}}, ("training", "export"): {"format": model_testing_config.checkpoint_format.name, "interval": 1}, "callbacks": { "streaming": { "type": "streaming", - "port": port, + "port": config_port, "broadcast": { - "port": port + 1000, + "port": config_port + 1000, "external_world_size": config.consumer_count, }, "export": {"format": model_testing_config.checkpoint_format.name}, - "timeout": 1.0, + "timeout": 120, } }, # Disable tensor logging. @@ -192,9 +202,12 @@ def test_run_model_distributed_streaming( ): if torch.cuda.device_count() < 2: pytest.skip(f"Not enough GPUs") + model_testing_config.get_dataset() + # Use a fixed shift to avoid port conflicts with other distributed tests. + port = worker_resources.torchrun_port + 4321 run_parallel_script( _run_model_streaming_configs, - (run_test_script_base_path, model_testing_config, worker_resources.torchrun_port), + (run_test_script_base_path, model_testing_config, port), world_size=torch.cuda.device_count(), backend=model_testing_config.distributed_backend, ) diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py index b8f0b5b7a..78e4d4357 100644 --- a/tests/utils/subtest.py +++ b/tests/utils/subtest.py @@ -43,7 +43,10 @@ def __enter__(self): ) self._pool = ProcessGroupPool( - timeout=self._timeout, init_method=self._init_method, backend=self._backend, use_cuda=self._use_cuda + timeout=self._timeout, + init_method=self._init_method, + backend=self._backend, + device=None if self._use_cuda else torch.device("cpu"), ).__enter__() self._rank = self._pool.rank self._world_size = self._pool.world_size From 08e0b5e081c1362664a8351160bd8f02b4ce683b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 31 Mar 2026 00:38:04 -0400 Subject: [PATCH 33/33] Various fixes across data, layers, and conversions - token.py: fix confusing variable names in _get_nearest_split (left/right were swapped relative to what searchsorted returns) - grpo_loss.py: split shared_kwargs from epsilon kwargs to avoid passing epsilon params to the parallel_max_logits kernel that doesn't accept them - linear/config.py: inherit LinearBaseConfig from ModuleConfig and mark LinearConfig/AffineLinearConfig/CausalConv1dConfig as non-abstract so architecture comparison works correctly for linear layers - base_model/config.py: assert non-architecture fields are not ModuleConfig subclasses (which would silently skip nested architecture validation) - apriel2.py: use .value on sampling_strategy and activation enum fields when serializing to dict - qwen2.py: fix attention_bias default (False, not True) for Qwen2 import Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/data/dataset/memmap/token.py | 10 +++++----- fast_llm/engine/base_model/config.py | 2 ++ fast_llm/functional/triton/grpo_loss.py | 11 +++++++---- fast_llm/layers/common/linear/config.py | 11 +++++++++-- fast_llm/models/gpt/conversion/apriel2.py | 4 ++-- fast_llm/models/gpt/conversion/qwen2.py | 2 +- 6 files changed, 26 insertions(+), 14 deletions(-) diff --git a/fast_llm/data/dataset/memmap/token.py b/fast_llm/data/dataset/memmap/token.py index 05a16d22d..84d34613b 100644 --- a/fast_llm/data/dataset/memmap/token.py +++ b/fast_llm/data/dataset/memmap/token.py @@ -51,11 +51,11 @@ def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dic def _get_nearest_split(cumsum: torch.Tensor, value: float) -> int: - left = torch.searchsorted(cumsum, value, side="right") - if left == len(cumsum): - return left.item() - prev = cumsum[left - 1].item() if left > 0 else 0 - return left.item() + 1 if (value - prev) / (cumsum[left].item() - prev) > 0.5 else left.item() + right = torch.searchsorted(cumsum, value, side="right") + if right == len(cumsum): + return right.item() + left = cumsum[right - 1].item() if right > 0 else 0 + return right.item() + 1 if (value - left) / (cumsum[right].item() - left) > 0.5 else right.item() class TokenWriter(MemmapWriter): diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index a68c9ebc8..074412c9f 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -45,6 +45,8 @@ def _get_architecture(self) -> dict[str, typing.Any]: assert isinstance(field, Field), f"{name}, {field}" if field.hint == FieldHint.architecture: architecture[name] = self._serialize_architecture_field(getattr(self, name, MISSING)) + else: + assert not isinstance(field, ModuleConfig) return architecture def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: diff --git a/fast_llm/functional/triton/grpo_loss.py b/fast_llm/functional/triton/grpo_loss.py index deb261f09..39d832ccd 100644 --- a/fast_llm/functional/triton/grpo_loss.py +++ b/fast_llm/functional/triton/grpo_loss.py @@ -152,15 +152,18 @@ def triton_grpo_loss_forward_backward( block_size = min(triton.next_power_of_2(n_cols), 32768) if num_warps is None: num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) - kwargs = { + shared_kwargs = { "logits_stride_0": logits.stride(-2), "n_cols": n_cols, "logits_scale_factor": logits_scale_factor, - "epsilon_low": epsilon_low, - "epsilon_high": epsilon_high, "block_size": block_size, "num_warps": num_warps, } + kwargs = { + **shared_kwargs, + "epsilon_low": epsilon_low, + "epsilon_high": epsilon_high, + } if grad_output is None: backward_kwargs = {} else: @@ -205,7 +208,7 @@ def triton_grpo_loss_forward_backward( sum_exp_logits_ptr=sum_exp_logits, predicted_logits_ptr=predicted_logits_local, col_min=n_cols * group.rank(), - **kwargs, + **shared_kwargs, ) max_logits, sum_exp_logits = parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) torch.distributed.all_reduce(predicted_logits_local, op=torch.distributed.ReduceOp.SUM, group=group) diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index 4c64f0816..803edc302 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -1,6 +1,7 @@ import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.initialization import Initialization, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim @@ -14,7 +15,7 @@ @config_class() -class LinearBaseConfig(Config): +class LinearBaseConfig(ModuleConfig): """ Configuration for a linear-like layer without bias. """ @@ -47,6 +48,8 @@ class AffineLinearBaseConfig(LinearBaseConfig): class LinearConfig(LinearBaseConfig): """Configuration for a linear (weight-only, no bias) layer with optional PEFT and tensor-parallelism support.""" + _abstract = False + apply_peft: bool | None = Field( default=None, desc="Wrap this layer ." @@ -108,6 +111,8 @@ def get_layer( class AffineLinearConfig(AffineLinearBaseConfig, LinearConfig): """Configuration for an affine linear layer (weight + optional bias) with optional PEFT and tensor-parallelism support.""" + _abstract = False + def get_layer( self, in_dim: TensorDim, @@ -171,6 +176,8 @@ class CausalConv1dConfig(AffineLinearBaseConfig): Configuration for a 1d causal convolution, as used in mamba layers. """ + _abstract = False + kernel_size: int = Field( default=4, desc="Convolution kernel size.", diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index d073830c4..91e3be508 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -495,7 +495,7 @@ def export_config(cls, config: StochasticMixerConfig) -> dict: "type": "stochastic", "mixers": mixers, "main_mixer_name": config.main_mixer_name, - "sampling_strategy": config.sampling_strategy, + "sampling_strategy": config.sampling_strategy.value, } @classmethod @@ -620,7 +620,7 @@ def export_config(cls, config: DecoderBlockConfig) -> dict: mlp = { "type": "mlp", "intermediate_size": config.mlp.intermediate_size, - "activation": config.mlp.activation, + "activation": config.mlp.activation.value, "gated": config.mlp.gated, "add_linear_biases": config.mlp.add_linear_biases, } diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index 4ebf18c3a..473135648 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -25,7 +25,7 @@ class Qwen2AttentionConverter(LlamaAttentionConverter): @classmethod def import_config(cls, config: dict) -> dict: - config["attention_bias"] = True + config["attention_bias"] = False out = super().import_config(config) out["query_layer"] = {"bias": {"enabled": True}} out["key_layer"] = {"bias": {"enabled": True}}