diff --git a/.gitignore b/.gitignore index f468ffd00..8a01a55f7 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ __pycache__/ # Doc build .cache site +docs/reference/ # Distribution / packaging *.egg-info/ diff --git a/README.md b/README.md index d02e7f95e..2a1b9f4e2 100644 --- a/README.md +++ b/README.md @@ -60,12 +60,12 @@ As a truly open-source project, Fast-LLM allows full customization and extension We'll walk you through how to use Fast-LLM to train a large language model on a cluster with multiple nodes and GPUs. We'll show an example setup using a Slurm cluster and a Kubernetes cluster. -For this demo, we will train a Mistral-7B model from scratch for 100 steps on random data. The config file `examples/mistral-4-node-benchmark.yaml` is pre-configured for a multi-node setup with 4 DGX nodes, each with 8 A100-80GB or H100-80GB GPUs. +For this demo, we will train a Mistral-7B model from scratch for 100 steps on random data. The config file `examples/mistral.yaml` defines the model architecture and training settings, while the example launch scripts are pre-configured for a 4-node setup with 8 GPUs per node. > [!NOTE] > Fast-LLM scales from a single GPU to large clusters. You can start small and expand based on your resources. -Expect to see a significant speedup in training time compared to other libraries! For training Mistral-7B, Fast-LLM is expected to achieve a throughput of **9,800 tokens/s/H100** (batch size 32, sequence length 8k) on a 4-node cluster with 32 H100s. +Expect to see a significant speedup in training time compared to other libraries! For training Mistral-7B, Fast-LLM is expected to achieve a throughput of **9,800 tokens/s/H100** (micro-batch size 8k tokens, total batch size 256k tokens) on a 4-node cluster with 32 H100s. ### Running Fast-LLM on a Slurm Cluster @@ -77,7 +77,7 @@ Expect to see a significant speedup in training time compared to other libraries #### Steps -1. Deploy the [nvcr.io/nvidia/pytorch:24.07-py3](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) Docker image to all nodes (recommended), because it contains all the necessary dependencies. +1. Deploy the [nvcr.io/nvidia/pytorch:25.11-py3](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) Docker image to all nodes (recommended), because it contains all the necessary dependencies. 2. Install Fast-LLM on all nodes: ```bash @@ -88,7 +88,7 @@ Expect to see a significant speedup in training time compared to other libraries #SBATCH --ntasks=$(scontrol show node | grep -c NodeName) #SBATCH --exclusive - srun bash -c 'pip install --no-cache-dir -e "git+https://github.com/ServiceNow/Fast-LLM.git#egg=llm[CORE,OPTIONAL,DEV]"' + srun bash -c 'pip install --no-cache-dir "fast-llm[CORE,OPTIONAL] @ git+https://github.com/ServiceNow/Fast-LLM.git"' EOF ``` @@ -115,7 +115,7 @@ Now, you can sit back and relax while Fast-LLM trains your model at full speed! #### Steps -1. Create a Kubernetes [PersistentVolumeClaim](https://kubernetes.io/docs/concepts/storage/persistent-volumes/) (PVC) named `fast-llm-home` that will be mounted to `/home/fast-llm` in the container using [examples/fast-llm-pvc.yaml](examples/fast-llm-pvc.yaml): +1. Create a Kubernetes [PersistentVolumeClaim](https://kubernetes.io/docs/concepts/storage/persistent-volumes/) (PVC) named `pvc-fast-llm-home` that will be mounted to `/home/fast-llm` in the container using [examples/fast-llm-pvc.yaml](examples/fast-llm-pvc.yaml): ```bash kubectl apply -f examples/fast-llm-pvc.yaml diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 6f42d8b6a..4ce982110 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -76,124 +76,99 @@ class AwesomeHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler) ### Configuration conversion -The configuration conversion utility interfaces between two configurations in the form of nested dictionaries: -a serialized Fast-LLM configuration and an external configuration. -The `_load_config` method is expected to read the configuration on disk, as expected by the checkpoint format, -and return the same configuration in the forma of a nested dictionary, -with `_save_config` handling the reverse operation. -See the [Hugging Face implementation](https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/engine/checkpoint/huggingface.py) for an example. - -To perform the conversion, the checkpoint handler relies on a list of `ParamConverter` objects, -which describe how individual parameters (or in some case multiple ones) should be converted. -The `ParamConverter` base interface is a dataclass consisting of two variables and two methods: - -* `fast_llm_names: tuple[tuple[str, ...], ...]`: An array of entry names on the Fast-LLM side, in tuple format. -For example, `((transformer, head_groups),)` refers to the single entry `config["transformer"]["head_groups"]`. -* `export_names: tuple[tuple[str, ...], ...]`: An array of entry names on the external side, in the same tuple format. -* `export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]`: -This method takes the configuration parameters corresponding to `fast_llm_names` (in the same order), -and returns converted parameters corresponding to `export_names`. -* `import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]`: -The converse of`export_params`, converting parameters corresponding to `export_names` into those corresponding to `fast_llm_names`. - -While not strictly part of the interface, it may also be useful to define a dataclass `__post_init__`, -for example to restrict the number of parameters in `fast_llm_names` and `export_names`. - -Fast-LLM offers several generic configuration converter classes, including: - -* `RenameParamConverter`: A simple 1-1 mapping between parameters, with optional renaming but identical value. -Typically, most converters are of this type. -* `ConstantImportParamConverter`: A 1-0 mapping for Fast-LLM parameters that without an equivalent in the external format, -that must take a specific value `fast_llm_value` for conversion to make sense (i.e., they take a hard-coded value in the external format). -This type of converter is common for Hugging Face converters, as Hugging Face models support much fewer configuration parameters. -* `ConstantExportParamConverter`: A 0-1 mapping, the converse of `ConstantImportParamConverter` -* `MappedConfigParamConverter`: A 1-1 mapping similar to `RenameParamConverter`, but with a non-trivial relation between values. - -In addition to those, you may need to implement your own custom converter. -Here is an example that associates several Fast-LLM variables with a tuple. +Configuration conversion is handled by a `HuggingFaceBaseModelConverter` subclass, +which is linked to the handler via a `base_model_converter_class` class variable. +The converter implements three class methods: -```python -@dataclasses.dataclass(kw_only=True) -class PackingParamConverter(ParamConverter): - def __post_init__(self): - # There may be any number of Fast-LLM variables, but only one external one - Assert.eq(len(self.export_names), 1) - - def export_params(self, fast_llm_values): - # Pack the values into a single tuple. - return (fast_llm_values,) - - def import_params(self, export_values): - # Unpack the values. We can safely assume `export_values` has length one because of the assertion in `__post_init__` - return export_values[0] -``` +* `import_config(cls, config: dict) -> dict`: +Reads the external (e.g., Hugging Face) configuration dict and returns a Fast-LLM `base_model` config dict. +* `export_config(cls, config: BaseModelConfig) -> dict`: +Takes a Fast-LLM `BaseModelConfig` object and returns the corresponding external configuration dict. +* `get_converters(cls, config: BaseModelConfig, exported_config: dict) -> list[WeightConverter]`: +Returns the list of weight converters for this model (described in the next section). -Now that we've seen how parameter converters work, we're ready to add them to our handler class. -We do so by creating a list of converters in the `_create_config_converters` class method. -Continuing our `AwesomeModel` handler example, we define: +The `_load_config` and `_save_config` methods on the handler read and write the external configuration file. +See the [Hugging Face implementation](https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/engine/checkpoint/huggingface.py) for their default implementation. + +Continuing our `AwesomeModel` example, the base model converter class could look like: ```python +class AwesomeBaseModelConverter(HuggingFaceBaseModelConverter): @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - # For Hugging Face handlers, we need to call the superclass method. - return super()._create_config_converters() + [ - # A trivial example where both the name and value are the same on both sides. - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - # A non-trivial example of `RenameParamConverter` with renaming and handling of nested dictionaries. - RenameParamConverter( - fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) - ), - # A constant import example indicating that the external format does not support absolute positional embeddings. - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - # The `architectures` parameter is a common use case for `ConstantExportParamConverter` in Hugging Face models. - ConstantExportParamConverter(export_names=(("architectures",),), export_value=["AwesomeModelForCausalLM"]), - # A value mapping example, where we match Fast-LLM activation types with their Hugging Face equivalents. - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=(("hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - # A more hypothetical example using `PackingParamConverter` to pack two parameters `epsilon_1`, `epsilon_2` into a tuple `eps`. - PackingParamConverter( - fast_llm_names=(("epsilon_1",),("epsilon_2",)), - export_names=(("eps",),), - ), - ] -``` + def import_config(cls, config: dict) -> dict: + # Build and return a Fast-LLM base_model config dict from the external config. + return { + "hidden_size": config["hidden_size"], + "embeddings": {"vocab_size": config["vocab_size"]}, + "decoder": { + "num_blocks": config["num_hidden_layers"], + "block": { + "mixer": { + "heads": config["num_attention_heads"], + "head_groups": config.get("num_key_value_heads", config["num_attention_heads"]), + "rotary": {"type": "default", "theta": config.get("rope_theta", 10000)}, + "add_linear_biases": False, + }, + "mlp": { + "intermediate_size": config["intermediate_size"], + "gated": True, + "activation": ActivationType.from_hf_name(config["hidden_act"]), + "add_linear_biases": False, + }, + "normalization": {"type": "rms_norm", "epsilon": config["rms_norm_eps"]}, + }, + }, + "head": {"normalization": {"type": "rms_norm", "epsilon": config["rms_norm_eps"]}}, + "tied_embedding_weight": config.get("tie_word_embeddings", False), + } -!!! note "How conversion works" - The once the converters are defined, the conversion utility takes it from there. - Exporting works as follows (importing work similarly): - *The handler creates an empty export config dict, then loops over its list of converters. For each converter, it: - * Reads the value of each parameter defined in `fast_llm_names`, and gathers them in a tuple. - *Calls `converter.export_params`, providing the set of read values as argument. - * Ensure that the returned value has the correct length (that of `export_names`) - * Set the respective values in the export config dict. + @classmethod + def export_config(cls, config: AwesomeBaseModelConfig) -> dict: + # Build and return the external config dict from the Fast-LLM config object. + decoder_block = config.decoder.block + return { + "model_type": "awesome_model", + "architectures": ["AwesomeModelForCausalLM"], + "hidden_size": config.hidden_size, + "vocab_size": config.embeddings.vocab_size, + "num_hidden_layers": config.decoder.num_blocks, + "num_attention_heads": decoder_block.mixer.heads, + "num_key_value_heads": decoder_block.mixer.head_groups, + "rope_theta": decoder_block.mixer.rotary.theta, + "intermediate_size": decoder_block.mlp.intermediate_size, + "hidden_act": decoder_block.mlp.activation.hf_name, + "rms_norm_eps": decoder_block.normalization.epsilon, + "tie_word_embeddings": config.tied_embedding_weight, + } -!!! note "About `MISSING` and `DEFAULT`" - If a value is not found during import, it will be replaced by the `MISSING` tag. - The converter's `import_params` has the opportunity to handle this missing value, - and if a `MISSING`, the handler will throw an error because it does not know what value to set on the Fast-LLM side. + @classmethod + def get_converters(cls, config: AwesomeBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + # Described in the next section. + ... +``` - The `MISSING` tag is also supported during export, - but has a different meaning as the value is always expected to be found in the Fast-LLM configuration. - Instead, `export_params` may return a `MISSING` tag indicating that no value should not be added to the Fast-LLM config. - It may also return `DEFAULT`, which will be replaced by the default value for the configuration parameter. +Then wire the converter into the handler via `base_model_converter_class`: - Note that the handling of `MISSING` and `DEFAULT` is experimental and may be improved in the future. +```python +class AwesomeHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + _model_class = AwesomeModelConfig + architecture = "AwesomeModelForCausalLM" + base_model_converter_class = AwesomeBaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls): + from transformers import AutoConfig + return AutoConfig +``` ### State conversion State conversion follows the same principle as configuration conversion, but acts on flat dictionaries of state tensors. Converters are defined by subclassing `WeightConverter`, with the interface: -* `fast_llm_name: str | tuple[str, ...]`: An entry name or array of entry names on the Fast-LLM side. -For example, `((transformer, head_groups),)` refers to the single entry `config["transformer"]["head_groups"]`. -* `export_name: str | tuple[str, ...]`: An entry name or array of entry names on the external side. +* `fast_llm_name: str | tuple[str, ...]`: A state dict key, or tuple of keys, on the Fast-LLM side. +For example, `"layers.0.mixer.weight"` or `("layers.0.weight_1", "layers.0.weight_2")`. +* `export_name: str | tuple[str, ...]`: A state dict key, or tuple of keys, on the external side. * `export_weight(self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]) -> tuple[torch.Tensor | SafeTensorSlice, ...]`: This method takes the state dict entries corresponding to `fast_llm_name` (in the same order), and returns converted entries corresponding to `export_name`. @@ -225,19 +200,20 @@ class TransposeWeightConverter(WeightConverter): return (weight[0][:].transpose().contiguous(),) ``` -We define the list of weight converters in the `_create_weight_converters` method. -Continuing our `AwesomeModel` handler example, we define: +We define the list of weight converters in the `get_converters` class method of the base model converter. +Continuing our `AwesomeModel` example, we define: ```python - def _create_weight_converters(self) -> list[WeightConverter]: + @classmethod + def get_converters(cls, config: AwesomeBaseModelConfig, exported_config: dict) -> list[WeightConverter]: converters = [] - # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. - num_layers = len(self._model.config.base_model.decoder) + # The set of converters may depend on the base model configuration. + num_layers = config.decoder.num_blocks # A simple renaming example, for the word embeddings. converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - # We usually want to loop dynamically over layers + # We usually want to loop dynamically over layers. for i in range(num_layers): # A `SplitWeightConverter` example, splitting a weight in two. converters.append(SplitWeightConverter( diff --git a/docs/developer_guide/parallelism.md b/docs/developer_guide/parallelism.md new file mode 100644 index 000000000..651d0b4d6 --- /dev/null +++ b/docs/developer_guide/parallelism.md @@ -0,0 +1,212 @@ +--- +title: Parallelism Internals +--- + +This document describes how Fast-LLM's four parallelism strategies are implemented. It is aimed at contributors adding new layer types, modifying the distributed engine, or debugging performance issues. + +For user-facing configuration, see the [Parallelism guide](../user_guide/parallelism.md). + +--- + +## Rank Assignment + +All rank arithmetic lives in `fast_llm/engine/distributed/config.py`. Given `world_size`, `tensor_parallel`, `pipeline_parallel`, and `sequence_data_parallel`, the derived dimensions are: + +```python +data_parallel = world_size // (tensor_parallel * pipeline_parallel) +batch_data_parallel = data_parallel // sequence_data_parallel + +tensor_rank = rank % tensor_parallel +data_rank = (rank // tensor_parallel) % data_parallel +pipeline_rank = rank // (tensor_parallel * data_parallel) +batch_data_rank = data_rank // sequence_data_parallel +sequence_data_rank = data_rank % sequence_data_parallel +``` + +When `pipeline_first=True`, `data_rank` and `pipeline_rank` are swapped: + +```python +pipeline_rank = (rank // tensor_parallel) % pipeline_parallel +data_rank = (rank // tensor_parallel) // pipeline_parallel +``` + +This alternative ordering places pipeline stages nearer in global rank space, which can improve NUMA locality when each node runs multiple pipeline stages. + +--- + +## Process Groups + +`fast_llm/engine/distributed/distributed.py` constructs the NCCL (or Gloo for CPU) process groups from the `DistributedConfig`. Groups are de-duplicated through `ProcessGroupPool` — if two parallelism dimensions happen to cover the same set of ranks, they share a single underlying `torch.distributed.ProcessGroup`. + +The named groups used throughout the engine are: + +| Name | Members | Primary use | +| --- | --- | --- | +| `world` | All ranks | Global barriers | +| `tensor` | Same `data_rank`, `pipeline_rank` | TP all-reduces | +| `data` | Same `tensor_rank`, `pipeline_rank` | ZeRO reduce-scatter / all-gather | +| `pipeline` | Same `tensor_rank`, `data_rank` | Pipeline send/recv | +| `sequence_data` | Same `tensor_rank`, `pipeline_rank`, `batch_data_rank` | Sequence-parallel reduction | +| `batch_data` | Same `tensor_rank`, `pipeline_rank`, `sequence_data_rank` | Non-sequence data reduction | +| `tensor_and_data` | Same `pipeline_rank` | ZeRO with TP | +| `tensor_and_sequence_data` | Same `pipeline_rank`, `batch_data_rank` | Sequence-TP activations | +| `model_and_sequence_data` | Same `batch_data_rank` | Cross-pipeline sequence gradient | + +`Distributed.set_step(step, phase)` reseeds per-step generators (`pp_generator`, `tp_generator`) using large prime offsets per dimension, so dropout and other stochastic ops are deterministically reproducible across ranks and resumptions. + +--- + +## Tensor Parallelism + +### Sharded linear layers + +Tensor parallelism is implemented by two linear layer variants in `fast_llm/layers/common/linear/linear.py`: + +**`OutputParallelLinear`** (column split): + +- Weight shape: `[output_dim / tensor_parallel, input_dim]` +- Each rank computes a different slice of the output columns +- Forward: local `Y_local = X @ W_local`; output stays partitioned — no communication on the output +- If `sequence_parallel`: input is first **all-gathered** across the tensor group before the matmul +- Backward: grad_input is **all-reduced** (or **reduce-scattered** with sequence-TP) across the tensor group +- Used for: Q/K/V projections, MLP gate/up projections + +**`InputParallelLinear`** (row split): + +- Weight shape: `[output_dim, input_dim / tensor_parallel]` +- Each rank holds a row slice of the weight (a slice of the input dimension) +- Forward: local `Y_local = X_local @ W_local`, then **all-reduce** output across the tensor group (so every rank has the full output) +- If `sequence_parallel`: output is **reduce-scattered** instead of all-reduced, so each rank ends up with a sequence slice +- Custom autograd via `input_parallel_linear_autograd` to correctly handle gradient flow +- Used for: attention output projection, MLP down projection + +### Sequence-tensor parallelism + +The standard (non-sequence-TP) TP pattern replicates the full sequence on every rank between layers. Sequence-tensor parallelism keeps activations distributed across the sequence dimension between layers, reducing activation memory by a factor of `tensor_parallel`. + +At each transformer sub-layer (attention or MLP), the flow is: + +- **`OutputParallelLinear`**: **all-gather** the sequence-chunked input → full sequence × partial output columns per rank +- Attention / elementwise ops: operate on full-sequence slices +- **`InputParallelLinear`**: matmul → **reduce-scatter** the output → each rank returns to holding `seq_len / tensor_parallel` rows + +The total communication volume (all-gather + reduce-scatter) equals that of a single all-reduce, so there is no extra bandwidth cost. The benefit is smaller activation tensors between layers. + +### Adding a new tensor-parallel layer + +1. Declare weight dimensions using `TensorDim` objects from `fast_llm/engine/config_utils/tensor_dim.py`. Mark the sharded dimension with the appropriate `DistributedDim`. +2. Inherit from `OutputParallelLinear` or `InputParallelLinear`, or implement the custom forward/backward directly in `fast_llm/functional/`. +3. Ensure the layer's `forward()` uses the standard signature: `(input, kwargs, losses, metrics) → Tensor`. + +--- + +## Pipeline Parallelism + +### Stage splitting + +`MultiStageModel._split_into_stages()` (in `fast_llm/engine/multi_stage/multi_stage.py`) partitions the flat list of `Layer` objects returned by `BaseModel.get_layers()`. Each stage holds `layers_per_stage` consecutive layers. The mapping from stage index to pipeline rank is: + +```python +pipeline_rank = (stage_index // stages_per_pipeline_stage) % pipeline_parallel +``` + +Stages owned by this rank have full weights and compute. Stages on other pipeline ranks are metadata-only stubs (except for tied weights, see below). + +### Tied weights + +Embedding and LM-head weights are often shared. When these layers land on different pipeline stages, `Stage._tied_weight_copies` holds a reference-only copy: + +- Forward pass: tied weights are **broadcast** from the owning stage to all stages that need them. +- Backward pass: gradients from non-owning stages are **all-reduced** back to the owning stage. + +### Schedule + +The schedule (`fast_llm/engine/schedule/`) builds a DAG of `ScheduleNode` operations (forward, backward, send, recv, optimizer step) and executes them across three CUDA streams (compute, send, recv). Pipeline communication uses PyTorch `isend` / `irecv` for overlap with compute. + +`breadth_first_micro_batches` controls how many micro-batches are in-flight at once. With `N` pipeline stages and `breadth_first_micro_batches = N`, the pipeline bubble fraction approaches zero for large batches. + +--- + +## Data Parallelism and ZeRO/FSDP + +Data parallelism in Fast-LLM is inseparable from the ZeRO sharding implementation in `fast_llm/engine/multi_stage/fsdp.py`. The `FSDP` class owns the per-stage weight and gradient buffers and orchestrates all-gather / reduce-scatter across the data-parallel group. + +### Buffer layout + +Each `FSDP` instance maintains flat buffers for a pipeline stage's parameters: + +```text +_weight_shard : [total_params / data_parallel] # local shard, always resident +_weight_buffer : [total_params] # full weights, reconstructed on demand (ZeRO-3) +_grad_shard : [total_params / data_parallel] # local gradient shard +_grad_buffer : [total_params] # full gradient accumulation buffer +``` + +Every parameter is a view into the appropriate buffer slice, so there are no copies during forward/backward — the buffer *is* the parameter storage. + +Shards are padded to a multiple of `SHARD_PAD_TO_MULTIPLE` (32) per rank to ensure aligned communication. + +### Forward pass (`restore_parameters`) + +Before each forward pass through a stage: + +1. Copy `_weight_shard` into the local slice of `_weight_buffer` +2. If ZeRO stage 3: `all_gather(_weight_buffer)` across the data-parallel group + +With double-buffering (`num_weight_buffers=2`), the all-gather for stage `i+1` is issued asynchronously while stage `i` is computing. + +### Backward pass (`reduce_gradients`) + +After each backward pass: + +1. If sequence-parallel: `all_reduce` sequence-parallel gradient contributions across the tensor-and-sequence-data group +2. `reduce_scatter(_grad_buffer → _grad_shard)` across the data-parallel group (average reduction) +3. If the gradient shard dtype differs from the buffer dtype (e.g. fp32 grad shard, bf16 buffer), copy and cast + +With double-buffering (`num_grad_buffers=2`), the reduce-scatter for stage `i` is overlapped with the backward pass for stage `i-1`. + +### ZeRO stage effect on buffers + +| ZeRO stage | `_weight_buffer` | `_grad_buffer` | Communication | +| --- | --- | --- | --- | +| 1 | Not used (weights replicated) | Not used (grads replicated, then all-reduce) | All-reduce on grads | +| 2 | Not used | Used (grad reduce-scatter → shard) | Reduce-scatter on grads | +| 3 | Used (all-gather before forward) | Used | All-gather on weights + reduce-scatter on grads | + +--- + +## Sequence Data Parallelism + +Sequence data parallelism sub-divides the data-parallel group by the sequence dimension. The `sequence_data` process group covers ranks with the same `tensor_rank`, `pipeline_rank`, and `batch_data_rank`. + +During the forward pass of sequence-parallel layers, each rank holds a contiguous chunk of the sequence. When a layer requires the full sequence (e.g. attention), an all-gather is performed across the `sequence_data` group. After the layer, a reduce-scatter returns each rank to its sequence chunk. + +Gradient synchronization is handled in `FSDP.reduce_gradients`: gradients from the sequence-parallel group are all-reduced before the data-parallel reduce-scatter, so the sequence dimension is handled before any ZeRO sharding. + +--- + +## Seeding and Reproducibility + +`Distributed.set_step(step, phase)` is called at the start of each forward/backward pass. It reseeds two per-rank generators: + +- `pp_generator`: seeded by `(step, phase, tensor_rank, data_rank)` — ensures dropout is identical across pipeline ranks within the same TP group. +- `tp_generator`: seeded by `(step, phase, pipeline_rank, data_rank)` — ensures TP ranks sample the same dropout mask. + +Large prime offsets per dimension ensure seeds are distinct across all rank combinations. This guarantees deterministic training regardless of which GPU processes which micro-batch, and allows exact resumption from a checkpoint. + +--- + +## Key Source Files + +| File | Purpose | +| --- | --- | +| `fast_llm/engine/distributed/config.py` | `DistributedConfig`: rank arithmetic, derived fields | +| `fast_llm/engine/distributed/distributed.py` | `Distributed`: process group construction, `ProcessGroupPool`, seeding | +| `fast_llm/engine/multi_stage/fsdp.py` | `FSDP`: buffer management, all-gather, reduce-scatter, checkpoint loading | +| `fast_llm/engine/multi_stage/multi_stage.py` | `MultiStageModel`: stage splitting, tied weights | +| `fast_llm/engine/multi_stage/config.py` | `MultiStageConfig`: ZeRO stage, buffer counts | +| `fast_llm/layers/common/linear/linear.py` | `OutputParallelLinear`, `InputParallelLinear` | +| `fast_llm/functional/linear.py` | Functional forward/backward for TP linear ops | +| `fast_llm/engine/schedule/config.py` | `ScheduleConfig`: micro-batch and pipeline scheduling | +| `fast_llm/engine/schedule/runner.py` | `ScheduleRunner`: DAG execution, CUDA stream management | +| `tests/utils/distributed_configs.py` | 20+ reference configurations combining all strategies | diff --git a/docs/help.md b/docs/help.md index ed59dffa7..e368e349e 100644 --- a/docs/help.md +++ b/docs/help.md @@ -10,7 +10,7 @@ Welcome to the Fast-LLM Help Center! Here, you'll find fixes for common hiccups, Let's stay one step ahead of those pesky gotchas. Here's a list of common issues and quick fixes: -- **CUDA Out of Memory**: When the GPU throws a fit, a few tweaks can help. First, try lowering `micro_batch_size` or `sequence_length` in the configuration to fit within the available memory. Still stuck? Try setting the `mlp_recompute_level` option to `activation` or `full` to save memory in the backward pass, or experiment with higher ZeRO stages for reduced memory usage. And if that's not enough, tensor or model parallelism may be your friend. +- **CUDA Out of Memory**: When the GPU throws a fit, a few tweaks can help. First, try lowering `micro_batch_size` or `maximum_document_length` under `data:` in the configuration to fit within the available memory. Still stuck? Try setting the `recompute_level` option under `model: base_model: decoder: block: mlp:` to `activation` or `full` to save memory in the backward pass, or experiment with higher ZeRO stages for reduced memory usage. And if that's not enough, tensor or model parallelism may be your friend. - **Python Hash Seed Sync Error**: Encountering an error like @@ -28,7 +28,7 @@ Let's stay one step ahead of those pesky gotchas. Here's a list of common issues Watchdog caught collective operation timeout: WorkNCCL(SeqNum=408951, OpType=_ALLGATHER_BASE, … , Timeout(ms)=600000) ran for 600351 milliseconds before timing out ``` - appearing across all GPU workers, it usually means one or more hosts failed to complete a NCCL operation, causing others to block. NCCL errors can be frustrating to diagnose since they rarely specify which node or GPU caused the issue. It is difficult to surface which messages and operations are in progress during these crashes. If the issue happens at a specific moment of training like dataset preparation or model export, the issue might be that this specific procedure took too long and timed out other processes (e.g. when preparing large datasets for long training runs, or saving large models on slow storage). In this case, it can help to increase the timeout `distributed_timeout: 3600`. + appearing across all GPU workers, it usually means one or more hosts failed to complete a NCCL operation, causing others to block. NCCL errors can be frustrating to diagnose since they rarely specify which node or GPU caused the issue. It is difficult to surface which messages and operations are in progress during these crashes. If the issue happens at a specific moment of training like dataset preparation or model export, the issue might be that this specific procedure took too long and timed out other processes (e.g. when preparing large datasets for long training runs, or saving large models on slow storage). In this case, it can help to increase the timeout by setting `model: distributed: timeout: 3600` in your config. In some other cases, the best we can do is to restart the training job and hope it doesn't happen again. If the issue persists, it might be because of network congestion or a problematic GPU. If the worker that crashed is consistent across multiple runs, it's likely a hardware issue. If you can't resolve it, open an issue on GitHub, and we'll help you troubleshoot. For more detailed solutions, check out our GitHub Issues page. Odds are someone's already tackled a similar problem, and you might find the exact fix you need. diff --git a/docs/index.md b/docs/index.md index 80277ffd2..de3698f1a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -34,7 +34,7 @@ Fast-LLM isn't just another library, **it's a platform for powering the next gen Fast-LLM offers all the capabilities you need to accelerate your LLM training and **push the boundaries of what's possible**: -- **🚀 Speed Like No Other:** Achieve record-breaking training throughput with Fast-LLM. For instance, train Mistral-7B at **10,350 tokens/s/GPU** on a 4-node cluster with 32 H100 GPUs (batch size 64, sequence length 8k). Our optimized kernels, advanced parallelism, and memory-efficient techniques drastically reduce training time and cost. +- **🚀 Speed Like No Other:** Achieve record-breaking training throughput with Fast-LLM. For instance, train Mistral-7B at **10,350 tokens/s/GPU** on a 4-node cluster with 32 H100 GPUs (micro-batch size 8k tokens, total batch size 256k tokens). Our optimized kernels, advanced parallelism, and memory-efficient techniques drastically reduce training time and cost. - **📡 Unmatched Scalability:** Seamlessly scale from a single GPU to large compute clusters. Fast-LLM supports 3D parallelism (data, tensor, and pipeline), sequence length parallelism, and ZeRO-1,2,3 techniques for maximum memory efficiency. Scale to the size you need without sacrificing performance. diff --git a/docs/overrides/hooks/generate_config_docs_hook.py b/docs/overrides/hooks/generate_config_docs_hook.py new file mode 100644 index 000000000..39af6fd05 --- /dev/null +++ b/docs/overrides/hooks/generate_config_docs_hook.py @@ -0,0 +1,27 @@ +"""MkDocs hook: regenerate config reference docs before each build.""" + +import importlib.util +import pathlib +import sys + +_REPO_ROOT = pathlib.Path(__file__).parent.parent.parent.parent +_SCRIPT = _REPO_ROOT / "tools" / "generate_config_docs.py" + +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + + +def _load_gen(): + spec = importlib.util.spec_from_file_location("generate_config_docs", _SCRIPT) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def on_pre_build(config) -> None: # noqa: ANN001 + """Regenerate config reference markdown before the build processes files.""" + gen = _load_gen() + # Regenerate pages but do not update mkdocs.yaml — nav must be updated + # manually by running `python tools/generate_config_docs.py` when config + # classes are added or modules are restructured. + gen.generate(update_nav=False, verbose=False) diff --git a/docs/quick-start.md b/docs/quick-start.md index 20fc1a2b1..68ae9c056 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -11,7 +11,7 @@ To follow this guide, you'll need: - **Hardware**: At least one NVIDIA GPU, preferably with Ampere architecture or newer. Note that this tutorial is designed for 80 GB A100s or H100 GPUs, and some adjustments are needed to run it with less memory or an earlier architecture. - **Software**: Depending on your setup, you'll need one of the following: - **Docker**: If you're using the prebuilt Docker image on your local machine. - - **Python 3.10**: If you're setting up a custom environment (virtual environment, bare-metal, etc.) on your local machine. + - **Python 3.12**: If you're setting up a custom environment (virtual environment, bare-metal, etc.) on your local machine. - **Cluster Setup**: Access to a Docker-enabled Slurm cluster or to a Kubernetes cluster with Kubeflow if you're using those environments. ## 🏗 Step 1: Initial Setup @@ -69,7 +69,7 @@ Now, select the compute environment that matches your setup or preferred workflo Install PyTorch and pybind11 to meet Fast-LLM's requirements: ```bash - pip install pybind11 "torch>=2.2.2" + pip install pybind11 "torch>=2.9.0" ``` 4. **Install NVIDIA APEX**: @@ -86,7 +86,7 @@ Now, select the compute environment that matches your setup or preferred workflo Finally, install Fast-LLM along with its remaining dependencies, including [FlashAttention-2](https://github.com/Dao-AILab/flash-attention): ```bash - pip install --no-build-isolation "git+https://github.com/ServiceNow/Fast-LLM.git#egg=fast_llm[CORE,OPTIONAL,DEV]" + pip install --no-build-isolation "fast-llm[CORE,OPTIONAL] @ git+https://github.com/ServiceNow/Fast-LLM.git" ``` 6. **Verify the Installation**: @@ -220,7 +220,7 @@ Choose based on your goals for this tutorial. git clone https://huggingface.co/meta-llama/Llama-3.1-8B ./fast-llm-tutorial/pretrained-model ``` -## 📚 Step 3: Prepare the Training Data +## 📚 Step 4: Prepare the Training Data For this tutorial, we'll use text from the [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset. This dataset is a free approximation of the WebText data OpenAI used for GPT-2, and it's perfect for our test run! @@ -471,7 +471,7 @@ Fast-LLM ships with a `prepare` command that will download and preprocess the da You can follow the job's progress by running `kubectl get pods` and checking the logs with `kubectl logs fast-llm-prepare-master-0`. -## ⚙️ Step 4: Configure Fast-LLM +## ⚙️ Step 5: Configure Fast-LLM Next, we'll create a configuration file for Fast-LLM. @@ -481,7 +481,7 @@ Next, we'll create a configuration file for Fast-LLM. !!! warning "Micro-Batch Size" - The `micro_batch_size` in the configuration below is optimized for 80GB GPUs. If you're using GPUs with less memory, you will need to lower this value. Alternatively, you can decrease the `sequence_length` to reduce the memory footprint. + The `micro_batch_size` in the configuration below is optimized for 80GB GPUs. If you're using GPUs with less memory, you will need to lower this value. Alternatively, you can decrease `maximum_document_length` under `data:` to reduce the memory footprint. Save the following as `fast-llm-tutorial/train-config.yaml`: @@ -506,31 +506,31 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: project_name: fast-llm-tutorial group_name: Small entity_name: null - batch: - micro_batch_size: 60 # (4)! - sequence_length: 1024 - batch_size: 480 # (5)! data: datasets: training: type: file - path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (6)! + path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (5)! validation: type: file - path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (6)! + path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (5)! + micro_batch_size: 61440 # (4)! + maximum_document_length: 1024 optimizer: learning_rate: base: 6.0e-04 pretrained: - format: llama # (7)! + format: llama # (6)! path: fast-llm-tutorial/pretrained-model - model_weights: no # (8)! + model_weights: no # (7)! model: base_model: - transformer: - use_flash_attention: yes # (9)! + decoder: + block: + mixer: + use_flash_attention: yes # (8)! distributed: - training_dtype: bf16 # (10)! + compute_dtype: bf16 # (9)! run: experiment_dir: fast-llm-tutorial/experiment ``` @@ -538,13 +538,12 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: 1. For the small run, we'll stop after 100 iterations. 2. The trained model will be saved in `Transformers` Llama format to `fast-llm-tutorial/experiment/export/llama/100` at the end of the small run. You can also save as a `Fast-LLM` checkpoint by setting the `format` to `fast_llm`. 3. Entirely optional, but it's a good idea to track your training progress with Weights & Biases. Replace `null` with your own W&B entity name. If you don't want to use W&B, just ignore this section. - 4. Adjust the number of sequences per GPU based on GPU memory. For SmolLM2-135M at 1024 sequenced length and a 80GB GPU, a `micro_batch_size` of 60 should work well. - 5. Must be divisible by the number of GPUs and the `micro_batch_size`. At 1024 tokens per sequence, 480 corresponds to about 500,000 tokens per batch. - 6. Location of the dataset metadata files generated in Step 4. - 7. Format of the pretrained model. Since SmolLM is a Llama model, we set this to `llama`. - 8. We'll train SmolLM2-135M from scratch. You can set to `yes` to continue training from a checkpoint (if you put one in the model directory). - 9. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`. - 10. `bf16` (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use `fp16` (half-precision floating point) for training instead of `bf16`. + 4. Adjust the micro-batch size based on GPU memory. For SmolLM2-135M with a maximum document length of 1024 tokens and a 80GB GPU, a `micro_batch_size` of 61440 tokens should work well. At 1024 tokens per document, this corresponds to about 500,000 tokens per batch on 8 GPUs. + 5. Location of the dataset metadata files generated in Step 4. + 6. Format of the pretrained model. Since SmolLM is a Llama model, we set this to `llama`. + 7. We'll train SmolLM2-135M from scratch. You can set to `yes` to continue training from a checkpoint (if you put one in the model directory). + 8. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`. + 9. `bf16` (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use `fp16` (half-precision floating point) for training instead of `bf16`. === "Big" @@ -563,7 +562,6 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: checkpoint: interval: 1000 keep: 5 - test_iters: 0 export: # (2)! format: llama interval: 20_000 @@ -571,59 +569,56 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: project_name: fast-llm-tutorial group_name: Big entity_name: null - batch: - micro_batch_size: 2 # (4)! - sequence_length: 4096 - batch_size: 512 # (5)! data: datasets: training: type: file - path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (6)! + path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (5)! validation: type: file - path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (6)! - optimizer: # (7)! + path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (5)! + micro_batch_size: 8192 # (4)! + maximum_document_length: 4096 + optimizer: # (6)! weight_decay: 0.1 beta_1: 0.9 beta_2: 0.95 - learning_rate: # (8)! + learning_rate: # (7)! base: 6.0e-04 minimum: 6.0e-05 decay_style: cosine decay_iterations: 100_000 warmup_iterations: 2000 pretrained: - format: llama # (9)! + format: llama # (8)! path: fast-llm-tutorial/pretrained-model - model_weights: yes # (10)! + model_weights: yes # (9)! model: base_model: - transformer: - use_flash_attention: yes # (11)! - cross_entropy_impl: fused # (12)! + decoder: + block: + mixer: + use_flash_attention: yes # (10)! multi_stage: - zero_stage: 2 # (13)! + zero_stage: 2 # (11)! distributed: - training_dtype: bf16 # (14)! + compute_dtype: bf16 # (12)! run: experiment_dir: fast-llm-tutorial/experiment ``` - 1. Total number of training tokens will be approximately 210B: 100,000 iterations * 512 * 4096 tokens per batch. + 1. Total number of training tokens will be approximately 26B: 100,000 iterations × 32 GPUs × 8,192 tokens per micro-batch. 2. A permanent model checkpoint in `Transformers` Llama format will be saved to `fast-llm-tutorial/experiment/export/llama/[iteration]/` every 20,000 iterations. You can also save as a `Fast-LLM` checkpoint by setting the `format` to `fast_llm`. 3. Entirely optional, but it's a good idea to track your training progress with Weights & Biases. Replace `null` with your own W&B entity name. If you don't want to use W&B, just ignore this section. - 4. Adjust the number of sequences per GPU based on GPU memory. Considering a 4k token sequence length and 80GB GPUs, a `micro_batch_size` of 1 should work well. - 5. Must be divisible by the number of GPUs and the `micro_batch_size`. At 4k tokens per sequence, 512 corresponds to about 2.1 million tokens per batch. - 6. Location of the dataset metadata file generated in Step 4. - 7. These are good default optimizer settings for training models. - 8. We are using a cosine decay schedule with linear warmup. After reaching the peak learning rate `base` at `warmup_iterations`, the learning rate will decay to `minimum` at `decay_iterations`, following a cosine curve. The minimum learning rate should be 1/10th of the base learning rate per Chinchilla. - 9. Format of the pretrained model. Since it's a Llama model, we set this to `llama`. - 10. We want to continue training Llama-3.1-8B from a checkpoint. If you're training from scratch, set this to `no`. - 11. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`. - 12. Configure Fast-LLM to use the fused cross-entropy loss implementation rather than the default Triton implementation for models with a large vocabulary size such as Llama-3.1-8B. This avoids issues with block size limitations in our current Triton code. - 13. We are using ZeRO stage 2 for this tutorial. You can set this to `1`, `2`, or `3` for ZeRO-1, ZeRO-2, or ZeRO-3, respectively. - 14. `bf16` (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use `fp16` (half-precision floating point) for training instead of `bf16`. + 4. Adjust the micro-batch size based on GPU memory. Considering a maximum document length of 4096 tokens and 80GB GPUs, a `micro_batch_size` of 8192 tokens should work well. + 5. Location of the dataset metadata file generated in Step 4. + 6. These are good default optimizer settings for training models. + 7. We are using a cosine decay schedule with linear warmup. After reaching the peak learning rate `base` at `warmup_iterations`, the learning rate will decay to `minimum` at `decay_iterations`, following a cosine curve. The minimum learning rate should be 1/10th of the base learning rate per Chinchilla. + 8. Format of the pretrained model. Since it's a Llama model, we set this to `llama`. + 9. We want to continue training Llama-3.1-8B from a checkpoint. If you're training from scratch, set this to `no`. + 10. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`. + 11. We are using ZeRO stage 2 for this tutorial. You can set this to `1`, `2`, or `3` for ZeRO-1, ZeRO-2, or ZeRO-3, respectively. + 12. `bf16` (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use `fp16` (half-precision floating point) for training instead of `bf16`. ## 🔑 (Optional) Step 6: Add Your Weights & Biases API Key diff --git a/docs/recipes/continue-training.md b/docs/recipes/continue-training.md index d7df7a196..6c1e347ce 100644 --- a/docs/recipes/continue-training.md +++ b/docs/recipes/continue-training.md @@ -48,15 +48,12 @@ This is not much different from a pretraining config. We will: checkpoint: interval: 1000 keep: 5 - test_iters: 0 export: # (1)! format: llama interval: 20_000 - batch: - micro_batch_size: 2 - sequence_length: 4096 - batch_size: 256 data: + micro_batch_size: 4096 + maximum_document_length: 4096 datasets: training: type: file @@ -80,13 +77,14 @@ This is not much different from a pretraining config. We will: model_weights: yes # (5)! model: base_model: - transformer: - use_flash_attention: yes - cross_entropy_impl: fused + decoder: + block: + mixer: + use_flash_attention: yes multi_stage: zero_stage: 2 distributed: - training_dtype: bf16 + compute_dtype: bf16 run: experiment_dir: fast-llm-tutorial/Llama-3.1-8B-cpt ``` @@ -107,15 +105,12 @@ This is not much different from a pretraining config. We will: checkpoint: interval: 1000 keep: 5 - test_iters: 0 export: # (1)! format: qwen2 interval: 20_000 - batch: - micro_batch_size: 1 - sequence_length: 8192 - batch_size: 256 data: + micro_batch_size: 8192 + maximum_document_length: 8192 datasets: training: type: file @@ -139,13 +134,14 @@ This is not much different from a pretraining config. We will: model_weights: yes # (5)! model: base_model: - transformer: - use_flash_attention: yes - cross_entropy_impl: fused + decoder: + block: + mixer: + use_flash_attention: yes multi_stage: zero_stage: 2 distributed: - training_dtype: bf16 + compute_dtype: bf16 run: experiment_dir: fast-llm-tutorial/qwen-2.5-7B-cpt ``` diff --git a/docs/recipes/data-preparation.md b/docs/recipes/data-preparation.md index be0f8ef00..b3e3274f2 100644 --- a/docs/recipes/data-preparation.md +++ b/docs/recipes/data-preparation.md @@ -12,7 +12,7 @@ For this guide, you would need: - **Software**: Depending on your setup, you'll need one of the following: - **Docker**: If you're using the prebuilt Docker image on your local machine. - - **Python 3.10**: If you're setting up a custom environment (virtual environment, bare-metal, etc.) on your local machine. + - **Python 3.12**: If you're setting up a custom environment (virtual environment, bare-metal, etc.) on your local machine. - **Cluster Setup**: Access to a Docker-enabled Slurm cluster or to a Kubernetes cluster with Kubeflow if you're using those environments. ## 📚 Step 1: Download the dataset from Huggingface @@ -104,7 +104,7 @@ Fast-LLM's prepare command processes the dataset by tokenizing and saving it in === "Custom Installation" - Please follow the instructions in the [Quick-Start guide](quick-start/#step-1-initial-setup-custom-installation) to set up Fast-LLM in your environment. + Please follow the instructions in the [Quick-Start guide](quick-start/#step-1-initial-setup) to set up Fast-LLM in your environment. Then, run the following command: diff --git a/docs/recipes/generate.md b/docs/recipes/generate.md index d6d2333e1..77ea609a2 100644 --- a/docs/recipes/generate.md +++ b/docs/recipes/generate.md @@ -37,8 +37,8 @@ tokenizer = AutoTokenizer.from_pretrained(model_path) # Optional: updates to Fast-LLM config before loading the model updates = { - ("base_model", "transformer", "use_flash_attention"): True, - ("distributed", "training_dtype"): "bf16" + ("base_model", "decoder", "block", "mixer", "use_flash_attention"): True, + ("distributed", "compute_dtype"): "bf16" } # Load the model from the checkpoint with the given configuration diff --git a/docs/recipes/instruction-finetuning.md b/docs/recipes/instruction-finetuning.md index 2c58a987d..0e28b7dc8 100644 --- a/docs/recipes/instruction-finetuning.md +++ b/docs/recipes/instruction-finetuning.md @@ -107,7 +107,7 @@ splits: ## ⚙️ Step 4: Configure Fast-LLM -It's time to configure the Fast-LLM training config. This is very similar to [Quick Start](../quick-start.md) with two additional options, namely, `truncate_documents` and `cross_document_attention` which are important for improving the task performance of instruction-tuned models. +It's time to configure the Fast-LLM training config. This is very similar to [Quick Start](../quick-start.md) with one additional option, namely `truncate_documents`, which is important for improving the task performance of instruction-tuned models. ```yaml training: @@ -124,16 +124,12 @@ training: checkpoint: interval: 1000 keep: 5 - test_iters: 0 export: format: llama interval: 1000 -batch: - micro_batch_size: 1 - sequence_length: 4096 - batch_size: 32 - cross_document_attention: no # (1)! data: + micro_batch_size: 4096 + maximum_document_length: 4096 datasets: training: type: file @@ -141,7 +137,7 @@ data: validation: type: file path: ./sft-tutorial/tokenized/Llama-3.1-8B/fast_llm_config_validation.yaml - truncate_documents: no # (2)! + truncate_documents: no # (1)! sampling: use_loss_masking_spans: yes optimizer: @@ -160,19 +156,19 @@ pretrained: model_weights: yes model: base_model: - transformer: - use_flash_attention: yes - cross_entropy_impl: fused + decoder: + block: + mixer: + use_flash_attention: yes multi_stage: zero_stage: 3 distributed: timeout: 3600 - training_dtype: bf16 + compute_dtype: bf16 run: experiment_dir: ./sft-tutorial/llama-3.1-8b-instruct-magpie ``` -1. Prevents paying attention to other documents in a packed sequence -2. Avoids truncating documents to fit into a packed sequence and starts a new sequence instead. Documents longer than sequence length will be skipped altogether. +1. Avoids truncating documents to fit into a packed sequence and starts a new sequence instead. Documents longer than sequence length will be skipped altogether. Launching the training run is similar to Step 7 in the [Quick Start](../quick-start.md) guide. diff --git a/docs/recipes/train.md b/docs/recipes/train.md index efdf6111b..224fd23a1 100644 --- a/docs/recipes/train.md +++ b/docs/recipes/train.md @@ -28,15 +28,12 @@ Let's start from the following training configuration: checkpoint: interval: 1000 keep: 5 - test_iters: 0 export: format: llama interval: 20_000 - batch: - micro_batch_size: 2 - sequence_length: 4096 - batch_size: 256 data: + micro_batch_size: 4096 + maximum_document_length: 4096 datasets: training: type: file @@ -55,12 +52,10 @@ Let's start from the following training configuration: decay_iterations: 100_000 warmup_iterations: 2000 model: - base_model: - cross_entropy_impl: fused multi_stage: zero_stage: 2 distributed: - training_dtype: bf16 + compute_dtype: bf16 run: experiment_dir: fast-llm-tutorial/experiment ``` @@ -80,15 +75,12 @@ Let's start from the following training configuration: checkpoint: interval: 1000 keep: 5 - test_iters: 0 export: format: qwen2 interval: 20_000 - batch: - micro_batch_size: 1 - sequence_length: 8192 - batch_size: 256 data: + micro_batch_size: 8192 + maximum_document_length: 8192 datasets: training: type: file @@ -107,12 +99,10 @@ Let's start from the following training configuration: decay_iterations: 100_000 warmup_iterations: 2000 model: - base_model: - cross_entropy_impl: fused multi_stage: zero_stage: 2 distributed: - training_dtype: bf16 + compute_dtype: bf16 run: experiment_dir: fast-llm-tutorial/experiment ``` @@ -155,47 +145,72 @@ Alternatively, we define the model architecture ourselves as follows: ```yaml model: base_model: - tie_word_embeddings: false - use_position_embeddings: false - vocab_size: 128256 - transformer: - activation_type: silu - add_linear_biases: false - ffn_hidden_size: 14336 - gated: true - head_groups: 8 - hidden_size: 4096 # (1)! - kv_channels: 128 + tied_embedding_weight: false + hidden_size: 4096 # (1)! + embeddings: + vocab_size: 128256 + decoder: + num_blocks: 32 + block: + mixer: + heads: 32 + head_groups: 8 + head_size: 128 + add_linear_biases: false + rotary: + type: llama3 + theta: 500_000 + mlp: + intermediate_size: 14336 + gated: true + activation: silu + add_linear_biases: false + normalization: + type: rms_norm + head: normalization: type: rms_norm - num_attention_heads: 32 - num_layers: 32 - rotary: - type: llama3 - theta: 500_000 ``` === "Qwen 2.5 7B" ```yaml model: base_model: - tie_word_embeddings: false - use_position_embeddings: false - vocab_size: 152064 - transformer: - activation_type: silu - add_linear_biases: only_attn_qkv - ffn_hidden_size: 18944 - gated: true - head_groups: 4 - hidden_size: 3584 # (1)! + tied_embedding_weight: false + hidden_size: 3584 # (1)! + embeddings: + vocab_size: 152064 + decoder: + num_blocks: 28 + block: + mixer: + heads: 28 + head_groups: 4 + head_size: 128 + add_linear_biases: false + query_layer: + bias: + enabled: true + key_layer: + bias: + enabled: true + value_layer: + bias: + enabled: true + rotary: + type: default + theta: 1_000_000 + mlp: + intermediate_size: 18944 + gated: true + activation: silu + add_linear_biases: false + normalization: + type: rms_norm + epsilon: 1e-06 + head: normalization: type: rms_norm epsilon: 1e-06 - num_attention_heads: 28 - num_layers: 28 - rotary: - type: default - theta: 1_000_000 ``` 1. Hidden-size/num-layers will be used to provide good defaults for weight initialization std. diff --git a/docs/user_guide/parallelism.md b/docs/user_guide/parallelism.md index 406908cd8..1b34ff40d 100644 --- a/docs/user_guide/parallelism.md +++ b/docs/user_guide/parallelism.md @@ -2,6 +2,181 @@ title: Parallelism --- -!!! warning +Fast-LLM supports four complementary parallelism strategies that can be combined to train models at any scale. This guide explains each strategy, how to configure it, and when to use it. - Looking for the parallelism guide? It's on its way, come back soon! +For background on memory sharding (ZeRO), see the [Multi-Stage guide](multi-stage.md). The strategies below focus on how the computation itself is distributed. + +--- + +## Overview + +| Strategy | Config key | What it splits | Primary benefit | +| --- | --- | --- | --- | +| Data parallelism | `distributed.batch_data_parallel` (derived) | Batch across GPUs | Throughput | +| Tensor parallelism | `distributed.tensor_parallel` | Layers horizontally (weight matrices) | Model memory | +| Pipeline parallelism | `distributed.pipeline_parallel` | Layers vertically (by depth) | Model memory | +| Sequence data parallelism | `distributed.sequence_data_parallel` | Sequence dimension across GPUs | Activation memory | + +These strategies compose: a 64-GPU run might use 4-way tensor parallelism, 4-way pipeline parallelism, and 4-way data parallelism simultaneously. + +--- + +## Data Parallelism + +Data parallelism replicates the full model on every GPU and processes different micro-batches in parallel. Gradients are averaged across all replicas before the optimizer step. + +Fast-LLM infers the data-parallel degree automatically: + +```text +data_parallel = world_size / (tensor_parallel × pipeline_parallel) +``` + +You do not set `data_parallel` directly — it falls out from the other settings. + +Data parallelism is the baseline scaling strategy: it increases training throughput proportionally to the number of replicas (assuming sufficient batch size) and adds no memory pressure for the model itself. Any GPUs not used by tensor or pipeline parallelism are automatically used for data parallelism. Its only constraint is that the global batch size grows with the number of replicas. + +--- + +## Tensor Parallelism + +Tensor parallelism (sometimes called *intra-layer model parallelism*) shards individual weight matrices across GPUs within a group. Each GPU holds a slice of the weight and computes its portion of the output; an all-reduce synchronizes results. + +```yaml +model: + distributed: + tensor_parallel: 4 # shard weights across 4 GPUs +``` + +Valid values are 1 (disabled) or any integer that divides `world_size`. In practice, powers of two work best (1, 2, 4, 8). + +**When to use:** When a single model layer (e.g. attention projection or MLP) does not fit on one GPU, or when activation memory from large hidden dimensions is the bottleneck. Tensor parallelism requires high-bandwidth interconnects (NVLink within a node) because it adds an all-reduce communication on every forward *and* backward pass of every sharded layer. + +**Rule of thumb:** Keep tensor parallelism within a node (≤ 8 GPUs). Crossing node boundaries with tensor parallelism incurs heavy inter-node communication overhead. + +### Sequence-Tensor Parallelism + +When tensor parallelism is active, you can enable an additional optimization that keeps activations distributed along the sequence dimension between layers, rather than replicating the full sequence on every GPU: + +```yaml +model: + distributed: + tensor_parallel: 4 + sequence_tensor_parallel: true +``` + +With this enabled, each GPU holds only `1 / tensor_parallel` of the sequence at any given time. Activations are gathered before layers that need the full sequence, and scatter-reduced afterward. This reduces peak activation memory per GPU by a factor of `tensor_parallel`, at the same total communication cost as without the option. It is recommended whenever `tensor_parallel > 1`. + +--- + +## Pipeline Parallelism + +Pipeline parallelism splits the model by depth: each GPU holds a consecutive block of layers. Activations flow forward from stage to stage; gradients flow backward. Multiple micro-batches can be in-flight simultaneously to keep all stages busy. + +```yaml +model: + distributed: + pipeline_parallel: 4 # split model across 4 GPUs +``` + +The number of layers per pipeline stage is controlled by how the total layer count divides across stages (see the [Multi-Stage guide](multi-stage.md) for `layers_per_stage`). + +Pipeline parallelism works well across slow interconnects (e.g. InfiniBand between nodes) because point-to-point sends only occur at stage boundaries, and their volume is proportional to the activation size of a single layer rather than the full model. + +### Scheduling micro-batches + +To hide pipeline bubbles, Fast-LLM uses *breadth-first* scheduling: it keeps several micro-batches in flight simultaneously so each stage always has work to do. + +```yaml +schedule: + micro_batch_splits: 1 # sub-divide each micro-batch along the sequence + breadth_first_micro_batches: 4 # interleave this many micro-batches across stages + depth_first_micro_batches: 1 # gradient accumulation steps within one stage +``` + +A larger `breadth_first_micro_batches` reduces idle (bubble) time but increases activation memory, since activations from all in-flight micro-batches are held simultaneously. Start with a value equal to the number of pipeline stages. + +!!! note + The total number of micro-batches per step (`breadth_first_micro_batches × depth_first_micro_batches`) must be at least equal to `pipeline_parallel`. Otherwise some pipeline stages will be idle for the entire step. + +**When to use:** When the model is too large to fit on a single node, or when you want to spread memory across nodes without incurring the per-layer all-reduce cost of tensor parallelism. Pipeline parallelism is naturally suited to slow cross-node links. + +--- + +## Sequence Data Parallelism + +Sequence data parallelism sub-divides the data-parallel group along the sequence dimension. Instead of each GPU processing an independent sequence in full, a group of GPUs collectively processes one sequence by splitting it into chunks. + +```yaml +model: + distributed: + sequence_data_parallel: 2 # 2 GPUs share each sequence +``` + +`sequence_data_parallel` must divide `data_parallel`. The effective batch dimension is: + +```text +batch_data_parallel = data_parallel / sequence_data_parallel +``` + +**When to use:** When training on very long sequences and activation memory is the primary constraint. Sequence data parallelism reduces per-GPU activation memory roughly in proportion to its value, at the cost of added gradient synchronization along the sequence dimension. + +--- + +## Combining Strategies + +All four strategies compose freely. A typical large-scale configuration looks like: + +```yaml +model: + distributed: + tensor_parallel: 4 # within-node weight sharding + sequence_tensor_parallel: true # sequence-split activations + pipeline_parallel: 8 # cross-node layer sharding + sequence_data_parallel: 1 # each sequence lives on one GPU + # data_parallel is inferred: world_size / (4 × 8) = e.g. 4 for a 128-GPU run + +schedule: + breadth_first_micro_batches: 8 # match pipeline depth +``` + +### Choosing a configuration + +Start with the simplest setup that fits the model in memory and scale from there: + +1. **Single GPU**: no parallelism needed. +2. **Multi-GPU, single node**: add `tensor_parallel` up to the number of GPUs (typically 8). Enable `sequence_tensor_parallel` alongside it. +3. **Multi-node**: add `pipeline_parallel` across nodes. Keep `tensor_parallel` within nodes. +4. **Very long sequences**: add `sequence_data_parallel` to reduce activation memory. +5. **Still out of memory**: increase `zero_stage` (see [Multi-Stage guide](multi-stage.md)). + +### Rank ordering + +By default, Fast-LLM assigns global ranks in tensor → data → pipeline order. If pipeline stages are on different sockets of the same machine, setting `pipeline_first: true` can improve NUMA locality: + +```yaml +model: + distributed: + pipeline_first: true +``` + +--- + +## Configuration Reference + +All distributed settings live under `model.distributed`: + +| Field | Default | Description | +| --- | --- | --- | +| `tensor_parallel` | `1` | Size of the tensor-parallel group | +| `pipeline_parallel` | `1` | Number of pipeline stages | +| `sequence_data_parallel` | `1` | Sub-divide data-parallel group by sequence | +| `sequence_tensor_parallel` | `false` | Enable sequence-parallel activation splitting in TP layers | +| `pipeline_first` | `false` | Swap data and pipeline rank ordering for NUMA locality | + +Schedule settings live under `schedule`: + +| Field | Default | Description | +| --- | --- | --- | +| `breadth_first_micro_batches` | `1` | Micro-batches in flight simultaneously (reduces pipeline bubble) | +| `depth_first_micro_batches` | `1` | Gradient accumulation steps within a stage | +| `micro_batch_splits` | `1` | Sub-divide each micro-batch along the sequence dimension | diff --git a/examples/fast-llm.pytorchjob.yaml b/examples/fast-llm.pytorchjob.yaml index 13a7a4df8..03c51fcb9 100644 --- a/examples/fast-llm.pytorchjob.yaml +++ b/examples/fast-llm.pytorchjob.yaml @@ -42,7 +42,7 @@ spec: --rdzv_conf=timeout=3600 \ --no_python \ fast-llm train gpt \ - --config examples/mistral-4-node-benchmark.yaml + --config examples/mistral.yaml env: - name: NCCL_DEBUG value: "INFO" @@ -102,7 +102,7 @@ spec: --rdzv_conf=timeout=3600 \ --no_python \ fast-llm train gpt \ - --config examples/mistral-4-node-benchmark.yaml + --config examples/mistral.yaml env: - name: NCCL_DEBUG value: "INFO" diff --git a/examples/fast-llm.sbat b/examples/fast-llm.sbat index 13a966ec3..8099bc141 100644 --- a/examples/fast-llm.sbat +++ b/examples/fast-llm.sbat @@ -34,4 +34,4 @@ srun --gpus-per-node=$SLURM_GPUS_PER_NODE \ --rdzv_conf=timeout=3600 \ --no_python \ fast-llm train gpt \ - --config examples/mistral_4_node_benchmark.yaml" + --config examples/mistral.yaml" diff --git a/fast_llm/config.py b/fast_llm/config.py index 6b947bce5..61e22737a 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -40,17 +40,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): _AUTO_VALIDATE = self._old_value -class UpdateType(str, enum.Enum): +class UpdateType(enum.StrEnum): # Override entries no matter what they contain. override = "override" # Override atomic entries and lists, but update dicts recursively by setting or overriding only the specified entries. update = "update" -class FieldHint: +class FieldHint(enum.StrEnum): """ A label defined for each config field, to let the user and some methods know how important each field is. - * core: """ core = "core" @@ -127,7 +126,7 @@ def __init__( *, desc: str | None = None, doc: str | None = None, - hint: str = FieldHint.unknown, + hint: FieldHint = FieldHint.unknown, # Validation function on the field to satisfy. # Should raise an Exception in case of failure, and return the validated value. # Run before the default validation (type check). @@ -164,9 +163,9 @@ def __init__( # self.auto_instantiate = auto_instantiate -class FieldUpdate(dict): +class FieldOverride(dict): """ - Specify some entries in the field that should be updated from the base class. + Override some entries in the field inherited from the base class. Useful for changing the default or description in a derived class. Processed in `__init_subclass__`. """ @@ -185,20 +184,6 @@ def valid(x): return valid -def test_field(fn, *args, **kwargs): - """ - Helper function to define a condition that a config field should satisfy, - in the form of a function that returns a boolean. - """ - - def valid(x): - if not fn(x, *args, **kwargs): - raise ValueError(fn, x, args, kwargs) - return x - - return valid - - def process_field(fn, *args, **kwargs): """ Helper function to apply non-standard processing during validation, @@ -536,7 +521,7 @@ def _validate_array(cls, value, type_, name: str): ) else: if not issubclass(origin, tuple) and len(args) != 1: - FieldTypeError(f"Invalid array specification") + raise FieldTypeError(f"Invalid array specification") new_value = origin( cls._validate_nested(value_, args[0], f"{name}[{i}]", None, errors, True) for i, value_ in enumerate(value) @@ -649,8 +634,8 @@ def _add_field_to_args( all_fields: bool = False, serializable: bool = True, ) -> None: - if field is not None and (not field.init or field._field_type != dataclasses._FIELD) and not all_fields: - # Exclude class variables and derived fields unless requested explicitly. + if field is not None and (field._field_type != dataclasses._FIELD or (not field.init and not all_fields)): + # Always exclude class variables; exclude derived (init=False) fields unless all_fields=True. return explicit_field = ( field is None @@ -865,7 +850,7 @@ def _from_dict_array(cls, value, type_, strict: bool): new_value += value[len(value) - len(new_value) :] else: if not issubclass(origin, tuple) and len(args) != 1: - FieldTypeError(f"Invalid array specification") + raise FieldTypeError(f"Invalid array specification") new_value = origin(cls._from_dict_nested(value_, args[0], strict) for i, value_ in enumerate(value)) return new_value @@ -973,7 +958,7 @@ def __init_subclass__(cls): for name in list(cls.__dict__): value = getattr(cls, name) - if isinstance(value, FieldUpdate): + if isinstance(value, FieldOverride): # In case of multiple inheritance, the base class field may not appear in `cls.__dataclass_fields__`. # so we iterate over superclasses following mro and use the first match. base_class_field = None diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index 16f7d92c8..a71cbc306 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -164,9 +164,7 @@ def send( assert group is not None if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu": # send not supported for gloo on GPU. - tensor_cpu = tensor.cpu() - group.send([tensor_cpu], dst, tag).wait() - tensor.copy_(tensor_cpu) + group.send([tensor.cpu()], dst, tag).wait() return None work = group.send([tensor], dst, tag) if async_op: @@ -182,7 +180,7 @@ def recv( assert group is not None if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu": # recv not supported for gloo on GPU. - tensor_cpu = tensor.cpu() + tensor_cpu = tensor.new_empty(device="cpu") group.recv([tensor_cpu], src, tag).wait() tensor.copy_(tensor_cpu) return None diff --git a/fast_llm/csrc/data.cpp b/fast_llm/csrc/data.cpp index a1a24c7c9..1696af449 100644 --- a/fast_llm/csrc/data.cpp +++ b/fast_llm/csrc/data.cpp @@ -181,7 +181,7 @@ py::array build_padded_token_cumsum(const py::array_t& sizes_, }); const auto byte_size = sizeof(int64_t); - return py::array(std::vector{token_cumsum.size()}, + return py::array(std::vector{static_cast(token_cumsum.size())}, {byte_size}, token_cumsum_result, free_when_done); diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 78bc20636..98444d149 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -1,7 +1,7 @@ import enum -class MultiprocessingContext(str, enum.Enum): +class MultiprocessingContext(enum.StrEnum): # Fast but risk of segfaults due to interactions with triton # (for example https://github.com/openai/triton/issues/2088). fork = "fork" diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 415ddc195..9253d0311 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -112,9 +112,9 @@ def get_iterator( ), num_workers=num_workers, prefetch_factor=prefetch_factor, - pin_memory=True, + pin_memory=self._distributed_config.use_cuda, collate_fn=functools.partial(self._collate_fn, dataset_name=dataset_name, preprocess=preprocess), - multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, + multiprocessing_context=self._config.multiprocessing_context if num_workers > 0 else None, ) if self._datasets[dataset_name].requires_broadcast: data_loader = DistributedDataLoaderWrapper( diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index bfe1509d6..bef80f468 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -18,7 +18,9 @@ logger = logging.getLogger(__name__) -class ShufflingType(str, enum.Enum): +class ShufflingType(enum.StrEnum): + """Strategy for shuffling dataset samples across training epochs.""" + # Shuffle all epochs together. Not extendable. full = "full" # Shuffle all epochs separately. Default mode, recommended if the dataset doesn't come pre-shuffled. @@ -79,6 +81,13 @@ class SamplingConfig(SamplingConfigBase): # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. predicted_tokens: int = Field(default=1) + token_cumsum_rate: int = Field( + default=10, + desc="Sampling interval for the token cumulative sum index." + " A smaller value reduces per-sample seek time at the cost of a larger index.", + hint=FieldHint.performance, + valid=check_field(Assert.gt, 0), + ) cache_directory: pathlib.Path | None = Field(default=None) dataset_name: str = Field(default="dataset") world_size: int = Field(default=1) @@ -108,6 +117,8 @@ def sampling_maximum_document_length(self) -> int: @config_class() class DatasetConfig[DocumentType: Document](Config): + """Abstract base configuration for all dataset types.""" + _abstract: typing.ClassVar[bool] = True @@ -123,6 +134,8 @@ def build_and_sample(self, config: SamplingConfig, num_samples: int, seed: int) @config_class() class SamplableDatasetConfig[DocumentType: Document](SampledDatasetConfig[DocumentType]): + """Abstract configuration for datasets that can be built and then sampled.""" + def build(self) -> SamplableDataset[DocumentType]: raise NotImplementedError() @@ -132,6 +145,8 @@ def build_and_sample(self, config: SamplingConfig, num_samples: int, seed: int) @config_class() class IndexedDatasetConfig[DocumentType: Document](SamplableDatasetConfig[DocumentType]): + """Abstract configuration for indexed datasets that support random access by index.""" + def build(self) -> "IndexedDataset[DocumentType]": raise NotImplementedError() @@ -204,6 +219,8 @@ def build(self) -> "DatasetSlice": @config_class(dynamic_type={SampledDatasetConfig: "blended"}) class BlendedDatasetConfig[DocumentType: Document](SampledDatasetConfig[DocumentType]): + """Mixes multiple datasets together, sampling from each according to specified weights.""" + _abstract = False name: str = Field( default="blended", @@ -258,6 +275,8 @@ def build_and_sample(self, config: SamplingConfig, num_samples: int, seed: int) @config_class() class RedisConfig(Config): + """Configuration for connecting to a Redis server (host, port, timeout).""" + REDIS_FIELD: typing.ClassVar[str] = "data" REDIS_FIELD_B: typing.ClassVar[bytes] = REDIS_FIELD.encode() REDIS_GROUP_NAME: typing.ClassVar[str] = "fast_llm_group" diff --git a/fast_llm/data/dataset/memmap/config.py b/fast_llm/data/dataset/memmap/config.py index cc8665204..d043a20af 100644 --- a/fast_llm/data/dataset/memmap/config.py +++ b/fast_llm/data/dataset/memmap/config.py @@ -449,8 +449,8 @@ def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typi [metadata_["chosen_spans"] for metadata_ in metadata] ) if "rejected_spans" in metadata[0]: - out["image_patches"] = RangeReaderConfig.blend_metadata( - [metadata_["image_patches"] for metadata_ in metadata] + out["rejected_spans"] = RangeReaderConfig.blend_metadata( + [metadata_["rejected_spans"] for metadata_ in metadata] ) if "image_patches" in metadata[0]: out["image_patches"] = PatchReaderConfig.blend_metadata( diff --git a/fast_llm/data/dataset/memmap/memmap.py b/fast_llm/data/dataset/memmap/memmap.py index d44ed9093..69175c893 100644 --- a/fast_llm/data/dataset/memmap/memmap.py +++ b/fast_llm/data/dataset/memmap/memmap.py @@ -38,7 +38,7 @@ def _init(self, name: str, path: pathlib.Path | str) -> None: config_bytes = stream.read(int.from_bytes(stream.read(4), signed=False)) reader_config = MemmapIndexDatasetReaderConfig.from_dict(json.loads(config_bytes.decode("utf-8"))) - self._memmap = np.memmap(self._path, mode="r") + self._memmap = np.memmap(self._path, mode="c") self._reader = reader_config.get_reader(memoryview(self._memmap)) def __getstate__(self) -> tuple[str, pathlib.Path]: diff --git a/fast_llm/data/dataset/memmap/token.py b/fast_llm/data/dataset/memmap/token.py index 3e8b86a3c..84d34613b 100644 --- a/fast_llm/data/dataset/memmap/token.py +++ b/fast_llm/data/dataset/memmap/token.py @@ -51,10 +51,11 @@ def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dic def _get_nearest_split(cumsum: torch.Tensor, value: float) -> int: - left = torch.searchsorted(cumsum, value, side="right") - if left == len(cumsum): - return left.item() - return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item() + right = torch.searchsorted(cumsum, value, side="right") + if right == len(cumsum): + return right.item() + left = cumsum[right - 1].item() if right > 0 else 0 + return right.item() + 1 if (value - left) / (cumsum[right].item() - left) > 0.5 else right.item() class TokenWriter(MemmapWriter): diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index bffa9ff66..db123a354 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -63,10 +63,6 @@ def _lazy_load(self): self._array = np.load(self._path, mmap_mode="r") -# TODO: Make configurable? -TOKEN_CUMSUM_RATE = 10 - - class SampledIndexedDataset[DocumentType: Document](SampledDataset[DocumentType]): """ A sampled dataset. @@ -253,9 +249,9 @@ def _sample(self) -> None: # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. - # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. - # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. - # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` + # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals (`token_cumsum_rate`). + # A larger rate reduces pre-computation overhead at the cost of more runtime scanning per sample. + # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::token_cumsum_rate]))` if unshuffled_epochs > 0: token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( document_sizes, @@ -288,7 +284,7 @@ def _sample(self) -> None: ) self._token_cumsum_shuffled.save(token_cumsum_shuffled) self._document_shuffling.save( - document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy( + document_shuffling[: (token_cumsum_shuffled.size + 1) * self._config.token_cumsum_rate].numpy( force=self._config.gpu ) ) @@ -298,10 +294,12 @@ def _sample(self) -> None: def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]: if self._config.truncate_documents: # Create the output tensor. - out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype.torch) + out = sizes.new_empty(sizes.numel() // self._config.token_cumsum_rate + 1, dtype=dtype.torch) # Get partial sums for regular intervals, excluding the last incomplete interval. torch.sum( - sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE), + sizes[: sizes.numel() - sizes.numel() % self._config.token_cumsum_rate].view( + -1, self._config.token_cumsum_rate + ), dim=1, out=out[1:], ) @@ -319,7 +317,9 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - return out.numpy(force=self._config.gpu), None else: # TODO: dynamically handle int64 or int32 in CPP - out = build_padded_token_cumsum(sizes.cpu().numpy(), self._config.sample_size, TOKEN_CUMSUM_RATE, offset) + out = build_padded_token_cumsum( + sizes.cpu().numpy(), self._config.sample_size, self._config.token_cumsum_rate, offset + ) num_tokens = out[-1] out = out[:-1][ : np.clip( @@ -358,7 +358,9 @@ def __getitem__(self, index: int) -> list[DocumentType]: # Find the rightmost location `token_start_cumsum_index` in `token_cumsum` with `token_cumsum[token_start_cumsum_index] <= token_start` token_start_cumsum_index = np.searchsorted(token_start_array, token_start, side="right").item() - 1 - document_sampling_index = token_start_cumsum_index * TOKEN_CUMSUM_RATE + token_start_array_document_offset + document_sampling_index = ( + token_start_cumsum_index * self._config.token_cumsum_rate + token_start_array_document_offset + ) token_count = token_start_array[token_start_cumsum_index].item() diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index e3fce4eb3..ec8fe7bd1 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -3,6 +3,7 @@ import logging import time import typing +import warnings import redis import torch.utils.data @@ -33,20 +34,25 @@ class RedisStreamingDocumentData(Config): def _validate(self): # Decode message - if isinstance(self.tokens, bytes): - self.tokens = torch.frombuffer(self.tokens, dtype=torch.int64) - elif isinstance(self.tokens, (list, tuple)): - self.tokens = torch.tensor(self.tokens, dtype=torch.int64) + with warnings.catch_warnings(): + # The tensors are read-only in practice; the non-writable-buffer warning is expected. + warnings.simplefilter("ignore", UserWarning) + if isinstance(self.tokens, bytes): + self.tokens = torch.frombuffer(self.tokens, dtype=torch.int64) + elif isinstance(self.tokens, (list, tuple)): + self.tokens = torch.tensor(self.tokens, dtype=torch.int64) if isinstance(self.loss_masking_spans, str): self.loss_masking_spans = json.loads(self.loss_masking_spans) if isinstance(self.chosen_span, str): self.chosen_span = json.loads(self.chosen_span) if isinstance(self.rejected_span, str): self.rejected_span = json.loads(self.rejected_span) - if isinstance(self.old_log_probabilities, bytes): - self.old_log_probabilities = torch.frombuffer(self.old_log_probabilities, dtype=torch.float32) - elif isinstance(self.old_log_probabilities, (list, tuple)): - self.old_log_probabilities = torch.tensor(self.old_log_probabilities, dtype=torch.float32) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + if isinstance(self.old_log_probabilities, bytes): + self.old_log_probabilities = torch.frombuffer(self.old_log_probabilities, dtype=torch.float32) + elif isinstance(self.old_log_probabilities, (list, tuple)): + self.old_log_probabilities = torch.tensor(self.old_log_probabilities, dtype=torch.float32) super()._validate() if self.old_log_probabilities is not None: Assert.eq(len(self.old_log_probabilities), self.num_tokens) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 4cb529463..d0d634b63 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -97,7 +97,7 @@ class LayerBaseWithNamespace(LayerBase): TODO: Consider namespace for losses and metrics? """ - def __init__(self, layer: LayerBase, namespace: str = None): + def __init__(self, layer: LayerBase, namespace: str): super().__init__(layer._distributed_config) self._layer = layer self._namespace = namespace @@ -139,7 +139,7 @@ def _layers_with_namespace(self) -> list[Layer]: class LayerWithNamespace(LayerBaseWithNamespace, Layer): _layer: Layer - def __init__(self, layer: Layer, namespace: str = None): + def __init__(self, layer: Layer, namespace: str): super().__init__(layer, namespace) self.layer_count = self._layer.layer_count diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 2770e67a2..074412c9f 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -45,6 +45,8 @@ def _get_architecture(self) -> dict[str, typing.Any]: assert isinstance(field, Field), f"{name}, {field}" if field.hint == FieldHint.architecture: architecture[name] = self._serialize_architecture_field(getattr(self, name, MISSING)) + else: + assert not isinstance(field, ModuleConfig) return architecture def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: @@ -57,7 +59,7 @@ def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: elif isinstance(value, (list, tuple, set)): return [self._serialize_architecture_field(value_) for value_ in value] elif isinstance(value, dict): - return {self._serialize_architecture_field(value_) for name, value_ in value.items()} + return {name: self._serialize_architecture_field(value_) for name, value_ in value.items()} else: return self._serialize_value(value) @@ -106,15 +108,10 @@ class ResourceUsageConfig: class ReductionType(enum.StrEnum): - """ - An enum to represent data types independently of third party libraries, - so we can swap them more easily and allow for lazy imports. - """ - - sum = "float64" - average = "float32" - minimum = "float16" - maximum = "bfloat16" + sum = "sum" + average = "average" + minimum = "minimum" + maximum = "maximum" @property def torch(self) -> "typing.Callable[[torch.Tensor], torch.Tensor]": diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 98303539e..190be62f1 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -7,7 +7,7 @@ import yaml -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, FieldOverride, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -92,6 +92,8 @@ def load_fast_llm(self) -> bool: @config_class() class CheckpointConfigBase(Config): + """Abstract base configuration for all checkpoint operations, holding the checkpoint format.""" + _abstract = True # Note: the `format` may be a str when configuring from file or cli. # The actual class should be set through `setup` in a parent config validation. @@ -117,6 +119,8 @@ def setup(self, model_config: "FastLLMModelConfig| type[FastLLMModelConfig]") -> @config_class() class CheckpointStateConfigBase(CheckpointConfigBase): + """Abstract base configuration for checkpoint operations that include model weights and/or optimizer state.""" + _abstract = True # Defaults and descriptions are set in derived classes. model_weights: bool = Field(default=True, hint=FieldHint.feature) @@ -125,6 +129,8 @@ class CheckpointStateConfigBase(CheckpointConfigBase): @config_class() class CheckpointSaveConfigBase(CheckpointConfigBase): + """Abstract base configuration for saving checkpoints, with file-size and dtype options.""" + _abstract = True parameters_per_file: int = Field( default=2**32, @@ -141,9 +147,11 @@ class CheckpointSaveConfigBase(CheckpointConfigBase): @config_class() class CheckpointStateSaveConfigBase(CheckpointSaveConfigBase, CheckpointStateConfigBase): + """Configuration for saving model weights and/or optimizer state to a checkpoint.""" + _abstract = False - model_weights: bool = FieldUpdate(desc="Save the model weights.") - optimizer_state: bool = FieldUpdate(desc="Save the optimizer state. Default: save if supported by the `format`.") + model_weights: bool = FieldOverride(desc="Save the model weights.") + optimizer_state: bool = FieldOverride(desc="Save the optimizer state. Default: save if supported by the `format`.") def _validate(self) -> None: if self.optimizer_state is None and hasattr(self.format, "support_optimizer"): @@ -157,6 +165,8 @@ def _validate(self) -> None: @config_class() class CheckpointPathConfigBase(CheckpointConfigBase): + """Abstract base configuration for checkpoint operations that require a filesystem path and optional timeout.""" + _abstract = True path: pathlib.Path | None = Field( default=None, @@ -173,16 +183,22 @@ class CheckpointPathConfigBase(CheckpointConfigBase): @config_class() class CheckpointSaveMetadataConfig(CheckpointPathConfigBase): + """Configuration for saving checkpoint metadata (path and format) without weights or optimizer state.""" + _abstract = False @config_class() class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConfigBase): + """Full configuration for saving a checkpoint: path, format, weights, optimizer state, and file options.""" + _abstract = False @config_class() class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): + """Configuration for loading checkpoint metadata, controlling which config sections are loaded.""" + _abstract = False # TODO: Set default to model? (Not backward compatible) load_config: ModelConfigType = Field( @@ -194,10 +210,12 @@ class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): @config_class() class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): + """Full configuration for loading a checkpoint: path, format, and which state to restore.""" + _abstract = False - model_weights: bool = FieldUpdate(desc="Load the model weights.") - optimizer_state: bool = FieldUpdate(default=False, desc="Load the optimizer state.") + model_weights: bool = FieldOverride(desc="Load the model weights.") + optimizer_state: bool = FieldOverride(default=False, desc="Load the optimizer state.") def _validate(self) -> None: super()._validate() diff --git a/fast_llm/engine/checkpoint/convert.py b/fast_llm/engine/checkpoint/convert.py index 103d9488c..728877792 100644 --- a/fast_llm/engine/checkpoint/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -70,7 +70,7 @@ def _convert_model_partial( logger.info(f"Saving {output.format} checkpoint to {output.path}...") output.path.mkdir(parents=True, exist_ok=self.exist_ok) model.save_checkpoint(output) - (output.path / "ok").open("w") + (output.path / "ok").touch() logger.info(f"Done!") def run(self): @@ -120,7 +120,7 @@ def run(self): global_rename_map = {} file_count = 0 for step_path in step_paths: - step_index = json.load((step_path / index_filename).open("r")) + step_index = json.loads((step_path / index_filename).read_text()) if len(index) == 0: index.update(step_index) index["weight_map"] = weight_map @@ -141,7 +141,7 @@ def run(self): path = self.output.path / index_filename # Save the index. - json.dump(index, path.open("w"), indent=4) + path.write_text(json.dumps(index, indent=4)) # Copy the config (step_paths[0] / config_filename).rename(self.output.path / config_filename) @@ -158,5 +158,5 @@ def run(self): step_path.rmdir() # All good! - (self.output.path / "ok").open("w") + (self.output.path / "ok").touch() logger.info(f">>> All done!") diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index fecc35ef7..22782e49c 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -32,17 +32,17 @@ class DistributedCheckpointHandler(CheckpointHandler): def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata): serialized_metadata = metadata.to_dict() config.path.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) + (config.path / "metadata.yaml").write_text(yaml.safe_dump(serialized_metadata)) @classmethod def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: - return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r"))) + return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").read_text())) def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: serialized_metadata = metadata.to_dict() config.path.mkdir(parents=True, exist_ok=True) if self._model.config.distributed.rank == 0: - yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) + (config.path / "metadata.yaml").write_text(yaml.safe_dump(serialized_metadata)) safetensors.torch.save_file( tensors={f"{shard_name}_shard": self._model.get_shard(shard_name) for shard_name in metadata.shards}, filename=config.path / f"rank_{self._model.config.distributed.rank}.safetensors", diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 270171755..8cdb779dd 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -58,11 +58,7 @@ def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadat path = config.path / f"{cls.base_file_name}.safetensors.index.json" logger.info(f"Saving index to {path}") # Save the index. - json.dump( - {"metadata": metadata, "weight_map": index}, - path.open("w"), - indent=4, - ) + path.write_text(json.dumps({"metadata": metadata, "weight_map": index}, indent=4)) def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata) -> dict: huggingface_config = self._export_config(self._model.config) @@ -145,7 +141,7 @@ def _load_weights( logger.info(f"Loading index from {config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME}") paths = { config.path / path - for path in json.load((config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).open("r"))[ + for path in json.loads((config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).read_text())[ "weight_map" ].values() } @@ -155,7 +151,7 @@ def _load_weights( logger.info(f"Loading index from {config.path / transformers.utils.WEIGHTS_INDEX_NAME}") paths = { config.path / path - for path in json.load((config.path / transformers.utils.WEIGHTS_INDEX_NAME).open("r"))[ + for path in json.loads((config.path / transformers.utils.WEIGHTS_INDEX_NAME).read_text())[ "weight_map" ].values() } @@ -169,7 +165,7 @@ def _load_weights( for key in f.keys(): yield key, "weights", f.get_slice(key) elif path.suffix == ".bin": - # TODO: Confirm that loading works with `weights_only=True` - yield from torch.load(path, weights_only=True) + for key, tensor in torch.load(path, weights_only=True).items(): + yield key, "weights", tensor else: raise NotImplementedError(f"Unknown file format for {path}") diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index d3f72a47c..9667fa98b 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -149,17 +149,12 @@ def _check_parameters(self, errors: list[str]) -> None: f' and shard "{shard_name}": loaded {counter}, expected {local_size}' ) - counter_ = counter # Accumulate in a list for global counter check. if ( not parameter_meta.is_tensor_parallel and self._distributed.config.tensor_rank != 0 ) or stage.is_tied_weight_copy: # Ignore the counter from duplicate tensors. counter = 0 - if parameter_name == "layers.1.norm_1.weight": - logger.info( - f"Parameter {parameter_name} local {counter_} keep {counter} (size {parameter_meta.numel()} / {parameter_meta.global_shape.numel()})" - ) counters.append(counter) # Check for unexpected parameters. @@ -179,10 +174,6 @@ def _check_parameters(self, errors: list[str]) -> None: for stage, fsdp, parameter_name, parameter_meta in self._model.stages_fsdp_parameters: for shard_name in self._self_shards if fsdp.requires_grad else [ShardName.weights]: counter = counters.pop(0) - if parameter_name == "layers.1.norm_1.weight": - logger.info( - f"Parameter {parameter_name} global {counter} (size {parameter_meta.numel()} / {parameter_meta.global_shape.numel()})" - ) parameter_size = parameter_meta.global_shape.numel() if counter != parameter_size: errors.append( diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 32eea2db6..8106d85dc 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -154,7 +154,7 @@ class FastLLMCheckpointHandler(StateDictCheckpointHandler): def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: path = config.path / f"metadata.yaml" logger.warning(f"Loading metadata from {path}") - return CheckpointMetadata.from_dict(yaml.safe_load(path.open("r"))) + return CheckpointMetadata.from_dict(yaml.safe_load(path.read_text())) @classmethod def _save_serialized_metadata( @@ -166,7 +166,7 @@ def _save_serialized_metadata( if "metadata" not in serialized_metadata: serialized_metadata["metadata"] = {} serialized_metadata["metadata"]["state_index"] = index - yaml.safe_dump(serialized_metadata, path.open("w")) + path.write_text(yaml.safe_dump(serialized_metadata)) @classmethod def _get_key(cls, parameter_name: str, shard_name: str) -> str: @@ -259,15 +259,15 @@ def _merge_index(self) -> None: if self._do_save and self._distributed_config.pipeline_parallel != 1: # Combine the indexes from all pipeline ranks. logger.info(f"Merging pipeline-parallel indexes.") - yaml.dump( - self._index, (self._config.path / f"index_{self._distributed_config.pipeline_rank}.yaml").open("w") + (self._config.path / f"index_{self._distributed_config.pipeline_rank}.yaml").write_text( + yaml.dump(self._index) ) safe_barrier(self._distributed.pipeline_group, "save state dict", timeout=self._config.timeout) self._index = {} if self._distributed_config.pipeline_rank == 0: for rank in range(self._distributed_config.pipeline_parallel): file_name = self._config.path / f"index_{rank}.yaml" - local_index = yaml.safe_load(file_name.open("r")) + local_index = yaml.safe_load(file_name.read_text()) for key, value in local_index.items(): assert key not in self._index, key self._index[key] = value diff --git a/fast_llm/engine/config_utils/initialization.py b/fast_llm/engine/config_utils/initialization.py index 2f12a45d2..0395324f6 100644 --- a/fast_llm/engine/config_utils/initialization.py +++ b/fast_llm/engine/config_utils/initialization.py @@ -46,7 +46,7 @@ def get_initializer(self) -> "Initializer": @config_class(dynamic_type={InitializationConfig: "fill"}) class FillInitializationConfig(InitializationConfig): """ - Normal initialization: normal(mean, std).clamp(min,max) + Fill initialization: fills the tensor with a constant value. """ _abstract = False @@ -88,7 +88,7 @@ class NormalInitializationConfig(InitializationConfig): ) max: float | None = Field( default=None, - desc="Min value for initialization clamping.", + desc="Max value for initialization clamping.", hint=FieldHint.optional, ) @@ -105,16 +105,14 @@ class UniformInitializationConfig(InitializationConfig): _abstract = False scale: float = Field( - default=None, desc="Initialization scale.", - hint=FieldHint.optional, + hint=FieldHint.core, valid=check_field(Assert.geq, 0), ) mean: float = Field( - default=None, + default=0.0, desc="Initialization mean.", hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), ) def get_initializer(self) -> "Initializer": diff --git a/fast_llm/engine/config_utils/logging.py b/fast_llm/engine/config_utils/logging.py index 943b8de38..32deb4562 100644 --- a/fast_llm/engine/config_utils/logging.py +++ b/fast_llm/engine/config_utils/logging.py @@ -4,7 +4,7 @@ import pathlib import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import Assert @@ -73,7 +73,7 @@ class TensorLogsConfig(Config): default=8, desc="Maximum number of tensor values to print for each tensor when posting tensor logs to stdout.", hint=FieldHint.logging, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), + valid=check_field(Assert.gt, 0), ) full_tensors: bool = Field(default=False, desc="Save and/or print entire tensors.") diff --git a/fast_llm/engine/config_utils/parameter.py b/fast_llm/engine/config_utils/parameter.py index c0910c09a..3e2b61120 100644 --- a/fast_llm/engine/config_utils/parameter.py +++ b/fast_llm/engine/config_utils/parameter.py @@ -38,6 +38,7 @@ def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): @config_class() class ParameterConfig(ModuleConfig): + _abstract = False initialization: InitializationConfig = Field( desc="If provided, override the default initialization method set by the parent layer.", hint=FieldHint.feature, @@ -50,9 +51,6 @@ class ParameterConfig(ModuleConfig): ) # TODO: Initialization, lr_scale - def _validate(self) -> None: - pass - def get_parameter( self, dims: tuple[TensorDim, ...], @@ -83,9 +81,6 @@ class OptionalParameterConfig(ParameterConfig): default=None, ) - def _validate(self) -> None: - pass - def get_parameter( self, dims: tuple[TensorDim, ...], @@ -97,8 +92,6 @@ def get_parameter( default_enabled: bool = False, peft: PeftConfig | None, ) -> "ParameterMeta|None": - pass - if (self.enabled is None and default_enabled) or self.enabled: return super().get_parameter( dims, diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index ab6f27489..e32d7cd46 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -144,9 +144,9 @@ def __init__( (self._experiment_directory / "runs").mkdir(exist_ok=True, parents=True) run = len(list((self._experiment_directory / "runs").iterdir())) (self._experiment_directory / "runs" / str(run)).mkdir() - yaml.safe_dump(config_dict, (self._experiment_directory / "config.yaml").open("w")) + (self._experiment_directory / "config.yaml").write_text(yaml.safe_dump(config_dict)) # Dumping a verbose version of the config - yaml.safe_dump(config_dict_verbose, (self._experiment_directory / "config_verbose.yaml").open("w")) + (self._experiment_directory / "config_verbose.yaml").write_text(yaml.safe_dump(config_dict_verbose)) else: run = 0 # Make sure all the workers agree on the run. This also acts as a barrier. diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py index 58c490cb9..a19893b40 100644 --- a/fast_llm/engine/config_utils/runnable.py +++ b/fast_llm/engine/config_utils/runnable.py @@ -153,7 +153,7 @@ def _load_default_config_dict(cls, parsed: argparse.Namespace) -> typing.Any: elif urllib.parse.urlparse(parsed.config).scheme == "https": return yaml.safe_load(cls._load_url(parsed.config, parsed.config_auth_token_file)) elif pathlib.Path(parsed.config).is_file(): - return yaml.safe_load(pathlib.Path(parsed.config).open("r").read()) + return yaml.safe_load(pathlib.Path(parsed.config).read_text()) else: raise FileNotFoundError(parsed.config) @@ -165,9 +165,8 @@ def _load_url(cls, config_url: str, config_auth_token_file: pathlib.Path | None headers = {"Accept": "application/vnd.github.v3.raw"} if config_auth_token_file is not None: - config_auth_token = config_auth_token_file.open("r").read().strip() - with open(config_auth_token_file) as f: - headers["Authorization"] = f"token {config_auth_token}" + config_auth_token = config_auth_token_file.read_text().strip() + headers["Authorization"] = f"token {config_auth_token}" response = requests.get(config_url, headers=headers) if response.status_code == 200: return response.text diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index a214e8e50..b5ae7b4f5 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -1,5 +1,6 @@ import dataclasses import enum +import functools import logging import os import typing @@ -172,11 +173,6 @@ class DistributedConfig(Config): pipeline_parallel: int = Field( default=1, desc="Pipeline parallelism group size.", hint=FieldHint.performance, valid=check_field(Assert.gt, 0) ) - data_parallel: int = Field(init=False, desc="Data parallelism group size.", hint=FieldHint.derived) - model_parallel: int = Field( - init=False, desc="Model parallelism group size (tensor * pipeline).", hint=FieldHint.derived - ) - num_nodes: int = Field(init=False, desc="Number of GPU nodes.", hint=FieldHint.derived) sequence_tensor_parallel: bool = Field( default=False, desc="Enable sequence tensor parallelism.", hint=FieldHint.performance ) @@ -186,7 +182,6 @@ class DistributedConfig(Config): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - batch_data_parallel: int = Field(init=False, desc="Batch data parallelism group size.", hint=FieldHint.performance) world_size: int = Field( default=None, desc="Size of the world group, e.e., total number of GPUs. Typically provided by torchrun or equivalent through the `WORLD_SIZE` environment variable.", @@ -199,23 +194,6 @@ class DistributedConfig(Config): hint=FieldHint.expert, valid=check_field(Assert.geq, 0), ) - data_rank: int = Field(init=False, desc="Data-parallel rank of the local process.", hint=FieldHint.derived) - pipeline_rank: int = Field(init=False, desc="Pipeline-parallel rank of the local process.", hint=FieldHint.derived) - tensor_rank: int = Field(init=False, desc="Tensor-parallel rank of the local process.", hint=FieldHint.derived) - sequence_data_rank: int = Field( - init=False, desc="Sequence-data-parallel rank of the local process.", hint=FieldHint.derived - ) - batch_data_rank: int = Field( - init=False, desc="Batch-data-parallel rank of the local process.", hint=FieldHint.derived - ) - distributed_dims: dict[str, DistributedDim] = Field( - init=False, desc="The `DistributedDim` objects for the distributed dimensions.", hint=FieldHint.derived - ) - local_rank: int = Field( - init=False, - desc="The rank of the process on the current node.", - hint=FieldHint.derived, - ) local_world_size: int = Field( default=None, desc="Number of GPUs in each node. Typically provided by torchrun or equivalent through the `LOCAL_WORLD_SIZE` environment variable.", @@ -310,6 +288,112 @@ class DistributedConfig(Config): hint=FieldHint.derived, ) + @functools.cached_property + def model_parallel(self) -> int: + return self.tensor_parallel * self.pipeline_parallel + + @functools.cached_property + def data_parallel(self) -> int: + return div(self.world_size, self.model_parallel) + + @functools.cached_property + def num_nodes(self) -> int: + return div(self.world_size, self.local_world_size) + + @functools.cached_property + def local_rank(self) -> int: + return self.rank % self.local_world_size + + @functools.cached_property + def tensor_rank(self) -> int: + return self.rank % self.tensor_parallel + + @functools.cached_property + def data_rank(self) -> int: + if self.pipeline_first: + # Smaller models can be more demanding on pipeline parallel. + return (self.rank // self.tensor_parallel) // self.pipeline_parallel + else: + # Larger models are more demanding on data parallel. + return (self.rank // self.tensor_parallel) % self.data_parallel + + @functools.cached_property + def pipeline_rank(self) -> int: + if self.pipeline_first: + return (self.rank // self.tensor_parallel) % self.pipeline_parallel + else: + return (self.rank // self.tensor_parallel) // self.data_parallel + + @functools.cached_property + def batch_data_parallel(self) -> int: + return div(self.data_parallel, self.sequence_data_parallel) + + @functools.cached_property + def sequence_data_rank(self) -> int: + return self.data_rank % self.sequence_data_parallel + + @functools.cached_property + def batch_data_rank(self) -> int: + return self.data_rank // self.sequence_data_parallel + + @functools.cached_property + def distributed_dims(self) -> dict[str, "DistributedDim"]: + if self.reference_config is not None: + return self.reference_config.distributed_dims + dims: dict[str, DistributedDim] = {} + tensor_stride = 1 + sequence_data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1) + batch_data_stride = sequence_data_stride * self.sequence_data_parallel + pipeline_stride = self.tensor_parallel * (1 if self.pipeline_first else self.data_parallel) + self._add_distributed_dim_from_sizes_and_strides(dims, DistributedDimNames.world, (self.world_size, 1)) + self._add_distributed_dim_from_sizes_and_strides( + dims, + DistributedDimNames.data, + (self.sequence_data_parallel, sequence_data_stride), + (self.batch_data_parallel, batch_data_stride), + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, DistributedDimNames.pipeline, (self.pipeline_parallel, pipeline_stride) + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, DistributedDimNames.tensor, (self.tensor_parallel, tensor_stride) + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, DistributedDimNames.sequence_data, (self.sequence_data_parallel, sequence_data_stride) + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, DistributedDimNames.batch_data, (self.batch_data_parallel, batch_data_stride) + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, + DistributedDimNames.tensor_and_sequence_data, + (self.tensor_parallel, tensor_stride), + (self.sequence_data_parallel, sequence_data_stride), + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, + DistributedDimNames.tensor_and_data, + (self.tensor_parallel, tensor_stride), + (self.sequence_data_parallel, sequence_data_stride), + (self.batch_data_parallel, batch_data_stride), + ) + self._add_distributed_dim_from_sizes_and_strides( + dims, + DistributedDimNames.model_and_sequence_data, + (self.tensor_parallel, tensor_stride), + ( + (self.pipeline_parallel, pipeline_stride) + if self.pipeline_first + else (self.sequence_data_parallel, sequence_data_stride) + ), + ( + (self.sequence_data_parallel, sequence_data_stride) + if self.pipeline_first + else (self.pipeline_parallel, pipeline_stride) + ), + ) + return dims + def _validate(self) -> None: if self.world_size is None: self.world_size = self.default_world_size @@ -317,112 +401,36 @@ def _validate(self) -> None: self.rank = self.default_rank if self.local_world_size is None: self.local_world_size = self.default_local_world_size - self.model_parallel = self.tensor_parallel * self.pipeline_parallel - self.data_parallel = div(self.world_size, self.model_parallel) - self.num_nodes = div(self.world_size, self.local_world_size) - self.local_rank = self.rank % self.local_world_size - Assert.multiple(self.local_world_size, self.tensor_parallel) - - if self.pipeline_first: - # Smaller models can be more demanding on pipeline parallel. - self.data_rank = (self.rank // self.tensor_parallel) // self.pipeline_parallel - self.pipeline_rank = (self.rank // self.tensor_parallel) % self.pipeline_parallel - else: - # Larger models are more demanding on data parallel. - self.data_rank = (self.rank // self.tensor_parallel) % self.data_parallel - self.pipeline_rank = (self.rank // self.tensor_parallel) // self.data_parallel - self.sequence_data_rank = self.data_rank % self.sequence_data_parallel - self.batch_data_parallel = div(self.data_parallel, self.sequence_data_parallel) - self.batch_data_rank = self.data_rank // self.sequence_data_parallel - - self.tensor_rank = self.rank % self.tensor_parallel if self.tensor_parallel == 1 and self.sequence_tensor_parallel: self.sequence_tensor_parallel = False - if self.reference_config is not None: self.reference_config.validate() if self.reference_config.reference_config is not None: self.reference_config = self.reference_config.reference_config assert self.reference_config.reference_config is None - self.distributed_dims = self.reference_config.distributed_dims - else: - self.distributed_dims = {} - - tensor_stride = 1 - sequence_data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1) - batch_data_stride = sequence_data_stride * self.sequence_data_parallel - pipeline_stride = self.tensor_parallel * (1 if self.pipeline_first else self.data_parallel) - - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.world, - (self.world_size, 1), - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.data, - (self.sequence_data_parallel, sequence_data_stride), - (self.batch_data_parallel, batch_data_stride), - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.pipeline, (self.pipeline_parallel, pipeline_stride) - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.tensor, (self.tensor_parallel, tensor_stride) - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.sequence_data, - (self.sequence_data_parallel, sequence_data_stride), - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.batch_data, (self.batch_data_parallel, batch_data_stride) - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.tensor_and_sequence_data, - (self.tensor_parallel, tensor_stride), - (self.sequence_data_parallel, sequence_data_stride), - ) - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.tensor_and_data, - (self.tensor_parallel, tensor_stride), - (self.sequence_data_parallel, sequence_data_stride), - (self.batch_data_parallel, batch_data_stride), - ) - - self._add_distributed_dim_from_sizes_and_strides( - DistributedDimNames.model_and_sequence_data, - (self.tensor_parallel, tensor_stride), - ( - (self.pipeline_parallel, pipeline_stride) - if self.pipeline_first - else (self.sequence_data_parallel, sequence_data_stride) - ), - ( - (self.sequence_data_parallel, sequence_data_stride) - if self.pipeline_first - else (self.pipeline_parallel, pipeline_stride) - ), - ) - super()._validate() + Assert.multiple(self.local_world_size, self.tensor_parallel) if self.reference_config is not None: self.compare(self.reference_config, ValueError) Assert.in_range(self.rank, 0, self.world_size) Assert.in_range(self.local_rank, 0, self.local_world_size) - def _add_distributed_dim_from_sizes_and_strides(self, name: str, *sizes_and_strides: tuple[int, int]) -> None: - self._add_distributed_dim(DistributedDim.from_sizes_and_strides(name, self.rank, *sizes_and_strides)) + def _add_distributed_dim_from_sizes_and_strides( + self, dims: dict[str, DistributedDim], name: str, *sizes_and_strides: tuple[int, int] + ) -> None: + self._add_distributed_dim(dims, DistributedDim.from_sizes_and_strides(name, self.rank, *sizes_and_strides)) - def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None: + def _add_distributed_dim(self, dims: dict[str, DistributedDim], distributed_dim: DistributedDim) -> None: Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank, msg=distributed_dim) - try: distributed_dim.check_ranks_in_range(0, self.world_size) except: logger.info(str(self)) raise - if distributed_dim.name in self.distributed_dims: - Assert.eq(distributed_dim, self.distributed_dims[distributed_dim.name]) + if distributed_dim.name in dims: + Assert.eq(distributed_dim, dims[distributed_dim.name]) else: - self.distributed_dims[distributed_dim.name] = distributed_dim + dims[distributed_dim.name] = distributed_dim def get_distributed_dim(self, name: str) -> DistributedDim: return self.distributed_dims[name] diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index b0ab08482..372fb7f68 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -27,9 +27,9 @@ def __init__( world_size: int | None = None, local_world_size: int | None = None, timeout: float = 60, - use_cuda: bool = True, init_method: str = "env://", backend: DistributedBackend = DistributedBackend.nccl, + device: torch.device | None = None, ): self._rank = DistributedConfig.default_rank if rank is None else rank @@ -38,20 +38,24 @@ def __init__( DistributedConfig.default_local_world_size if local_world_size is None else local_world_size ) self._timeout = timeout - self._use_cuda = use_cuda self._backend = backend self._process_groups = {} - if self._use_cuda: + if device is None: assert torch.cuda.is_available() Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count()) torch.cuda.init() self._device = torch.device(self._rank % self._local_world_size) torch.cuda.set_device(self._device) + elif device.type == "cuda": + assert torch.cuda.is_available() + torch.cuda.init() + self._device = device + torch.cuda.set_device(self._device) else: if backend == DistributedBackend.nccl: Assert.eq(self._world_size, 1) - self._device = torch.device("cpu") + self._device = device if self._world_size > 1: if self._rank == 0: @@ -165,8 +169,8 @@ def __init__(self, config: DistributedConfig): self._config.world_size, self._config.local_world_size, self._config.timeout, - self._config.use_cuda, backend=self._config.backend, + device=None if self._config.use_cuda else torch.device("cpu"), ) else: self._pool = _default_pool diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py index 4db258093..19ffc87c2 100644 --- a/fast_llm/engine/evaluation/lm_eval/evaluator.py +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -26,7 +26,7 @@ def setup( run_count: int, ) -> None: if "HUGGINGFACE_API_KEY_PATH" in os.environ: - os.environ["HF_TOKEN"] = pathlib.Path(os.environ["HUGGINGFACE_API_KEY_PATH"]).open("r").read().strip() + os.environ["HF_TOKEN"] = pathlib.Path(os.environ["HUGGINGFACE_API_KEY_PATH"]).read_text().strip() else: if not "HF_TOKEN" in os.environ: logger.warning( @@ -62,4 +62,4 @@ def run( metrics: dict[str, typing.Any], ) -> None: assert self._is_setup - self._flm_wrapper.run(self._config.cli_args, metrics.get("completed_steps", 0), self._run.index) + self._flm_wrapper.run(self._config.cli_args, metrics.get("completed_steps", 0), run_index) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 56a2588c0..511bf6359 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -116,7 +116,7 @@ def max_length(self): if isinstance(self._config.fast_llm_config.base_model.transformer.mixer.rotary, NoRotaryConfig): return self._config.fast_llm_config.base_model.max_position_embeddings - # check if tokenizer holds model sequence leigh info + # check if tokenizer holds model sequence length info if hasattr(self._tokenizer, "model_max_length"): if self._tokenizer.model_max_length == 1000000000000000019884624838656: return self._DEFAULT_MAX_LENGTH @@ -528,7 +528,7 @@ def tok_batch_encode( if left_truncate_len: original_lengths = encoding["input_ids"].size(1) if original_lengths > left_truncate_len: - logger.warn( + logger.warning( f"Left truncation applied. Original sequence length was {original_lengths}, " f"truncating to last {left_truncate_len} tokens. Some content will be lost.", ) diff --git a/fast_llm/engine/inference/config.py b/fast_llm/engine/inference/config.py index d19e2478d..b0b3b33a0 100644 --- a/fast_llm/engine/inference/config.py +++ b/fast_llm/engine/inference/config.py @@ -72,7 +72,7 @@ def _get_config_dict( kwargs.pop("_from_pipeline", None) kwargs.pop("_from_auto", False) kwargs.pop("_commit_hash", None) - kwargs.get("gguf_file", None) + kwargs.pop("gguf_file", None) # Get the pretrained config. if "pretrained" in kwargs: diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 5a07bd51b..67be46558 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -154,7 +154,7 @@ def forward( # TODO: Bypassed if passed as positional argument. assert kwargs.get("past_key_values") is None and not kwargs.get("use_cache") broadcast_kwargs = {**kwargs, **{i: arg for i, arg in enumerate(args)}, "continue_work": continue_work} - tensor_kwargs = {key: value for key, value in broadcast_kwargs if torch.is_tensor(value)} + tensor_kwargs = {key: value for key, value in broadcast_kwargs.items() if torch.is_tensor(value)} broadcast_object( [(key, tensor.shape, tensor.dtype) for key, tensor in tensor_kwargs.items()], distributed.tensor_group, @@ -162,7 +162,7 @@ def forward( ) for tensor in tensor_kwargs.values(): broadcast(tensor.to(distributed.device), 0, distributed.tensor_group) - non_tensor_kwargs = {key: value for key, value in broadcast_kwargs if key not in tensor_kwargs} + non_tensor_kwargs = {key: value for key, value in broadcast_kwargs.items() if key not in tensor_kwargs} broadcast_object( non_tensor_kwargs, distributed.tensor_group, @@ -240,6 +240,6 @@ def stop_workers(self): self.forward(coordinator_forward=True, continue_work=False) safe_barrier(distributed.world_group, "forward_work_end") - def inner_forward(*args, **kwargs) -> tuple | transformers.utils.generic.ModelOutput: + def inner_forward(self, *args, **kwargs) -> tuple | transformers.utils.generic.ModelOutput: # Meant to be overridden in derived classes raise NotImplementedError() diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 41736aed6..958a3d228 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -42,7 +42,7 @@ class ShardName: grads = "grads" -class StageMode(str, enum.Enum): +class StageMode(enum.StrEnum): # Allow forward and backward passes and optimizer. # TODO: Add mode for forward and backward but not optimizer? training = "training" @@ -72,6 +72,8 @@ def on_device(self) -> bool: @config_class() class StageConfig(Config): + """Configuration for a single model stage: gradient precision, frozen weight storage, and debug logging.""" + full_precision_gradients: bool = Field( default=True, desc="Reduce and accumulate gradients in fp32 to improve numerical stability.", @@ -80,7 +82,7 @@ class StageConfig(Config): store_frozen_weights_in_optimization_precision: bool = Field( # TODO: Implement and set default to False default=True, - desc="Store frozen weights in full precision even if not not needed." + desc="Store frozen weights in full precision even if not needed." "Allows preserving the precision for saved checkpoints," " at the cost of memory and compute (copy) overheads.", hint=FieldHint.optional, @@ -141,6 +143,8 @@ class StageConfig(Config): @config_class() class MultiStageConfig(StageConfig): + """Configuration for the multi-stage model layout: layers per stage, ZeRO sharding, and buffer counts.""" + layers_per_stage: float = Field( default=1.0, desc="Number of layers to include in each Fast LLM stage.", @@ -206,6 +210,8 @@ def _validate(self) -> None: @config_class(registry=True) class FastLLMModelConfig(Config): + """Abstract base configuration for a Fast-LLM model: base model, multi-stage layout, and distributed config.""" + _abstract = True checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = ( DistributedCheckpointFormat, @@ -283,6 +289,8 @@ def save_metadata(self, config: CheckpointSaveMetadataConfig, **kwargs) -> None: @config_class() class PretrainedFastLLMModelConfig(Config): + """Configuration wrapper that optionally loads model weights and config from a pretrained checkpoint.""" + # TODO: Generalize data, schedule, logging, etc. _abstract = True # This configs may be overridden with the pretrained config during validation, so we should be careful about accessing them before. diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index ed293b103..cd781beb7 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -487,13 +487,13 @@ def get_state_tensor_iterator( self, shard_names: tuple[str, ...], data_type: DataType | None = None ) -> typing.Generator[tuple[str, str, torch.Tensor], None, None]: for shard_name in shard_names: - shard_split = self._shards[shard_name].split(self._stage_weight_shard_sizes, 0) + shard_split = self._shards[shard_name].split(self._get_stage_shard_sizes(shard_name), 0) for shard_index, ((stage_index, stage), shard) in enumerate( zip(self._stages_on_device.items(), shard_split, strict=True) ): if stage_index in self._stages_owned: for name, tensor in stage._export_shard( - shard.split(self._fsdp_weight_shard_sizes[shard_index]), data_type=data_type + shard.split(self._get_fsdp_shard_sizes(shard_name)[shard_index]), data_type=data_type ): # noqa yield name, shard_name, tensor @@ -508,8 +508,8 @@ def import_state_tensor(self, parameter_name: str, shard_name: str, tensor: torc shard_index = self._stage_shard_indices[self._parameter_stages[parameter_name]] stage_shards = ( self._shards[shard_name] - .split(self._stage_weight_shard_sizes, 0)[shard_index] - .split(self._fsdp_weight_shard_sizes[shard_index]) + .split(self._get_stage_shard_sizes(shard_name), 0)[shard_index] + .split(self._get_fsdp_shard_sizes(shard_name)[shard_index]) ) return self.get_parameter_stage(parameter_name).import_state_tensor(parameter_name, stage_shards, tensor) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 23ee5d8bd..ea737524b 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -204,17 +204,19 @@ def initialize_weights(self) -> None: meta.init_parameter(parameter, self._distributed, debug=self._config.debug_param_init) if self.mode.on_device: - fsdp.reset_shard_pad(fsdp.weight_shard, ShardName.weights) + for fsdp in self._fsdps: + fsdp.reset_shard_pad(fsdp.weight_shard, ShardName.weights) if self._config.debug_param_init: if self._mode.on_device: - fsdp.log_shard( - name="param", - shard=fsdp.weight_shard, - distributed=self._distributed, - level=self._config.debug_param_init, - global_=self._config.debug_global_tensors, - ) + for fsdp in self._fsdps: + fsdp.log_shard( + name="param", + shard=fsdp.weight_shard, + distributed=self._distributed, + level=self._config.debug_param_init, + global_=self._config.debug_global_tensors, + ) def get_param_groups( self, optimizer_state_shards: dict[str, tuple[torch.Tensor]], param_group_cls: type[ParamGroup] @@ -238,9 +240,9 @@ def get_param_groups( continue chunk_size = div(parameter_meta.numel(), len(parameter_meta.lr_scale)) buffer_begin = fsdp.get_parameter_begin_in_buffer(parameter_meta.tensor_name) - for i, lr_scale in enumerate(parameter_meta.lr_scale): - begin = fsdp.index_buffer_to_shard(buffer_begin + i * chunk_size) - end = fsdp.index_buffer_to_shard(buffer_begin + (i + 1) * chunk_size) + for lr_scale_index, lr_scale in enumerate(parameter_meta.lr_scale): + begin = fsdp.index_buffer_to_shard(buffer_begin + lr_scale_index * chunk_size) + end = fsdp.index_buffer_to_shard(buffer_begin + (lr_scale_index + 1) * chunk_size) if lr_scale == 0 or begin == end: continue optimizer_params = (parameter_meta.param_weight_decay, lr_scale) @@ -279,7 +281,7 @@ def get_param_groups( grads_norm_slices = [] for name in grad_norm_names: begin, end = fsdp._get_parameter_range_in_shard(name) - if len(grads_norm_slices) < 0 and begin == grads_norm_slices[-1].stop: + if len(grads_norm_slices) > 0 and begin == grads_norm_slices[-1].stop: grads_norm_slices[-1] = slice(grads_norm_slices[-1].start, end) else: grads_norm_slices.append(slice(begin, end)) diff --git a/fast_llm/engine/optimizer/config.py b/fast_llm/engine/optimizer/config.py index f4303a5d3..2b0e8709b 100644 --- a/fast_llm/engine/optimizer/config.py +++ b/fast_llm/engine/optimizer/config.py @@ -17,6 +17,8 @@ class LearningRateStageType: @config_class() class LearningRateScheduleConfig(Config): + """Configuration for the learning rate schedule (warmup, decay style, and bounds).""" + base: float = Field(default=0.0001, desc="Base learning rate for the optimizer.", hint=FieldHint.core) decay_style: str = Field(default="constant", desc="The learning rate decay formula.", hint=FieldHint.feature) decay_iterations: int | None = Field( @@ -38,6 +40,8 @@ class LearningRateScheduleConfig(Config): @config_class() class GradientScalerConfig(Config): + """Configuration for loss scaling, either fixed (constant) or dynamic (for fp16 training).""" + constant: float | None = Field( default=None, desc="Constant multiplier applied to the loss. Setting this disables dynamic scaling.", @@ -72,6 +76,7 @@ class GradientScalerConfig(Config): @config_class() class OptimizerConfig(Config): + """Configuration for the AdamW optimizer: learning rate schedule, gradient scaling, and hyperparameters.""" learning_rate: LearningRateScheduleConfig = Field( desc="A schedule for the learning rate.", diff --git a/fast_llm/engine/optimizer/learning_rate.py b/fast_llm/engine/optimizer/learning_rate.py index c6912e4f1..3f58c953c 100644 --- a/fast_llm/engine/optimizer/learning_rate.py +++ b/fast_llm/engine/optimizer/learning_rate.py @@ -59,22 +59,20 @@ def __post_init__(self) -> None: super().__post_init__() Assert.gt(self.power, 0) - @abc.abstractmethod def _interpolate(self, coeff: float) -> float: return coeff**self.power @dataclasses.dataclass() class CosineLRStage(InterpolateLRStage): - lr: int - end_lr: int + lr: float + end_lr: float power: float = 1.0 def __post_init__(self) -> None: super().__post_init__() Assert.gt(self.power, 0) - @abc.abstractmethod def _interpolate(self, coeff: float) -> float: return 0.5 * (1.0 - math.cos(math.pi * coeff**self.power)) diff --git a/fast_llm/engine/optimizer/optimizer.py b/fast_llm/engine/optimizer/optimizer.py index 0dd094390..80def6a28 100644 --- a/fast_llm/engine/optimizer/optimizer.py +++ b/fast_llm/engine/optimizer/optimizer.py @@ -242,8 +242,11 @@ def unscale_and_check_nans(self, tensor: torch.Tensor) -> None: class ConstantGradScaler(VariableGradScaler): def load(self, state, validate=True) -> None: - if validate: - Assert.eq(self._scale, state["scale"]) + if hasattr(self, "_scale"): + if validate: + Assert.eq(self._scale, state["scale"]) + else: + self._set_scale(state["scale"]) super().load(state, validate=validate) def _set_scale(self, value) -> None: @@ -282,6 +285,7 @@ def save(self) -> dict[str, typing.Any]: def load(self, state, validate=True) -> None: super().load(state, validate=validate) + self._set_scale(state["scale"]) self._growth_tracker = state["growth"] self._hysteresis_tracker = state["hysteresis"] diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 48714db40..29720b90b 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -1,26 +1,29 @@ import enum import functools -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, test_field +from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.utils import Assert -class StepType(str, enum.Enum): +class StepType(enum.StrEnum): forward = "forward" backward = "backward" @config_class() class ScheduleConfig(Config): + """Configuration for the micro-batch execution schedule: pipeline overlap, CPU throttling, and debug options.""" + depth_first_micro_batches: int = Field( default=1, - desc="Size of individual micro-batches. May be derived or constrained be other quantities.", + desc="Number of micro-batches processed depth-first, i.e., each runs through all model stages before the next" + " begins. This is the standard way to perform gradient accumulation.", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) breadth_first_micro_batches: int = Field( default=1, - desc="Size of individual micro-batches. May be derived or constrained be other quantities.", + desc="Number of micro-batches processed breadth-first, i.e., interleaved across model stages.", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) @@ -69,10 +72,6 @@ class ScheduleConfig(Config): desc="Detailed time table for the schedule execution (cpu and gpu times).", hint=FieldHint.logging, ) - # TODO: Remove - estimate_critical_batch: bool = Field( - default=False, desc="No longer supported.", hint=FieldHint.deprecated, valid=test_field(lambda x: not x) - ) # Skip the weight update and related ops (debug) skip_step: bool = Field( default=False, @@ -89,18 +88,18 @@ def num_inputs(self) -> int: return self.sequential_micro_batches * self.micro_batch_splits -class StreamType(str, enum.Enum): +class StreamType(enum.StrEnum): compute = "compute" data = "data" pipeline = "pipeline" -class StepScheduleType(str, enum.Enum): +class StepScheduleType(enum.StrEnum): breadth_first = "breadth_first" depth_first = "depth_first" -class EventType(str, enum.Enum): +class EventType(enum.StrEnum): # Global events batch_begin = "batch_begin" batch_end = "batch_end" diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 7ad03b24c..b95d39463 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -164,7 +164,7 @@ def run_step( if self._multi_stage.config.multi_stage.debug_activation_memory: log_pipeline_parallel_main_rank( - lambda: log_memory_usage(f"Beginning of {context.phase.value} iteration {iteration}", str) + lambda: log_memory_usage(f"Beginning of {context.phase} iteration {iteration}", str) ) self._multi_stage.train(context.is_training) self._distributed.set_step(iteration, schedule.phase) @@ -278,7 +278,7 @@ def run_step( if self._multi_stage.config.multi_stage.debug_activation_memory: log_pipeline_parallel_main_rank( - lambda: log_memory_usage(f"End of {context.phase.value} iteration {iteration}", str) + lambda: log_memory_usage(f"End of {context.phase} iteration {iteration}", str) ) return self._reduce_losses(context), update_successful, metrics @@ -487,12 +487,12 @@ def _handle_events(self, context: BatchContext) -> None: def _save_events(self, events, context: BatchContext) -> None: out = { "iteration": context.iteration, - "phase": context.phase.value, + "phase": context.phase, "rank": self._distributed_config.rank, "events": [ { - "event_type": type_.value, - "stream": stream.value, + "event_type": type_, + "stream": stream, "gpu_time": gpu_time, "cpu_time": cpu_time, **( @@ -500,7 +500,7 @@ def _save_events(self, events, context: BatchContext) -> None: if step is None else { "step_idx": step.global_index, - "step_type": step.type_.value, + "step_type": step.type_, "step_stage": step.stage, "step_depth_first_micro_batch": step.depth_first_micro_batch, "step_breadth_first_micro_batch": step.breadth_first_micro_batch, @@ -514,7 +514,7 @@ def _save_events(self, events, context: BatchContext) -> None: yaml.safe_dump( out, get_run().open_artifact( - f"schedule_profile_rank_{self._distributed_config.rank}_{context.phase.value}_step_{context.iteration}" + f"schedule_profile_rank_{self._distributed_config.rank}_{context.phase}_step_{context.iteration}" ), ) diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index e2a9c75b5..6f7bf1d95 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -149,7 +149,7 @@ def __init__( self._setup_metas() if self._config.debug_schedule: - logger.info(f"{self._phase.value} schedule:\n{self._steps}") + logger.info(f"{self._phase} schedule:\n{self._steps}") @property def phase(self) -> PhaseType: @@ -210,7 +210,7 @@ def _create_index(self) -> None: for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages): assert ( step_map.pop((type_, stage, data_index), None) is not None - ), f"Missing {type_.value} step with stage={stage}, data_index={data_index}" + ), f"Missing {type_} step with stage={stage}, data_index={data_index}" Assert.empty(step_map) # Related steps diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index d5c6fbc7c..bece3cb49 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -1,4 +1,5 @@ import abc +import functools import os import pathlib import shlex @@ -10,7 +11,7 @@ Configurable, Field, FieldHint, - FieldUpdate, + FieldOverride, NoAutoValidate, check_field, config_class, @@ -50,6 +51,8 @@ def _validate_script(value: str | list[str]) -> list[str]: @config_class() class CallbackConfig(Config): + """Configuration for an optional shell script callback invoked after a checkpoint, export, or shutdown event.""" + script: list[str] | None = Field( default=None, desc="Shell script to run.", @@ -71,12 +74,14 @@ def run(self) -> None: @config_class() class WandbAlertConfig(IntervalConfig): - interval = FieldUpdate( + """Configuration for periodic Weights & Biases status alerts during training.""" + + interval = FieldOverride( desc="The number of training iterations between each Wandb status post (alert)." " Setting to None will disable iteration-based wandb alerts." " Must be a sub-interval of the logging interval." ) - offset = FieldUpdate( + offset = FieldOverride( desc="Offset for the first Wandb status post (alert)." " Must be compatible with the logging offset.", ) status_updates: bool | None = Field( @@ -85,26 +90,28 @@ class WandbAlertConfig(IntervalConfig): "The update may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.feature, ) - post_alerts: bool = Field(init=False) - def _validate(self) -> None: - if self.status_updates is None: - self.post_alerts = self.enabled() - super()._validate() + @functools.cached_property + def post_alerts(self) -> bool: + return self.status_updates if self.status_updates is not None else self.enabled() @config_class() class MetricsLogsConfig(IntervalConfig): - interval = FieldUpdate( + """Configuration for training metric logging interval (loss, throughput, etc.).""" + + interval = FieldOverride( default=100, desc="The number of training iterations between each metric logs." " Setting to None will disable metric logging.", ) - offset = FieldUpdate(desc="Offset for the first metric logs.") + offset = FieldOverride(desc="Offset for the first metric logs.") @config_class() class WandbConfig(Config): + """Configuration for Weights & Biases experiment tracking (project, entity, alerts).""" + alert: WandbAlertConfig = Field( desc="Configuration for Wandb alerts." " The alerts may be posted by email and/or slack depending on the Wandb account configuration.", @@ -117,6 +124,8 @@ class WandbConfig(Config): @config_class() class TrainingCheckpointBaseConfig(IntervalConfig): + """Abstract base configuration for periodic saving operations (checkpoints and exports).""" + _abstract = True save_name: typing.ClassVar[str] = "save" callback: CallbackConfig = Field( @@ -157,14 +166,16 @@ def to_delete(self, iterations: list[int]) -> list[int]: @config_class() class TrainingCheckpointConfig(TrainingCheckpointBaseConfig): + """Configuration for saving full training checkpoints (weights + optimizer state) at a fixed interval.""" + _abstract = False save_name: typing.ClassVar[str] = "checkpoint" - interval = FieldUpdate( + interval = FieldOverride( desc="The number of training iterations between each checkpoint. Setting to None will disable checkpoints." ) - offset = FieldUpdate(desc="Offset for the first checkpoint.") - callback: CallbackConfig = FieldUpdate(desc="Callback (shell script) to run after checkpoint.") - keep: int | None = FieldUpdate(default=5) + offset = FieldOverride(desc="Offset for the first checkpoint.") + callback: CallbackConfig = FieldOverride(desc="Callback (shell script) to run after checkpoint.") + keep: int | None = FieldOverride(default=5) def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path: return experiment_directory / "checkpoint" @@ -190,13 +201,15 @@ def get_load_config(self, path: pathlib.Path, timeout: float | None) -> Checkpoi @config_class() class TrainingExportConfig(TrainingCheckpointBaseConfig, CheckpointStateSaveConfigBase): + """Configuration for exporting model weights to an external format (e.g. HuggingFace) at a fixed interval.""" + _abstract = False save_name: typing.ClassVar[str] = "export" - interval = FieldUpdate( + interval = FieldOverride( desc="The number of training iterations between each export." " Setting to None will disable exports." ) - offset = FieldUpdate(desc="Offset for the first export.") - callback: CallbackConfig = FieldUpdate(desc="Callback (shell script) to run after export.") + offset = FieldOverride(desc="Offset for the first export.") + callback: CallbackConfig = FieldOverride(desc="Callback (shell script) to run after export.") def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path: return experiment_directory / "export" / self.format.name @@ -207,18 +220,22 @@ def get_save_config(self, path: pathlib.Path, timeout: float | None) -> Checkpoi @config_class() class ShutdownConfig(IntervalConfig): - interval = FieldUpdate( + """Configuration for automatic training shutdown after a checkpoint, useful for preemptible jobs.""" + + interval = FieldOverride( desc="The number of training iterations between each automated shutdown." " Setting to None will disable automated shutdowns." " Must be a sub-interval of the checkpoint interval." ) - offset = FieldUpdate( + offset = FieldOverride( desc="Offset for the first automated shutdown." " Must be compatible with the checkpoint offset." ) @config_class() class TrainingConfig(Config): + """Configuration for training phases: iterations, checkpoints, exports, logging, evaluators, and W&B.""" + evaluators: dict[str, EvaluatorConfig] = Field( default_factory=dict, desc="A dictionary of evaluation dataset names and their configurations for the validation phase.", @@ -261,6 +278,8 @@ def _validate(self) -> None: @config_class(registry=True) class TrainerCallbackConfig(Config): + """Abstract base configuration for trainer callbacks that hook into training events.""" + def get_callback(self, model: "FastLLMModel") -> "TrainerCallback": raise NotImplementedError() @@ -270,6 +289,8 @@ def setup(self, config: "TrainerConfig") -> None: @config_class() class WeightsBroadcastConfig(Config): + """Configuration for broadcasting model weights to an external process via NCCL (used in online RL pipelines).""" + # TODO: Have the external model send these instead? host: str = Field( default="localhost", @@ -295,9 +316,7 @@ class WeightsBroadcastConfig(Config): @config_class(dynamic_type={TrainerCallbackConfig: "streaming"}) class StreamingTrainerCallbackConfig(TrainerCallbackConfig, RedisConfig): - """ - Aggregates all trainer-side Redis-based event configurations. - """ + """Trainer callback for online RL: exports and broadcasts model weights via Redis after each update.""" broadcast: WeightsBroadcastConfig = Field( desc="Configuration for signaling weight-ready events via Redis.", @@ -320,6 +339,8 @@ def setup(self, config: "TrainerConfig") -> None: @config_class(registry=True, dynamic_type={RunnableConfig: "train"}) class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): + """Abstract base configuration for a training run: model, data, schedule, optimizer, callbacks, and checkpointing.""" + _abstract = True # TODO: Generalize data, schedule, logging, etc. training: TrainingConfig = Field( diff --git a/fast_llm/engine/training/streaming.py b/fast_llm/engine/training/streaming.py index 7870b45bc..aec14530f 100644 --- a/fast_llm/engine/training/streaming.py +++ b/fast_llm/engine/training/streaming.py @@ -2,6 +2,8 @@ import logging import typing +import torch + from fast_llm.core.distributed import broadcast as _broadcast from fast_llm.core.distributed import broadcast_object as _broadcast_object from fast_llm.engine.distributed.config import DistributedBackend @@ -26,15 +28,16 @@ def __init__(self, config: ConfigType, model: "FastLLMModel"): init_method = f"tcp://{config.broadcast.host}:{config.broadcast.port}" logger.info(f"Waiting for weights broadcast rendezvous at {init_method} ...") world_size = config.broadcast.external_world_size + 1 - self._process_group = ProcessGroupPool( + self._pool = ProcessGroupPool( rank=0, world_size=world_size, local_world_size=1, timeout=self._config.timeout, - use_cuda=self._config.broadcast.backend == DistributedBackend.nccl, + device=None if self._config.broadcast.backend == DistributedBackend.nccl else torch.device("cpu"), init_method=init_method, backend=self._config.broadcast.backend, - ).get_process_group(range(world_size), 0) + ) + self._process_group = self._pool.get_process_group(range(world_size), 0) logger.info(f"Weights broadcast rendezvous at {init_method} connected") def run_begin(self, step: int): @@ -61,8 +64,9 @@ def __del__(self): self._clear() def _clear(self): - if hasattr(self, "_process_group"): - self._process_group.shutdown() + if hasattr(self, "_pool"): + self._pool.shutdown() + del self._pool del self._process_group def _broadcast_weights(self, step: int): diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 1e341077f..00cf2fa0d 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -192,7 +192,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: interrupter = Interrupter(self._config.training.checkpoint.enabled()) train_iterator = self._get_data_iterator( - PhaseType.training.value, + PhaseType.training, self._completed_steps, self._config.training.prefetch_factor, ) @@ -254,7 +254,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: remaining_time = average_time_per_iteration * ( self._config.training.train_iters - self._completed_steps ) - metrics_key = PhaseType.training.value + metrics_key = PhaseType.training metrics[metrics_key] = { "batch_size": self._batch_size, **{ @@ -402,7 +402,7 @@ def _save_checkpoint( ) # Mark the checkpoint as complete. if self._run.is_main_rank: - (checkpoint_directory / "ok").open("w") + (checkpoint_directory / "ok").touch() logger.info(f"Saved {config.save_name} to {checkpoint_directory}") to_delete = config.to_delete(sorted(int(path.name) for path in checkpoint_base_directory.iterdir())) @@ -429,7 +429,7 @@ def _load_checkpoint(self, config: TrainingCheckpointConfig, iteration: int) -> self._optimizer.load(metadata["optimizer"]) if "schedules" in metadata: # Backward compatibility. - self._completed_steps = metadata["schedules"][PhaseType.training.value]["completed_steps"] + self._completed_steps = metadata["schedules"][PhaseType.training]["completed_steps"] else: self._completed_steps = metadata["completed_steps"] # TODO: Move barrier, ok file to FastLLMModel diff --git a/fast_llm/engine/training/wandb.py b/fast_llm/engine/training/wandb.py index 724b5b718..3349cff26 100644 --- a/fast_llm/engine/training/wandb.py +++ b/fast_llm/engine/training/wandb.py @@ -19,14 +19,14 @@ def __init__(self, config: WandbConfig, run: Run, experiment_config: Config): # Wandb login from file api_key_path = os.environ.get("WANDB_API_KEY_PATH") if api_key_path: - os.environ["WANDB_API_KEY"] = pathlib.Path(api_key_path).open("r").read().strip() + os.environ["WANDB_API_KEY"] = pathlib.Path(api_key_path).read_text().strip() wandb_path = ( None if self._run.experiment_directory is None else self._run.experiment_directory / "wandb_config.yaml" ) if wandb_path is not None and wandb_path.is_file(): - wandb_config = yaml.safe_load(wandb_path.open("r")) + wandb_config = yaml.safe_load(wandb_path.read_text()) else: wandb_config = { "id": wandb.sdk.lib.runid.generate_id(16), @@ -38,7 +38,7 @@ def __init__(self, config: WandbConfig, run: Run, experiment_config: Config): "resume": "allow", } if wandb_path is not None: - yaml.safe_dump(wandb_config, wandb_path.open("w")) + wandb_path.write_text(yaml.safe_dump(wandb_config)) # TODO: Does wandb work with nested configs? self._wandb = wandb.init(config=experiment_config.to_dict(), **wandb_config) else: @@ -53,8 +53,6 @@ def log_metrics(self, completed_steps: int, metrics: dict[str, dict[str, float | def alert(self, title, text, level="INFO", wait=0.001) -> None: if self._wandb is not None and self._config.alert.post_alerts: - pass - self._wandb.alert( # noqa title=title() if callable(title) else title, text=f"[{self._config.project_name}/{self._run.experiment_name}, run {self._run.index}]" diff --git a/fast_llm/functional/triton/__init__.py b/fast_llm/functional/triton/__init__.py index f5b394bfb..baf50099d 100644 --- a/fast_llm/functional/triton/__init__.py +++ b/fast_llm/functional/triton/__init__.py @@ -27,7 +27,20 @@ tl_arange = None tl_full = None elif triton_interpret: - # Workaround for a triton bug. + # Workaround for a triton interpreter bug: constexpr int arguments to device functions + # arrive as 1-d numpy arrays rather than scalars. The interpreter's _patch_lang_tensor sets + # tensor.__index__ = lambda self: int(self.handle.data), which fails for 1-d arrays. + # Patch _patch_lang_tensor to use .item() instead, which works for both 0-d and 1-d arrays. + import triton.runtime.interpreter as _triton_interpreter + + _orig_patch_lang_tensor = _triton_interpreter._patch_lang_tensor + + def _fixed_patch_lang_tensor(tensor): + _orig_patch_lang_tensor(tensor) + tensor.__index__ = lambda self: self.handle.data.item() + + _triton_interpreter._patch_lang_tensor = _fixed_patch_lang_tensor + @triton_jit def tl_arange(start, end): return tl.arange(int(start), int(end)) diff --git a/fast_llm/functional/triton/grpo_loss.py b/fast_llm/functional/triton/grpo_loss.py new file mode 100644 index 000000000..39d832ccd --- /dev/null +++ b/fast_llm/functional/triton/grpo_loss.py @@ -0,0 +1,232 @@ +import torch + +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton.entropy_loss import ( + parallel_sum_exp_logits, + triton_cross_entropy_forward_from_labels_parallel_kernel, + triton_fused_softmax_base, +) +from fast_llm.functional.utils import reduce_losses + + +@triton_jit() +def triton_grpo_loss_forward_backward_kernel( + logits_ptr, + labels_ptr, + advantages_ptr, + old_log_probs_ptr, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + block_size: tl_constexpr, + losses_ptr=None, + new_logprobs_mean_parts_ptr=None, + num_labels_in_seq_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + predicted_logits_ptr=None, + grad_losses=None, + grad_logits_ptr=None, + grad_logits_stride_0: tl_constexpr = None, + col_min: tl_constexpr = 0, + logits_scale_factor: tl_constexpr = 1.0, + epsilon_low: tl_constexpr = 0.2, + epsilon_high: tl_constexpr = 0.2, + accumulate: tl_constexpr = False, +): + block_idx = tl.program_id(0).to(tl.int64) + logits_ptr = logits_ptr + block_idx * logits_stride_0 + + label_idx = tl.load(labels_ptr + block_idx) + if label_idx < 0: + # Masked position. + if losses_ptr is not None: + tl.store(losses_ptr + block_idx, 0) + if new_logprobs_mean_parts_ptr is not None: + tl.store(new_logprobs_mean_parts_ptr + block_idx, 0) + if grad_losses is not None and not accumulate: + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=col_offsets < n_cols + ) + return + + label_idx -= col_min + + if max_logits_ptr is None or sum_exp_logits_ptr is None: + # Non-parallel: compute softmax and predicted logit in one forward pass. + exp_logits, sum_exp_logits, max_logits, col_offsets, mask = triton_fused_softmax_base( + logits_ptr, n_cols=n_cols, block_size=block_size, logits_scale_factor=logits_scale_factor + ) + if label_idx >= 0 and label_idx < n_cols: + predicted_logit = tl.load(logits_ptr + label_idx).to(tl.float32) + if logits_scale_factor != 1.0: + predicted_logit *= logits_scale_factor + else: + # Parallel case only: target not in local vocab shard. + predicted_logit = 0.0 + else: + # Parallel case: use globally reduced values from the first pass. + max_logits = tl.load(max_logits_ptr + block_idx) + sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) + predicted_logit = tl.load(predicted_logits_ptr + block_idx) + + # new_log_prob = log_softmax(logits * scale)[label] + # = logits[label]*scale - (max_logits + log(sum_exp_logits)) + new_log_prob = predicted_logit - max_logits - tl.log(sum_exp_logits) + old_log_prob = tl.load(old_log_probs_ptr + block_idx).to(tl.float32) + advantage = tl.load(advantages_ptr + block_idx).to(tl.float32) + + ratio = tl.exp(new_log_prob - old_log_prob) + clipped_ratio = tl.minimum(tl.maximum(ratio, 1.0 - epsilon_low), 1.0 + epsilon_high) + loss = -tl.minimum(ratio * advantage, clipped_ratio * advantage) + + if losses_ptr is not None: + tl.store(losses_ptr + block_idx, loss) + + if new_logprobs_mean_parts_ptr is not None: + num_labels = tl.load(num_labels_in_seq_ptr + block_idx).to(tl.float32) + tl.store(new_logprobs_mean_parts_ptr + block_idx, new_log_prob / tl.maximum(num_labels, 1.0)) + + if grad_losses is not None: + if logits_scale_factor != 1.0: + grad_losses *= logits_scale_factor + # effective_grad = probability_ratio_grad * ratio + # = (clamp_min(adv, 0) * (ratio <= 1+eps_high) + clamp_max(adv, 0) * (ratio >= 1-eps_low)) * ratio * grad_losses + effective_grad = ( + ( + tl.maximum(advantage, 0.0) * (ratio <= 1.0 + epsilon_high) + + tl.minimum(advantage, 0.0) * (ratio >= 1.0 - epsilon_low) + ) + * ratio + * grad_losses + ) + + # grad_logits_i = effective_grad * (p_i - delta_{i, label}) + col_offset_start: tl.constexpr = (n_cols - 1) // block_size * block_size + for col_offset in tl.static_range(col_offset_start, -1, -block_size): + if max_logits_ptr is not None or sum_exp_logits_ptr is not None or col_offset != col_offset_start: + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + exp_logits = tl.exp(logits - max_logits) + prob = exp_logits / sum_exp_logits + if label_idx < 0 or label_idx >= n_cols: + # Target not in local vocab shard (parallel case): no delta term. + grad_logits = effective_grad * prob + else: + grad_logits = effective_grad * tl.where(col_offsets == label_idx, prob - 1.0, prob) + grad_logits_col_ptr = grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets + if accumulate: + grad_logits += tl.load(grad_logits_col_ptr, mask=mask) + tl.store(grad_logits_col_ptr, grad_logits, mask=mask) + + +def triton_grpo_loss_forward_backward( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) + advantages: torch.Tensor, # (*batch,) + old_log_probabilities: torch.Tensor, # (*batch,) + grad_logits: torch.Tensor | None = None, + grad_output: float | None = None, + group: torch.distributed.ProcessGroup | None = None, + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + logits_scale_factor: float = 1.0, + num_labels_in_seq: torch.Tensor | None = None, + divisor: float | None = None, + block_size: int | None = None, + num_warps: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + assert logits.is_contiguous() + assert target.is_contiguous() + assert advantages.is_contiguous() + assert old_log_probabilities.is_contiguous() + n_rows = logits.shape[:-1].numel() + n_cols = logits.size(-1) + if divisor is None: + divisor = n_rows + if block_size is None: + block_size = min(triton.next_power_of_2(n_cols), 32768) + if num_warps is None: + num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) + shared_kwargs = { + "logits_stride_0": logits.stride(-2), + "n_cols": n_cols, + "logits_scale_factor": logits_scale_factor, + "block_size": block_size, + "num_warps": num_warps, + } + kwargs = { + **shared_kwargs, + "epsilon_low": epsilon_low, + "epsilon_high": epsilon_high, + } + if grad_output is None: + backward_kwargs = {} + else: + accumulate = grad_logits is not None + grad_logits = torch.empty_like(logits) if grad_logits is None else grad_logits + backward_kwargs = { + "grad_logits_ptr": grad_logits, + "grad_losses": grad_output / divisor, + "grad_logits_stride_0": grad_logits.stride(-2), + "accumulate": accumulate, + } + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + if num_labels_in_seq is not None: + assert num_labels_in_seq.is_contiguous() + new_logprobs_mean_parts = torch.empty(n_rows, dtype=torch.float, device=logits.device) + new_logprobs_mean_kwargs = { + "new_logprobs_mean_parts_ptr": new_logprobs_mean_parts, + "num_labels_in_seq_ptr": num_labels_in_seq, + } + else: + new_logprobs_mean_kwargs = {} + + if group is None: + triton_grpo_loss_forward_backward_kernel[(n_rows,)]( + logits, + target, + advantages, + old_log_probabilities, + losses_ptr=losses, + **kwargs, + **backward_kwargs, + **new_logprobs_mean_kwargs, + ) + else: + local_max_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) + sum_exp_logits = torch.empty_like(local_max_logits) + predicted_logits_local = torch.empty_like(local_max_logits) + triton_cross_entropy_forward_from_labels_parallel_kernel[(n_rows,)]( + logits, + target, + max_logits_ptr=local_max_logits, + sum_exp_logits_ptr=sum_exp_logits, + predicted_logits_ptr=predicted_logits_local, + col_min=n_cols * group.rank(), + **shared_kwargs, + ) + max_logits, sum_exp_logits = parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) + torch.distributed.all_reduce(predicted_logits_local, op=torch.distributed.ReduceOp.SUM, group=group) + triton_grpo_loss_forward_backward_kernel[(n_rows,)]( + logits, + target, + advantages, + old_log_probabilities, + losses_ptr=losses, + max_logits_ptr=max_logits, + sum_exp_logits_ptr=sum_exp_logits, + predicted_logits_ptr=predicted_logits_local, + col_min=n_cols * group.rank(), + **kwargs, + **backward_kwargs, + **new_logprobs_mean_kwargs, + ) + + loss = reduce_losses(losses, divisor) + new_logprobs_mean = new_logprobs_mean_parts.sum() if num_labels_in_seq is not None else None + return loss, grad_logits, new_logprobs_mean diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index 4a8c5f179..52af93bde 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -156,7 +156,7 @@ def triton_mlp_activation_forward( input_, output, gated=gated, # noqa - activation_type=activation_type.value, # noqa + activation_type=activation_type, # noqa n_cols=n_cols, # noqa block_size=TritonConfig.POINTWISE_BLOCK_SIZE, ) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 86469c3d9..fcb5bfaf6 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -42,6 +42,8 @@ class AttentionImplementation(enum.StrEnum): @config_class(dynamic_type={MixerConfig: "attention"}) class AttentionConfig(MixerConfig): + """Configuration for multi-head and grouped-query attention with optional rotary embeddings.""" + # TODO: Make mixer class dynamic. _abstract = False diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index acf807c69..805eae1e5 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -68,7 +68,7 @@ def __call__( "", tensor, level=level, - meta=self._get_meta(tensor, name + f"{name}.grad", dims), + meta=self._get_meta(tensor, f"{name}.grad", dims), **logging_kwargs, ) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index b6ed2d851..aa47a5f2e 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -169,7 +169,7 @@ def _validate(self): if missing := used_blocks - available_blocks: raise ValueError(f"The following blocks are present in the pattern but undefined: {missing}") if extra := available_blocks - used_blocks: - raise warnings.warn(f"The following blocks are defined but unused: {extra}") + warnings.warn(f"The following blocks are defined but unused: {extra}") super()._validate() diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index e7c6d9e92..803edc302 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -1,6 +1,7 @@ import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.initialization import Initialization, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim @@ -14,7 +15,7 @@ @config_class() -class LinearBaseConfig(Config): +class LinearBaseConfig(ModuleConfig): """ Configuration for a linear-like layer without bias. """ @@ -45,6 +46,10 @@ class AffineLinearBaseConfig(LinearBaseConfig): @config_class() class LinearConfig(LinearBaseConfig): + """Configuration for a linear (weight-only, no bias) layer with optional PEFT and tensor-parallelism support.""" + + _abstract = False + apply_peft: bool | None = Field( default=None, desc="Wrap this layer ." @@ -104,6 +109,10 @@ def get_layer( @config_class() class AffineLinearConfig(AffineLinearBaseConfig, LinearConfig): + """Configuration for an affine linear layer (weight + optional bias) with optional PEFT and tensor-parallelism support.""" + + _abstract = False + def get_layer( self, in_dim: TensorDim, @@ -167,6 +176,8 @@ class CausalConv1dConfig(AffineLinearBaseConfig): Configuration for a 1d causal convolution, as used in mamba layers. """ + _abstract = False + kernel_size: int = Field( default=4, desc="Convolution kernel size.", @@ -175,6 +186,7 @@ class CausalConv1dConfig(AffineLinearBaseConfig): ) activation: ActivationType | None = Field( default=None, + desc="Activation function applied after the convolution. None means no activation.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index 1c23d6d8a..fd9670d9f 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ -60,7 +60,7 @@ def _forward_causal_conv1d( input_, self.weight.squeeze(1), self.bias, - activation=(None if self._activation == ActivationType.identity else self._activation.value), + activation=(None if self._activation == ActivationType.identity else self._activation), seq_idx=document_index, ) diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 4b8edaebe..274215bf2 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -14,7 +14,7 @@ from fast_llm.layers.common.normalization.normalization import Normalization -class NormalizationImplementation(str, enum.Enum): +class NormalizationImplementation(enum.StrEnum): """ An enum for the available implementations of layer norm. """ @@ -28,6 +28,8 @@ class NormalizationImplementation(str, enum.Enum): @config_class(registry=True) class NormalizationConfig(ModuleConfig): + """Abstract base configuration for normalization layers. Use `type: layer_norm`, `rms_norm`, `gated_rms_norm`, or `none`.""" + lr_scale: float | None = Field( default=None, desc="Scaling factor for the layer learning rate." @@ -62,6 +64,8 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi @config_class(dynamic_type={NormalizationConfig: "none"}) class NoNormalizationConfig(NormalizationConfig): + """Disables normalization entirely (identity pass-through).""" + _abstract = False @property @@ -106,6 +110,8 @@ def module_class(self): @config_class(dynamic_type={NormalizationConfig: "layer_norm"}) class LayerNormalizationConfig(LayerNormalizationBaseConfig): + """Configuration for standard layer normalization (mean and variance, with learnable weight and bias).""" + bias: ParameterConfig = Field( desc="Configuration for the weight.", hint=FieldHint.architecture, @@ -121,6 +127,8 @@ def module_class(self): @config_class(dynamic_type={NormalizationConfig: "rms_norm"}) class RMSNormalizationConfig(LayerNormalizationBaseConfig): + """Configuration for RMS normalization (variance only, no mean subtraction, no bias).""" + _abstract = False @property @@ -132,6 +140,8 @@ def module_class(self): @config_class(dynamic_type={NormalizationConfig: "gated_rms_norm"}) class GatedRMSNormalizationConfig(RMSNormalizationConfig): + """Configuration for gated RMS normalization, which applies a learned activation gate alongside the norm weight.""" + _abstract = False activation: ActivationType = Field( diff --git a/fast_llm/layers/common/peft/lora.py b/fast_llm/layers/common/peft/lora.py index eaf9f67f0..badfc91f2 100644 --- a/fast_llm/layers/common/peft/lora.py +++ b/fast_llm/layers/common/peft/lora.py @@ -61,8 +61,9 @@ def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor input_ = input_.detach().requires_grad_() with torch.enable_grad(): output = old_forward(input_) + layer_out = output if isinstance(output, tuple): - layer_out, tp_bias = output[0] + layer_out, tp_bias = output assert tp_bias is None lora_out = (alpha / rank) * module.lora_1( module.lora_0(torch.dropout(input_, dropout, module.training) if dropout > 0.0 else input_) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 4cab2d39b..6ab259b2b 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -56,6 +56,8 @@ def get_layer( @config_class(registry=True) class MLPBaseConfig(BlockWithBiasConfig): + """Abstract base configuration for MLP (feedforward) layers. Use `type: mlp` or `type: moe` to select a variant.""" + _abstract = True def get_layer( @@ -200,9 +202,17 @@ def layer_class(self) -> "type[StochasticMixer]": @config_class(dynamic_type={BlockConfig: "decoder"}) class DecoderBlockConfig(BlockConfig): + """Configuration for a transformer decoder block (attention + MLP + normalization + residual).""" + _abstract = False - mixer: MixerConfig = Field() - mlp: MLPBaseConfig = Field() + mixer: MixerConfig = Field( + desc="Configuration for the attention/mixer layer.", + hint=FieldHint.architecture, + ) + mlp: MLPBaseConfig = Field( + desc="Configuration for the feedforward (MLP) layer.", + hint=FieldHint.architecture, + ) # TODO: Review names normalization: NormalizationConfig = Field( desc="Configuration for the block normalization layers.", diff --git a/fast_llm/layers/decoder/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py index 36841b45b..28198f2e4 100644 --- a/fast_llm/layers/decoder/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -18,13 +18,15 @@ class MLPLossNames: router_z_loss = "router_z_loss" -class RoutingType(str, enum.Enum): +class RoutingType(enum.StrEnum): topk = "aux_loss" sinkhorn = "sinkhorn" @config_class(dynamic_type={MLPBaseConfig: "mlp"}) class MLPConfig(MLPBaseConfig): + """Configuration for a dense feedforward (MLP) layer with optional gating and activation recomputation.""" + # TODO: Review names # TODO: Separate MoE? _abstract = False @@ -81,6 +83,8 @@ def layer_class(self) -> "type[MLP]": @config_class(dynamic_type={MLPBaseConfig: "moe"}) class MoEMLPConfig(MLPConfig): + """Configuration for a Mixture-of-Experts (MoE) feedforward layer with top-k token routing.""" + router: LinearConfig = Field( # TODO: Improve default? desc="Configuration for the MoE router.", diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 48bc5a5e1..04f0cd04a 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -116,7 +116,7 @@ def _forward( if self._config.routing == RoutingType.topk: scores, top_experts = self._topk_routing(logits, kwargs.get(BlockKwargs.grad_output), losses) if self._config.shared_experts > 0: - scores, top_experts = self._add_shared_experts(top_experts, scores) + scores, top_experts = self._add_shared_experts(scores, top_experts) elif self._config.routing == RoutingType.sinkhorn: scores, top_experts = self._sinkhorn_routing(logits) else: diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index 1048f7c2a..80599da97 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -71,7 +71,6 @@ def __init__( def _get_intermediate_dims(self): intermediate_2_dim = TensorDim("intermediate", self._config.intermediate_size, self._parallel_dim) if self._config.gated: - TensorDim("gate_and_up", 2) intermediate_1_dim = ConcatenatedTensorDim("gate_and_up", (intermediate_2_dim, intermediate_2_dim)) else: intermediate_1_dim = intermediate_2_dim diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 97bd1f477..a3ea8b846 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -161,8 +161,6 @@ def _forward( return self.mixers[mixer_name]._forward(input_, kwargs, losses, metrics) def get_preprocessing_config(self) -> dict[str, typing.Any]: - for mixer in self.mixers.values(): - mixer.get_preprocessing_config() return safe_merge_dicts(*(mixer.get_preprocessing_config() for mixer in self.mixers.values())) def _sample_allocation(self, num_layers: int, generator: torch.Generator) -> list[int]: diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index a2c067a95..4381aa5d9 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -200,6 +200,11 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): epsilon_low: float = Field(default=0.2, desc="Lower clip parameter for ratio of log probs") epsilon_high: float = Field(default=0.2, desc="Upper clip parameter for ratio of log probs") + use_triton: bool | None = Field( + default=None, + desc="Enable triton implementation. Default: use if available.", + hint=FieldHint.expert, + ) @property def loss_class(self) -> "type[LanguageModelGRPOLoss]": diff --git a/fast_llm/layers/language_model/loss/dpo.py b/fast_llm/layers/language_model/loss/dpo.py index af2149d36..059f808e5 100644 --- a/fast_llm/layers/language_model/loss/dpo.py +++ b/fast_llm/layers/language_model/loss/dpo.py @@ -13,8 +13,6 @@ def __init__(self, *args, **kwargs): raise NotImplementedError() if self._num_splits > 1: raise NotImplementedError() - if self._prediction_distance > 1: - raise NotImplementedError() if self._vocab_parallel: raise NotImplementedError() diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 62f591d9f..a933fec99 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -4,6 +4,7 @@ import torch from fast_llm.engine.base_model.config import LossDef +from fast_llm.functional.config import TritonConfig from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base from fast_llm.functional.utils import reduce_losses from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs @@ -19,7 +20,13 @@ def _forward_backward( split_index: int = 0, grad_logits: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - loss, grad, new_logprobs_mean = fused_grpo_loss_forward_backward( + if TritonConfig.enabled(logits.device, self._config.use_triton): + from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward + + fn = triton_grpo_loss_forward_backward + else: + fn = fused_grpo_loss_forward_backward + loss, grad, new_logprobs_mean = fn( logits, self._get_labels(kwargs, split_index), self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index ce54685e8..9e690e668 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -89,9 +89,6 @@ def layer_class(self) -> "type[GatedDeltaNet]": return GatedDeltaNet - def _validate(self) -> None: - super()._validate() - @config_class(dynamic_type={MixerConfig: "kda"}) class KimiDeltaAttentionConfig(MixerConfig): diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 16222b3c5..770139816 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -1,7 +1,7 @@ import logging import typing -from fast_llm.config import Field, FieldHint, FieldUpdate, config_class +from fast_llm.config import Field, FieldHint, FieldOverride, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointFormat @@ -56,9 +56,11 @@ def base_model_class(self) -> type["GPTBaseModel"]: @config_class(dynamic_type={FastLLMModelConfig: "gpt"}) class GPTModelConfig(FastLLMModelConfig): + """Configuration for the GPT model, including distributed, multi-stage, and HuggingFace checkpoint formats.""" + _abstract = False model_name: typing.ClassVar[str] = "gpt" - base_model: GPTBaseModelConfig = FieldUpdate() + base_model: GPTBaseModelConfig = FieldOverride() checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + ( AutoGPTHuggingfaceCheckpointFormat, LlamaCheckpointFormat, @@ -93,15 +95,19 @@ def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelF @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): + """Configuration for a GPT model together with an optional pretrained checkpoint to load.""" + _abstract = False - model: GPTModelConfig = FieldUpdate() + model: GPTModelConfig = FieldOverride() @config_class(dynamic_type={RunnableConfig: "train_gpt", TrainerConfig: "gpt"}) class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate() + """Top-level configuration for training a GPT model. Entry point for `fast-llm train gpt`.""" + + data: GPTDataConfig = FieldOverride() # TODO: Use dynamic model type? - reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() + reference_models: dict[str, PretrainedGPTModelConfig] = FieldOverride() def _validate(self) -> None: if self.model.base_model.use_megatron_initialization: diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 505c62d70..ac732ba22 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -34,7 +34,7 @@ def import_config(cls, config: dict) -> dict: "d_xb": config["ssm_cfg"].get("d_xb") or config["hidden_size"], "dt_layer": {"bias": {"enabled": config["ssm_cfg"].get("dt_proj_bias", True)}}, "dt_rank": ( - math.ceil(config["hidden_size"]) + math.ceil(config["hidden_size"] / 16) if config["ssm_cfg"].get("dt_rank", "auto") == "auto" else config["ssm_cfg"]["dt_rank"] ), diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 38dc38586..491ddde6e 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -217,7 +217,7 @@ def import_config(cls, config: dict) -> dict: } ) else: - raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") + raise NotImplementedError(f"Unsupported rotary type: {rope_type}") out = { "rotary": rotary_config, "heads": config["num_attention_heads"], diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index 4ebf18c3a..473135648 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -25,7 +25,7 @@ class Qwen2AttentionConverter(LlamaAttentionConverter): @classmethod def import_config(cls, config: dict) -> dict: - config["attention_bias"] = True + config["attention_bias"] = False out = super().import_config(config) out["query_layer"] = {"bias": {"enabled": True}} out["key_layer"] = {"bias": {"enabled": True}} diff --git a/fast_llm/models/multimodal/config.py b/fast_llm/models/multimodal/config.py index becdcacbb..15d62ad9e 100644 --- a/fast_llm/models/multimodal/config.py +++ b/fast_llm/models/multimodal/config.py @@ -1,7 +1,7 @@ import logging import typing -from fast_llm.config import FieldUpdate, config_class +from fast_llm.config import FieldOverride, config_class from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig @@ -40,7 +40,7 @@ def base_model_class(self) -> type["MultiModalBaseModel"]: class MultiModalModelConfig(GPTModelConfig): _abstract = False model_name: typing.ClassVar[str] = "multimodal" - base_model: MultiModalBaseModelConfig = FieldUpdate() + base_model: MultiModalBaseModelConfig = FieldOverride() checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + ( LlavaCheckpointFormat, LlavaHybridSSMCheckpointFormat, @@ -69,13 +69,13 @@ def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceMultiModa @config_class() class PretrainedMultiModalModelConfig(PretrainedGPTModelConfig): _abstract = False - model: MultiModalModelConfig = FieldUpdate() + model: MultiModalModelConfig = FieldOverride() @config_class(dynamic_type={RunnableConfig: "train_multimodal", TrainerConfig: "multimodal"}) class MultiModalTrainerConfig(PretrainedMultiModalModelConfig, GPTTrainerConfig): # TODO: Use dynamic model type? - reference_models: dict[str, PretrainedMultiModalModelConfig] = FieldUpdate() + reference_models: dict[str, PretrainedMultiModalModelConfig] = FieldOverride() @classmethod def get_trainer_class(cls) -> type["MultiModalTrainer"]: diff --git a/fast_llm/profile.py b/fast_llm/profile.py index a3902cf1e..58a72764d 100644 --- a/fast_llm/profile.py +++ b/fast_llm/profile.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -class ProfileType(str, enum.Enum): +class ProfileType(enum.StrEnum): cpu = "cpu" cuda = "cuda" diff --git a/mkdocs.yaml b/mkdocs.yaml index 00e52a011..56c79f520 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -77,6 +77,7 @@ theme: # Hooks hooks: - docs/overrides/hooks/shortcodes.py + - docs/overrides/hooks/generate_config_docs_hook.py # Additional configuration extra: @@ -157,6 +158,10 @@ plugins: branch: main - bibtex: bib_file: "docs/refs.bib" + enable_inline_citations: false + +exclude_docs: | + README.md nav: - Welcome: index.md @@ -182,6 +187,7 @@ nav: - Evaluators: user_guide/evaluators.md - Developer Guide: - Configuration: developer_guide/configuration.md + - Parallelism: developer_guide/parallelism.md - Model: - Model: developer_guide/model.md - Conversion: developer_guide/conversion.md @@ -190,5 +196,249 @@ nav: - Style Guide: contributing/style-guide.md - Development Practices: contributing/dev-practices.md - Testing: contributing/testing.md + - How to Release: contributing/how-to-release.md + # BEGIN AUTO-GENERATED CONFIG REFERENCE + - Configuration Reference: + - reference/configuration/index.md + - Data: + - reference/configuration/data/index.md + - Data: + - reference/configuration/data/data/index.md + - reference/configuration/data/data/DataConfig.md + - Gpt: + - reference/configuration/data/data/gpt/index.md + - reference/configuration/data/data/gpt/GPTDataConfig.md + - Dataset: + - reference/configuration/data/dataset/index.md + - reference/configuration/data/dataset/BlendedDatasetConfig.md + - reference/configuration/data/dataset/ConcatenatedDatasetConfig.md + - reference/configuration/data/dataset/DatasetConfig.md + - reference/configuration/data/dataset/DatasetSliceConfig.md + - reference/configuration/data/dataset/IndexedDatasetConfig.md + - reference/configuration/data/dataset/RedisConfig.md + - reference/configuration/data/dataset/SamplableDatasetConfig.md + - reference/configuration/data/dataset/SampledDatasetConfig.md + - reference/configuration/data/dataset/SamplingConfig.md + - reference/configuration/data/dataset/SamplingConfigBase.md + - reference/configuration/data/dataset/StreamingDatasetConfig.md + - Gpt: + - reference/configuration/data/dataset/gpt/index.md + - reference/configuration/data/dataset/gpt/FimConfig.md + - reference/configuration/data/dataset/gpt/GPTDatasetFromFileConfig.md + - reference/configuration/data/dataset/gpt/GPTFimSampledDatasetConfig.md + - reference/configuration/data/dataset/gpt/GPTRandomDatasetConfig.md + - reference/configuration/data/dataset/gpt/GPTSamplingConfig.md + - reference/configuration/data/dataset/gpt/GPTTestSlowDatasetConfig.md + - Memmap: + - reference/configuration/data/dataset/memmap/index.md + - reference/configuration/data/dataset/memmap/LanguageModelReaderConfig.md + - reference/configuration/data/dataset/memmap/MemmapDatasetConfig.md + - reference/configuration/data/dataset/memmap/MemmapIndexDatasetReaderConfig.md + - reference/configuration/data/dataset/memmap/MemmapReaderBaseConfig.md + - reference/configuration/data/dataset/memmap/MemmapReaderConfig.md + - reference/configuration/data/dataset/memmap/NullReaderConfig.md + - reference/configuration/data/dataset/memmap/PatchReaderBaseConfig.md + - reference/configuration/data/dataset/memmap/PatchReaderConfig.md + - reference/configuration/data/dataset/memmap/RangeReaderBaseConfig.md + - reference/configuration/data/dataset/memmap/RangeReaderConfig.md + - reference/configuration/data/dataset/memmap/TokenDataReaderConfig.md + - reference/configuration/data/dataset/memmap/TokenReaderConfig.md + - Document: + - reference/configuration/data/document/index.md + - reference/configuration/data/document/BatchPreprocessingConfig.md + - reference/configuration/data/document/ImageNormalizationConfig.md + - reference/configuration/data/document/LanguageModelBatchPreprocessingConfig.md + - reference/configuration/data/document/LengthPreprocessingConfig.md + - reference/configuration/data/document/PatchPreprocessingConfig.md + - reference/configuration/data/document/TokenPreprocessingConfig.md + - Preparation: + - reference/configuration/data/preparation/index.md + - reference/configuration/data/preparation/DatasetPreparatorConfig.md + - Dataset Discovery: + - reference/configuration/data/preparation/dataset_discovery/index.md + - reference/configuration/data/preparation/dataset_discovery/DatasetDiscoveryConfig.md + - Gpt Memmap: + - reference/configuration/data/preparation/gpt_memmap/index.md + - reference/configuration/data/preparation/gpt_memmap/ConversationSourceConfig.md + - reference/configuration/data/preparation/gpt_memmap/DatasetPreparatorDistributedConfig.md + - reference/configuration/data/preparation/gpt_memmap/DocumentSourceConfig.md + - reference/configuration/data/preparation/gpt_memmap/GPTHuggingfaceDatasetConfig.md + - reference/configuration/data/preparation/gpt_memmap/GPTMemmapDatasetPreparatorConfig.md + - reference/configuration/data/preparation/gpt_memmap/LanguageModelSourceConfig.md + - Image Patch: + - reference/configuration/data/preparation/image_patch/index.md + - reference/configuration/data/preparation/image_patch/ImagePreparationConfig.md + - Tokenizer: + - reference/configuration/data/preparation/tokenizer/index.md + - reference/configuration/data/preparation/tokenizer/TokenizerConfig.md + - Engine: + - reference/configuration/engine/index.md + - Base Model: + - reference/configuration/engine/base_model/index.md + - reference/configuration/engine/base_model/BaseModelConfig.md + - reference/configuration/engine/base_model/ModuleConfig.md + - Checkpoint: + - reference/configuration/engine/checkpoint/index.md + - reference/configuration/engine/checkpoint/CheckpointConfigBase.md + - reference/configuration/engine/checkpoint/CheckpointLoadConfig.md + - reference/configuration/engine/checkpoint/CheckpointLoadMetadataConfig.md + - reference/configuration/engine/checkpoint/CheckpointPathConfigBase.md + - reference/configuration/engine/checkpoint/CheckpointSaveConfig.md + - reference/configuration/engine/checkpoint/CheckpointSaveConfigBase.md + - reference/configuration/engine/checkpoint/CheckpointSaveMetadataConfig.md + - reference/configuration/engine/checkpoint/CheckpointStateConfigBase.md + - reference/configuration/engine/checkpoint/CheckpointStateSaveConfigBase.md + - Config Utils: + - reference/configuration/engine/config_utils/index.md + - Initialization: + - reference/configuration/engine/config_utils/initialization/index.md + - reference/configuration/engine/config_utils/initialization/DefaultInitializationConfig.md + - reference/configuration/engine/config_utils/initialization/FillInitializationConfig.md + - reference/configuration/engine/config_utils/initialization/InitializationConfig.md + - reference/configuration/engine/config_utils/initialization/NormalInitializationConfig.md + - reference/configuration/engine/config_utils/initialization/UniformInitializationConfig.md + - Interval: + - reference/configuration/engine/config_utils/interval/index.md + - reference/configuration/engine/config_utils/interval/IntervalConfig.md + - Logging: + - reference/configuration/engine/config_utils/logging/index.md + - reference/configuration/engine/config_utils/logging/TensorLogsConfig.md + - Parameter: + - reference/configuration/engine/config_utils/parameter/index.md + - reference/configuration/engine/config_utils/parameter/OptionalParameterConfig.md + - reference/configuration/engine/config_utils/parameter/ParameterConfig.md + - Run: + - reference/configuration/engine/config_utils/run/index.md + - reference/configuration/engine/config_utils/run/ExperimentConfig.md + - reference/configuration/engine/config_utils/run/RunConfig.md + - Runnable: + - reference/configuration/engine/config_utils/runnable/index.md + - reference/configuration/engine/config_utils/runnable/RunnableConfig.md + - Distributed: + - reference/configuration/engine/distributed/index.md + - reference/configuration/engine/distributed/DistributedConfig.md + - Evaluation: + - reference/configuration/engine/evaluation/index.md + - reference/configuration/engine/evaluation/EvaluatorConfig.md + - reference/configuration/engine/evaluation/LmEvalEvaluatorConfig.md + - reference/configuration/engine/evaluation/LossEvaluatorConfig.md + - Multi Stage: + - reference/configuration/engine/multi_stage/index.md + - reference/configuration/engine/multi_stage/CheckpointMetadata.md + - reference/configuration/engine/multi_stage/FastLLMModelConfig.md + - reference/configuration/engine/multi_stage/MultiStageConfig.md + - reference/configuration/engine/multi_stage/PretrainedFastLLMModelConfig.md + - reference/configuration/engine/multi_stage/StageConfig.md + - Optimizer: + - reference/configuration/engine/optimizer/index.md + - reference/configuration/engine/optimizer/GradientScalerConfig.md + - reference/configuration/engine/optimizer/LearningRateScheduleConfig.md + - reference/configuration/engine/optimizer/OptimizerConfig.md + - Schedule: + - reference/configuration/engine/schedule/index.md + - reference/configuration/engine/schedule/ScheduleConfig.md + - Training: + - reference/configuration/engine/training/index.md + - reference/configuration/engine/training/CallbackConfig.md + - reference/configuration/engine/training/MetricsLogsConfig.md + - reference/configuration/engine/training/ShutdownConfig.md + - reference/configuration/engine/training/StreamingTrainerCallbackConfig.md + - reference/configuration/engine/training/TrainerCallbackConfig.md + - reference/configuration/engine/training/TrainerConfig.md + - reference/configuration/engine/training/TrainingCheckpointBaseConfig.md + - reference/configuration/engine/training/TrainingCheckpointConfig.md + - reference/configuration/engine/training/TrainingConfig.md + - reference/configuration/engine/training/TrainingExportConfig.md + - reference/configuration/engine/training/WandbAlertConfig.md + - reference/configuration/engine/training/WandbConfig.md + - reference/configuration/engine/training/WeightsBroadcastConfig.md + - Layers: + - reference/configuration/layers/index.md + - Attention: + - reference/configuration/layers/attention/index.md + - reference/configuration/layers/attention/AttentionConfig.md + - Rotary: + - reference/configuration/layers/attention/rotary/index.md + - reference/configuration/layers/attention/rotary/DefaultRotaryConfig.md + - reference/configuration/layers/attention/rotary/Llama3RotaryConfig.md + - reference/configuration/layers/attention/rotary/NoRotaryConfig.md + - reference/configuration/layers/attention/rotary/Rotary2DConfig.md + - reference/configuration/layers/attention/rotary/RotaryConfig.md + - reference/configuration/layers/attention/rotary/YarnRotaryConfig.md + - Block: + - reference/configuration/layers/block/index.md + - reference/configuration/layers/block/BlockConfig.md + - reference/configuration/layers/block/BlockSequenceConfig.md + - reference/configuration/layers/block/FixedBlockSequenceConfig.md + - reference/configuration/layers/block/PatternBlockSequenceConfig.md + - Common: + - reference/configuration/layers/common/index.md + - Linear: + - reference/configuration/layers/common/linear/index.md + - reference/configuration/layers/common/linear/AffineLinearBaseConfig.md + - reference/configuration/layers/common/linear/AffineLinearConfig.md + - reference/configuration/layers/common/linear/CausalConv1dConfig.md + - reference/configuration/layers/common/linear/LinearBaseConfig.md + - reference/configuration/layers/common/linear/LinearConfig.md + - Normalization: + - reference/configuration/layers/common/normalization/index.md + - reference/configuration/layers/common/normalization/GatedRMSNormalizationConfig.md + - reference/configuration/layers/common/normalization/LayerNormalizationBaseConfig.md + - reference/configuration/layers/common/normalization/LayerNormalizationConfig.md + - reference/configuration/layers/common/normalization/NoNormalizationConfig.md + - reference/configuration/layers/common/normalization/NormalizationConfig.md + - reference/configuration/layers/common/normalization/RMSNormalizationConfig.md + - Peft: + - reference/configuration/layers/common/peft/index.md + - reference/configuration/layers/common/peft/LoRAConfig.md + - reference/configuration/layers/common/peft/NoPeftConfig.md + - reference/configuration/layers/common/peft/PeftConfig.md + - Decoder: + - reference/configuration/layers/decoder/index.md + - reference/configuration/layers/decoder/BlockWithBiasConfig.md + - reference/configuration/layers/decoder/DecoderBlockConfig.md + - reference/configuration/layers/decoder/MLPBaseConfig.md + - reference/configuration/layers/decoder/MixerConfig.md + - reference/configuration/layers/decoder/StochasticMixerConfig.md + - Mlp: + - reference/configuration/layers/decoder/mlp/index.md + - reference/configuration/layers/decoder/mlp/MLPConfig.md + - reference/configuration/layers/decoder/mlp/MoEMLPConfig.md + - Language Model: + - reference/configuration/layers/language_model/index.md + - reference/configuration/layers/language_model/LanguageModelConfig.md + - reference/configuration/layers/language_model/LanguageModelEmbeddingsConfig.md + - reference/configuration/layers/language_model/LanguageModelHeadConfig.md + - Loss: + - reference/configuration/layers/language_model/loss/index.md + - reference/configuration/layers/language_model/loss/LanguageModelDPOLossConfig.md + - reference/configuration/layers/language_model/loss/LanguageModelDistillationLossConfig.md + - reference/configuration/layers/language_model/loss/LanguageModelGRPOLossConfig.md + - reference/configuration/layers/language_model/loss/LanguageModelLabelEntropyLossConfig.md + - reference/configuration/layers/language_model/loss/LanguageModelLossConfig.md + - reference/configuration/layers/language_model/loss/LanguageModelZLossConfig.md + - Vision: + - reference/configuration/layers/vision/index.md + - reference/configuration/layers/vision/PatchEmbeddingsConfig.md + - reference/configuration/layers/vision/VisionEncoderConfig.md + - reference/configuration/layers/vision/VisionMultiModalModelConfig.md + - Models: + - reference/configuration/models/index.md + - Gpt: + - reference/configuration/models/gpt/index.md + - reference/configuration/models/gpt/GPTBaseModelConfig.md + - reference/configuration/models/gpt/GPTModelConfig.md + - reference/configuration/models/gpt/GPTTrainerConfig.md + - reference/configuration/models/gpt/PretrainedGPTModelConfig.md + - Multimodal: + - reference/configuration/models/multimodal/index.md + - reference/configuration/models/multimodal/MultiModalBaseModelConfig.md + - reference/configuration/models/multimodal/MultiModalModelConfig.md + - reference/configuration/models/multimodal/MultiModalTrainerConfig.md + - reference/configuration/models/multimodal/PretrainedMultiModalModelConfig.md + - Profile: + - reference/configuration/profile/index.md + - reference/configuration/profile/ProfilingConfig.md + # END AUTO-GENERATED CONFIG REFERENCE - About Us: about-us.md - Join Us: join-us.md diff --git a/pyproject.toml b/pyproject.toml index 8488623d7..c119daad4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,12 @@ testpaths = [ "fast_llm_external_models/tests" # External models tests ] norecursedirs = ["Megatron-LM"] +filterwarnings = [ + # PYTHONHASHSEED is not set by pytest; DataLoader workers will use a deterministic seed anyway. + "ignore:PYTHONHASHSEED should be set:UserWarning", + # Python 3.14 will remove pickle/copy support from itertools; comes from multiprocessing internals. + "ignore:Pickle, copy, and deepcopy support will be removed from itertools:DeprecationWarning", +] [tool.isort] profile = "black" diff --git a/setup.cfg b/setup.cfg index 955702907..e035cc0c1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,7 @@ install_requires = CORE = # Available through the nvidia base image torch>=2.9.0 + # apex # Available through the nvidia base image, requires manual build with --cuda_ext --fast_layer_norm numpy>=2.1.0 # Used for checkpoints safetensors>=0.6.2 diff --git a/tests/config/common.py b/tests/config/common.py index b341bd0cb..4b54fbc5e 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -33,6 +33,12 @@ class ExampleConfig(Config): core_field: int = Field(default=4, hint=FieldHint.core) complex_field: dict[str, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional) + @classmethod + def _from_dict(cls, default: dict, strict: bool = True): + cls._handle_renamed_field(default, "old_int_field", "int_field") + cls._handle_renamed_field(default, "original_float_field", "float_field", fn=lambda value: value * 2) + return super()._from_dict(default, strict) + def _validate(self) -> None: with self._set_implicit_default(): if self.implicit_field is None: diff --git a/tests/config/test_config.py b/tests/config/test_config.py index 4c473fa6d..4f7722c00 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -1,7 +1,54 @@ +import typing + import pytest +import yaml + +from fast_llm.config import ( + Config, + Field, + FieldHint, + FieldVerboseLevel, + NoAutoValidate, + UpdateType, + ValidationError, + config_class, +) +from fast_llm.utils import Assert, check_equal_nested, header +from tests.config.common import ExampleConfig, ExampleNestedConfig + +# --- Dynamic dispatch fixtures --- + + +@config_class(registry=True) +class AnimalConfig(Config): + name: str = Field(default="", hint=FieldHint.optional) + + +@config_class(dynamic_type={AnimalConfig: "dog"}) +class DogConfig(AnimalConfig): + breed: str = Field(default="mutt", hint=FieldHint.optional) + + +@config_class(dynamic_type={AnimalConfig: "cat"}) +class CatConfig(AnimalConfig): + indoor: bool = Field(default=True, hint=FieldHint.optional) + -from fast_llm.config import NoAutoValidate -from tests.config.common import ExampleConfig +# --- Verbose level fixtures --- + + +@config_class() +class ExampleHintConfig(Config): + """One field at each FieldHint importance level for testing verbose output filtering.""" + + core_field: int = Field(default=1, hint=FieldHint.core) + architecture_field: int = Field(default=2, hint=FieldHint.architecture) + optional_field: int = Field(default=3, hint=FieldHint.optional) + performance_field: int = Field(default=4, hint=FieldHint.performance) + expert_field: int = Field(default=5, hint=FieldHint.expert) + + +# --- Lifecycle --- def test_auto_validate(): @@ -28,3 +75,231 @@ def test_auto_validate(): assert not (config := ExampleConfig.from_dict({}))._validated config.validate() assert config._validated + + +def test_multiple_validation_errors_all_reported(): + with pytest.raises(ValidationError) as exc_info: + ExampleConfig.from_dict({"int_field": "not_an_int", "float_field": "not_a_float"}) + error_message = str(exc_info.value) + assert "int_field" in error_message + assert "float_field" in error_message + + +# --- compare() --- + + +def test_compare_equal_returns_none(): + config_a = ExampleConfig.from_dict({"int_field": 5}) + config_b = ExampleConfig.from_dict({"int_field": 5}) + assert config_a.compare(config_b) is None + + +def test_compare_different(): + config_a = ExampleConfig.from_dict({"int_field": 5}) + config_b = ExampleConfig.from_dict({"int_field": 7}) + with pytest.raises(ValueError): + config_a.compare(config_b) + # Custom log_fn receives the difference instead of raising. + messages = [] + config_a.compare(config_b, log_fn=messages.append) + assert len(messages) == 1 + + +# --- strict mode --- + + +@pytest.mark.parametrize( + ("config_dict", "cls"), + [ + ({"int_field": 3, "unknown_field": 5}, ExampleConfig), + ({"nested_field": {"int_field": 3, "unknown_sub_field": 5}}, ExampleNestedConfig), + ], + ids=["top_level", "nested"], +) +def test_strict_unknown_field_raises(config_dict, cls): + with pytest.raises(ValidationError): + cls.from_dict(config_dict) + + +def test_strict_false_unknown_field_ignored(): + config = ExampleConfig.from_dict({"int_field": 3, "unknown_field": 5}, strict=False) + assert config.int_field == 3 + assert not hasattr(config, "unknown_field") + + +def test_strict_false_unknown_nested_field_ignored(): + config = ExampleNestedConfig.from_dict({"nested_field": {"int_field": 3, "unknown_sub_field": 5}}, strict=False) + assert config.nested_field.int_field == 3 + + +# --- Dynamic dispatch --- + + +@pytest.mark.parametrize( + ("input_dict", "expected_cls", "expected_field", "expected_value"), + [ + ({"type": "dog", "breed": "labrador"}, DogConfig, "breed", "labrador"), + ({"type": "cat", "indoor": False}, CatConfig, "indoor", False), + ], + ids=["dog", "cat"], +) +def test_dynamic_dispatch_selects_subclass(input_dict, expected_cls, expected_field, expected_value): + config = AnimalConfig.from_dict(input_dict) + assert isinstance(config, expected_cls) + Assert.eq(getattr(config, expected_field), expected_value) + + +def test_dynamic_dispatch_type_serialized(): + config = DogConfig.from_dict({"breed": "poodle"}) + result = config.to_dict() + Assert.eq(result["type"], "dog") + Assert.eq(result["breed"], "poodle") + + +def test_dynamic_dispatch_unknown_type_raises(): + with pytest.raises(ValidationError): + AnimalConfig.from_dict({"type": "fish"}) + + +def test_dynamic_dispatch_roundtrip(): + original = DogConfig.from_dict({"breed": "husky"}) + roundtrip = AnimalConfig.from_dict(original.to_dict()) + assert isinstance(roundtrip, DogConfig) + Assert.eq(roundtrip.breed, "husky") + + +# --- Renamed fields --- + + +def test_renamed_field(): + with pytest.warns(DeprecationWarning, match="old_int_field"): + config = ExampleConfig.from_dict({"old_int_field": 5}) + Assert.eq(config.int_field, 5) + # New name works without a deprecation warning. + Assert.eq(ExampleConfig.from_dict({"int_field": 7}).int_field, 7) + + +def test_renamed_field_with_transform(): + with pytest.warns(DeprecationWarning, match="original_float_field"): + config = ExampleConfig.from_dict({"original_float_field": 4.0}) + Assert.eq(config.float_field, 8.0) + + +# --- Verbose levels --- + +# At verbose >= optional (10), the base Config.type field (hint=feature, importance=10) also appears. +_VERBOSE_LEVEL_CASES = [ + (FieldVerboseLevel.explicit, {}), + (FieldVerboseLevel.core, {"core_field": 1, "architecture_field": 2}), + (FieldVerboseLevel.optional, {"core_field": 1, "architecture_field": 2, "optional_field": 3, "type": None}), + ( + FieldVerboseLevel.performance, + {"core_field": 1, "architecture_field": 2, "optional_field": 3, "performance_field": 4, "type": None}, + ), + ( + FieldVerboseLevel.debug, + { + "core_field": 1, + "architecture_field": 2, + "optional_field": 3, + "performance_field": 4, + "expert_field": 5, + "type": None, + }, + ), +] + + +@pytest.mark.parametrize(("verbose", "expected"), _VERBOSE_LEVEL_CASES) +def test_verbose_level(verbose, expected): + check_equal_nested(ExampleHintConfig.from_dict({}).to_dict(verbose=verbose), expected) + + +# --- Field definition error fixtures --- + + +with pytest.raises(ValueError, match="default_factory"): + # Defining this at module level triggers Field.__init__ validation immediately. + @config_class() + class _BothDefaultAndFactoryConfig(Config): + x: list = Field(default=[], default_factory=list, hint=FieldHint.optional) + + +with pytest.raises(ValueError, match="default_factory"): + + @config_class() + class _ConfigAsDefaultFactoryConfig(Config): + nested: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.optional) + + +with pytest.raises(TypeError, match="__post_init__"): + + @config_class() + class _PostInitConfig(Config): + def __post_init__(self): + pass + + +@config_class() +class _AbstractConfig(Config): + _abstract: typing.ClassVar[bool] = True + + +# --- Abstract config --- + + +def test_abstract_config_raises(): + with pytest.raises(ValidationError, match="abstract"): + _AbstractConfig() + + +# --- Delete on validated config --- + + +def test_delattr_after_validation_raises(): + config = ExampleConfig.from_dict({}) + with pytest.raises(RuntimeError, match="delete"): + del config.int_field + + +# --- to_logs / __repr__ --- + + +@pytest.mark.parametrize( + ("cls", "config_dict", "expected_core_dict"), + [ + (ExampleConfig, {}, {"core_field": 4}), + (ExampleConfig, {"int_field": 3}, {"int_field": 3, "core_field": 4}), + ( + ExampleConfig, + {"int_field": 3, "str_field": "hello"}, + {"int_field": 3, "str_field": "hello", "core_field": 4}, + ), + ( + ExampleNestedConfig, + {"nested_field": {"int_field": 5}}, + {"core_field": 4, "nested_field": {"int_field": 5, "core_field": 4}}, + ), + ], +) +def test_repr_and_to_logs(cls, config_dict, expected_core_dict): + config = cls.from_dict(config_dict) + expected = ( + f"\n{header(config._get_class_name(), 80, '-')}" + f"\n{yaml.safe_dump(expected_core_dict, sort_keys=False)}" + f"{header('end', 80, '-')}" + ) + Assert.eq(repr(config), expected) + messages = [] + config.to_logs(log_fn=messages.append) + Assert.eq(len(messages), 1) + Assert.eq(messages[0], expected) + + +# --- Validated config as update --- + + +def test_validated_config_as_update_raises(): + validated = ExampleConfig.from_dict({"int_field": 1}) + with pytest.raises(ValueError, match="Validated"): + ExampleConfig.from_dict({}, validated, update_type=UpdateType.update) diff --git a/tests/config/test_field.py b/tests/config/test_field.py index bc5881167..2a49c8c60 100644 --- a/tests/config/test_field.py +++ b/tests/config/test_field.py @@ -1,203 +1,331 @@ +import dataclasses +import functools import math import pathlib +from typing import Any import numpy import pytest -from fast_llm.config import FieldVerboseLevel +from fast_llm.config import ( + Config, + Field, + FieldHint, + FieldOverride, + FieldVerboseLevel, + check_field, + config_class, + process_field, + skip_valid_if_none, +) from fast_llm.utils import Assert, check_equal_nested -from tests.config.common import ExampleConfig, ExampleEnum, ExampleVerboseConfig, check_config, check_invalid_config - - -def test_create_and_serialize_config(): - Assert.eq(ExampleConfig.from_dict({}).to_dict(), {}) - - -@pytest.mark.parametrize("value", (0, -6, 3)) -def test_int_field(value): - check_config({"int_field": value}) - - -@pytest.mark.parametrize("value", (4.0, math.inf, "1", None, [4], True)) -def test_int_field_invalid(value): - check_invalid_config({"int_field": value}) - - -@pytest.mark.parametrize("value", (True, False)) -def test_bool_field(value): - check_config({"bool_field": value}) - - -@pytest.mark.parametrize("value", (1, "True", None, [True])) -def test_bool_field_invalid(value): - check_invalid_config({"bool_field": value}) - - -@pytest.mark.parametrize("value", ("", "text", "1")) -def test_str_field(value): - check_config({"str_field": str(value)}, {"str_field": value}) - - -@pytest.mark.parametrize("value", (1, True, None, ["text"], pathlib.Path("a"), ExampleEnum.a)) -def test_str_field_invalid(value): - check_invalid_config({"str_field": value}) - - -@pytest.mark.parametrize("value", (".", "text", "/a/b/c.d")) -def test_path_field(value): - check_config({"path_field": pathlib.Path(value)}, {"path_field": value}) - - -@pytest.mark.parametrize("value", (1, True, None, [pathlib.Path("a")])) -def test_path_field_invalid(value): - check_invalid_config({"path_field": value}) - - -@pytest.mark.parametrize("value", (4.0, math.pi, math.inf, 3, math.nan)) -def test_float_field(value): - check_config( - {"float_field": float(value)}, {"float_field": value}, serialized_config={"float_field": float(value)} - ) - - -@pytest.mark.parametrize("value", (None, [4.7], "0.0", True, numpy.float64(3))) -def test_float_field_invalid(value): - check_invalid_config({"float_field": value}) - - -@pytest.mark.parametrize("value", ("", None, "text")) -def test_optional_field(value): - check_config({"optional_field": value}) - - -@pytest.mark.parametrize("value", (True, 6, [None])) -def test_optional_field_invalid(value): - check_invalid_config({"optional": value}) - - -@pytest.mark.parametrize("value", ("", 0, "text", 7)) -def test_union_field(value): - check_config({"union_field": value}) - - -@pytest.mark.parametrize("value", (6.0, [""], True)) -def test_union_field_invalid(value): - check_invalid_config({"optional": value}) - - -def test_implicit_field_value(): - Assert.eq(ExampleConfig.from_dict({}).implicit_field, "implicit") - - -@pytest.mark.parametrize("value", ("implicit", "", "text")) -def test_implicit_field(value): - check_config({"implicit_field": value}) - - -ARRAY_VALUES = ((), (1,), (3, 4, 6), (4, 5, 4)) -ARRAY_VALUES_INVALID = (6.0, {}, True, "text") - - -@pytest.mark.parametrize("value", ARRAY_VALUES) -def test_list_field(value): - check_config( - {"list_field": list(value)}, - {"list_field": value}, - serialized_config={"list_field": list(value)}, - ) - - -@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) -def test_list_field_invalid(value): - check_invalid_config({"list_field": value}) - - -@pytest.mark.parametrize("value", ARRAY_VALUES) -def test_tuple_field(value): - check_config( - {"tuple_field": list(value)}, - {"tuple_field": value}, - serialized_config={"tuple_field": list(value)}, - ) - - -@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) -def test_tuple_field_invalid(value): - check_invalid_config({"tuple_field": value}) - - -@pytest.mark.parametrize("value", ARRAY_VALUES) -def test_set_field(value): - check_config( - {"set_field": list(set(value))}, - {"set_field": set(value)}, - {"set_field": list(value)}, - {"set_field": tuple(value)}, - serialized_config={"set_field": list(set(value))}, - ) - - -@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) -def test_tuple_field_invalid(value): - check_invalid_config({"set_field": value}) - - -@pytest.mark.parametrize("value", ({}, {1: 2, 3: 4})) -def test_dict_field(value): - check_config({"dict_field": value}) - - -@pytest.mark.parametrize("value", ({True: 2}, {4: "3"}, {4: {1: 4}}, None, 4, {1}, [5, 7], "text")) -def test_dict_field_invalid(value): - check_invalid_config({"dict_field": value}) +from tests.config.common import ( + ExampleConfig, + ExampleEnum, + ExampleNestedConfig, + ExampleVerboseConfig, + check_config, + check_invalid_config, +) -class IntClass(int): +class IntSubclass(int): pass -@pytest.mark.parametrize("value", (int, bool, IntClass)) -def test_type_field(value): - check_config({"type_field": value}, serialized_config={"type_field": str(value)}) - - -@pytest.mark.parametrize("value", (5, None, [], "text")) -def test_type_field_invalid(value): - check_invalid_config({"type_field": value}) +# --- Validator configs (referenced in _FIELD_TEST_CASES) --- -@pytest.mark.parametrize("value", (ExampleEnum.a, ExampleEnum.b, ExampleEnum.c)) -def test_enum_field(value): - check_config({"enum_field": value}, {"enum_field": str(value)}) +@config_class() +class ExampleCheckFieldConfig(Config): + positive_field: int = Field(default=0, hint=FieldHint.optional, valid=check_field(Assert.geq, 0)) -@pytest.mark.parametrize("value", (5, None, [], "text")) -def test_enum_field_invalid(value): - check_invalid_config({"type_field": value}) +@config_class() +class ExampleSkipIfNoneConfig(Config): + optional_positive_field: int | None = Field( + default=None, + hint=FieldHint.optional, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) -def test_core_field(): - Assert.eq(ExampleConfig.from_dict({}).to_dict(verbose=FieldVerboseLevel.core), {"core_field": 4}) +@config_class() +class ExampleProcessFieldConfig(Config): + doubled_field: int = Field(default=0, hint=FieldHint.optional, valid=process_field(lambda value: value * 2)) + + +# --- FieldOverride configs --- + + +@config_class() +class ExampleUpdatedDefaultConfig(ExampleConfig): + int_field = FieldOverride(default=42) + + +@config_class() +class ExampleUpdatedHintConfig(ExampleConfig): + # Promote str_field from optional to core so it appears at verbose=core. + str_field = FieldOverride(hint=FieldHint.core) + + +@dataclasses.dataclass +class ValidCase: + # Canonical Python-side value. Used as input to from_dict() and as expected to_dict(serialized=False) result. + internal: Any + # Expected to_dict() result. Defaults to internal. + serialized: Any = None + # Other input values that should produce the same internal/serialized result. + alternates: tuple = () + + def __post_init__(self): + if self.serialized is None: + self.serialized = self.internal + + +@dataclasses.dataclass +class FieldTestCase: + field_name: str + valid: list[ValidCase] + invalid: list[Any] + cls: type = ExampleConfig + # When the config class has other fields with non-empty defaults, check only this field. + fields: list[str] | None = None + + @functools.cached_property + def params(self) -> list: + return [ + *( + pytest.param( + self.field_name, + self.cls, + valid_case.internal, + valid_case, + self.fields, + id=f"{self.field_name}-{valid_case.internal!r}", + ) + for valid_case in self.valid + ), + *( + pytest.param( + self.field_name, + self.cls, + invalid_value, + None, + self.fields, + id=f"{self.field_name}-invalid-{invalid_value!r}", + ) + for invalid_value in self.invalid + ), + ] + + +_FIELD_TEST_CASES: list[FieldTestCase] = [ + FieldTestCase( + field_name="int_field", + valid=[ValidCase(0), ValidCase(-6), ValidCase(3)], + # Rejects float (even if whole number), bool, string, None, list. + invalid=[4.0, math.inf, "1", None, [4], True], + ), + FieldTestCase( + field_name="bool_field", + valid=[ValidCase(True), ValidCase(False)], + # Rejects int (bool is a subclass of int but the reverse is not accepted), string, None, list. + invalid=[1, "True", None, [True]], + ), + FieldTestCase( + field_name="str_field", + valid=[ValidCase(""), ValidCase("text"), ValidCase("1")], + # Rejects int, bool, None, list, Path, Enum. + invalid=[1, True, None, ["text"], pathlib.Path("a"), ExampleEnum.a], + ), + FieldTestCase( + field_name="path_field", + valid=[ + # Stores as pathlib.Path; serializes to string; accepts string input. + ValidCase(pathlib.Path("."), serialized=".", alternates=(".",)), + ValidCase(pathlib.Path("text"), serialized="text", alternates=("text",)), + ValidCase(pathlib.Path("/a/b/c.d"), serialized="/a/b/c.d", alternates=("/a/b/c.d",)), + ], + # Rejects int, bool, None, list. + invalid=[1, True, None, [pathlib.Path("a")]], + ), + FieldTestCase( + field_name="float_field", + valid=[ + # Accepts int and float; stores and serializes as float; inf and nan are valid. + ValidCase(4.0), + ValidCase(math.pi), + ValidCase(math.inf), + ValidCase(math.nan), + ValidCase(3.0, alternates=(3,)), # int input coerced to float + ], + # Rejects None, list, string, bool, numpy scalar. + invalid=[None, [4.7], "0.0", True, numpy.float64(3)], + ), + FieldTestCase( + field_name="optional_field", + valid=[ValidCase(None), ValidCase(""), ValidCase("text")], + # Rejects bool, int, list. + invalid=[True, 6, [None]], + ), + FieldTestCase( + field_name="union_field", + valid=[ValidCase(""), ValidCase(0), ValidCase("text"), ValidCase(7)], + # Rejects float, list, bool. + invalid=[6.0, [""], True], + ), + FieldTestCase( + field_name="implicit_field", + valid=[ + # _validate() sets "implicit" when not provided; explicit value overrides. + ValidCase("implicit"), + ValidCase(""), + ValidCase("text"), + ], + invalid=[], # Any string is valid; invalids are covered by str_field tests. + ), + FieldTestCase( + field_name="list_field", + valid=[ + # Stores as list; accepts tuple input; duplicates preserved. + ValidCase([]), + ValidCase([1], alternates=((1,),)), + ValidCase([3, 4, 6], alternates=((3, 4, 6),)), + ValidCase([4, 5, 4], alternates=((4, 5, 4),)), + ], + # Rejects float, dict, bool, string. + invalid=[6.0, {}, True, "text"], + ), + FieldTestCase( + field_name="tuple_field", + valid=[ + # Stores as tuple; serializes as list; accepts list or tuple input. + ValidCase([], serialized=[], alternates=((),)), + ValidCase([1], serialized=[1], alternates=((1,),)), + ValidCase([3, 4, 6], serialized=[3, 4, 6], alternates=((3, 4, 6),)), + ValidCase([4, 5, 4], serialized=[4, 5, 4], alternates=((4, 5, 4),)), + ], + # Rejects float, dict, bool, string. + invalid=[6.0, {}, True, "text"], + ), + FieldTestCase( + field_name="set_field", + valid=[ + # Deduplicates; serializes as list; accepts list/tuple/set input. + # Note: CPython iterates small-int sets in insertion/hash order, matching sorted order here. + ValidCase([], serialized=[], alternates=(set(), ())), + ValidCase([1], serialized=[1], alternates=({1}, (1,))), + ValidCase([3, 4, 6], serialized=[3, 4, 6], alternates=({3, 4, 6}, (3, 4, 6))), + ValidCase([4, 5], serialized=[4, 5], alternates=({4, 5}, [4, 5, 4], (4, 5, 4))), # deduplication + ], + # Rejects float, dict, bool, string. + invalid=[6.0, {}, True, "text"], + ), + FieldTestCase( + field_name="dict_field", + valid=[ValidCase({}), ValidCase({1: 2, 3: 4})], + # Rejects bool keys, wrong value types, nested dict values, None, int, set, list, string. + invalid=[{True: 2}, {4: "3"}, {4: {1: 4}}, None, 4, {1}, [5, 7], "text"], + ), + FieldTestCase( + field_name="type_field", + valid=[ + # Accepts type objects that are subclasses of int; serializes as repr string. + ValidCase(int, serialized=str(int)), + ValidCase(bool, serialized=str(bool)), + ValidCase(IntSubclass, serialized=str(IntSubclass)), + ], + # Rejects non-type values. + invalid=[5, None, [], "text"], + ), + FieldTestCase( + field_name="enum_field", + valid=[ + # Accepts enum values and their string equivalents; serializes as string. + ValidCase(ExampleEnum.a, serialized="a", alternates=("a",)), + ValidCase(ExampleEnum.b, serialized="b", alternates=("b",)), + ValidCase(ExampleEnum.c, serialized="c", alternates=("c",)), + ], + # Rejects non-string, None, list, and strings not in the enum. + invalid=[5, None, [], "d"], + ), + FieldTestCase( + field_name="complex_field", + valid=[ + ValidCase({}), + ValidCase({"3": None, "text": [], "0": [["", 3], ["a", -7]]}), + ValidCase({"0": [[".", 8]]}), + ], + # Rejects non-string dict keys. + invalid=[{False: [["", 3]]}], + ), + FieldTestCase( + field_name="tuple_fixed_length_field", + valid=[ + # Fixed-length (int, str) tuple; stores and serializes as list; accepts list or tuple input. + ValidCase([0, ""], alternates=((0, ""),)), + ValidCase([5, "text"], alternates=((5, "text"),)), + ValidCase([7, "True"], alternates=((7, "True"),)), + ], + # Rejects wrong length (too short/long) and wrong element types. + invalid=[(), (5,), ("", 0), ("0", "True"), (0, "", "text")], + cls=ExampleVerboseConfig, + fields=["tuple_fixed_length_field"], + ), + FieldTestCase( + field_name="nested_field", + valid=[ + # Non-empty sub-configs only: empty nested config serializes back to {} (no nested_field key). + ValidCase({"int_field": 3}), + ValidCase({"int_field": 3, "str_field": "text"}), + ValidCase({"list_field": [1, 2], "dict_field": {1: 2}}), + ], + # Rejects None, non-dict, and dicts with invalid sub-field values. + invalid=[None, 5, {"int_field": "not_an_int"}, {"int_field": True}], + cls=ExampleNestedConfig, + ), + FieldTestCase( + field_name="positive_field", + valid=[ValidCase(0), ValidCase(5)], + # Rejects values failing check_field(>=0); type invalids already covered by int_field tests. + invalid=[-1], + cls=ExampleCheckFieldConfig, + ), + FieldTestCase( + field_name="optional_positive_field", + valid=[ValidCase(None), ValidCase(0), ValidCase(5)], + # Rejects negative values; None bypasses the validator (skip_valid_if_none). + invalid=[-1], + cls=ExampleSkipIfNoneConfig, + ), +] @pytest.mark.parametrize( - "value", - ( - {}, - {"3": None, "text": [], "0": [["", 3], ["a", -7]]}, - {"0": [[".", 8]]}, - ), + ("field_name", "cls", "value", "expected", "fields"), + [case for field_test_case in _FIELD_TEST_CASES for case in field_test_case.params], ) -def test_complex_field(value): - check_config({"complex_field": value}) +def test_field(field_name: str, cls: type, value: Any, expected: ValidCase | None, fields: list[str] | None): + if expected is None: + check_invalid_config({field_name: value}, cls=cls) + else: + check_config( + {field_name: value}, + *({field_name: alternate} for alternate in expected.alternates), + serialized_config={field_name: expected.serialized}, + cls=cls, + fields=fields, + ) -@pytest.mark.parametrize( - "value", - ({"3": None, "text": [], False: [["", 3], ["a", -7]]},), -) -def test_complex_field_invalid(value): - check_invalid_config({"complex_field": value}) +def test_implicit_field_value(): + # When implicit_field is not provided, _validate() fills it in as "implicit". + config = ExampleConfig.from_dict({}) + Assert.eq(config.implicit_field, "implicit") + # Implicitly-set fields are not included in the serialized dict; all other fields are default, + # so the empty config serializes back to {}. + Assert.eq(config.to_dict(), {}) def test_verbose_config_default(): @@ -214,21 +342,23 @@ def test_verbose_config_default(): check_equal_nested(config.to_dict(serialized=False), default_values) -@pytest.mark.parametrize("value", ((0, ""), (5, "text"), (7, "True"))) -def test_tuple_fixed_length_field(value): - check_config( - {"tuple_fixed_length_field": list(value)}, - {"tuple_fixed_length_field": value}, - serialized_config={"tuple_fixed_length_field": list(value)}, - cls=ExampleVerboseConfig, - fields=["tuple_fixed_length_field"], - ) +def test_nested_field_empty(): + # An empty sub-config is accepted; sub-fields take their defaults. + config = ExampleNestedConfig.from_dict({"nested_field": {}}) + Assert.eq(config.nested_field.int_field, 0) + Assert.eq(config.nested_field.str_field, "") + + +def test_process_field_transforms_value(): + # process_field transforms the value during validation; input 5 is stored as 10. + Assert.eq(ExampleProcessFieldConfig.from_dict({"doubled_field": 5}).doubled_field, 10) -@pytest.mark.parametrize("value", ((), (5,), ("", 0), ("0", "True"), (0, "", "text"))) -def test_tuple_fixed_length_field_invalid(value): - check_invalid_config({"tuple_fixed_length_field": value}, cls=ExampleVerboseConfig) +def test_field_update_default(): + Assert.eq(ExampleUpdatedDefaultConfig.from_dict({}).int_field, 42) + Assert.eq(ExampleConfig.from_dict({}).int_field, 0) # parent default unchanged -# TODO: Test other fields with defaults. -# TODO: Test nested fields. +def test_field_update_hint(): + assert "str_field" in ExampleUpdatedHintConfig.from_dict({}).to_dict(verbose=FieldVerboseLevel.core) + assert "str_field" not in ExampleConfig.from_dict({}).to_dict(verbose=FieldVerboseLevel.core) # parent unchanged diff --git a/tests/config/test_update.py b/tests/config/test_update.py index 525c47694..eea1a5252 100644 --- a/tests/config/test_update.py +++ b/tests/config/test_update.py @@ -6,27 +6,58 @@ TEST_CONFIGS = ( ( - # Empty config + # Empty config: updating nothing changes nothing. {}, {}, {}, None, ), ( - # Update unset field; don't update set field; update + # Flat fields: update adds new fields and overwrites shared fields; unrelated base fields survive. {"int_field": 4, "str_field": "text"}, {"float_field": 3.0, "str_field": ""}, {"int_field": 4, "float_field": 3.0, "str_field": ""}, None, ), ( - # Update/override nested field. + # Nested field: update merges sub-fields; override replaces the whole nested config. {"nested_field": {"int_field": 4, "str_field": "text"}}, {"nested_field": {"float_field": 3.0, "str_field": ""}}, {"nested_field": {"int_field": 4, "float_field": 3.0, "str_field": ""}}, {"nested_field": {"float_field": 3.0, "str_field": ""}}, ), - # TODO: Add more complex cases + ( + # Top-level and nested fields together: top-level fields and nested sub-fields both update correctly. + {"int_field": 1, "nested_field": {"int_field": 4, "str_field": "text"}}, + {"str_field": "new", "nested_field": {"float_field": 3.0}}, + { + "int_field": 1, + "str_field": "new", + "nested_field": {"int_field": 4, "float_field": 3.0, "str_field": "text"}, + }, + {"int_field": 1, "str_field": "new", "nested_field": {"float_field": 3.0}}, + ), + ( + # Update from empty: base has no fields set; all update fields appear in result. + {}, + {"int_field": 7, "str_field": "hello"}, + {"int_field": 7, "str_field": "hello"}, + None, + ), + ( + # Update to empty: update has no fields set; base is preserved unchanged. + {"int_field": 7, "str_field": "hello"}, + {}, + {"int_field": 7, "str_field": "hello"}, + None, + ), + ( + # Collection fields: list and dict fields in update replace their counterparts in base. + {"int_field": 1, "list_field": [1, 2, 3]}, + {"list_field": [4, 5]}, + {"int_field": 1, "list_field": [4, 5]}, + None, + ), ) diff --git a/tests/conftest.py b/tests/conftest.py index 1d3264103..43f1fc65f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -277,7 +277,7 @@ def worker_resources(request) -> WorkerResources: return request.config.worker_resources -@pytest.mark.trylast +@pytest.hookimpl(trylast=True) def pytest_xdist_make_scheduler(config, log): # Always use grouped load balancing to handle dependencies, and make it work with `-n`. assert config.getvalue("dist") == "load" diff --git a/tests/data/common.py b/tests/data/common.py index 5771f9b11..1a17695bc 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -167,7 +167,7 @@ def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_s token_ids = torch.stack([LanguageModelBatch.from_documents(sampled[i]).tokens for i in range(len(sampled))]).to( torch.int64 ) - Assert.all_equal(token_ids, validate_samples) + Assert.all_equal(token_ids, np.array(validate_samples)) if expected_samples is not None: Assert.all_equal(token_ids, expected_samples) return token_ids diff --git a/tests/data/conftest.py b/tests/data/conftest.py new file mode 100644 index 000000000..306351471 --- /dev/null +++ b/tests/data/conftest.py @@ -0,0 +1,8 @@ +import pathlib + +import pytest + + +@pytest.fixture(scope="session") +def data_result_path(result_path: pathlib.Path) -> pathlib.Path: + return result_path / "data" diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 5d72c7152..edbe479cc 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -106,7 +106,7 @@ def test_blending(probs): Assert.all_equal(samples, samples_alt) -def test_gpt_blended(): +def test_gpt_blended(data_result_path): # Make sure dataset blending works and check for unintended changes in behavior. _, config, _, preprocessing = get_common_test_dataset() _, alt_config, _, _ = get_alt_test_dataset() @@ -127,10 +127,11 @@ def test_gpt_blended(): sequence_length=5, expected_samples=GPT_BLENDED_SAMPLES, preprocessing=preprocessing, + cache_directory=data_result_path / "blended", ) -def test_gpt_blended_mixed(): +def test_gpt_blended_mixed(data_result_path): # Make sure dataset blending works and check for unintended changes in behavior. _, config, _, preprocessing = get_common_test_dataset() # Random dataset needs an explicit vocab size. @@ -155,4 +156,5 @@ def test_gpt_blended_mixed(): sequence_length=5, expected_samples=GPT_BLENDED_MIXED_SAMPLES, preprocessing=preprocessing, + cache_directory=data_result_path / "blended_mixed", ) diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 200a771f7..6774374bb 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -23,7 +23,7 @@ ] -def test_gpt_concatenate(): +def test_gpt_concatenate(data_result_path): # Make sure the dataset concatenation works and check for unintended changes in behavior. _, config, _, preprocessing = get_common_test_dataset() memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() @@ -47,4 +47,5 @@ def test_gpt_concatenate(): sequence_length=5, expected_samples=GPT_CONCATENATED_SAMPLES, preprocessing=preprocessing, + cache_directory=data_result_path / "concatenate", ) diff --git a/tests/data/test_dataset_discovery.py b/tests/data/test_dataset_discovery.py index e94da8499..0dd9c31a4 100644 --- a/tests/data/test_dataset_discovery.py +++ b/tests/data/test_dataset_discovery.py @@ -141,11 +141,11 @@ ), ) def test_dataset_discovery( - result_path: pathlib.Path, name: str, paths: tuple[pathlib.Path], ignore_paths, expected_config: dict + data_result_path: pathlib.Path, name: str, paths: tuple[pathlib.Path], ignore_paths, expected_config: dict ): """Test end-to-end discovery with multiple datasets in various structure.""" test_dataset_path = [get_common_test_dataset()[0], get_alt_test_dataset()[0]] - (dataset_path := result_path / f"dataset_discovery/{name}").mkdir(parents=True) + (dataset_path := data_result_path / f"dataset_discovery/{name}").mkdir(parents=True) for index, path in enumerate(paths): (path_ := dataset_path / path).mkdir(parents=True, exist_ok=True) shutil.copy( @@ -157,7 +157,7 @@ def test_dataset_discovery( # Run dataset discovery config = DatasetDiscoveryConfig( directory=dataset_path, - output=result_path / f"dataset_discovery/configs/{name}.yaml", + output=data_result_path / f"dataset_discovery/configs/{name}.yaml", ignore_paths=ignore_paths, ) config.run() diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index ec6ac3011..25e42fb97 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -20,7 +20,7 @@ ] -def test_gpt_fim(): +def test_gpt_fim(data_result_path): # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior. _, config, _, preprocessing = get_common_test_dataset() # The test tokenizer doesn't have fim tokens, so we work around it. @@ -45,4 +45,5 @@ def test_gpt_fim(): sequence_length=5, expected_samples=GPT_FIM_SAMPLES, preprocessing=preprocessing, + cache_directory=data_result_path / "fim", ) diff --git a/tests/data/test_random.py b/tests/data/test_random.py index 9f0d4d9c6..9b7941600 100644 --- a/tests/data/test_random.py +++ b/tests/data/test_random.py @@ -15,7 +15,7 @@ ] -def test_gpt_random_dataset(): +def test_gpt_random_dataset(data_result_path): # Make sure the random dataset works and check for unintended changes in behavior. preprocessing = LanguageModelBatchPreprocessingConfig(vocab_size=8192) sampled = get_dataset_config(config := {"type": "random"}, GPTRandomDatasetConfig).build_and_sample( @@ -30,4 +30,5 @@ def test_gpt_random_dataset(): sequence_length=7, expected_samples=RANDOM_DATASET_EXPECTED_SAMPLES, preprocessing=preprocessing, + cache_directory=data_result_path / "random", ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 2c753a98f..c45160ac2 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -1,9 +1,13 @@ +import dataclasses +import functools +import pathlib + import numpy as np import pytest import torch from fast_llm.data.dataset.config import ShufflingType -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingConfig from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument from fast_llm.utils import Assert @@ -35,24 +39,6 @@ ] -def test_gpt_sampled(): - # Make sure the memmap dataset works and check for unintended changes in behavior. - _, config, _, preprocessing = get_common_test_dataset() - sampled = get_dataset_config( - dataset_config := config, GPTDatasetFromFileConfig[LanguageModelDocument] - ).build_and_sample(*get_sampling_config(8, sequence_length=5, preprocessing=preprocessing)) - validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) - - # Test in data. - get_test_data_and_compare_samples( - {"datasets": {"training": dataset_config}}, - 8, - sequence_length=5, - expected_samples=GPT_MEMMAP_SAMPLES, - preprocessing=preprocessing, - ) - - class SimpleGPTIndexedDataset[DocumentType: LanguageModelDocument](IndexedDataset[DocumentType]): # TODO: worth adding to the main codebase? def __init__(self, samples): @@ -72,6 +58,7 @@ def get_document_sizes(self) -> torch.Tensor: def get_document_size(self, index: int) -> int: return len(self._samples[index]) + @property def name(self) -> str: return "dataset" @@ -87,6 +74,130 @@ def name(self) -> str: ] ) +# Document sizes: 3, 5, 2, 4, 6. +# With maximum_document_length=4, truncate_documents=False: docs of size 5 and 6 are dropped. +# With maximum_document_length=4, truncate_documents=True: docs of size 5 and 6 are split into chunks of ≤4. +TRUNCATE_DATASET = SimpleGPTIndexedDataset( + [ + [0, 1, 2], # length 3 — fits + [3, 4, 5, 6, 7], # length 5 — exceeds maximum_document_length=4 + [8, 9], # length 2 — fits + [10, 11, 12, 13], # length 4 — exactly at limit + [14, 15, 16, 17, 18, 19], # length 6 — exceeds + ] +) + + +@dataclasses.dataclass +class SamplingTestConfig: + name: str + num_samples: int + sequence_length: int = 5 + seed: int = 54983 + shuffle: ShufflingType = ShufflingType.epoch + truncate_documents: bool = True + maximum_document_length: int | None = None + expected_samples: list[list[int]] | None = None + # Tokens that must not appear in any sample (validated for drop/filter cases). + # Defaults to empty — the check is always run but trivially passes. + forbidden_tokens: frozenset[int] = frozenset() + # Tokens that must collectively appear across all samples (validated for truncate cases). + # Defaults to empty — the check is always run but trivially passes. + required_tokens: frozenset[int] = frozenset() + requires_extension: bool = False + dataset: SimpleGPTIndexedDataset | None = dataclasses.field(default=None, compare=False, repr=False) + + @functools.cached_property + def sampling_config_overrides(self) -> dict: + if self.maximum_document_length is not None: + return {"maximum_document_length": self.maximum_document_length} + return {} + + +_SAMPLING_TEST_CASES = [ + SamplingTestConfig( + name="simple", + num_samples=20, + ), + SamplingTestConfig( + # With truncate_documents=False, documents exceeding maximum_document_length are dropped entirely. + # Only the 3 docs with length ≤ 4 contribute tokens: [0,1,2], [8,9], [10,11,12,13] = 9 tokens. + name="maximum_document_length_drop", + num_samples=2, + sequence_length=4, + shuffle=ShufflingType.disabled, + truncate_documents=False, + maximum_document_length=4, + forbidden_tokens=frozenset(range(3, 8)) | frozenset(range(14, 20)), + dataset=TRUNCATE_DATASET, + requires_extension=True, + ), + SamplingTestConfig( + # With truncate_documents=True, documents exceeding maximum_document_length are split into chunks. + # All tokens should appear in the output; none should be dropped. + name="maximum_document_length_truncate", + num_samples=10, + sequence_length=4, + shuffle=ShufflingType.disabled, + truncate_documents=True, + maximum_document_length=4, + required_tokens=frozenset(range(20)), + dataset=TRUNCATE_DATASET, + ), +] + + +@pytest.mark.parametrize("test_config", [pytest.param(c, id=c.name) for c in _SAMPLING_TEST_CASES]) +def test_sampling(test_config: SamplingTestConfig): + if test_config.requires_extension and not _extension_available: + pytest.skip("CPP Extension not available") + + dataset = test_config.dataset if test_config.dataset is not None else TEST_DATASET + base_config, num_samples, seed = get_sampling_config( + test_config.num_samples, + sequence_length=test_config.sequence_length, + seed=test_config.seed, + shuffle=test_config.shuffle, + truncate_documents=test_config.truncate_documents, + ) + sampling_config = GPTSamplingConfig.from_dict(base_config.to_dict(), test_config.sampling_config_overrides) + sampled = dataset.sample(sampling_config, num_samples, seed) + + # validate_indexed_dataset_sampling's reference implementation concatenates tokens without padding, + # so it only applies when truncate_documents=True (no padding between documents). + if test_config.truncate_documents: + tokens = validate_indexed_dataset_sampling(sampled, test_config.expected_samples) + else: + tokens = torch.stack( + [ + LanguageModelBatch.from_documents(sampled[i], test_config.sequence_length + 1).tokens + for i in range(len(sampled)) + ] + ) + + valid_tokens = set(tokens[tokens >= 0].tolist()) + assert test_config.forbidden_tokens.isdisjoint(valid_tokens) + assert test_config.required_tokens.issubset(valid_tokens) + + +def test_gpt_sampled(data_result_path: pathlib.Path): + # Make sure the memmap dataset works and check for unintended changes in behavior. + _, config, _, preprocessing = get_common_test_dataset() + sampled = get_dataset_config( + dataset_config := config, GPTDatasetFromFileConfig[LanguageModelDocument] + ).build_and_sample(*get_sampling_config(8, sequence_length=5, preprocessing=preprocessing)) + validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) + + # Test in data. + get_test_data_and_compare_samples( + {"datasets": {"training": dataset_config}}, + 8, + sequence_length=5, + expected_samples=GPT_MEMMAP_SAMPLES, + preprocessing=preprocessing, + cache_directory=data_result_path / "sampling/gpt_sampled", + ) + @pytest.mark.parametrize("seed", (0, 32, 88)) @pytest.mark.parametrize( @@ -111,6 +222,42 @@ def test_gpt_sample(seed, shuffle): previous_samples = samples +@pytest.mark.parametrize("token_cumsum_rate", (1, 3, 7, 20)) +def test_token_cumsum_rate(token_cumsum_rate): + # Different token_cumsum_rate values are a performance/memory tradeoff only — + # sampling output must be identical regardless of the rate chosen. + config, num_samples, seed = get_sampling_config(20, sequence_length=5) + reference = validate_indexed_dataset_sampling(TEST_DATASET.sample(config, num_samples, seed)) + + config_with_rate = GPTSamplingConfig.from_dict(config.to_dict(), {"token_cumsum_rate": token_cumsum_rate}) + result = validate_indexed_dataset_sampling(TEST_DATASET.sample(config_with_rate, num_samples, seed)) + Assert.all_equal(result, reference) + + +def test_cache_directory(data_result_path: pathlib.Path): + # Verify that the cache is written on first run and reused on subsequent runs. + cache_dir = data_result_path / "sampling/cache_directory" + config, num_samples, seed = get_sampling_config(20, sequence_length=5, cache_directory=cache_dir) + + first = validate_indexed_dataset_sampling(TEST_DATASET.sample(config, num_samples, seed)) + assert cache_dir.exists() and any(cache_dir.iterdir()) + + # Second run with the same config must produce identical output (reads from cache). + second = validate_indexed_dataset_sampling(TEST_DATASET.sample(config, num_samples, seed)) + Assert.all_equal(first, second) + + +def test_cache_invalidated_on_config_change(data_result_path: pathlib.Path): + # Changing a sampling parameter should raise rather than silently return stale data. + cache_dir = data_result_path / "sampling/cache_invalidation" + config, num_samples, seed = get_sampling_config(20, sequence_length=5, cache_directory=cache_dir) + TEST_DATASET.sample(config, num_samples, seed) + + config_changed = GPTSamplingConfig.from_dict(config.to_dict(), {"token_cumsum_rate": 3}) + with pytest.raises(RuntimeError, match="Invalid dataset cache"): + TEST_DATASET.sample(config_changed, num_samples, seed) + + @pytest.mark.skipif(not _extension_available, reason="CPP Extension not available") def test_build_padded_token_cumsum(): sizes = np.array([100, 256, 580, 600, 550, 89, 339, 430, 400, 795, 680, 50], dtype=np.int32) diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 838562f64..d5d09f58c 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -29,7 +29,7 @@ ] -def test_gpt_slice(): +def test_gpt_slice(data_result_path): # Make sure dataset splitting works and check for unintended changes in behavior. _, config, _, preprocessing = get_common_test_dataset() memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() @@ -73,4 +73,5 @@ def test_gpt_slice(): "validation": GPT_SLICE_VALIDATION_SAMPLES, }, preprocessing=preprocessing, + cache_directory=data_result_path / "slice", ) diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index 83f7657a0..9ad96b961 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -231,9 +231,9 @@ def _run_test_data_streaming_distributed( @pytest.mark.parametrize("num_workers", (0, 1)) -def test_data_streaming(result_path, worker_resources, num_workers): +def test_data_streaming(data_result_path, worker_resources, num_workers): distributed_config = _get_distributed_config({}) - path = result_path / "data_streaming/single_gpu" + path = data_result_path / f"data_streaming/single_gpu_workers_{num_workers}" _run_test_data_streaming(path, distributed_config, worker_resources.torchrun_port, num_workers) check_data_streaming_results(path, distributed_config) @@ -254,10 +254,10 @@ def test_data_streaming(result_path, worker_resources, num_workers): @pytest.mark.slow @pytest.mark.depends_on(on=["test_data_streaming"]) -def test_run_data_streaming_distributed(run_parallel_script, result_path, worker_resources): +def test_run_data_streaming_distributed(run_parallel_script, data_result_path, worker_resources): run_parallel_script( _run_test_data_streaming_distributed, - (result_path / "data_streaming", worker_resources.torchrun_port), + (data_result_path / "data_streaming", worker_resources.torchrun_port), world_size=4, backend=DistributedBackend.gloo, use_cuda=False, # Disable device count check. @@ -267,7 +267,7 @@ def test_run_data_streaming_distributed(run_parallel_script, result_path, worker @pytest.mark.slow @pytest.mark.depends_on(on=["test_data_streaming"]) @pytest.mark.parametrize(("name", "num_gpus", "distributed_config_dict"), _DISTRIBUTED_TESTING_CONFIGS) -def test_data_streaming_distributed(result_path, name, num_gpus, distributed_config_dict, report_subtest): - report_subtest(path := result_path / f"data_streaming/{name}", num_gpus, use_cuda=False) +def test_data_streaming_distributed(data_result_path, name, num_gpus, distributed_config_dict, report_subtest): + report_subtest(path := data_result_path / f"data_streaming/{name}", num_gpus, use_cuda=False) distributed_config = _get_distributed_config(distributed_config_dict, num_gpus) check_data_streaming_results(path, distributed_config) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 7980f05bf..07b8768da 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -48,7 +48,7 @@ def test_mlp_recomputation(gated, activation, testing_device): param_grad_refs = [param.grad for param in params] for i, recompute_level in enumerate(MLPRecomputeLevel): - print(recompute_level.value) # noqa + print(recompute_level) # noqa input_.grad = None for param in params: param.grad = None diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 3a68a999f..9b93aeb66 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -13,6 +13,7 @@ from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward from fast_llm.functional.triton import triton_available from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward +from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward from fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.layers.language_model.loss.grpo import fused_grpo_loss_forward_backward @@ -250,6 +251,13 @@ def _test_grpo_loss( logits, target, advantages, old_log_probabilities = _get_grpo_loss_inputs( num_columns, loss_masking, batch_shape, dtype ) + num_labels = int((target >= 0).sum().item()) + num_labels_in_seq = torch.where( + target >= 0, + torch.full(batch_shape, num_labels, dtype=torch.int32, device=target.device), + torch.zeros(batch_shape, dtype=torch.int32, device=target.device), + ) + divisor = max(num_labels, 1) out_ref, grad_ref = loss_forward_backward( grad_output, lambda *args, **kwargs: reference_grpo_loss(*args, **kwargs)[0], @@ -263,7 +271,7 @@ def _test_grpo_loss( previous_grad = torch.randn_like(grad_ref) grad_ref = grad_ref + previous_grad local_previous_grad = split_op(previous_grad, group, -1).contiguous() - out_fused, grad_fused, _ = fused_grpo_loss_forward_backward( + out_fused, grad_fused, new_logprobs_fused = fused_grpo_loss_forward_backward( split_op(logits, group, -1), target, advantages, @@ -272,10 +280,29 @@ def _test_grpo_loss( grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, - divisor=(target >= 0).sum().item(), + num_labels_in_seq=num_labels_in_seq, + divisor=divisor, ) _compare_losses_and_grads(out_fused, out_ref, grad_output is not None, grad_fused, grad_ref, group=group) + if not triton_available: + return + out_triton, grad_triton, new_logprobs_triton = triton_grpo_loss_forward_backward( + split_op(logits, group, -1).contiguous(), + target, + advantages, + old_log_probabilities, + grad_logits=local_previous_grad.clone() if accumulate else None, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + num_labels_in_seq=num_labels_in_seq, + divisor=divisor, + block_size=block_size, + ) + _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref, group=group) + Assert.rms_close_relative(new_logprobs_triton, new_logprobs_fused, 1e-5, 1e-6) + def _test_z_loss( batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size, accumulate, group=None diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 5f0f5a80f..094cbc094 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -235,7 +235,7 @@ def test_load_pretrained( ): # Test that loadind a pretrained model from either converted checkpoint always yields the exact same model. reference_config = model_testing_config.model_config_class.from_dict( - yaml.safe_load(get_convert_path().parents[1].joinpath("config.yaml").open("r"))["model"] + yaml.safe_load(get_convert_path().parents[1].joinpath("config.yaml").read_text())["model"] ) reference_shard = safetensors.torch.load_file( get_convert_path() / "rank_0.safetensors", device=str(testing_device) @@ -260,7 +260,7 @@ def test_load_pretrained( "base_model": yaml.safe_load( get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format) .joinpath("metadata.yaml") - .open("r") + .read_text() )["config"]["base_model"] } ) diff --git a/tests/models/test_streaming.py b/tests/models/test_streaming.py index 0c40f0a48..e65c128f6 100644 --- a/tests/models/test_streaming.py +++ b/tests/models/test_streaming.py @@ -9,6 +9,8 @@ import safetensors import torch +from fast_llm.core.distributed import broadcast as _broadcast +from fast_llm.core.distributed import broadcast_object as _broadcast_object from fast_llm.engine.distributed.config import DistributedBackend from fast_llm.engine.distributed.distributed import ProcessGroupPool from fast_llm.engine.training.config import StreamingTrainerCallbackConfig @@ -68,27 +70,34 @@ def _run_event_consumer( path.mkdir(parents=True, exist_ok=True) field = REDIS_TRAINING_FIELD.encode() # TODO: Create a custom process group instead. + pool = None try: world_size = streaming_config.broadcast.external_world_size + 1 + consumer_rank = consumer_index + 1 backend = DistributedBackend.nccl if torch.cuda.is_available() else DistributedBackend.gloo - process_group = ProcessGroupPool( - rank=0, + pool = ProcessGroupPool( + rank=consumer_rank, world_size=world_size, - local_world_size=world_size, timeout=streaming_config.timeout, - use_cuda=backend == DistributedBackend.nccl, init_method=init_method, backend=backend, - ).get_process_group(range(world_size), 0) + device=( + torch.device("cuda", torch.cuda.current_device()) + if backend == DistributedBackend.nccl + else torch.device("cpu") + ), + ) + process_group = pool.get_process_group(range(world_size), consumer_rank) + timeout_ms = int(streaming_config.timeout * 1000) last_id = "0-0" while True: result = client.xread( streams={REDIS_TRAINING_STREAM: last_id}, count=1, - block=10000, + block=timeout_ms, ) if not result: - raise TimeoutError("No message received after 10000 ms...") + raise TimeoutError(f"No message received after {timeout_ms} ms...") ((stream, events),) = result Assert.eq(stream.decode(), REDIS_TRAINING_STREAM) @@ -102,15 +111,14 @@ def _run_event_consumer( elif message["type"] == "weights_ready": weights = {} while True: - meta = [None] - torch.distributed.broadcast_object_list(meta, group=process_group, group_src=0) - if meta[0] is None: + meta = _broadcast_object(None, process_group, src=0) + if meta is None: print(f"Weight broadcast finished") break - logging.info(f"receiving {meta[0]}") - shard_name, layer_name, tensor_size, tensor_type = meta[0] + logging.info(f"receiving {meta}") + shard_name, layer_name, tensor_size, tensor_type = meta tensor = torch.zeros(tuple(tensor_size), dtype=tensor_type, device="cuda") - torch.distributed.broadcast(tensor, group=process_group, group_src=0) + _broadcast(tensor, 0, process_group) if shard_name == "weights": weights[layer_name] = tensor safetensors.torch.save_file( @@ -118,7 +126,8 @@ def _run_event_consumer( ) finally: - torch.distributed.destroy_process_group(process_group) + if pool is not None: + pool.shutdown() def _run_model_streaming_configs( @@ -127,23 +136,24 @@ def _run_model_streaming_configs( # Import all dynamic classes. import fast_llm.cli # noqa - for config in _DISTRIBUTED_STREAMING_CONFIGS: + for config_index, config in enumerate(_DISTRIBUTED_STREAMING_CONFIGS): + config_port = port + config_index model_testing_config = update_and_add_testing_config( model_testing_config, None, updates={ - ("data", "datasets"): {"training": {"port": port, "timeout": 1.0}}, + ("data", "datasets"): {"training": {"port": config_port, "timeout": 1.0}}, ("training", "export"): {"format": model_testing_config.checkpoint_format.name, "interval": 1}, "callbacks": { "streaming": { "type": "streaming", - "port": port, + "port": config_port, "broadcast": { - "port": port + 1000, + "port": config_port + 1000, "external_world_size": config.consumer_count, }, "export": {"format": model_testing_config.checkpoint_format.name}, - "timeout": 1.0, + "timeout": 120, } }, # Disable tensor logging. @@ -192,9 +202,12 @@ def test_run_model_distributed_streaming( ): if torch.cuda.device_count() < 2: pytest.skip(f"Not enough GPUs") + model_testing_config.get_dataset() + # Use a fixed shift to avoid port conflicts with other distributed tests. + port = worker_resources.torchrun_port + 4321 run_parallel_script( _run_model_streaming_configs, - (run_test_script_base_path, model_testing_config, worker_resources.torchrun_port), + (run_test_script_base_path, model_testing_config, port), world_size=torch.cuda.device_count(), backend=model_testing_config.distributed_backend, ) diff --git a/tests/test_config.py b/tests/test_config.py index bf76595f9..792eab077 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -70,7 +70,7 @@ def test_serialize_default_config_updates(cls): @pytest.mark.parametrize("load_config", tuple(ModelConfigType)) def test_pretrained_config(load_config: ModelConfigType, result_path): - config_path = result_path / "pretrained_config" + config_path = result_path / "pretrained_config" / load_config pretrained_model_config = GPTModelConfig.from_dict( { "base_model": { diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tools/test_generate_config_docs.py b/tests/tools/test_generate_config_docs.py new file mode 100644 index 000000000..c9c7dc654 --- /dev/null +++ b/tests/tools/test_generate_config_docs.py @@ -0,0 +1,470 @@ +"""Unit tests for tools/generate_config_docs.py.""" + +import importlib.util +import pathlib +import typing + +import pytest + +from fast_llm.config import Config, Field, FieldHint, config_class + +# --------------------------------------------------------------------------- +# Load the generator module via importlib (it is not a package). +# --------------------------------------------------------------------------- + +_SCRIPT = pathlib.Path(__file__).parent.parent.parent / "tools" / "generate_config_docs.py" +_spec = importlib.util.spec_from_file_location("generate_config_docs", _SCRIPT) +_gen = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_gen) + + +# --------------------------------------------------------------------------- +# Minimal synthetic Config classes used across multiple tests. +# --------------------------------------------------------------------------- + + +@config_class() +class _InnerConfig(Config): + """A simple inner config for doc-generation tests.""" + + value: int = Field(default=0, hint=FieldHint.core, desc="A value.") + + +@config_class() +class _OuterConfig(Config): + """An outer config that references _InnerConfig.""" + + inner: _InnerConfig = Field(hint=FieldHint.core, desc="Inner config.") + required: str = Field(hint=FieldHint.core, desc="Required string field.") + inner_optional: _InnerConfig | None = Field(default=None, hint=FieldHint.feature, desc="Optional inner.") + string: str = Field(default="hello", hint=FieldHint.core, desc="A string.") + large_int: int = Field(default=2**32, hint=FieldHint.core, desc="A large integer.") + list_of_str: list[str] = Field(default_factory=list, hint=FieldHint.core, desc="A list of strings.") + dict_field: dict[str, int] = Field(default_factory=dict, hint=FieldHint.core, desc="A dict.") + + +# Minimal `found` and `cls_output_paths` dicts used in render_* tests. +_FOUND: dict = { + _InnerConfig: { + "module": "tests.tools._InnerConfig", + "fields": [], + "registry": None, + "registered_in": [], + "abstract": False, + }, + _OuterConfig: { + "module": "tests.tools._OuterConfig", + "fields": [], + "registry": None, + "registered_in": [], + "abstract": False, + }, +} +_CLS_OUTPUT_PATHS: dict[type, pathlib.Path] = { + _InnerConfig: pathlib.Path("tests/InnerConfig.md"), + _OuterConfig: pathlib.Path("tests/OuterConfig.md"), +} +_OWN_PATH = pathlib.Path("tests/SomeConfig.md") + +_OUTER_FIELDS = dict(_OuterConfig.fields()) +_OUTER_HINTS = typing.get_type_hints(_OuterConfig) + + +# --------------------------------------------------------------------------- +# get_module_dir +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "module_name, expected", + [ + ("fast_llm.config", pathlib.Path(".")), + ("fast_llm.engine.distributed.config", pathlib.Path("engine/distributed")), + ("fast_llm.data.dataset.config", pathlib.Path("data/dataset")), + ("fast_llm.models.gpt.config", pathlib.Path("models/gpt")), + ("fast_llm.engine.training.config", pathlib.Path("engine/training")), + # Module without trailing .config — just strip the fast_llm prefix. + ("fast_llm.profile", pathlib.Path("profile")), + ], +) +def test_get_module_dir(module_name, expected): + assert _gen.get_module_dir(module_name) == expected + + +# --------------------------------------------------------------------------- +# _relative_link +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "from_path, to_path, expected", + [ + # Same directory. + ("engine/distributed/A.md", "engine/distributed/B.md", "B.md"), + # Descend into child directory. + ("engine/A.md", "engine/distributed/B.md", "distributed/B.md"), + # Ascend to parent directory. + ("engine/distributed/A.md", "engine/B.md", "../B.md"), + # Sibling directory (up one, down one). + ("engine/distributed/A.md", "engine/training/B.md", "../training/B.md"), + # Deep cross-package link. + ("engine/training/runner/A.md", "data/dataset/B.md", "../../../data/dataset/B.md"), + # Top-level sibling packages. + ("engine/A.md", "data/B.md", "../data/B.md"), + ], +) +def test_relative_link(from_path, to_path, expected): + assert _gen._relative_link(pathlib.Path(from_path), pathlib.Path(to_path)) == expected + + +# --------------------------------------------------------------------------- +# _unwrap_optional +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "annotation, expected", + [ + (int | None, int), + (str | None, str), + (_InnerConfig | None, _InnerConfig), + (int, int), + (str, str), + ], +) +def test_unwrap_optional_strips_none(annotation, expected): + assert _gen._unwrap_optional(annotation) is expected + + +def test_unwrap_optional_union_unchanged(): + # Two non-None types: should not be simplified. + annotation = int | str + assert _gen._unwrap_optional(annotation) is annotation + + +def test_unwrap_optional_triple_union_unchanged(): + # Optional with two non-None types: should not be simplified. + annotation = int | str | None + assert _gen._unwrap_optional(annotation) is annotation + + +# --------------------------------------------------------------------------- +# render_hint_badge +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "hint, expected", + [ + (FieldHint.core, "`core`"), + (FieldHint.architecture, "`architecture`"), + (FieldHint.optional, "`optional`"), + (FieldHint.performance, "`performance`"), + (FieldHint.feature, "`feature`"), + (FieldHint.expert, "`expert`"), + (FieldHint.logging, "`logging`"), + (FieldHint.deprecated, "`deprecated`"), + (FieldHint.wip, "`wip`"), + # unknown → empty string (no badge). + (FieldHint.unknown, ""), + ], +) +def test_render_hint_badge(hint, expected): + assert _gen.render_hint_badge(hint) == expected + + +# --------------------------------------------------------------------------- +# _class_one_liner +# --------------------------------------------------------------------------- + + +class _DocOneLiner: + """A clean one-liner description.""" + + +class _DocMultiLine: + """First line only. + + More detail here that should be ignored. + """ + + +class _DocAutoSignature: + """SomeName(**kwargs)""" + + +class _DocNoDocstring: + pass + + +class _DocTrailingDot: + """Description ending with a dot.""" + + +@pytest.mark.parametrize( + "cls, expected", + [ + (_DocOneLiner, "A clean one-liner description"), + (_DocMultiLine, "First line only"), + (_DocAutoSignature, ""), # auto-generated __init__ signature — filtered out + (_DocNoDocstring, ""), + (_DocTrailingDot, "Description ending with a dot"), # trailing dot stripped + ], +) +def test_class_one_liner(cls, expected): + assert _gen._class_one_liner(cls, {}) == expected + + +# --------------------------------------------------------------------------- +# is_user_field — uses fields extracted from a synthetic Config class +# --------------------------------------------------------------------------- + + +@config_class() +class _IsUserFieldConfig(Config): + normal: str = Field(default="x", hint=FieldHint.core, desc="Normal field.") + feature: str = Field(default="x", hint=FieldHint.feature, desc="Feature field.") + derived: str = Field(default="x", hint=FieldHint.derived, desc="Derived field.") + testing: str = Field(default="x", hint=FieldHint.testing, desc="Testing field.") + setup_field: str = Field(default="x", hint=FieldHint.setup, desc="Setup field.") + + +_IS_USER_FIELD_FIELDS = dict(_IsUserFieldConfig.fields()) + + +@pytest.mark.parametrize( + "field_name, expected", + [ + ("normal", True), + ("feature", True), + ("derived", False), # excluded hint + ("testing", False), # excluded hint + ("setup_field", False), # excluded hint + ], +) +def test_is_user_field_hint(field_name, expected): + assert _gen.is_user_field(field_name, _IS_USER_FIELD_FIELDS[field_name]) == expected + + +@pytest.mark.parametrize( + "name, expected", + [ + ("_private", False), # underscore prefix → always excluded + ("type", False), # "type" is always excluded regardless of field content + ("normal_name", True), + ], +) +def test_is_user_field_name(name, expected): + # Use a valid public field object; only the name varies. + field = _IS_USER_FIELD_FIELDS["normal"] + assert _gen.is_user_field(name, field) == expected + + +# --------------------------------------------------------------------------- +# _extract_config_types +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "annotation, expected_set", + [ + (_InnerConfig, {_InnerConfig}), + (_InnerConfig | None, {_InnerConfig}), + (list[_InnerConfig], {_InnerConfig}), + (dict[str, _InnerConfig], {_InnerConfig}), + (_InnerConfig | _OuterConfig, {_InnerConfig, _OuterConfig}), + (int, set()), + (str | None, set()), + (list[str], set()), + ], +) +def test_extract_config_types(annotation, expected_set): + result = _gen._extract_config_types(annotation, _FOUND) + assert set(result) == expected_set + + +# --------------------------------------------------------------------------- +# render_type +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "annotation, expected", + [ + (str, "`str`"), + (int, "`int`"), + (bool, "`bool`"), + (type(None), "`None`"), + (typing.Any, "`Any`"), + (str | None, "`str` or `None`"), + (int | None, "`int` or `None`"), + (list[str], "list[`str`]"), + (list[int], "list[`int`]"), + (dict[str, int], "dict[`str`, `int`]"), + (tuple[str, int], "tuple[`str`, `int`]"), + (set[str], "set[`str`]"), + ], +) +def test_render_type_primitives(annotation, expected): + assert _gen.render_type(annotation, _FOUND, _CLS_OUTPUT_PATHS, _OWN_PATH) == expected + + +def test_render_type_config_produces_link(): + result = _gen.render_type(_InnerConfig, _FOUND, _CLS_OUTPUT_PATHS, _OWN_PATH) + # Should be a markdown link to the class page. + assert result.startswith("[_InnerConfig](") + assert result.endswith(")") + + +def test_render_type_config_not_in_found(): + # Config type absent from found → backtick name, no link. + result = _gen.render_type(_InnerConfig, {}, {}, _OWN_PATH) + assert result == "`_InnerConfig`" + + +def test_render_type_optional_config(): + result = _gen.render_type(_InnerConfig | None, _FOUND, _CLS_OUTPUT_PATHS, _OWN_PATH) + assert "[_InnerConfig](" in result + assert "or `None`" in result + + +# --------------------------------------------------------------------------- +# render_default +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "field_name, expected", + [ + ("string", '`"hello"`'), + ("large_int", "`4_294_967_296`"), # 2**32 with underscores + ("list_of_str", "`list()`"), + ("dict_field", "`dict()`"), + ], +) +def test_render_default_simple(field_name, expected): + field = _OUTER_FIELDS[field_name] + resolved = _OUTER_HINTS.get(field_name, field.type) + assert _gen.render_default(field, resolved, _FOUND) == expected + + +def test_render_default_none(): + field = _OUTER_FIELDS["inner_optional"] + assert _gen.render_default(field, _InnerConfig | None, _FOUND) == "`None`" + + +def test_render_default_required_primitive(): + field = _OUTER_FIELDS["required"] + assert _gen.render_default(field, str, _FOUND) == "*(required)*" + + +def test_render_default_config_field_sub_fields_optional(): + # Config-typed field with no default → sub-fields are optional. + field = _OUTER_FIELDS["inner"] + assert _gen.render_default(field, _InnerConfig, _FOUND) == "*(sub-fields optional)*" + + +@config_class() +class _TypeDefaultConfig(Config): + fmt: type = Field(default=_InnerConfig, hint=FieldHint.core, desc="A type default.") + + +def test_render_default_type_class(): + # A field whose default value is itself a type object. + fields = dict(_TypeDefaultConfig.fields()) + assert _gen.render_default(fields["fmt"], type, _FOUND) == "`_InnerConfig`" + + +# --------------------------------------------------------------------------- +# format_nav_yaml +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "entries, indent, expected", + [ + # Flat list of strings. + ( + ["reference/a.md", "reference/b.md"], + 0, + ["- reference/a.md", "- reference/b.md"], + ), + # Single nested section. + ( + [{"Section": ["reference/a.md"]}], + 0, + ["- Section:", " - reference/a.md"], + ), + # Double-nested sections. + ( + [{"Outer": [{"Inner": ["reference/a.md"]}]}], + 0, + ["- Outer:", " - Inner:", " - reference/a.md"], + ), + # Non-zero base indent. + ( + ["reference/a.md"], + 1, + [" - reference/a.md"], + ), + # Mixed strings and dicts. + ( + ["reference/index.md", {"Sub": ["reference/sub/a.md"]}], + 0, + ["- reference/index.md", "- Sub:", " - reference/sub/a.md"], + ), + ], +) +def test_format_nav_yaml(entries, indent, expected): + assert _gen.format_nav_yaml(entries, indent) == expected + + +# --------------------------------------------------------------------------- +# render_class_page smoke test +# --------------------------------------------------------------------------- + + +def test_render_class_page_contains_key_sections(): + info = _FOUND[_OuterConfig] + # Build minimal fields list as the generator would. + fields = [] + for name, field in _OuterConfig.fields(): + if _gen.is_user_field(name, field): + resolved = _OUTER_HINTS.get(name, field.type) + fields.append((name, field, resolved)) + info_with_fields = {**info, "fields": fields} + + content = _gen.render_class_page( + _OuterConfig, + info_with_fields, + back_refs=[], + found=_FOUND, + cls_output_paths=_CLS_OUTPUT_PATHS, + own_path=_CLS_OUTPUT_PATHS[_OuterConfig], + ) + + assert "# _OuterConfig" in content + assert "## Fields" in content + assert "`string`" in content + assert "`large_int`" in content + assert "*(sub-fields optional)*" in content # inner field + assert "*(required)*" in content # required field + + +# --------------------------------------------------------------------------- +# render_index_page smoke test +# --------------------------------------------------------------------------- + + +def test_render_index_page_lists_classes(): + classes_in_dir = list(_FOUND.items()) + content = _gen.render_index_page( + pathlib.Path("tests"), + classes_in_dir, + cls_output_paths=_CLS_OUTPUT_PATHS, + subdirs=[], + ) + + assert "## Classes" in content + assert "_InnerConfig" in content + assert "_OuterConfig" in content diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 6268ac194..0f89d9323 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -706,7 +706,7 @@ def update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, ModelTestingGroup.streaming: ModelTestingGroupAction.normal, }, ) diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py index b8f0b5b7a..78e4d4357 100644 --- a/tests/utils/subtest.py +++ b/tests/utils/subtest.py @@ -43,7 +43,10 @@ def __enter__(self): ) self._pool = ProcessGroupPool( - timeout=self._timeout, init_method=self._init_method, backend=self._backend, use_cuda=self._use_cuda + timeout=self._timeout, + init_method=self._init_method, + backend=self._backend, + device=None if self._use_cuda else torch.device("cpu"), ).__enter__() self._rank = self._pool.rank self._world_size = self._pool.world_size diff --git a/tools/generate_config_docs.py b/tools/generate_config_docs.py new file mode 100644 index 000000000..d797eb4ae --- /dev/null +++ b/tools/generate_config_docs.py @@ -0,0 +1,772 @@ +#!/usr/bin/env python3 +""" +Generate markdown documentation for Fast-LLM configuration classes. + +Walks the fast_llm package, finds all @config_class-decorated classes, and writes +one markdown file per class under docs/reference/configuration/, mirroring the +package structure. Also writes index.md files per directory and updates the nav +section in mkdocs.yaml. + +Usage: + python tools/generate_config_docs.py +""" + +import dataclasses +import importlib +import pathlib +import pkgutil +import re +import sys +import types +import typing + +from fast_llm.config import Config, Field, FieldHint, FieldHintImportance # noqa: E402 + +# --------------------------------------------------------------------------- +# Paths +# --------------------------------------------------------------------------- + +REPO_ROOT = pathlib.Path(__file__).parent.parent +OUTPUT_DIR = REPO_ROOT / "docs" / "reference" / "configuration" +MKDOCS_YAML = REPO_ROOT / "mkdocs.yaml" + +sys.path.insert(0, str(REPO_ROOT)) + + +# --------------------------------------------------------------------------- +# Field filtering +# --------------------------------------------------------------------------- + +# Hints that describe internal/computed/testing fields — not useful in config docs. +EXCLUDED_HINTS: set[FieldHint] = {FieldHint.derived, FieldHint.setup, FieldHint.testing} + +# Field names that are always excluded regardless of hint. +EXCLUDED_FIELD_NAMES: set[str] = {"type"} + + +def is_user_field(name: str, field: Field) -> bool: + """Return True if this field should appear in user-facing documentation.""" + if name.startswith("_"): + return False + if name in EXCLUDED_FIELD_NAMES: + return False + if not field.init or field._field_type is not dataclasses._FIELD: # noqa: SLF001 + return False + if getattr(field, "hint", None) in EXCLUDED_HINTS: + return False + return True + + +# --------------------------------------------------------------------------- +# Module collection +# --------------------------------------------------------------------------- + + +def import_all_config_modules() -> None: + """Import every module in the fast_llm package so all Config subclasses are registered.""" + import fast_llm # noqa: F401 + + for module_info in pkgutil.walk_packages( + path=[str(REPO_ROOT / "fast_llm")], + prefix="fast_llm.", + onerror=lambda name: None, + ): + # Only import config modules — they are safe to import without GPU. + if not module_info.name.endswith(".config"): + continue + try: + importlib.import_module(module_info.name) + except Exception as exc: # noqa: BLE001 + print(f" [skip] {module_info.name}: {exc}", file=sys.stderr) + + +def collect_config_classes() -> dict[type, dict]: + """ + Return a dict mapping each Config subclass to metadata: + { + "module": str, + "fields": list[(name, Field, resolved_type)], + "registry": dict[str, type] | None, # subclasses if this has a registry + "registered_in": list[(base_cls, type_key)], # registries this class is in + "abstract": bool, + } + """ + import fast_llm.config as config_module + + config_base = config_module.Config + + # Collect all Config subclasses that have been processed by @config_class. + found: dict[type, dict] = {} + for cls in _all_subclasses(config_base): + if not getattr(cls, "__class_validated__", False): + continue + if cls.__module__ == "builtins": + continue + found[cls] = { + "module": cls.__module__, + "fields": [], + "registry": None, + "registered_in": [], + "abstract": bool(getattr(cls, "_abstract", False)), + } + + # Resolve type hints and build field lists. + for cls, info in found.items(): + try: + hints = typing.get_type_hints(cls) + except Exception: # noqa: BLE001 + hints = {} + for name, field in cls.fields(): + if not is_user_field(name, field): + continue + resolved = hints.get(name, field.type) + info["fields"].append((name, field, resolved)) + # Sort by hint importance (lower = more important), then alphabetically. + info["fields"].sort( + key=lambda t: (FieldHintImportance.get(getattr(t[1], "hint", FieldHint.unknown), 20), t[0]) + ) + + # Build registry info. + for cls, info in found.items(): + registry = getattr(cls, "_registry", None) + if registry is not None: + info["registry"] = {key: found_cls for key in registry if (found_cls := registry[key]) in found} + + # Build registered_in back-references. + for cls, info in found.items(): + registry = getattr(cls, "_registry", None) + if registry is None: + continue + for key in registry: + subclass = registry[key] + if subclass in found: + found[subclass]["registered_in"].append((cls, key)) + + return found + + +def _all_subclasses(cls: type) -> list[type]: + """Recursively collect all subclasses of a class.""" + result = [] + queue = list(cls.__subclasses__()) + seen = set() + while queue: + sub = queue.pop() + if sub in seen: + continue + seen.add(sub) + result.append(sub) + queue.extend(sub.__subclasses__()) + return result + + +# --------------------------------------------------------------------------- +# Back-reference computation +# --------------------------------------------------------------------------- + + +def build_back_refs(found: dict[type, dict]) -> dict[type, list[tuple[type, str]]]: + """ + For each config class, find all (owner_class, field_name) pairs that reference it + as part of their field type. + """ + back_refs: dict[type, list[tuple[type, str]]] = {cls: [] for cls in found} + + for owner_cls, info in found.items(): + for name, _field, resolved_type in info["fields"]: + for referenced_cls in _extract_config_types(resolved_type, found): + back_refs[referenced_cls].append((owner_cls, name)) + + return back_refs + + +def _extract_config_types(annotation, found: dict[type, dict]) -> list[type]: + """Extract all Config subclass types referenced in an annotation.""" + results = [] + if isinstance(annotation, type) and annotation in found: + results.append(annotation) + elif isinstance(annotation, types.UnionType) or ( + hasattr(annotation, "__origin__") and annotation.__origin__ is typing.Union + ): + for arg in typing.get_args(annotation): + results.extend(_extract_config_types(arg, found)) + elif hasattr(annotation, "__origin__"): + for arg in typing.get_args(annotation): + results.extend(_extract_config_types(arg, found)) + return results + + +# --------------------------------------------------------------------------- +# Type rendering +# --------------------------------------------------------------------------- + + +def render_type( + annotation, + found: dict[type, dict], + cls_output_paths: dict[type, pathlib.Path], + own_path: pathlib.Path, +) -> str: + """Render a type annotation as a markdown string, linking to Config class pages.""" + if annotation is type(None): + return "`None`" + if annotation is typing.Any: + return "`Any`" + if isinstance(annotation, type): + if annotation in found: + rel_path = cls_output_paths.get(annotation) + if rel_path is not None: + link = _relative_link(own_path, rel_path) + return f"[{annotation.__name__}]({link})" + return f"`{annotation.__name__}`" + if issubclass(annotation, type): + return "`type`" + return f"`{annotation.__name__}`" + if isinstance(annotation, types.UnionType) or ( + hasattr(annotation, "__origin__") and annotation.__origin__ is typing.Union + ): + args = [a for a in typing.get_args(annotation) if a is not type(None)] + none_part = " or `None`" if type(None) in typing.get_args(annotation) else "" + inner = " or ".join(render_type(a, found, cls_output_paths, own_path) for a in args) + return inner + none_part + if hasattr(annotation, "__origin__"): + origin = annotation.__origin__ + args = typing.get_args(annotation) + if origin is list: + return f"list[{render_type(args[0], found, cls_output_paths, own_path)}]" if args else "`list`" + if origin is dict: + k = render_type(args[0], found, cls_output_paths, own_path) if args else "`Any`" + v = render_type(args[1], found, cls_output_paths, own_path) if len(args) > 1 else "`Any`" + return f"dict[{k}, {v}]" + if origin is tuple: + inner = ", ".join(render_type(a, found, cls_output_paths, own_path) for a in args) + return f"tuple[{inner}]" + if origin is set: + return f"set[{render_type(args[0], found, cls_output_paths, own_path)}]" if args else "`set`" + # Fallback for other generics + return f"`{getattr(origin, '__name__', str(origin))}`" + return f"`{annotation}`" + + +def render_default(field: Field, resolved_type, found: dict[type, dict]) -> str: + """Render the default value of a field as a string.""" + if field.default is not dataclasses.MISSING: + value = field.default + if isinstance(value, str): + return f'`"{value}"`' + if value is None: + return "`None`" + # Class objects: show the class name, not `` + if isinstance(value, type): + return f"`{value.__name__}`" + # Large integers: insert underscores every 3 digits for readability + if isinstance(value, int) and abs(value) > 999_999: + return f"`{value:_}`" + return f"`{value}`" + if field.default_factory is not dataclasses.MISSING: + factory = field.default_factory + # A factory that is itself a Config class means "instantiate with defaults". + if isinstance(factory, type) and factory in found: + return "*(sub-fields optional)*" + if hasattr(factory, "__name__"): + return f"`{factory.__name__}()`" + # If the type itself is a Config class, the value is still required in YAML + # but every sub-field within it has its own default — don't say "required". + core_type = _unwrap_optional(resolved_type) + if isinstance(core_type, type) and core_type in found: + return "*(sub-fields optional)*" + return "*(required)*" + + +def _unwrap_optional(annotation) -> type | None: + """Return the inner type of Optional[X] / X | None, or the annotation itself.""" + if isinstance(annotation, types.UnionType) or ( + hasattr(annotation, "__origin__") and annotation.__origin__ is typing.Union + ): + args = [a for a in typing.get_args(annotation) if a is not type(None)] + if len(args) == 1: + return args[0] + return annotation + + +# --------------------------------------------------------------------------- +# Output path computation +# --------------------------------------------------------------------------- + + +def get_module_dir(module_name: str) -> pathlib.Path: + """ + Convert a module name like 'fast_llm.engine.distributed.config' to a + relative output path like 'engine/distributed'. + """ + parts = module_name.split(".") + # Strip 'fast_llm' prefix. + if parts and parts[0] == "fast_llm": + parts = parts[1:] + # Strip trailing 'config'. + if parts and parts[-1] == "config": + parts = parts[:-1] + return pathlib.Path(*parts) if parts else pathlib.Path(".") + + +def compute_output_paths(found: dict[type, dict]) -> dict[type, pathlib.Path]: + """ + Return a dict mapping each class to its output path relative to OUTPUT_DIR, + e.g. engine/distributed/DistributedConfig.md + """ + return {cls: get_module_dir(info["module"]) / f"{cls.__name__}.md" for cls, info in found.items()} + + +# --------------------------------------------------------------------------- +# Markdown rendering +# --------------------------------------------------------------------------- + + +def render_hint_badge(hint: FieldHint) -> str: + badge_map = { + FieldHint.core: "core", + FieldHint.architecture: "architecture", + FieldHint.optional: "optional", + FieldHint.performance: "performance", + FieldHint.stability: "stability", + FieldHint.feature: "feature", + FieldHint.expert: "expert", + FieldHint.logging: "logging", + FieldHint.deprecated: "deprecated", + FieldHint.wip: "wip", + FieldHint.unknown: "", + } + label = badge_map.get(hint, str(hint)) + return f"`{label}`" if label else "" + + +def render_class_page( + cls: type, + info: dict, + back_refs: list[tuple[type, str]], + found: dict[type, dict], + cls_output_paths: dict[type, pathlib.Path], + own_path: pathlib.Path, +) -> str: + """Render the full markdown page for a config class.""" + lines = [] + + # Title + lines.append(f"# {cls.__name__}\n") + + # Abstract badge + if info["abstract"]: + lines.append( + '!!! note "Abstract"\n This class cannot be instantiated directly. Use one of the variants listed below.\n' + ) + + # Module + lines.append(f"**Module:** `{cls.__module__}`\n") + + # Registered as / variant of + if info["registered_in"]: + for base_cls, type_key in info["registered_in"]: + base_path = cls_output_paths.get(base_cls) + if base_path is not None: + rel = _relative_link(own_path, base_path) + lines.append(f"**Variant of:** [{base_cls.__name__}]({rel}) — select with `type: {type_key}`\n") + else: + lines.append(f"**Variant of:** `{base_cls.__name__}` — select with `type: {type_key}`\n") + + # Inheritance (Config parents only, skip Config itself and internal bases) + config_parents = [ + base + for base in cls.__mro__[1:] + if base is not cls + and isinstance(base, type) + and issubclass(base, Config) + and base.__name__ != "Config" + and base in found + ] + if config_parents: + parent_links = [] + for parent in config_parents[:3]: # limit to 3 to avoid noise + p_path = cls_output_paths.get(parent) + if p_path is not None: + rel = _relative_link(own_path, p_path) + parent_links.append(f"[{parent.__name__}]({rel})") + else: + parent_links.append(f"`{parent.__name__}`") + lines.append(f"**Inherits from:** {', '.join(parent_links)}\n") + + lines.append("") + + # Fields — definition list, one entry per field + user_fields = info["fields"] + if user_fields: + lines.append("## Fields\n") + for name, field, resolved_type in user_fields: + type_str = render_type(resolved_type, found, cls_output_paths, own_path) + default_str = render_default(field, resolved_type, found) + hint = getattr(field, "hint", FieldHint.unknown) + hint_str = render_hint_badge(hint) + desc = getattr(field, "desc", None) or "" + doc = getattr(field, "doc", None) + if doc: + desc = f"{desc} {doc}".strip() if desc else doc + # Flatten multi-line descriptions (newlines break def-list indentation). + desc = " ".join(desc.split()) + # Term: field name + hint badge (omit separator when hint is empty) + term = f"`{name}`" + (f" — {hint_str}" if hint_str else "") + lines.append(term) + # Definition: metadata line, then description as a separate paragraph. + meta = f"**Type:** {type_str}    **Default:** {default_str}" + lines.append(f": {meta}") + if desc: + # Blank line + 4-space indent = new paragraph within the definition. + lines.append(f"") + lines.append(f" {desc}") + lines.append("") + else: + lines.append("*No user-configurable fields.*\n") + + # Variants table (if this class has a registry) + registry = info.get("registry") + if registry: + lines.append("## Variants\n") + lines.append("Select a variant by setting `type:` to one of the following values.\n") + lines.append("| `type` value | Class | Description |") + lines.append("|--------------|-------|-------------|") + for key in sorted(registry): + subclass = registry[key] + sub_path = cls_output_paths.get(subclass) + if sub_path is not None: + rel = _relative_link(own_path, sub_path) + class_link = f"[{subclass.__name__}]({rel})" + else: + class_link = f"`{subclass.__name__}`" + sub_info = found.get(subclass, {}) + desc = _class_one_liner(subclass, sub_info) + lines.append(f"| `{key}` | {class_link} | {desc} |") + lines.append("") + + # Used in (back-references) + if back_refs: + lines.append("## Used in\n") + seen = set() + for owner_cls, field_name in sorted(back_refs, key=lambda t: (t[0].__name__, t[1])): + key = (owner_cls, field_name) + if key in seen: + continue + seen.add(key) + owner_path = cls_output_paths.get(owner_cls) + if owner_path is not None: + rel = _relative_link(own_path, owner_path) + lines.append(f"- [`{field_name}`]({rel}) in [{owner_cls.__name__}]({rel})") + else: + lines.append(f"- `{field_name}` in `{owner_cls.__name__}`") + lines.append("") + + return "\n".join(lines) + + +def _class_one_liner(cls: type, info: dict) -> str: + """Return a short description for a class, or empty string if none is available.""" + doc = getattr(cls, "__doc__", None) + if doc: + first_line = doc.strip().split("\n")[0].strip().rstrip(".") + # Skip auto-generated __init__ signatures like "ClassName(**kwargs)" + if first_line and not re.match(r"^\w.*\(.*\)\s*$", first_line): + return first_line + return "" + + +def _relative_link(from_path: pathlib.Path, to_path: pathlib.Path) -> str: + """ + Compute a relative markdown link from one page to another, + both paths relative to OUTPUT_DIR. + """ + from_dir = from_path.parent + try: + rel = pathlib.Path(to_path).relative_to(from_dir) + except ValueError: + # Go up from from_dir to the common ancestor + parts_from = from_dir.parts + parts_to = to_path.parts + # Find common prefix length + common = 0 + for a, b in zip(parts_from, parts_to): + if a == b: + common += 1 + else: + break + up = len(parts_from) - common + rel = pathlib.Path(*[".."] * up, *parts_to[common:]) + return str(rel).replace("\\", "/") + + +# --------------------------------------------------------------------------- +# Index page rendering +# --------------------------------------------------------------------------- + + +def render_index_page( + directory: pathlib.Path, + classes_in_dir: list[tuple[type, dict]], + cls_output_paths: dict[type, pathlib.Path], + subdirs: list[pathlib.Path], +) -> str: + """Render an index.md for a directory.""" + lines = [] + + # Title: use the directory name + if directory == pathlib.Path("."): + title = "Configuration Reference" + else: + title = " / ".join(p.replace("_", " ").title() for p in directory.parts) + lines.append(f"# {title}\n") + + directory / "index.md" + + # Subdirectory links + if subdirs: + lines.append("## Sections\n") + for subdir in sorted(subdirs): + section_name = subdir.name.replace("_", " ").title() + rel = str((subdir / "index.md").relative_to(directory)).replace("\\", "/") + lines.append(f"- [{section_name}]({rel})") + lines.append("") + + # Class table + if classes_in_dir: + lines.append("## Classes\n") + lines.append("| Class | Description |") + lines.append("|-------|-------------|") + for cls, info in sorted(classes_in_dir, key=lambda t: t[0].__name__): + cls_path = cls_output_paths[cls] + rel = str(cls_path.relative_to(directory)).replace("\\", "/") + desc = _class_one_liner(cls, info) + abstract_note = " *(abstract)*" if info["abstract"] else "" + lines.append(f"| [{cls.__name__}]({rel}){abstract_note} | {desc} |") + lines.append("") + + return "\n".join(lines) + + +def render_root_index( + found: dict[type, dict], + cls_output_paths: dict[type, pathlib.Path], + top_level_dirs: list[pathlib.Path], +) -> str: + """Render the top-level index.md.""" + lines = [ + "# Configuration Reference\n", + "This reference documents all configuration classes in Fast-LLM.", + "Configurations are YAML files passed to the `fast-llm` CLI.", + "The entry point is `GPTTrainerConfig`, which composes all other configurations.\n", + "## Sections\n", + ] + for d in sorted(top_level_dirs): + section_name = d.name.replace("_", " ").title() + lines.append(f"- [{section_name}]({d.name}/index.md)") + lines.append("") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Nav generation +# --------------------------------------------------------------------------- + + +def build_nav_tree(cls_output_paths: dict[type, pathlib.Path], found: dict[type, dict]) -> dict: + """ + Build a nested dict representing the nav tree: + { dir_path: { "index": index_path, "classes": [...], "subdirs": {subdir: ...} } } + """ + tree: dict = {} + + for cls, rel_path in cls_output_paths.items(): + parts = rel_path.parent.parts + node = tree + for part in parts: + node = node.setdefault(part, {}) + node.setdefault("_classes", []).append(cls) + + return tree + + +def nav_entries( + tree: dict, + cls_output_paths: dict[type, pathlib.Path], + prefix: pathlib.Path = pathlib.Path("."), +) -> list: + """Recursively build the mkdocs nav list for the config reference section.""" + entries = [] + + # Index for this directory + if prefix == pathlib.Path("."): + index_rel = "reference/configuration/index.md" + else: + index_rel = f"reference/configuration/{prefix}/index.md".replace("\\", "/") + entries.append(index_rel) + + # Classes directly in this directory + classes = tree.get("_classes", []) + for cls in sorted(classes, key=lambda c: c.__name__): + rel = cls_output_paths[cls] + entries.append(f"reference/configuration/{rel}".replace("\\", "/")) + + # Subdirectories + for key, subtree in sorted((k, v) for k, v in tree.items() if not k.startswith("_")): + subprefix = prefix / key if prefix != pathlib.Path(".") else pathlib.Path(key) + section_name = key.replace("_", " ").title() + sub_entries = nav_entries(subtree, cls_output_paths, subprefix) + entries.append({section_name: sub_entries}) + + return entries + + +def format_nav_yaml(entries: list, indent: int = 0) -> list[str]: + """Render nav entries as YAML lines.""" + lines = [] + pad = " " * indent + for entry in entries: + if isinstance(entry, str): + lines.append(f"{pad}- {entry}") + elif isinstance(entry, dict): + for key, sub_entries in entry.items(): + lines.append(f"{pad}- {key}:") + lines.extend(format_nav_yaml(sub_entries, indent + 1)) + return lines + + +# --------------------------------------------------------------------------- +# mkdocs.yaml nav update +# --------------------------------------------------------------------------- + +NAV_SENTINEL_START = " # BEGIN AUTO-GENERATED CONFIG REFERENCE" +NAV_SENTINEL_END = " # END AUTO-GENERATED CONFIG REFERENCE" + + +def update_mkdocs_nav(nav_lines: list[str]) -> None: + """ + Replace the auto-generated config reference section in mkdocs.yaml. + If the sentinels are not present, append the section to the nav. + """ + content = MKDOCS_YAML.read_text() + + new_block = "\n".join([NAV_SENTINEL_START] + nav_lines + [NAV_SENTINEL_END]) + + if NAV_SENTINEL_START in content and NAV_SENTINEL_END in content: + # Replace existing block + pattern = re.escape(NAV_SENTINEL_START) + r".*?" + re.escape(NAV_SENTINEL_END) + content = re.sub(pattern, new_block, content, flags=re.DOTALL) + else: + # Append before the last line of the nav section + # Find the nav: key and append at the end of its list + lines = content.splitlines() + # Find the last non-empty line inside the nav block (heuristic: insert before next top-level key) + insert_at = len(lines) + in_nav = False + for i, line in enumerate(lines): + if line.startswith("nav:"): + in_nav = True + elif in_nav and line and not line.startswith(" "): + insert_at = i + break + indent = " " + nav_indented = "\n".join(indent + l for l in new_block.splitlines()) + lines.insert(insert_at, nav_indented) + content = "\n".join(lines) + "\n" + + MKDOCS_YAML.write_text(content) + print(f"Updated nav in {MKDOCS_YAML}") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def generate(*, update_nav: bool = True, verbose: bool = True) -> None: + """Generate all config reference docs, optionally updating mkdocs.yaml nav.""" + + def log(msg: str) -> None: + if verbose: + print(msg) + + log("Importing fast_llm config modules...") + import_all_config_modules() + + log("Collecting config classes...") + found = collect_config_classes() + log(f" Found {len(found)} config classes") + + log("Computing output paths...") + cls_output_paths = compute_output_paths(found) + + log("Building back-references...") + back_refs = build_back_refs(found) + + # Group classes by output directory + dir_to_classes: dict[pathlib.Path, list[tuple[type, dict]]] = {} + for cls, info in found.items(): + directory = cls_output_paths[cls].parent + dir_to_classes.setdefault(directory, []).append((cls, info)) + + log(f"Writing to {OUTPUT_DIR} ...") + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + + # Write class pages + for cls, info in found.items(): + rel_path = cls_output_paths[cls] + out_path = OUTPUT_DIR / rel_path + out_path.parent.mkdir(parents=True, exist_ok=True) + content = render_class_page(cls, info, back_refs[cls], found, cls_output_paths, rel_path) + out_path.write_text(content) + + # Write index pages — include all ancestor directories, not just leaf dirs with classes. + leaf_dirs = {cls_output_paths[cls].parent for cls in found} + all_dirs: set[pathlib.Path] = set() + for directory in leaf_dirs: + all_dirs.add(directory) + for i in range(len(directory.parts)): + all_dirs.add(pathlib.Path(*directory.parts[:i]) if i > 0 else pathlib.Path(".")) + + # Find all top-level directories (direct children of output root) + top_level_dirs = sorted({d.parts[0] for d in all_dirs if d != pathlib.Path(".")}) + + for directory in sorted(all_dirs): + classes_in_dir = dir_to_classes.get(directory, []) + # Find immediate subdirectories + subdirs = sorted( + { + directory / d.parts[len(directory.parts)] + for d in all_dirs + if len(d.parts) > len(directory.parts) and d.parts[: len(directory.parts)] == directory.parts + } + ) + index_content = render_index_page(directory, classes_in_dir, cls_output_paths, subdirs) + index_path = OUTPUT_DIR / directory / "index.md" + index_path.parent.mkdir(parents=True, exist_ok=True) + index_path.write_text(index_content) + + # Write root index + root_index = render_root_index( + found, + cls_output_paths, + [pathlib.Path(d) for d in top_level_dirs], + ) + (OUTPUT_DIR / "index.md").write_text(root_index) + + if update_nav: + log("Updating mkdocs.yaml nav...") + tree = build_nav_tree(cls_output_paths, found) + nav_root = nav_entries(tree, cls_output_paths) + nav_yaml_lines = format_nav_yaml([{"Configuration Reference": nav_root}], indent=1) + update_mkdocs_nav(nav_yaml_lines) + + log("Done.") + + +def main() -> None: + generate(update_nav=True, verbose=True) + + +if __name__ == "__main__": + main()