diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index ad4c917f9..ecad0220c 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -2,6 +2,7 @@ from peft import LoraConfig from tqdm import tqdm +from torch.optim import Muon # PyTorch 2.9+; matrix-orthogonalized momentum optimizer. import twinkle from twinkle import DeviceMesh, get_device_placement, get_logger @@ -64,7 +65,16 @@ def train(): model.add_adapter_to_model( args.lora.adapter_name, lora_config, gradient_accumulation_steps=args.training.gradient_accumulation_steps) - model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate) + # Muon optimizes 2D hidden-layer weight matrices via Newton-Schulz orthogonalization. + # In LoRA training the trainable params are exclusively lora_A / lora_B (both 2D), + # so Muon applies cleanly without an AdamW fallback for 1D params. + # ``adjust_lr_fn='match_rms_adamw'`` rescales the orthogonalized update so the same + # lr / weight_decay tuned for AdamW can be reused directly (Moonshot Muon recipe). + model.set_optimizer( + optimizer_cls=Muon, + lr=args.optimizer.learning_rate, + adjust_lr_fn='match_rms_adamw', + ) # Add LRScheduler for lora `default` model.set_lr_scheduler( diff --git a/docs/source_en/Components/Gym/Gym.md b/docs/source_en/Components/Gym/Gym.md index 4db355b8a..8d243677c 100644 --- a/docs/source_en/Components/Gym/Gym.md +++ b/docs/source_en/Components/Gym/Gym.md @@ -3,7 +3,8 @@ The Gym component provides an interface for reinforcement learning environments in Twinkle. ```python -from twinkle.gym import Gym +from twinkle_agentic.env import Gym + class CustomGym(Gym): diff --git "a/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" "b/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" index 63dc87aa7..4e34ffb93 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" @@ -3,7 +3,8 @@ Gym 组件为 Twinkle 中的强化学习环境提供接口。 ```python -from twinkle.gym import Gym +from twinkle_agentic.env import Gym + class CustomGym(Gym): diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py index cde5c519d..3860d2840 100644 --- a/src/twinkle/checkpoint_engine/manager.py +++ b/src/twinkle/checkpoint_engine/manager.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py -import time from typing import List, Optional from twinkle import Platform, get_logger diff --git a/src/twinkle/checkpoint_engine/mixin.py b/src/twinkle/checkpoint_engine/mixin.py index e2e5d94d5..8dc15c926 100644 --- a/src/twinkle/checkpoint_engine/mixin.py +++ b/src/twinkle/checkpoint_engine/mixin.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import os from twinkle import Platform, remote_function from twinkle.checkpoint_engine.base import CheckpointEngine diff --git a/src/twinkle/cli/cli.py b/src/twinkle/cli/cli.py index 1730887f2..10ad4a396 100644 --- a/src/twinkle/cli/cli.py +++ b/src/twinkle/cli/cli.py @@ -1,12 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from __future__ import annotations - import os import sys from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields from pathlib import Path -from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union +from typing import Any, Iterator, Literal + # ──────────────────────────────────────────────────────────────────────────────── # Arg group dataclasses @@ -243,7 +242,7 @@ def _resolve_path(self) -> Path | None: class EnvVarSource(ConfigSource): """Reads os.environ; recognizes TWINKLE_ prefix and any key known to the registry.""" - def __init__(self, registry: ConfigRegistry): + def __init__(self, registry: 'ConfigRegistry'): self._registry = registry def load(self) -> dict[str, str]: diff --git a/src/twinkle/data_format/output.py b/src/twinkle/data_format/output.py index 763ef246f..596252fb6 100644 --- a/src/twinkle/data_format/output.py +++ b/src/twinkle/data_format/output.py @@ -20,11 +20,13 @@ class ModelOutput(TypedDict, total=False): loss: The loss calculated by the model. logps: The log-probabilities of correct tokens by the model. num_tokens: The token denominator associated with ``loss``. + embeddings: The embeddings output by the model, used be embedding task. """ logits: Optional[OutputType] loss: Optional[OutputType] logps: Optional[OutputType] num_tokens: Optional[OutputType] + embeddings: Optional[OutputType] class LossOutput(TypedDict, total=False): diff --git a/src/twinkle/data_format/sampling.py b/src/twinkle/data_format/sampling.py index 1d5fe07c6..01ff0377d 100644 --- a/src/twinkle/data_format/sampling.py +++ b/src/twinkle/data_format/sampling.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import numpy as np from dataclasses import dataclass from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index c392d56cf..408c8d4b4 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -146,7 +146,7 @@ def _tracking_iter(self, inner): def skip_consumed_samples(self, consumed_train_samples: int) -> None: from torch.utils.data import IterableDataset - if isinstance(self.dataset, IterableDataset): + if isinstance(self.dataset, IterableDataset) or consumed_train_samples is None or consumed_train_samples <= 0: warnings.warn('IterableDataset does not support consumed-data skipping; continuing without skipping.') self._skip_samples = 0 return @@ -164,6 +164,7 @@ def resume_from_checkpoint(self, consumed_train_samples, **kwargs): @remote_function() def get_state(self) -> dict: + """The dataloader state for saving.""" return {'consumed_train_samples': self._consumed_train_samples} def _rebuild_sampler_stack(self): diff --git a/src/twinkle/dataset/iterable_packing_dataset.py b/src/twinkle/dataset/iterable_packing_dataset.py index ca7c6fbd8..ab1d3a982 100644 --- a/src/twinkle/dataset/iterable_packing_dataset.py +++ b/src/twinkle/dataset/iterable_packing_dataset.py @@ -88,10 +88,27 @@ def _fetch_data_out_queue(self, last_res, num_samples): last_res += res return last_res - @staticmethod - def _cyclic_iter(iterable): - while True: - yield from iterable + def _write_through_iter(self, iterable): + """Yields from iterable, meanwhile, save it to disk if needed. + Saving is needed when you are using several datasets at a time. + """ + if not self.cyclic: + for row in iterable: + self._write_through(row) + yield row + return + else: + first_pass = True + while True: + empty = True + for row in iterable: + empty = False + if first_pass: + self._write_through(row) + yield row + if empty: + return + first_pass = False @remote_function() def __iter__(self): @@ -102,10 +119,7 @@ def __iter__(self): except StopIteration: return - if self.cyclic: - iterator = self._cyclic_iter(self.dataset) - else: - iterator = iter(self.dataset) + iterator = self._write_through_iter(self.dataset) data = [] max_length = self.template.max_length or 2048 while True: diff --git a/src/twinkle/dataset/packing_dataset.py b/src/twinkle/dataset/packing_dataset.py index fa4acbd57..ada9498b8 100644 --- a/src/twinkle/dataset/packing_dataset.py +++ b/src/twinkle/dataset/packing_dataset.py @@ -114,6 +114,8 @@ def __getitem__(self, index): assert self._packed_called, 'Call `pack_dataset()` first before index the sample.' sequence = self.packed_idx[index] rows = [self.dataset[i] for i in sequence] + for row in rows: + self._write_through(row) output = {} for key in rows[0]: output[key] = [r[key] for r in rows] diff --git a/src/twinkle/gym/__init__.py b/src/twinkle/gym/__init__.py deleted file mode 100644 index 44b0771bb..000000000 --- a/src/twinkle/gym/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from .base import Gym diff --git a/src/twinkle/gym/base.py b/src/twinkle/gym/base.py deleted file mode 100644 index aca798093..000000000 --- a/src/twinkle/gym/base.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. - - -class Gym: - - def __init__(self): - pass - - def step(self): - pass diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 83e10d132..1227584fb 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -1,11 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import functools import inspect -import itertools import json import numpy as np import os -import random import sys from typing import Any, Callable, List, Literal, Optional, TypeVar, Union @@ -59,7 +57,7 @@ def _tag_exc(exc: BaseException, caller: Optional[str]) -> None: prefix = f'[twinkle driver caller: {caller}] ' exc.args = (prefix + str(exc.args[0]), *exc.args[1:]) if exc.args else (prefix.rstrip(), ) exc._twinkle_caller_augmented = True - except Exception: # noqa: BLE001 + except Exception: # noqa pass @@ -404,6 +402,7 @@ def dispatch_func(arg, n): return result elif dispatch == 'slice_dp': + assert device_mesh is not None # split by dp. each worker in one ep will receive the same argument result = [] # if device_mesh is not None: @@ -420,14 +419,6 @@ def dispatch_func(arg, n): import torch if isinstance(arg, list) or isinstance(arg, torch.Tensor): _args = [] - if device_mesh is None: - total = len(arg) - chunk = max(1, (total + n - 1) // n) - for i in range(n): - start = i * chunk - end = min(total, start + chunk) - _args.append(arg[start:end]) - return _args for i in range(n): _args.append(arg[device_mesh.get_slice( len(arg), device_mesh.get_data_rank_from_global_rank(i * _rank_stride))]) @@ -696,7 +687,7 @@ def __next__(_self): return decorator -def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp'], Callable] = 'slice', +def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp', 'last_pp_first'], Callable] = 'slice', execute: Literal['first', 'peer', 'all'] = 'all', collect: Union[Literal['none', 'flatten', 'mean', 'sum', 'first', 'last_pp'], Callable] = 'none', sync: bool = False, @@ -803,18 +794,13 @@ def wrapper(self, *args, **kwargs) -> T1: # And this is user independent, only decided by the code. _local_lazy_collect = self._lazy_collect if _local_lazy_collect: - # Wrap the deferred collector so that exceptions - # raised when the caller later materializes the - # result also trigger the notifier. Attributes - # (``_futures`` etc.) on the original collector - # are preserved for downstream code paths. _orig_result_func = result_func @functools.wraps(_orig_result_func) def _notifying_result_func(*rargs, **rkwargs): try: return _orig_result_func(*rargs, **rkwargs) - except Exception as _e: # noqa: BLE001 + except Exception as _e: # noqa _tag_exc(_e, _caller) notify_exception(_notifier, _ctx, _e, _name) raise diff --git a/src/twinkle/loss/chunked_cross_entropy.py b/src/twinkle/loss/chunked_cross_entropy.py index 22d3d4077..061ca2168 100644 --- a/src/twinkle/loss/chunked_cross_entropy.py +++ b/src/twinkle/loss/chunked_cross_entropy.py @@ -1,63 +1,159 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import math -from typing import Any - -from ..data_format import LossOutput +from twinkle.data_format import LossOutput from .base import Loss +# Lazily-built singleton autograd.Function, so we neither pay the +# class-construction cost on every forward nor force a top-level torch import. +_CHUNKED_CE_FUNC = None + + +def _get_chunked_ce_func(): + global _CHUNKED_CE_FUNC + if _CHUNKED_CE_FUNC is not None: + return _CHUNKED_CE_FUNC + + import torch + import torch.nn.functional as F + + class _ChunkedCrossEntropyFunc(torch.autograd.Function): + """Chunked CE that materialises log_softmax(B, V) only one chunk at a time. + + Forward returns a scalar loss; backward writes per-token gradients into + a freshly allocated `grad_logits` tensor (the input `logits` is never + mutated). Mathematically equivalent to ``CrossEntropyLoss`` in the same + package; ``chunk_size`` only controls the memory/throughput trade-off. + """ + + @staticmethod + def forward(ctx, logits, labels, chunk_size, ignore_index, reduction, dft): + ctx.save_for_backward(logits, labels) + ctx.chunk_size = chunk_size + ctx.ignore_index = ignore_index + ctx.reduction = reduction + ctx.dft = dft + + n = logits.shape[0] + # Use fp32 accumulators so we don't lose precision when summing + # over many tokens under fp16/bf16 autocast (matches cross_entropy.py). + total_loss = logits.new_zeros((), dtype=torch.float32) + total_count = logits.new_zeros((), dtype=torch.float32) + + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + logits_chunk = logits[start:end] + labels_chunk = labels[start:end] + mask = (labels_chunk != ignore_index).float() + + logps = F.log_softmax(logits_chunk, dim=-1).gather( + -1, labels_chunk.clamp(min=0).unsqueeze(-1)).squeeze(-1) + per_token = -logps * logps.exp() if dft else -logps + + total_loss = total_loss + (per_token * mask).sum() + total_count = total_count + mask.sum() + + ctx.num_tokens = total_count.detach() + if reduction == 'mean': + return total_loss / total_count.clamp(min=1) + return total_loss + + @staticmethod + def backward(ctx, grad_output): + logits, labels = ctx.saved_tensors + chunk_size = ctx.chunk_size + ignore_index = ctx.ignore_index + reduction = ctx.reduction + dft = ctx.dft + + if reduction == 'mean': + scale = grad_output / ctx.num_tokens.clamp(min=1) + else: + scale = grad_output + + grad_logits = torch.empty_like(logits) + n = logits.shape[0] + + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + logits_chunk = logits[start:end].detach().requires_grad_(True) + labels_chunk = labels[start:end] + mask = (labels_chunk != ignore_index).float() + + with torch.enable_grad(): + logps = F.log_softmax(logits_chunk, dim=-1).gather( + -1, labels_chunk.clamp(min=0).unsqueeze(-1)).squeeze(-1) + per_token = -logps * logps.exp() if dft else -logps + loss_chunk = (per_token * mask).sum() + + grad_chunk = torch.autograd.grad(loss_chunk, logits_chunk, retain_graph=False)[0] + grad_logits[start:end] = grad_chunk * scale + + # logits, labels, chunk_size, ignore_index, reduction, dft + return grad_logits, None, None, None, None, None + + _CHUNKED_CE_FUNC = _ChunkedCrossEntropyFunc + return _CHUNKED_CE_FUNC + class ChunkedCrossEntropyLoss(Loss): - """TODO untested code""" + """CE loss that chunks the (B, V) softmax to bound peak memory. + + Drop-in replacement for :class:`CrossEntropyLoss` when ``outputs['logits']`` + is large (e.g. long sequence x big vocab). Behaviour matches that loss + bit-for-bit; ``chunk_size`` only affects memory/throughput. + + Args: + chunk_size: How many rows of ``logits`` to process per chunk. + ignore_index: Label id treated as padding (excluded from loss). + reduction: ``'mean'`` or ``'sum'``; matches ``CrossEntropyLoss``. + dft: If True, use DFT weighting ``-p*log(p)`` (arxiv 2508.05629). + """ - def __init__(self, chunk_size): + require_logits = True + # We chunk the (B, V) softmax ourselves; tell upstream not to materialise + # `logps` (which would already pay the full memory cost we're trying to + # avoid). The `_loss_from_logps` fast path is kept only for the rare case + # where someone explicitly hands us pre-computed logps. + require_logps = False + + def __init__(self, + chunk_size: int, + ignore_index: int = -100, + reduction: str = 'mean', + dft: bool = False, + **kwargs): + super().__init__() + assert chunk_size > 0, 'chunk_size must be positive' + assert reduction in ('mean', 'sum'), f"reduction must be 'mean' or 'sum', got {reduction!r}" self.chunk_size = chunk_size + self.ignore_index = ignore_index + self.reduction = reduction + self.dft = dft def __call__(self, inputs, outputs, **kwargs): - import torch - - class ChunkedCrossEntropyLossFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, logits, labels, chunk_size): - import torch - ctx.save_for_backward(logits, labels) - ctx.chunk_size = chunk_size - - losses = [] - for i in range(math.ceil(logits.shape[0] / chunk_size)): - l_start = i * chunk_size - l_end = min((i + 1) * chunk_size, logits.shape[0]) - logits_chunk = logits[l_start:l_end] - labels_chunk = labels[l_start:l_end] - loss_fct = torch.nn.CrossEntropyLoss(reduction='none') - loss_chunk = loss_fct(logits_chunk, labels_chunk) - losses.append(loss_chunk) - del logits_chunk - del labels_chunk - all_losses = torch.cat(losses) - return all_losses - - @staticmethod - def backward(ctx: Any, *grad_outputs: Any): - import torch - logits, labels = ctx.saved_tensors - chunk_size = ctx.chunk_size - - for i in range(math.ceil(logits.shape[0] / chunk_size)): - l_start = i * chunk_size - l_end = min((i + 1) * chunk_size, logits.shape[0]) - logits_chunk = logits[l_start:l_end].detach().requires_grad_(True) - labels_chunk = labels[l_start:l_end] - loss_fct = torch.nn.CrossEntropyLoss(reduction='none') - with torch.enable_grad(): - loss_chunk = loss_fct(logits_chunk, labels_chunk) - grad_output_chunk = grad_outputs[0][l_start:l_end] - _loss_chunk = (loss_chunk * grad_output_chunk).sum() - grad_chunk = torch.autograd.grad(_loss_chunk, logits_chunk, retain_graph=False)[0] - logits[l_start:l_end] = grad_chunk - - return logits, None, None + labels = inputs['labels'] + logps = outputs.get('logps') + + # Fast path: if logps is already gathered upstream, chunking the + # softmax is moot — fall back to the same scalar formula as + # CrossEntropyLoss to keep behaviour identical. + if logps is not None: + return self._loss_from_logps(labels, logps) logits = outputs['logits'] - labels = inputs['labels'] - return LossOutput(loss=ChunkedCrossEntropyLossFunc.apply(logits, labels, self.chunk_size), num_tokens=0) + labels = labels.view(-1) + logits = logits.view(-1, logits.shape[-1]) + + func = _get_chunked_ce_func() + loss = func.apply(logits, labels, self.chunk_size, self.ignore_index, self.reduction, self.dft) + + if self.reduction == 'mean': + return LossOutput(loss=loss, num_tokens=0) + num_tokens = (labels != self.ignore_index).float().sum().clamp(min=1) + return LossOutput(loss=loss, num_tokens=num_tokens) + + def _loss_from_logps(self, labels, logps): + mask = (labels != self.ignore_index).float() + per_token = -logps * logps.exp() if self.dft else -logps + if self.reduction == 'mean': + return LossOutput(loss=(per_token * mask).sum() / mask.sum().clamp(min=1), num_tokens=0) + return LossOutput(loss=(per_token * mask).sum(), num_tokens=mask.sum().clamp(min=1)) diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index fe526ab46..d53019513 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -7,7 +7,7 @@ (https://arxiv.org/abs/2305.18290) """ from typing import TYPE_CHECKING, Dict, List, Optional, Union - +import math from twinkle.data_format import LossOutput from twinkle.loss.base import Loss from twinkle.utils.torch_utils import selective_log_softmax @@ -132,6 +132,12 @@ def __init__( **kwargs, ): super().__init__(ignore_index=ignore_index) + if loss_type not in ('sigmoid', 'hinge', 'ipo', 'kto_pair'): + raise ValueError(f'Unknown loss_type: {loss_type}') + if label_smoothing > 0 and loss_type != 'sigmoid': + raise ValueError( + f'label_smoothing > 0 is only defined for loss_type="sigmoid", ' + f'got loss_type="{loss_type}". Set label_smoothing=0.0 or switch to sigmoid.') self.beta = beta self.label_smoothing = label_smoothing self.loss_type = loss_type @@ -217,6 +223,11 @@ def _compute_dpo_loss( if self.loss_type == 'sigmoid': # Standard DPO loss: -log(sigmoid(beta * margin)) losses = -F.logsigmoid(logits) + # Apply label smoothing (only meaningful here: Bradley-Terry soft labels). + if self.label_smoothing > 0: + # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected + smooth_losses = -F.logsigmoid(-logits) # Loss for flipped preference + losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses elif self.loss_type == 'hinge': # Hinge loss variant losses = torch.relu(1 - logits) @@ -234,12 +245,6 @@ def _compute_dpo_loss( else: raise ValueError(f'Unknown loss_type: {self.loss_type}') - # Apply label smoothing if specified - if self.label_smoothing > 0: - # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected - smooth_losses = -F.logsigmoid(-logits) # Loss for flipped preference - losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses - return losses.mean() def __call__( @@ -321,7 +326,8 @@ def __call__( reference_chosen_logps = torch.zeros_like(policy_chosen_logps) reference_rejected_logps = torch.zeros_like(policy_rejected_logps) else: - return LossOutput(loss=torch.tensor(0.0, device=chosen_logps.device), num_tokens=0) + zero = (policy_chosen_logps.sum() + policy_rejected_logps.sum()) * 0.0 + return LossOutput(loss=zero, num_tokens=0) # Compute DPO loss dpo_loss = self._compute_dpo_loss( @@ -535,11 +541,23 @@ def __call__( # Odds ratio: log(odds_chosen / odds_rejected) # log_odds = log(p/(1-p)) = log(p) - log(1-p) - # Compute entirely in log-space to avoid exp() underflow: - # log(p) = avg_logps (already in log-space) - # log(1-p) = log1p(-exp(avg_logps)) (numerically stable via log1p) - log_odds_chosen = chosen_avg_logps - torch.log1p(-torch.exp(chosen_avg_logps)) - log_odds_rejected = rejected_avg_logps - torch.log1p(-torch.exp(rejected_avg_logps)) + # Compute log(1-p) = log(1 - exp(avg_logp)) numerically stably: + # - For x > -log(2): log(-expm1(x)) (avoids log(0) when p → 1) + # - For x ≤ -log(2): log1p(-exp(x)) (avoids cancellation when p → 0) + # ``avg_logp ∈ (-∞, 0]`` so the threshold partitions the safe regime. + log_two = math.log(2.0) + + def _log1mexp(x: 'torch.Tensor') -> 'torch.Tensor': + # Clamp at a tiny negative to keep both branches well-defined when p≈1. + x_safe = torch.clamp(x, max=-1e-7) + return torch.where( + x_safe > -log_two, + torch.log(-torch.expm1(x_safe)), + torch.log1p(-torch.exp(x_safe)), + ) + + log_odds_chosen = chosen_avg_logps - _log1mexp(chosen_avg_logps) + log_odds_rejected = rejected_avg_logps - _log1mexp(rejected_avg_logps) # ORPO odds ratio loss odds_ratio = log_odds_chosen - log_odds_rejected diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index 3f7db4bfb..7c198ad02 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -41,6 +41,10 @@ def __init__( chunk_size: int = 512, **kwargs, ): + if not (0.0 <= beta <= 1.0): + raise ValueError(f'beta must be in [0, 1], got {beta}') + if temperature <= 0: + raise ValueError(f'temperature must be > 0, got {temperature}') self.beta = beta self.temperature = temperature self.ignore_index = ignore_index @@ -94,6 +98,7 @@ def __call__( labels=labels, beta=self.beta, temperature=self.temperature, + ignore_index=self.ignore_index, chunk_size=self.chunk_size, topk=topk, teacher_topk_logprobs=teacher_topk_logprobs, @@ -108,6 +113,7 @@ def _generalized_jsd_loss( labels=None, beta: float = 0.5, temperature: float = 1.0, + ignore_index: int = -100, chunk_size: int = 512, topk: Optional[int] = None, teacher_topk_logprobs=None, @@ -164,7 +170,7 @@ def _generalized_jsd_loss( # ── Mask valid (response) tokens ────────────────────────────────────── if labels is not None: - mask = labels != -100 # ignore_index is always -100 per convention + mask = labels != ignore_index # Vocab-size mismatch (e.g. Qwen2.5-VL-3B vs 7B): pad the smaller side # so both distributions are defined over the same token set. stu_dim = student_logits.shape[-1] @@ -178,12 +184,15 @@ def _generalized_jsd_loss( student_logits = student_logits[mask] # [num_valid, vocab/topk] teacher_logits = teacher_logits[mask] num_valid = mask.sum() + # ``[mask]`` already created fresh storage, so in-place divide is safe + # and avoids an extra [num_valid, V] allocation. + student_logits.div_(temperature) + teacher_logits.div_(temperature) else: - student_logits = student_logits.view(-1, student_logits.size(-1)) - teacher_logits = teacher_logits.view(-1, teacher_logits.size(-1)) + # Keep logits, may be an infer scenario + student_logits = student_logits.reshape(-1, student_logits.size(-1)) / temperature + teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1)) / temperature num_valid = student_logits.size(0) - student_logits.div_(temperature) - teacher_logits.div_(temperature) if num_valid == 0: return student_logits.new_zeros(()) diff --git a/src/twinkle/loss/grpo.py b/src/twinkle/loss/grpo.py index 4bb71216c..781b22060 100644 --- a/src/twinkle/loss/grpo.py +++ b/src/twinkle/loss/grpo.py @@ -42,18 +42,6 @@ def __init__( self.require_entropy = entropy_coef > 0.0 self.ignore_index = ignore_index - def _compute_loss_mask(self, labels: 'torch.Tensor') -> 'torch.Tensor': - """ - Compute loss mask from labels. - - Args: - labels: [batch, seq_len] target token ids, -100 for ignored positions - - Returns: - mask: [batch, seq_len] float tensor, 1.0 for valid positions, 0.0 for ignored - """ - return (labels != self.ignore_index).float() - def _compute_log_importance_weights( self, per_token_logps: 'torch.Tensor', @@ -165,10 +153,13 @@ def _pad_and_align_to_batch( return data # Already aligned if data.dim() == 1: data = data.unsqueeze(1) - if data.shape[1] == 1: # Scalars - result = torch.full((batch_size, seq_len), fill_value, dtype=dtype, device=device) - result[mask] = data[mask.any(dim=1).nonzero(as_tuple=True)[0].repeat_interleave(mask.sum(dim=1)), 0] - return result + if data.shape[1] == 1: + assert data.shape[0] == batch_size, ( + f'scalar broadcast expects data.shape[0]==batch_size, ' + f'got data.shape={tuple(data.shape)} mask.shape={(batch_size, seq_len)}') + fill = torch.full((batch_size, seq_len), fill_value, dtype=dtype, device=device) + expanded = data.expand(batch_size, seq_len) + return torch.where(mask, expanded, fill) data = [data[i] for i in range(batch_size)] # To list # Handle list (scalars or sequences) @@ -276,10 +267,12 @@ def __call__( ) # GRPO loss is ill-defined without advantages (e.g. ref-logps-only forward, - # or eval/validation forwards). Return a zero loss so the forward still - # flows through cleanly and callers can harvest outputs['logps'] freely. + # or eval/validation forwards). Return a zero loss that still flows through + # autograd so DDP/FSDP do not see unused params, and callers can harvest + # outputs['logps'] freely. if advantages is None: - return LossOutput(loss=torch.zeros((), device=device, dtype=logps.dtype), num_tokens=0) + zero = logps.sum() * 0.0 + return LossOutput(loss=zero, num_tokens=0) advantages = self._pad_and_align_to_batch( advantages, diff --git a/src/twinkle/loss/infonce.py b/src/twinkle/loss/infonce.py index 68d14840c..c356bd64c 100644 --- a/src/twinkle/loss/infonce.py +++ b/src/twinkle/loss/infonce.py @@ -13,8 +13,6 @@ import numpy as np import torch import torch.distributed as dist -import torch.nn.functional as F -from enum import Enum from torch import nn from typing import Optional @@ -22,15 +20,6 @@ from .base import Loss -# Borrowed from sentence_transformers. -class SiameseDistanceMetric(Enum): - """Distance metrics available to the pairwise contrastive losses.""" - - EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2) # noqa - MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1) # noqa - COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y) # noqa - - def _extract_sentences(outputs) -> torch.Tensor: """Return [B, D] sentence embeddings from postprocess_tensor_sp output. @@ -119,6 +108,11 @@ def __init__( process_group=None, **kwargs, ): + if mask_fake_negative and fake_neg_margin <= 0: + raise ValueError( + f'fake_neg_margin must be > 0 when mask_fake_negative=True, got {fake_neg_margin}. ' + 'A non-positive margin would mask out the positive itself or every above-positive ' + 'logit indiscriminately, collapsing the contrastive signal.') self.temperature = temperature self.use_batch = use_batch self.hard_negatives = hard_negatives @@ -129,7 +123,13 @@ def __init__( self.process_group = process_group def _gather_across_dp(self, sentences: torch.Tensor, labels: torch.Tensor): - """All-gather embeddings & labels across DP ranks; only local shard keeps grad.""" + """All-gather embeddings & labels across DP ranks; only local shard keeps grad. + + NCCL ``all_gather`` requires every rank to send the *same* tensor size. Under + ``slice_dp`` dispatch the per-rank batch is uneven (``divmod`` splits), so we + pad each rank to the global max along dim-0, do an equal-sized all_gather, + then strip padding back. Only the local shard retains gradients. + """ if not (dist.is_available() and dist.is_initialized()): return sentences, labels world_size = dist.get_world_size(group=self.process_group) @@ -137,24 +137,40 @@ def _gather_across_dp(self, sentences: torch.Tensor, labels: torch.Tensor): return sentences, labels rank = dist.get_rank(group=self.process_group) - # variable per-rank shapes require communicating shape first - local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long) - shapes = [torch.empty_like(local_shape) for _ in range(world_size)] - dist.all_gather(shapes, local_shape, group=self.process_group) - all_sentences = [sentences.new_empty(shape.tolist()) for shape in shapes] - dist.all_gather(all_sentences, sentences.contiguous(), group=self.process_group) - - local_label_shape = labels.new_tensor(labels.shape, dtype=torch.long) - label_shapes = [torch.empty_like(local_label_shape) for _ in range(world_size)] - dist.all_gather(label_shapes, local_label_shape, group=self.process_group) - all_labels = [labels.new_empty(shape.tolist()) for shape in label_shapes] - dist.all_gather(all_labels, labels.contiguous(), group=self.process_group) - - # keep the local shard differentiable; detach others - all_sentences[rank] = sentences + # ``labels`` is a 1-D mask aligned to ``sentences`` along dim-0, so they + # share the same per-rank size. Gather sizes once and reuse for both. + assert sentences.shape[0] == labels.shape[0], ( + f'sentences/labels dim-0 mismatch: {sentences.shape[0]} vs {labels.shape[0]}') + local_n = torch.tensor([sentences.shape[0]], device=sentences.device, dtype=torch.long) + sizes = [torch.empty_like(local_n) for _ in range(world_size)] + dist.all_gather(sizes, local_n, group=self.process_group) + sizes_int = [int(s.item()) for s in sizes] + max_n = max(sizes_int) + + def _pad_gather(tensor: torch.Tensor): + if tensor.shape[0] < max_n: + pad_shape = (max_n - tensor.shape[0],) + tuple(tensor.shape[1:]) + padded = torch.cat([tensor, tensor.new_zeros(pad_shape)], dim=0) + else: + padded = tensor + buffers = [torch.empty_like(padded) for _ in range(world_size)] + dist.all_gather(buffers, padded.contiguous(), group=self.process_group) + return buffers + + sent_buffers = _pad_gather(sentences) + label_buffers = _pad_gather(labels) + + # Strip padding; keep local shard differentiable, detach others. + all_sentences = [] + all_labels = [] for idx in range(world_size): - if idx != rank: - all_sentences[idx] = all_sentences[idx].detach() + n = sizes_int[idx] + if idx == rank: + all_sentences.append(sentences) + all_labels.append(labels) + else: + all_sentences.append(sent_buffers[idx][:n].detach()) + all_labels.append(label_buffers[idx][:n]) return torch.cat(all_sentences, dim=0), torch.cat(all_labels, dim=0) def __call__(self, inputs, outputs, **kwargs) -> LossOutput: diff --git a/src/twinkle/metric/accuracy.py b/src/twinkle/metric/accuracy.py index b3034c57a..4dfb01198 100644 --- a/src/twinkle/metric/accuracy.py +++ b/src/twinkle/metric/accuracy.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import numpy as np from typing import List, Union from ..data_format import InputFeature, ModelOutput diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py index b203d255e..024cb0473 100644 --- a/src/twinkle/metric/dpo.py +++ b/src/twinkle/metric/dpo.py @@ -131,6 +131,12 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M ref_outputs = kwargs.get('ref_outputs') if ref_outputs is not None: ref_logps = ref_outputs.get('logps') + if ref_logps is not None: + if isinstance(ref_logps, list): + if len(ref_logps) == 0: + ref_logps = None + else: + ref_logps = pad_and_stack_tensors(ref_logps) if ref_logps is not None: # Align ref_logps to match labels shape (handles different seq lengths) ref_logps = self._align_logps(ref_logps, labels.shape, labels.device, logps.dtype) diff --git a/src/twinkle/metric/embedding.py b/src/twinkle/metric/embedding.py index 9fb3aed8c..8b3681031 100644 --- a/src/twinkle/metric/embedding.py +++ b/src/twinkle/metric/embedding.py @@ -1,7 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import torch -import torch.distributed as dist -import torch.nn.functional as F from typing import List, Union from twinkle.data_format import InputFeature, ModelOutput @@ -32,6 +29,9 @@ def reset(self): self.grad_norm = 0.0 def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs): + import torch + import torch.distributed as dist + import torch.nn.functional as F sentences = outputs.get('embeddings') if sentences is None: sentences = outputs.get('logits') @@ -44,22 +44,34 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M inputs = [inputs] labels = torch.cat([inp['labels'].view(-1) for inp in inputs], dim=0) - # Gather embeddings and labels across DP for in-batch stats + # Gather embeddings and labels across DP for in-batch stats. + # NCCL ``all_gather`` requires every rank to send the same tensor size, + # but ``slice_dp`` dispatch (``divmod`` split) can leave per-rank dim-0 + # uneven. Pad to the global max along dim-0, gather, then strip padding. if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: world_size = dist.get_world_size() - local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long) - shapes = [torch.empty_like(local_shape) for _ in range(world_size)] - dist.all_gather(shapes, local_shape) - all_sentences = [sentences.new_empty(s.tolist()) for s in shapes] - dist.all_gather(all_sentences, sentences.contiguous()) - sentences = torch.cat(all_sentences, dim=0) - - local_lshape = labels.new_tensor(labels.shape, dtype=torch.long) - lshapes = [torch.empty_like(local_lshape) for _ in range(world_size)] - dist.all_gather(lshapes, local_lshape) - all_labels = [labels.new_empty(s.tolist()) for s in lshapes] - dist.all_gather(all_labels, labels.contiguous()) - labels = torch.cat(all_labels, dim=0) + assert sentences.shape[0] == labels.shape[0], ( + f'sentences/labels dim-0 mismatch: {sentences.shape[0]} vs {labels.shape[0]}') + local_n = torch.tensor([sentences.shape[0]], device=sentences.device, dtype=torch.long) + sizes = [torch.empty_like(local_n) for _ in range(world_size)] + dist.all_gather(sizes, local_n) + sizes_int = [int(s.item()) for s in sizes] + max_n = max(sizes_int) + + def _pad_gather(tensor: 'torch.Tensor') -> 'List[torch.Tensor]': + if tensor.shape[0] < max_n: + pad_shape = (max_n - tensor.shape[0],) + tuple(tensor.shape[1:]) + padded = torch.cat([tensor, tensor.new_zeros(pad_shape)], dim=0) + else: + padded = tensor + buffers = [torch.empty_like(padded) for _ in range(world_size)] + dist.all_gather(buffers, padded.contiguous()) + return buffers + + sent_buffers = _pad_gather(sentences) + label_buffers = _pad_gather(labels) + sentences = torch.cat([sent_buffers[i][:sizes_int[i]] for i in range(world_size)], dim=0) + labels = torch.cat([label_buffers[i][:sizes_int[i]] for i in range(world_size)], dim=0) anchor_idx = torch.nonzero(labels, as_tuple=False).squeeze(-1) if anchor_idx.numel() == 0: diff --git a/src/twinkle/metric/grpo.py b/src/twinkle/metric/grpo.py index 06e082eeb..e2797b1ec 100644 --- a/src/twinkle/metric/grpo.py +++ b/src/twinkle/metric/grpo.py @@ -3,9 +3,12 @@ from typing import Any, Dict, List, Optional, Union from twinkle.data_format import InputFeature, ModelOutput +from twinkle.utils import get_logger from twinkle.utils.transformers_utils import align_logps_to_mask from .base import Metric +logger = get_logger() + class GRPOMetric(Metric): @@ -254,6 +257,11 @@ def accumulate( if len(seq_lens) == 1: merged = torch.cat(label_tensors, dim=0) inputs_list = [{'labels': merged}] + else: + logger.warning( + f'GRPOMetric: logps is a single tensor but inputs_list has ' + f'{len(inputs_list)} mb with mismatched seq_lens={sorted(seq_lens)}. ' + f'Only mb[0] will be accumulated; check the model forward path.') flat_old: Optional[List] = None if old_logps is not None and isinstance(old_logps, (list, tuple)): @@ -284,7 +292,17 @@ def accumulate( # Uncommon: aligned global tensor. Only honour when it # exactly matches the single-mb shape; otherwise drop. import torch as _torch # noqa: F811 - old_slice = old_logps if (_torch.is_tensor(old_logps) and old_logps.shape == logps_mb.shape) else None + if _torch.is_tensor(old_logps) and old_logps.shape == logps_mb.shape: + old_slice = old_logps + else: + if mb_idx == 0: + # Warn once per accumulate call (not per mb) to avoid log spam. + old_shape = tuple(old_logps.shape) if _torch.is_tensor(old_logps) else 'unknown' + logger.warning( + f'GRPOMetric: old_logps shape {old_shape} does not match ' + f'logps_mb shape {tuple(logps_mb.shape)}; ratio/kl metrics will ' + f'be skipped for this step.') + old_slice = None else: old_slice = None diff --git a/src/twinkle/metric/train_metric.py b/src/twinkle/metric/train_metric.py index da82a8783..8d785c38b 100644 --- a/src/twinkle/metric/train_metric.py +++ b/src/twinkle/metric/train_metric.py @@ -2,7 +2,7 @@ import time from typing import List, Union -from ..data_format import InputFeature, ModelOutput +from twinkle.data_format import InputFeature, ModelOutput from .base import Metric diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index a4d4ea064..8ea00d696 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -1,8 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from abc import ABC, abstractmethod -from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union from twinkle import Platform, torch_util from twinkle.data_format import InputFeature, ModelOutput diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index a5ea3fc56..60b45f774 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -420,7 +420,7 @@ def forward_step_func(data_iterator, model): embeddings = output_tensor elif labels is not None and is_last_pp: _loss_require_logps = getattr(_loss_instance, 'require_logps', True) - _loss_require_entropy = (hasattr(_loss_instance, 'require_entropy') and _loss_instance.require_entropy) + _loss_require_entropy = getattr(_loss_instance, 'require_entropy', True) _packed = batch.get('packed_seq_params') cu_seqlens_q = getattr(_packed, 'cu_seqlens_q', None) if _packed is not None else None if _loss_require_logps: @@ -446,7 +446,7 @@ def forward_step_func(data_iterator, model): _outputs = {'logps': logps} if entropies is not None: _outputs['entropies'] = entropies - if hasattr(_loss_instance, 'require_logits') and _loss_instance.require_logits: + if getattr(_loss_instance, 'require_logits', False): _outputs['logits'] = output_tensor batch, _outputs = processor.unpack_packed_sequences(batch, _outputs) logps = _outputs['logps'] @@ -990,7 +990,9 @@ def _get_rng_state() -> 'ShardedObject': 'random_rng_state': random.getstate(), 'np_rng_state': np.random.get_state(), 'torch_rng_state': torch.get_rng_state(), - 'cuda_rng_state': torch.cuda.get_rng_state(), + # Backend-agnostic device RNG (CUDA / NPU / MPS); key kept as + # 'cuda_rng_state' for backward compatibility with existing checkpoints. + 'cuda_rng_state': Platform.get_device_rng_state(), 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states(), } rng_state_list = [rng_state] @@ -1112,7 +1114,7 @@ def _save_mcore_optimizer( with open(tracker_path, 'w') as f: f.write(str(iteration)) - logging.getLogger(__name__).info(f'Saved mcore optimizer state at iteration {iteration} ' + logger.info(f'Saved mcore optimizer state at iteration {iteration} ' f'to {checkpoint_dir}') def _load_mcore_optimizer( @@ -1139,7 +1141,7 @@ def _load_mcore_optimizer( ) iteration = self._read_iteration(tracker_path) if iteration == 0: - logging.getLogger(__name__).warning(f'No checkpoint found in {checkpoint_dir}') + logger.warning(f'No checkpoint found in {checkpoint_dir}') return iter_dir = os.path.join(checkpoint_dir, f'iter_{iteration:07d}') @@ -1201,7 +1203,9 @@ def _load_mcore_optimizer( random.setstate(rng['random_rng_state']) np.random.set_state(rng['np_rng_state']) torch.set_rng_state(rng['torch_rng_state']) - torch.cuda.set_rng_state(rng['cuda_rng_state']) + # Backend-agnostic restore: tolerates ckpt produced on different backend + # (returns None) and avoids hard-coded torch.cuda which crashes on NPU. + Platform.set_device_rng_state(rng.get('cuda_rng_state')) tensor_parallel.get_cuda_rng_tracker().set_states(rng['rng_tracker_states'], ) # Restore iteration counter. @@ -1211,26 +1215,26 @@ def _load_mcore_optimizer( if dist.is_initialized(): dist.barrier() - logging.getLogger(__name__).info(f'Resumed from mcore checkpoint at iteration {iteration} ' + logger.info(f'Resumed from mcore checkpoint at iteration {iteration} ' f'from {checkpoint_dir}') @staticmethod def _read_iteration(tracker_path: str) -> int: - if not os.path.exists(tracker_path): - return 0 - with open(tracker_path) as f: - iteration = int(f.read().strip()) + # All ranks must enter the all_reduce together; missing tracker on some + # ranks (e.g. NFS lag, partial mount) must NOT short-circuit, otherwise + # the remaining ranks hang at the collective. Treat missing as 0 and + # let MAX reduction recover the canonical iteration from any rank that + # successfully read the file. + iteration = 0 + if os.path.exists(tracker_path): + with open(tracker_path) as f: + iteration = int(f.read().strip()) if torch.distributed.is_initialized(): - iters_cuda = torch.tensor( - [iteration], - dtype=torch.long, - device='cuda', - ) - torch.distributed.all_reduce( - iters_cuda, - op=torch.distributed.ReduceOp.MAX, - ) - iteration = iters_cuda[0].item() + # Use Platform.get_local_device() to stay backend-agnostic + # (CUDA / NPU / MPS); 'cuda' would crash on NPU. + iters_dev = torch.tensor([iteration], dtype=torch.long, device=Platform.get_local_device()) + torch.distributed.all_reduce(iters_dev, op=torch.distributed.ReduceOp.MAX) + iteration = int(iters_dev[0].item()) return iteration def _merge_lora_adapters(self, adapter_name: str = 'default'): @@ -1256,7 +1260,7 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non For distributed training: - All PP ranks participate in export (each has different layers) - - Only DP rank 0 actually writes to disk + - Only global rank 0 actually writes shared config files - Uses barrier for synchronization For LoRA training: @@ -1264,12 +1268,9 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non """ # Check if this is LoRA training is_peft_format = (adapter_name != _default_adapter_name) + is_global_zero = (not dist.is_initialized()) or dist.get_rank() == 0 - # Create output directory on rank 0 only - from megatron.core import parallel_state as mpu - dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0 - - if dp_rank == 0: + if is_global_zero: os.makedirs(output_dir, exist_ok=True) # Synchronize before saving @@ -1281,8 +1282,8 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non self.strategy.bridge.save_weights( model, output_dir, peft_format=is_peft_format, adapter_name=adapter_name, converter=lora_converter) - # Save config on rank 0 only - if dp_rank == 0: + # Save config on global rank 0 only (avoid concurrent writers). + if is_global_zero: self.hf_config.save_pretrained(output_dir) if isinstance(model[0], PeftModel): config = model[0].peft_config[adapter_name] @@ -1291,11 +1292,13 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non model[0].peft_config[adapter_name].save_pretrained(output_dir) config.target_modules = target_modules + if dist.is_initialized(): + dist.barrier() + def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_converter=None): """Save in Megatron checkpoint format.""" + is_global_zero = (not dist.is_initialized()) or dist.get_rank() == 0 os.makedirs(output_dir, exist_ok=True) - from megatron.core import parallel_state as mpu - dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0 state_dict = self._get_trainable_parameters(adapter_name) cpu_state_dict = {} for k, v in state_dict.items(): @@ -1311,13 +1314,18 @@ def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_convert rank = dist.get_rank() if dist.is_initialized() else 0 checkpoint_path = os.path.join(output_dir, f'model_rank{rank}.pt') torch.save(cpu_state_dict, checkpoint_path) - # Save config on rank 0 only + # Save shared config on global rank 0 only (avoid concurrent writers). model = self.strategy.unwrap_model(self.model) - if dp_rank == 0: + if is_global_zero: self.hf_config.save_pretrained(output_dir) if isinstance(model[0], PeftModel): model[0].peft_config[adapter_name].save_pretrained(output_dir) + # Finalize barrier: ensure all ranks finish writing model_rank*.pt + # before the caller proceeds (e.g. uploading / loading the ckpt). + if dist.is_initialized(): + dist.barrier() + def _save_tokenizer(self, output_dir: str, **kwargs): from twinkle.utils import is_last_rank if not is_last_rank(): diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 2dd6b7a53..ae2eb4b2c 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -15,7 +15,7 @@ from transformers import AutoConfig, PretrainedConfig from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union -from twinkle import DeviceMesh, remote_class, remote_function, requires, template, torch_util +from twinkle import DeviceMesh, Platform, remote_class, remote_function, requires, template, torch_util from twinkle.data_format import InputFeature, Trajectory from twinkle.hub import HubOperation from twinkle.infra import collect_tensor_dict @@ -221,8 +221,11 @@ def _save_local_training_rng_state(): 'np_rng_state': np.random.get_state(), 'torch_rng_state': torch.get_rng_state(), } - if torch.cuda.is_available(): - rng_state['cuda_rng_state'] = torch.cuda.get_rng_state() + # Backend-agnostic device RNG capture (CUDA / NPU / MPS). Key is kept as + # 'cuda_rng_state' for backward compatibility with existing checkpoints. + device_rng = Platform.get_device_rng_state() + if device_rng is not None: + rng_state['cuda_rng_state'] = device_rng rng_state['rng_tracker_states'] = tensor_parallel.get_cuda_rng_tracker().get_states() return rng_state @@ -233,8 +236,10 @@ def _load_local_training_rng_state(rng_state): random.setstate(rng_state['random_rng_state']) np.random.set_state(rng_state['np_rng_state']) torch.set_rng_state(rng_state['torch_rng_state']) - if 'cuda_rng_state' in rng_state and torch.cuda.is_available(): - torch.cuda.set_rng_state(rng_state['cuda_rng_state']) + # Backend-agnostic device RNG restore: tolerates ckpt produced on different + # backend (key absent or None) and avoids hard-coded torch.cuda on NPU. + if 'cuda_rng_state' in rng_state: + Platform.set_device_rng_state(rng_state['cuda_rng_state']) tensor_parallel.get_cuda_rng_tracker().set_states(rng_state['rng_tracker_states']) def _save_multi_lora_optimizer(self, checkpoint_dir: str, optimizer_config, **kwargs): @@ -251,6 +256,9 @@ def _save_multi_lora_optimizer(self, checkpoint_dir: str, optimizer_config, **kw torch.save(state_dict, self._rank_local_optimizer_path(checkpoint_dir)) + if dist.is_initialized(): + dist.barrier() + def _load_multi_lora_optimizer(self, checkpoint_dir: str, adapter_name: str = '', **kwargs): no_load_optim = kwargs.pop('no_load_optim', False) no_load_rng = kwargs.pop('no_load_rng', False) @@ -260,6 +268,13 @@ def _load_multi_lora_optimizer(self, checkpoint_dir: str, adapter_name: str = '' if not no_load_optim and optimizer_config is not None: if optimizer_config.optimizer is not None and 'optimizer' in state_dict: optimizer_config.optimizer.load_state_dict(state_dict['optimizer']) + device = Platform.get_local_device() + for group_state in optimizer_config.optimizer.state.values(): + if not isinstance(group_state, dict): + continue + for k, v in group_state.items(): + if isinstance(v, torch.Tensor): + group_state[k] = v.to(device) if optimizer_config.lr_scheduler is not None and 'opt_param_scheduler' in state_dict: optimizer_config.lr_scheduler.load_state_dict(state_dict['opt_param_scheduler']) if not no_load_rng and 'rng_state' in state_dict: diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index 819014eb8..1bd809025 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -48,7 +48,6 @@ def __init__( ddp_config: Dict[str, Any] = None, **kwargs, ): - import torch.distributed as dist from megatron.core import mpu self.device_mesh = device_mesh self.use_distributed_optimizer = use_distributed_optimizer diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index d0434991f..018f4c494 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -34,10 +34,8 @@ def __init__( parallelism_config = self._parallelism_config_from_device_mesh(device_mesh) fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config, memory_efficient_init) - kwargs_handlers = [] - kwargs_handlers.append( - InitProcessGroupKwargs( - timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))))) + kwargs_handlers = [InitProcessGroupKwargs( + timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))))] if ddp_config is not None: from accelerate import DistributedDataParallelKwargs ddp_config = DistributedDataParallelKwargs(**ddp_config) @@ -131,8 +129,7 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di return fsdp_plugin def wrap_model(self, model, *args): - result = self.accelerator.prepare(model, *args) - return result + return self.accelerator.prepare(model, *args) def unwrap_model(self, model): return self.accelerator.unwrap_model(model, keep_torch_compile=False) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 61733d7dc..cac375263 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -414,8 +414,8 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec inputs = optimizer_config.template.batch_encode(inputs) # noqa processor: InputProcessor = optimizer_config.processor loss_instance = optimizer_config.loss_instance - loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits) - loss_require_entropy = (hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy) + loss_require_logits = getattr(loss_instance, 'require_logits', False) + loss_require_entropy = getattr(loss_instance, 'require_entropy', False) loss_require_logps = getattr(loss_instance, 'require_logps', True) assert isinstance(processor, InputProcessor), 'Set a correct `InputProcessor` before forwarding' inputs: Dict[str, Any] = processor( @@ -490,8 +490,8 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T processor: InputProcessor = optimizer_config.processor assert isinstance(processor, InputProcessor), 'Set InputProcessor correctly before forwarding' loss_instance = optimizer_config.loss_instance - loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits) - loss_require_entropy = (hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy) + loss_require_logits = getattr(loss_instance, 'require_logits', False) + loss_require_entropy = getattr(loss_instance, 'require_entropy', False) loss_require_logps = getattr(loss_instance, 'require_logps', True) inputs: Dict[str, Any] = processor( inputs, @@ -929,7 +929,6 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int if optimizer_config.cur_step % interval != 0: return model = self.strategy.unwrap_model(self.model) - processed_state_dict = {} save_kwargs = {} if adapter_name == _default_adapter_name: # Full model save diff --git a/src/twinkle/notifier/__init__.py b/src/twinkle/notifier/__init__.py index 329cb6f1d..067db71a1 100644 --- a/src/twinkle/notifier/__init__.py +++ b/src/twinkle/notifier/__init__.py @@ -1,2 +1,3 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. from .base import Notifier, notify_exception from .ding_notifier import DingNotifier diff --git a/src/twinkle/notifier/base.py b/src/twinkle/notifier/base.py index a83903b53..6f50ca659 100644 --- a/src/twinkle/notifier/base.py +++ b/src/twinkle/notifier/base.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. import os from typing import Dict, Optional @@ -66,7 +67,7 @@ def notify_exception(notifier: Notifier, context: str, exc: BaseException, name: if not _try_claim_notify_slot(exc, context, name): try: setattr(exc, '_twinkle_notified', True) - except Exception: # noqa: BLE001 + except Exception: # noqa pass return diff --git a/src/twinkle/notifier/ding_notifier.py b/src/twinkle/notifier/ding_notifier.py index fe102d8a5..fc535edd7 100644 --- a/src/twinkle/notifier/ding_notifier.py +++ b/src/twinkle/notifier/ding_notifier.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. import base64 import hashlib import hmac diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index 8709d98ab..c505d91cd 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -42,7 +42,6 @@ class InputProcessor: 'video_grid_thw': 0, 'input_features': 0.0, 'feature_attention_mask': 0, - 'mm_token_type_ids': 0, } # VLM fields to concatenate (not pad) in batch @@ -108,8 +107,12 @@ def to_tensor(_input): # so tensor ops like labels != ignore_index or .to(device) would fail without this. if isinstance(value, np.ndarray): value = torch.from_numpy(value) - elif (isinstance(value, list) and isinstance(value[0], - (int, float, np.number))) or key == 'position_ids': + elif isinstance(value, list) and len(value) > 0 and isinstance( + value[0], (int, float, np.number)): + value = torch.tensor(value) + elif key == 'position_ids' and not isinstance(value, torch.Tensor): + if value is None: + continue value = torch.tensor(value) elif (isinstance(value, list)) and key in ('completion_mask', 'mm_token_type_ids'): value = torch.tensor(value) @@ -284,7 +287,9 @@ def pad_cp_inputs(input_tensor: torch.Tensor, padding_value: int) -> torch.Tenso return input_tensor if cp_size > 1: - position_ids_f = position_ids.flatten() + pos_for_cu = position_ids[:1] if position_ids.dim() >= 2 and position_ids.shape[0] > 1 \ + else position_ids + position_ids_f = pos_for_cu.flatten() indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32) cu_seqlens = torch.cat([ indices_q[position_ids_f == 0], @@ -354,8 +359,11 @@ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], di view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] // (2 * cp_size), *inputs.shape[dim + 1:]) val = val.view(view_shape) - index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu', - pin_memory=True).cuda(non_blocking=True) + index = torch.tensor( + [cp_rank, (2 * cp_size - cp_rank - 1)], + device=inputs.device, + dtype=torch.long, + ) val = val.index_select(dim, index) view_shape = (*inputs.shape[:dim], -1, *inputs.shape[dim + 1:]) new_inputs.append(val.view(view_shape)) @@ -402,17 +410,18 @@ def prepare_transformers_padding_free_patch(self, inputs: List[InputFeature], ** if not padding_free or bool(kwargs.get('enable_sp', False)): return inputs - from twinkle.patch import apply_patch - from twinkle.patch.gdn_padding_free import GatedDeltaNetPaddingFreePatch - - apply_patch( - model, - GatedDeltaNetPaddingFreePatch, - hf_config=kwargs.get('hf_config'), - enable_sp=False, - ) if not getattr(model, '_twinkle_gdn_padding_free_patched', False): - return inputs + from twinkle.patch import apply_patch + from twinkle.patch.gdn_padding_free import GatedDeltaNetPaddingFreePatch + + apply_patch( + model, + GatedDeltaNetPaddingFreePatch, + hf_config=kwargs.get('hf_config'), + enable_sp=False, + ) + if not getattr(model, '_twinkle_gdn_padding_free_patched', False): + return inputs for _inp in inputs: position_ids = _inp.get('position_ids') @@ -631,15 +640,27 @@ def to_transformers_dict(inputs: List[InputFeature], **kwargs) -> List[InputFeat output = {} _keys = [ 'input_ids', - 'input_embeddings', + 'inputs_embeds', 'attention_mask', 'position_ids', 'labels', 'completion_mask', + 'cu_seq_lens_q', + 'cu_seq_lens_k', + 'cu_seqlens_q', + 'cu_seqlens_kv', + 'max_length_q', + 'max_length_k', + 'packed_seq_params', ] + list(InputProcessor.VLM_CONCAT_FIELDS) for key in list(_input.keys()): - if key in _keys: - output[key] = np.array(_input[key]) if not isinstance(_input[key], torch.Tensor) else _input[key] + if key not in _keys: + continue + value = _input[key] + if isinstance(value, torch.Tensor) or not isinstance(value, (list, np.ndarray)): + output[key] = value + else: + output[key] = np.array(value) results.append(InputFeature(**output)) return results @@ -776,6 +797,5 @@ def postprocess_tensor_cp(self, tensor, cu_seqlens=None): if self.device_mesh.cp_world_size <= 1: return tensor from megatron.core import parallel_state as mpu - from twinkle.utils.torch_utils import gather_cp_load_balanced return gather_cp_load_balanced(tensor, mpu.get_context_parallel_group(), seq_dim=1, cu_seqlens=cu_seqlens) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 32dc1ca50..0433ef5a8 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -1,24 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""vLLM-based sampler using VLLMEngine (AsyncLLM). - -Device Configuration: - vLLMSampler automatically detects the number of available GPUs from - CUDA_VISIBLE_DEVICES environment variable (set by twinkle's ResourceManager) - and configures vLLM's tensor_parallel_size accordingly. - - To use tensor parallelism, configure DeviceGroup with gpus_per_worker > 1: - - # DP2 with TP2 (4 GPUs total, 2 workers, each with 2 GPUs) - DeviceGroup(name='sampler', ranks=[0,1,2,3], gpus_per_worker=2) - - # TP4 (4 GPUs, 1 worker with all 4 GPUs) - DeviceGroup(name='sampler', ranks=[0,1,2,3], gpus_per_worker=4) - -Data Flow: - When multiple vLLMSampler workers exist (DP > 1): - - Data is dispatched via dispatch='slice_dp' (each worker gets a slice) - - Results are collected via collect='flatten' (merged into single list) -""" import asyncio import atexit import numpy as np diff --git a/src/twinkle/server/state/backend/factory.py b/src/twinkle/server/state/backend/factory.py index 326a6916f..be24c7824 100644 --- a/src/twinkle/server/state/backend/factory.py +++ b/src/twinkle/server/state/backend/factory.py @@ -1,12 +1,11 @@ """Backend factory for creating StateBackend instances based on configuration.""" from __future__ import annotations -import logging - from twinkle.server.config.persistence import PersistenceConfig +from twinkle.utils import get_logger from .base import StateBackend -logger = logging.getLogger(__name__) +logger = get_logger() def create_backend(config: PersistenceConfig | None = None) -> StateBackend: diff --git a/src/twinkle/server/state/base.py b/src/twinkle/server/state/base.py index d931ce336..8cb055ae4 100644 --- a/src/twinkle/server/state/base.py +++ b/src/twinkle/server/state/base.py @@ -1,7 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations -import logging import time from abc import ABC, abstractmethod from datetime import datetime, timezone @@ -9,9 +8,10 @@ from typing import Generic, TypeVar from twinkle.server.state.backend.base import StateBackend +from twinkle.utils import get_logger T = TypeVar('T', bound=BaseModel) -logger = logging.getLogger(__name__) +logger = get_logger() class BaseManager(ABC, Generic[T]): diff --git a/src/twinkle/server/state/session_manager.py b/src/twinkle/server/state/session_manager.py index 442019ea4..1bc901b0c 100644 --- a/src/twinkle/server/state/session_manager.py +++ b/src/twinkle/server/state/session_manager.py @@ -2,14 +2,14 @@ from __future__ import annotations import functools -import logging import time +from twinkle.utils import get_logger from .backend.base import ConcurrencyError, StateBackend from .base import BaseManager from .models import SessionRecord -logger = logging.getLogger(__name__) +logger = get_logger() def _session_touch_transform(existing: dict | None, *, now: float) -> dict | None: diff --git a/src/twinkle/server/telemetry/provider.py b/src/twinkle/server/telemetry/provider.py index 77212c757..059f301fb 100644 --- a/src/twinkle/server/telemetry/provider.py +++ b/src/twinkle/server/telemetry/provider.py @@ -16,8 +16,9 @@ from typing import Any from twinkle.server.config.telemetry import TelemetryConfig +from twinkle.utils import get_logger -logger = logging.getLogger(__name__) +logger = get_logger() # Loggers belonging to the OTLP transport stack. Their own records must never # be routed back through the OTLP LoggingHandler: an exporter error logged diff --git a/src/twinkle/server/telemetry/worker_init.py b/src/twinkle/server/telemetry/worker_init.py index 997f2e140..40edc6628 100644 --- a/src/twinkle/server/telemetry/worker_init.py +++ b/src/twinkle/server/telemetry/worker_init.py @@ -7,10 +7,11 @@ """ from __future__ import annotations -import logging import os -logger = logging.getLogger(__name__) +from twinkle.utils import get_logger + +logger = get_logger() _worker_initialized = False diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py index 6c4bdddd2..168456eab 100644 --- a/src/twinkle/template/__init__.py +++ b/src/twinkle/template/__init__.py @@ -2,3 +2,4 @@ from .base import Template from .deepseek_v4 import DeepseekV4Template from .qwen3_5_vl import Qwen3_5Template +from .tools import ToolCallParser, ToolCallRegistry, ClineParser, HermesQwenParser, ReActParser, VCPParser diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 26c2e4f26..c1e8f069f 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -32,6 +32,19 @@ class Template: video_placeholder: str = '