From 28ab28fa4ea5f3c321dfcc4605a9bfcac03ea0fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 11:12:33 +0800 Subject: [PATCH 01/14] wip --- src/twinkle/checkpoint_engine/manager.py | 1 - src/twinkle/checkpoint_engine/mixin.py | 1 - src/twinkle/cli/cli.py | 7 +++---- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py index cde5c519d..3860d2840 100644 --- a/src/twinkle/checkpoint_engine/manager.py +++ b/src/twinkle/checkpoint_engine/manager.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py -import time from typing import List, Optional from twinkle import Platform, get_logger diff --git a/src/twinkle/checkpoint_engine/mixin.py b/src/twinkle/checkpoint_engine/mixin.py index e2e5d94d5..8dc15c926 100644 --- a/src/twinkle/checkpoint_engine/mixin.py +++ b/src/twinkle/checkpoint_engine/mixin.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import os from twinkle import Platform, remote_function from twinkle.checkpoint_engine.base import CheckpointEngine diff --git a/src/twinkle/cli/cli.py b/src/twinkle/cli/cli.py index 1730887f2..10ad4a396 100644 --- a/src/twinkle/cli/cli.py +++ b/src/twinkle/cli/cli.py @@ -1,12 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from __future__ import annotations - import os import sys from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields from pathlib import Path -from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union +from typing import Any, Iterator, Literal + # ──────────────────────────────────────────────────────────────────────────────── # Arg group dataclasses @@ -243,7 +242,7 @@ def _resolve_path(self) -> Path | None: class EnvVarSource(ConfigSource): """Reads os.environ; recognizes TWINKLE_ prefix and any key known to the registry.""" - def __init__(self, registry: ConfigRegistry): + def __init__(self, registry: 'ConfigRegistry'): self._registry = registry def load(self) -> dict[str, str]: From 30e8412de6d3ec5e7ae929c62d6a40c28829a88a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 11:43:42 +0800 Subject: [PATCH 02/14] wip --- src/twinkle/data_format/sampling.py | 1 - src/twinkle/dataloader/dataloader.py | 3 +- .../dataset/iterable_packing_dataset.py | 30 ++++++++++++++----- src/twinkle/dataset/packing_dataset.py | 2 ++ 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/twinkle/data_format/sampling.py b/src/twinkle/data_format/sampling.py index 1d5fe07c6..01ff0377d 100644 --- a/src/twinkle/data_format/sampling.py +++ b/src/twinkle/data_format/sampling.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import numpy as np from dataclasses import dataclass from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index c392d56cf..408c8d4b4 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -146,7 +146,7 @@ def _tracking_iter(self, inner): def skip_consumed_samples(self, consumed_train_samples: int) -> None: from torch.utils.data import IterableDataset - if isinstance(self.dataset, IterableDataset): + if isinstance(self.dataset, IterableDataset) or consumed_train_samples is None or consumed_train_samples <= 0: warnings.warn('IterableDataset does not support consumed-data skipping; continuing without skipping.') self._skip_samples = 0 return @@ -164,6 +164,7 @@ def resume_from_checkpoint(self, consumed_train_samples, **kwargs): @remote_function() def get_state(self) -> dict: + """The dataloader state for saving.""" return {'consumed_train_samples': self._consumed_train_samples} def _rebuild_sampler_stack(self): diff --git a/src/twinkle/dataset/iterable_packing_dataset.py b/src/twinkle/dataset/iterable_packing_dataset.py index ca7c6fbd8..ab1d3a982 100644 --- a/src/twinkle/dataset/iterable_packing_dataset.py +++ b/src/twinkle/dataset/iterable_packing_dataset.py @@ -88,10 +88,27 @@ def _fetch_data_out_queue(self, last_res, num_samples): last_res += res return last_res - @staticmethod - def _cyclic_iter(iterable): - while True: - yield from iterable + def _write_through_iter(self, iterable): + """Yields from iterable, meanwhile, save it to disk if needed. + Saving is needed when you are using several datasets at a time. + """ + if not self.cyclic: + for row in iterable: + self._write_through(row) + yield row + return + else: + first_pass = True + while True: + empty = True + for row in iterable: + empty = False + if first_pass: + self._write_through(row) + yield row + if empty: + return + first_pass = False @remote_function() def __iter__(self): @@ -102,10 +119,7 @@ def __iter__(self): except StopIteration: return - if self.cyclic: - iterator = self._cyclic_iter(self.dataset) - else: - iterator = iter(self.dataset) + iterator = self._write_through_iter(self.dataset) data = [] max_length = self.template.max_length or 2048 while True: diff --git a/src/twinkle/dataset/packing_dataset.py b/src/twinkle/dataset/packing_dataset.py index fa4acbd57..ada9498b8 100644 --- a/src/twinkle/dataset/packing_dataset.py +++ b/src/twinkle/dataset/packing_dataset.py @@ -114,6 +114,8 @@ def __getitem__(self, index): assert self._packed_called, 'Call `pack_dataset()` first before index the sample.' sequence = self.packed_idx[index] rows = [self.dataset[i] for i in sequence] + for row in rows: + self._write_through(row) output = {} for key in rows[0]: output[key] = [r[key] for r in rows] From 6909de3bd3f2ef03ffefb45524a40ee2c02105de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 11:44:19 +0800 Subject: [PATCH 03/14] wip --- docs/source_en/Components/Gym/Gym.md | 3 ++- "docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" | 3 ++- src/twinkle/gym/__init__.py | 2 -- src/twinkle/gym/base.py | 10 ---------- 4 files changed, 4 insertions(+), 14 deletions(-) delete mode 100644 src/twinkle/gym/__init__.py delete mode 100644 src/twinkle/gym/base.py diff --git a/docs/source_en/Components/Gym/Gym.md b/docs/source_en/Components/Gym/Gym.md index 4db355b8a..8d243677c 100644 --- a/docs/source_en/Components/Gym/Gym.md +++ b/docs/source_en/Components/Gym/Gym.md @@ -3,7 +3,8 @@ The Gym component provides an interface for reinforcement learning environments in Twinkle. ```python -from twinkle.gym import Gym +from twinkle_agentic.env import Gym + class CustomGym(Gym): diff --git "a/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" "b/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" index 63dc87aa7..4e34ffb93 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/Gym/Gym.md" @@ -3,7 +3,8 @@ Gym 组件为 Twinkle 中的强化学习环境提供接口。 ```python -from twinkle.gym import Gym +from twinkle_agentic.env import Gym + class CustomGym(Gym): diff --git a/src/twinkle/gym/__init__.py b/src/twinkle/gym/__init__.py deleted file mode 100644 index 44b0771bb..000000000 --- a/src/twinkle/gym/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from .base import Gym diff --git a/src/twinkle/gym/base.py b/src/twinkle/gym/base.py deleted file mode 100644 index aca798093..000000000 --- a/src/twinkle/gym/base.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. - - -class Gym: - - def __init__(self): - pass - - def step(self): - pass From c7a8b88b125545c64b7ff4b748ded32725bfa422 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 14:02:40 +0800 Subject: [PATCH 04/14] wip --- src/twinkle/infra/__init__.py | 20 +-- src/twinkle/loss/chunked_cross_entropy.py | 200 ++++++++++++++++------ src/twinkle/notifier/base.py | 2 +- 3 files changed, 152 insertions(+), 70 deletions(-) diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 83e10d132..23fb1fdae 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -1,11 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import functools import inspect -import itertools import json import numpy as np import os -import random import sys from typing import Any, Callable, List, Literal, Optional, TypeVar, Union @@ -59,7 +57,7 @@ def _tag_exc(exc: BaseException, caller: Optional[str]) -> None: prefix = f'[twinkle driver caller: {caller}] ' exc.args = (prefix + str(exc.args[0]), *exc.args[1:]) if exc.args else (prefix.rstrip(), ) exc._twinkle_caller_augmented = True - except Exception: # noqa: BLE001 + except Exception: # noqa pass @@ -404,6 +402,7 @@ def dispatch_func(arg, n): return result elif dispatch == 'slice_dp': + assert device_mesh is not None # split by dp. each worker in one ep will receive the same argument result = [] # if device_mesh is not None: @@ -420,14 +419,6 @@ def dispatch_func(arg, n): import torch if isinstance(arg, list) or isinstance(arg, torch.Tensor): _args = [] - if device_mesh is None: - total = len(arg) - chunk = max(1, (total + n - 1) // n) - for i in range(n): - start = i * chunk - end = min(total, start + chunk) - _args.append(arg[start:end]) - return _args for i in range(n): _args.append(arg[device_mesh.get_slice( len(arg), device_mesh.get_data_rank_from_global_rank(i * _rank_stride))]) @@ -803,18 +794,13 @@ def wrapper(self, *args, **kwargs) -> T1: # And this is user independent, only decided by the code. _local_lazy_collect = self._lazy_collect if _local_lazy_collect: - # Wrap the deferred collector so that exceptions - # raised when the caller later materializes the - # result also trigger the notifier. Attributes - # (``_futures`` etc.) on the original collector - # are preserved for downstream code paths. _orig_result_func = result_func @functools.wraps(_orig_result_func) def _notifying_result_func(*rargs, **rkwargs): try: return _orig_result_func(*rargs, **rkwargs) - except Exception as _e: # noqa: BLE001 + except Exception as _e: # noqa _tag_exc(_e, _caller) notify_exception(_notifier, _ctx, _e, _name) raise diff --git a/src/twinkle/loss/chunked_cross_entropy.py b/src/twinkle/loss/chunked_cross_entropy.py index 22d3d4077..061ca2168 100644 --- a/src/twinkle/loss/chunked_cross_entropy.py +++ b/src/twinkle/loss/chunked_cross_entropy.py @@ -1,63 +1,159 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import math -from typing import Any - -from ..data_format import LossOutput +from twinkle.data_format import LossOutput from .base import Loss +# Lazily-built singleton autograd.Function, so we neither pay the +# class-construction cost on every forward nor force a top-level torch import. +_CHUNKED_CE_FUNC = None + + +def _get_chunked_ce_func(): + global _CHUNKED_CE_FUNC + if _CHUNKED_CE_FUNC is not None: + return _CHUNKED_CE_FUNC + + import torch + import torch.nn.functional as F + + class _ChunkedCrossEntropyFunc(torch.autograd.Function): + """Chunked CE that materialises log_softmax(B, V) only one chunk at a time. + + Forward returns a scalar loss; backward writes per-token gradients into + a freshly allocated `grad_logits` tensor (the input `logits` is never + mutated). Mathematically equivalent to ``CrossEntropyLoss`` in the same + package; ``chunk_size`` only controls the memory/throughput trade-off. + """ + + @staticmethod + def forward(ctx, logits, labels, chunk_size, ignore_index, reduction, dft): + ctx.save_for_backward(logits, labels) + ctx.chunk_size = chunk_size + ctx.ignore_index = ignore_index + ctx.reduction = reduction + ctx.dft = dft + + n = logits.shape[0] + # Use fp32 accumulators so we don't lose precision when summing + # over many tokens under fp16/bf16 autocast (matches cross_entropy.py). + total_loss = logits.new_zeros((), dtype=torch.float32) + total_count = logits.new_zeros((), dtype=torch.float32) + + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + logits_chunk = logits[start:end] + labels_chunk = labels[start:end] + mask = (labels_chunk != ignore_index).float() + + logps = F.log_softmax(logits_chunk, dim=-1).gather( + -1, labels_chunk.clamp(min=0).unsqueeze(-1)).squeeze(-1) + per_token = -logps * logps.exp() if dft else -logps + + total_loss = total_loss + (per_token * mask).sum() + total_count = total_count + mask.sum() + + ctx.num_tokens = total_count.detach() + if reduction == 'mean': + return total_loss / total_count.clamp(min=1) + return total_loss + + @staticmethod + def backward(ctx, grad_output): + logits, labels = ctx.saved_tensors + chunk_size = ctx.chunk_size + ignore_index = ctx.ignore_index + reduction = ctx.reduction + dft = ctx.dft + + if reduction == 'mean': + scale = grad_output / ctx.num_tokens.clamp(min=1) + else: + scale = grad_output + + grad_logits = torch.empty_like(logits) + n = logits.shape[0] + + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + logits_chunk = logits[start:end].detach().requires_grad_(True) + labels_chunk = labels[start:end] + mask = (labels_chunk != ignore_index).float() + + with torch.enable_grad(): + logps = F.log_softmax(logits_chunk, dim=-1).gather( + -1, labels_chunk.clamp(min=0).unsqueeze(-1)).squeeze(-1) + per_token = -logps * logps.exp() if dft else -logps + loss_chunk = (per_token * mask).sum() + + grad_chunk = torch.autograd.grad(loss_chunk, logits_chunk, retain_graph=False)[0] + grad_logits[start:end] = grad_chunk * scale + + # logits, labels, chunk_size, ignore_index, reduction, dft + return grad_logits, None, None, None, None, None + + _CHUNKED_CE_FUNC = _ChunkedCrossEntropyFunc + return _CHUNKED_CE_FUNC + class ChunkedCrossEntropyLoss(Loss): - """TODO untested code""" + """CE loss that chunks the (B, V) softmax to bound peak memory. + + Drop-in replacement for :class:`CrossEntropyLoss` when ``outputs['logits']`` + is large (e.g. long sequence x big vocab). Behaviour matches that loss + bit-for-bit; ``chunk_size`` only affects memory/throughput. + + Args: + chunk_size: How many rows of ``logits`` to process per chunk. + ignore_index: Label id treated as padding (excluded from loss). + reduction: ``'mean'`` or ``'sum'``; matches ``CrossEntropyLoss``. + dft: If True, use DFT weighting ``-p*log(p)`` (arxiv 2508.05629). + """ - def __init__(self, chunk_size): + require_logits = True + # We chunk the (B, V) softmax ourselves; tell upstream not to materialise + # `logps` (which would already pay the full memory cost we're trying to + # avoid). The `_loss_from_logps` fast path is kept only for the rare case + # where someone explicitly hands us pre-computed logps. + require_logps = False + + def __init__(self, + chunk_size: int, + ignore_index: int = -100, + reduction: str = 'mean', + dft: bool = False, + **kwargs): + super().__init__() + assert chunk_size > 0, 'chunk_size must be positive' + assert reduction in ('mean', 'sum'), f"reduction must be 'mean' or 'sum', got {reduction!r}" self.chunk_size = chunk_size + self.ignore_index = ignore_index + self.reduction = reduction + self.dft = dft def __call__(self, inputs, outputs, **kwargs): - import torch - - class ChunkedCrossEntropyLossFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, logits, labels, chunk_size): - import torch - ctx.save_for_backward(logits, labels) - ctx.chunk_size = chunk_size - - losses = [] - for i in range(math.ceil(logits.shape[0] / chunk_size)): - l_start = i * chunk_size - l_end = min((i + 1) * chunk_size, logits.shape[0]) - logits_chunk = logits[l_start:l_end] - labels_chunk = labels[l_start:l_end] - loss_fct = torch.nn.CrossEntropyLoss(reduction='none') - loss_chunk = loss_fct(logits_chunk, labels_chunk) - losses.append(loss_chunk) - del logits_chunk - del labels_chunk - all_losses = torch.cat(losses) - return all_losses - - @staticmethod - def backward(ctx: Any, *grad_outputs: Any): - import torch - logits, labels = ctx.saved_tensors - chunk_size = ctx.chunk_size - - for i in range(math.ceil(logits.shape[0] / chunk_size)): - l_start = i * chunk_size - l_end = min((i + 1) * chunk_size, logits.shape[0]) - logits_chunk = logits[l_start:l_end].detach().requires_grad_(True) - labels_chunk = labels[l_start:l_end] - loss_fct = torch.nn.CrossEntropyLoss(reduction='none') - with torch.enable_grad(): - loss_chunk = loss_fct(logits_chunk, labels_chunk) - grad_output_chunk = grad_outputs[0][l_start:l_end] - _loss_chunk = (loss_chunk * grad_output_chunk).sum() - grad_chunk = torch.autograd.grad(_loss_chunk, logits_chunk, retain_graph=False)[0] - logits[l_start:l_end] = grad_chunk - - return logits, None, None + labels = inputs['labels'] + logps = outputs.get('logps') + + # Fast path: if logps is already gathered upstream, chunking the + # softmax is moot — fall back to the same scalar formula as + # CrossEntropyLoss to keep behaviour identical. + if logps is not None: + return self._loss_from_logps(labels, logps) logits = outputs['logits'] - labels = inputs['labels'] - return LossOutput(loss=ChunkedCrossEntropyLossFunc.apply(logits, labels, self.chunk_size), num_tokens=0) + labels = labels.view(-1) + logits = logits.view(-1, logits.shape[-1]) + + func = _get_chunked_ce_func() + loss = func.apply(logits, labels, self.chunk_size, self.ignore_index, self.reduction, self.dft) + + if self.reduction == 'mean': + return LossOutput(loss=loss, num_tokens=0) + num_tokens = (labels != self.ignore_index).float().sum().clamp(min=1) + return LossOutput(loss=loss, num_tokens=num_tokens) + + def _loss_from_logps(self, labels, logps): + mask = (labels != self.ignore_index).float() + per_token = -logps * logps.exp() if self.dft else -logps + if self.reduction == 'mean': + return LossOutput(loss=(per_token * mask).sum() / mask.sum().clamp(min=1), num_tokens=0) + return LossOutput(loss=(per_token * mask).sum(), num_tokens=mask.sum().clamp(min=1)) diff --git a/src/twinkle/notifier/base.py b/src/twinkle/notifier/base.py index a83903b53..6c138c997 100644 --- a/src/twinkle/notifier/base.py +++ b/src/twinkle/notifier/base.py @@ -66,7 +66,7 @@ def notify_exception(notifier: Notifier, context: str, exc: BaseException, name: if not _try_claim_notify_slot(exc, context, name): try: setattr(exc, '_twinkle_notified', True) - except Exception: # noqa: BLE001 + except Exception: # noqa pass return From 47a83204450a99182876d2093a7ed5059a5ba7cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 17:13:29 +0800 Subject: [PATCH 05/14] wip --- src/twinkle/loss/dpo.py | 44 +++++++++++++++------- src/twinkle/loss/gkd.py | 19 +++++++--- src/twinkle/loss/grpo.py | 31 ++++++---------- src/twinkle/loss/infonce.py | 74 ++++++++++++++++++++++--------------- 4 files changed, 102 insertions(+), 66 deletions(-) diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index fe526ab46..d53019513 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -7,7 +7,7 @@ (https://arxiv.org/abs/2305.18290) """ from typing import TYPE_CHECKING, Dict, List, Optional, Union - +import math from twinkle.data_format import LossOutput from twinkle.loss.base import Loss from twinkle.utils.torch_utils import selective_log_softmax @@ -132,6 +132,12 @@ def __init__( **kwargs, ): super().__init__(ignore_index=ignore_index) + if loss_type not in ('sigmoid', 'hinge', 'ipo', 'kto_pair'): + raise ValueError(f'Unknown loss_type: {loss_type}') + if label_smoothing > 0 and loss_type != 'sigmoid': + raise ValueError( + f'label_smoothing > 0 is only defined for loss_type="sigmoid", ' + f'got loss_type="{loss_type}". Set label_smoothing=0.0 or switch to sigmoid.') self.beta = beta self.label_smoothing = label_smoothing self.loss_type = loss_type @@ -217,6 +223,11 @@ def _compute_dpo_loss( if self.loss_type == 'sigmoid': # Standard DPO loss: -log(sigmoid(beta * margin)) losses = -F.logsigmoid(logits) + # Apply label smoothing (only meaningful here: Bradley-Terry soft labels). + if self.label_smoothing > 0: + # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected + smooth_losses = -F.logsigmoid(-logits) # Loss for flipped preference + losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses elif self.loss_type == 'hinge': # Hinge loss variant losses = torch.relu(1 - logits) @@ -234,12 +245,6 @@ def _compute_dpo_loss( else: raise ValueError(f'Unknown loss_type: {self.loss_type}') - # Apply label smoothing if specified - if self.label_smoothing > 0: - # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected - smooth_losses = -F.logsigmoid(-logits) # Loss for flipped preference - losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses - return losses.mean() def __call__( @@ -321,7 +326,8 @@ def __call__( reference_chosen_logps = torch.zeros_like(policy_chosen_logps) reference_rejected_logps = torch.zeros_like(policy_rejected_logps) else: - return LossOutput(loss=torch.tensor(0.0, device=chosen_logps.device), num_tokens=0) + zero = (policy_chosen_logps.sum() + policy_rejected_logps.sum()) * 0.0 + return LossOutput(loss=zero, num_tokens=0) # Compute DPO loss dpo_loss = self._compute_dpo_loss( @@ -535,11 +541,23 @@ def __call__( # Odds ratio: log(odds_chosen / odds_rejected) # log_odds = log(p/(1-p)) = log(p) - log(1-p) - # Compute entirely in log-space to avoid exp() underflow: - # log(p) = avg_logps (already in log-space) - # log(1-p) = log1p(-exp(avg_logps)) (numerically stable via log1p) - log_odds_chosen = chosen_avg_logps - torch.log1p(-torch.exp(chosen_avg_logps)) - log_odds_rejected = rejected_avg_logps - torch.log1p(-torch.exp(rejected_avg_logps)) + # Compute log(1-p) = log(1 - exp(avg_logp)) numerically stably: + # - For x > -log(2): log(-expm1(x)) (avoids log(0) when p → 1) + # - For x ≤ -log(2): log1p(-exp(x)) (avoids cancellation when p → 0) + # ``avg_logp ∈ (-∞, 0]`` so the threshold partitions the safe regime. + log_two = math.log(2.0) + + def _log1mexp(x: 'torch.Tensor') -> 'torch.Tensor': + # Clamp at a tiny negative to keep both branches well-defined when p≈1. + x_safe = torch.clamp(x, max=-1e-7) + return torch.where( + x_safe > -log_two, + torch.log(-torch.expm1(x_safe)), + torch.log1p(-torch.exp(x_safe)), + ) + + log_odds_chosen = chosen_avg_logps - _log1mexp(chosen_avg_logps) + log_odds_rejected = rejected_avg_logps - _log1mexp(rejected_avg_logps) # ORPO odds ratio loss odds_ratio = log_odds_chosen - log_odds_rejected diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index 3f7db4bfb..7c198ad02 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -41,6 +41,10 @@ def __init__( chunk_size: int = 512, **kwargs, ): + if not (0.0 <= beta <= 1.0): + raise ValueError(f'beta must be in [0, 1], got {beta}') + if temperature <= 0: + raise ValueError(f'temperature must be > 0, got {temperature}') self.beta = beta self.temperature = temperature self.ignore_index = ignore_index @@ -94,6 +98,7 @@ def __call__( labels=labels, beta=self.beta, temperature=self.temperature, + ignore_index=self.ignore_index, chunk_size=self.chunk_size, topk=topk, teacher_topk_logprobs=teacher_topk_logprobs, @@ -108,6 +113,7 @@ def _generalized_jsd_loss( labels=None, beta: float = 0.5, temperature: float = 1.0, + ignore_index: int = -100, chunk_size: int = 512, topk: Optional[int] = None, teacher_topk_logprobs=None, @@ -164,7 +170,7 @@ def _generalized_jsd_loss( # ── Mask valid (response) tokens ────────────────────────────────────── if labels is not None: - mask = labels != -100 # ignore_index is always -100 per convention + mask = labels != ignore_index # Vocab-size mismatch (e.g. Qwen2.5-VL-3B vs 7B): pad the smaller side # so both distributions are defined over the same token set. stu_dim = student_logits.shape[-1] @@ -178,12 +184,15 @@ def _generalized_jsd_loss( student_logits = student_logits[mask] # [num_valid, vocab/topk] teacher_logits = teacher_logits[mask] num_valid = mask.sum() + # ``[mask]`` already created fresh storage, so in-place divide is safe + # and avoids an extra [num_valid, V] allocation. + student_logits.div_(temperature) + teacher_logits.div_(temperature) else: - student_logits = student_logits.view(-1, student_logits.size(-1)) - teacher_logits = teacher_logits.view(-1, teacher_logits.size(-1)) + # Keep logits, may be an infer scenario + student_logits = student_logits.reshape(-1, student_logits.size(-1)) / temperature + teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1)) / temperature num_valid = student_logits.size(0) - student_logits.div_(temperature) - teacher_logits.div_(temperature) if num_valid == 0: return student_logits.new_zeros(()) diff --git a/src/twinkle/loss/grpo.py b/src/twinkle/loss/grpo.py index 4bb71216c..781b22060 100644 --- a/src/twinkle/loss/grpo.py +++ b/src/twinkle/loss/grpo.py @@ -42,18 +42,6 @@ def __init__( self.require_entropy = entropy_coef > 0.0 self.ignore_index = ignore_index - def _compute_loss_mask(self, labels: 'torch.Tensor') -> 'torch.Tensor': - """ - Compute loss mask from labels. - - Args: - labels: [batch, seq_len] target token ids, -100 for ignored positions - - Returns: - mask: [batch, seq_len] float tensor, 1.0 for valid positions, 0.0 for ignored - """ - return (labels != self.ignore_index).float() - def _compute_log_importance_weights( self, per_token_logps: 'torch.Tensor', @@ -165,10 +153,13 @@ def _pad_and_align_to_batch( return data # Already aligned if data.dim() == 1: data = data.unsqueeze(1) - if data.shape[1] == 1: # Scalars - result = torch.full((batch_size, seq_len), fill_value, dtype=dtype, device=device) - result[mask] = data[mask.any(dim=1).nonzero(as_tuple=True)[0].repeat_interleave(mask.sum(dim=1)), 0] - return result + if data.shape[1] == 1: + assert data.shape[0] == batch_size, ( + f'scalar broadcast expects data.shape[0]==batch_size, ' + f'got data.shape={tuple(data.shape)} mask.shape={(batch_size, seq_len)}') + fill = torch.full((batch_size, seq_len), fill_value, dtype=dtype, device=device) + expanded = data.expand(batch_size, seq_len) + return torch.where(mask, expanded, fill) data = [data[i] for i in range(batch_size)] # To list # Handle list (scalars or sequences) @@ -276,10 +267,12 @@ def __call__( ) # GRPO loss is ill-defined without advantages (e.g. ref-logps-only forward, - # or eval/validation forwards). Return a zero loss so the forward still - # flows through cleanly and callers can harvest outputs['logps'] freely. + # or eval/validation forwards). Return a zero loss that still flows through + # autograd so DDP/FSDP do not see unused params, and callers can harvest + # outputs['logps'] freely. if advantages is None: - return LossOutput(loss=torch.zeros((), device=device, dtype=logps.dtype), num_tokens=0) + zero = logps.sum() * 0.0 + return LossOutput(loss=zero, num_tokens=0) advantages = self._pad_and_align_to_batch( advantages, diff --git a/src/twinkle/loss/infonce.py b/src/twinkle/loss/infonce.py index 68d14840c..c356bd64c 100644 --- a/src/twinkle/loss/infonce.py +++ b/src/twinkle/loss/infonce.py @@ -13,8 +13,6 @@ import numpy as np import torch import torch.distributed as dist -import torch.nn.functional as F -from enum import Enum from torch import nn from typing import Optional @@ -22,15 +20,6 @@ from .base import Loss -# Borrowed from sentence_transformers. -class SiameseDistanceMetric(Enum): - """Distance metrics available to the pairwise contrastive losses.""" - - EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2) # noqa - MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1) # noqa - COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y) # noqa - - def _extract_sentences(outputs) -> torch.Tensor: """Return [B, D] sentence embeddings from postprocess_tensor_sp output. @@ -119,6 +108,11 @@ def __init__( process_group=None, **kwargs, ): + if mask_fake_negative and fake_neg_margin <= 0: + raise ValueError( + f'fake_neg_margin must be > 0 when mask_fake_negative=True, got {fake_neg_margin}. ' + 'A non-positive margin would mask out the positive itself or every above-positive ' + 'logit indiscriminately, collapsing the contrastive signal.') self.temperature = temperature self.use_batch = use_batch self.hard_negatives = hard_negatives @@ -129,7 +123,13 @@ def __init__( self.process_group = process_group def _gather_across_dp(self, sentences: torch.Tensor, labels: torch.Tensor): - """All-gather embeddings & labels across DP ranks; only local shard keeps grad.""" + """All-gather embeddings & labels across DP ranks; only local shard keeps grad. + + NCCL ``all_gather`` requires every rank to send the *same* tensor size. Under + ``slice_dp`` dispatch the per-rank batch is uneven (``divmod`` splits), so we + pad each rank to the global max along dim-0, do an equal-sized all_gather, + then strip padding back. Only the local shard retains gradients. + """ if not (dist.is_available() and dist.is_initialized()): return sentences, labels world_size = dist.get_world_size(group=self.process_group) @@ -137,24 +137,40 @@ def _gather_across_dp(self, sentences: torch.Tensor, labels: torch.Tensor): return sentences, labels rank = dist.get_rank(group=self.process_group) - # variable per-rank shapes require communicating shape first - local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long) - shapes = [torch.empty_like(local_shape) for _ in range(world_size)] - dist.all_gather(shapes, local_shape, group=self.process_group) - all_sentences = [sentences.new_empty(shape.tolist()) for shape in shapes] - dist.all_gather(all_sentences, sentences.contiguous(), group=self.process_group) - - local_label_shape = labels.new_tensor(labels.shape, dtype=torch.long) - label_shapes = [torch.empty_like(local_label_shape) for _ in range(world_size)] - dist.all_gather(label_shapes, local_label_shape, group=self.process_group) - all_labels = [labels.new_empty(shape.tolist()) for shape in label_shapes] - dist.all_gather(all_labels, labels.contiguous(), group=self.process_group) - - # keep the local shard differentiable; detach others - all_sentences[rank] = sentences + # ``labels`` is a 1-D mask aligned to ``sentences`` along dim-0, so they + # share the same per-rank size. Gather sizes once and reuse for both. + assert sentences.shape[0] == labels.shape[0], ( + f'sentences/labels dim-0 mismatch: {sentences.shape[0]} vs {labels.shape[0]}') + local_n = torch.tensor([sentences.shape[0]], device=sentences.device, dtype=torch.long) + sizes = [torch.empty_like(local_n) for _ in range(world_size)] + dist.all_gather(sizes, local_n, group=self.process_group) + sizes_int = [int(s.item()) for s in sizes] + max_n = max(sizes_int) + + def _pad_gather(tensor: torch.Tensor): + if tensor.shape[0] < max_n: + pad_shape = (max_n - tensor.shape[0],) + tuple(tensor.shape[1:]) + padded = torch.cat([tensor, tensor.new_zeros(pad_shape)], dim=0) + else: + padded = tensor + buffers = [torch.empty_like(padded) for _ in range(world_size)] + dist.all_gather(buffers, padded.contiguous(), group=self.process_group) + return buffers + + sent_buffers = _pad_gather(sentences) + label_buffers = _pad_gather(labels) + + # Strip padding; keep local shard differentiable, detach others. + all_sentences = [] + all_labels = [] for idx in range(world_size): - if idx != rank: - all_sentences[idx] = all_sentences[idx].detach() + n = sizes_int[idx] + if idx == rank: + all_sentences.append(sentences) + all_labels.append(labels) + else: + all_sentences.append(sent_buffers[idx][:n].detach()) + all_labels.append(label_buffers[idx][:n]) return torch.cat(all_sentences, dim=0), torch.cat(all_labels, dim=0) def __call__(self, inputs, outputs, **kwargs) -> LossOutput: From 18526e508eab5c9f62d9a747b942d7d218eb954f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 17:17:38 +0800 Subject: [PATCH 06/14] wip --- src/twinkle/metric/accuracy.py | 1 - src/twinkle/metric/embedding.py | 6 +++--- src/twinkle/metric/train_metric.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/twinkle/metric/accuracy.py b/src/twinkle/metric/accuracy.py index b3034c57a..4dfb01198 100644 --- a/src/twinkle/metric/accuracy.py +++ b/src/twinkle/metric/accuracy.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import numpy as np from typing import List, Union from ..data_format import InputFeature, ModelOutput diff --git a/src/twinkle/metric/embedding.py b/src/twinkle/metric/embedding.py index 9fb3aed8c..ec67ac020 100644 --- a/src/twinkle/metric/embedding.py +++ b/src/twinkle/metric/embedding.py @@ -1,7 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import torch -import torch.distributed as dist -import torch.nn.functional as F from typing import List, Union from twinkle.data_format import InputFeature, ModelOutput @@ -32,6 +29,9 @@ def reset(self): self.grad_norm = 0.0 def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs): + import torch + import torch.distributed as dist + import torch.nn.functional as F sentences = outputs.get('embeddings') if sentences is None: sentences = outputs.get('logits') diff --git a/src/twinkle/metric/train_metric.py b/src/twinkle/metric/train_metric.py index da82a8783..8d785c38b 100644 --- a/src/twinkle/metric/train_metric.py +++ b/src/twinkle/metric/train_metric.py @@ -2,7 +2,7 @@ import time from typing import List, Union -from ..data_format import InputFeature, ModelOutput +from twinkle.data_format import InputFeature, ModelOutput from .base import Metric From 34a3e9a39e2584552d258beca4f56057aa1be8da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 17:39:33 +0800 Subject: [PATCH 07/14] wip --- src/twinkle/metric/dpo.py | 6 +++++ src/twinkle/metric/embedding.py | 40 +++++++++++++++++++++------------ src/twinkle/metric/grpo.py | 20 ++++++++++++++++- 3 files changed, 51 insertions(+), 15 deletions(-) diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py index b203d255e..024cb0473 100644 --- a/src/twinkle/metric/dpo.py +++ b/src/twinkle/metric/dpo.py @@ -131,6 +131,12 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M ref_outputs = kwargs.get('ref_outputs') if ref_outputs is not None: ref_logps = ref_outputs.get('logps') + if ref_logps is not None: + if isinstance(ref_logps, list): + if len(ref_logps) == 0: + ref_logps = None + else: + ref_logps = pad_and_stack_tensors(ref_logps) if ref_logps is not None: # Align ref_logps to match labels shape (handles different seq lengths) ref_logps = self._align_logps(ref_logps, labels.shape, labels.device, logps.dtype) diff --git a/src/twinkle/metric/embedding.py b/src/twinkle/metric/embedding.py index ec67ac020..8b3681031 100644 --- a/src/twinkle/metric/embedding.py +++ b/src/twinkle/metric/embedding.py @@ -44,22 +44,34 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M inputs = [inputs] labels = torch.cat([inp['labels'].view(-1) for inp in inputs], dim=0) - # Gather embeddings and labels across DP for in-batch stats + # Gather embeddings and labels across DP for in-batch stats. + # NCCL ``all_gather`` requires every rank to send the same tensor size, + # but ``slice_dp`` dispatch (``divmod`` split) can leave per-rank dim-0 + # uneven. Pad to the global max along dim-0, gather, then strip padding. if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: world_size = dist.get_world_size() - local_shape = sentences.new_tensor(sentences.shape, dtype=torch.long) - shapes = [torch.empty_like(local_shape) for _ in range(world_size)] - dist.all_gather(shapes, local_shape) - all_sentences = [sentences.new_empty(s.tolist()) for s in shapes] - dist.all_gather(all_sentences, sentences.contiguous()) - sentences = torch.cat(all_sentences, dim=0) - - local_lshape = labels.new_tensor(labels.shape, dtype=torch.long) - lshapes = [torch.empty_like(local_lshape) for _ in range(world_size)] - dist.all_gather(lshapes, local_lshape) - all_labels = [labels.new_empty(s.tolist()) for s in lshapes] - dist.all_gather(all_labels, labels.contiguous()) - labels = torch.cat(all_labels, dim=0) + assert sentences.shape[0] == labels.shape[0], ( + f'sentences/labels dim-0 mismatch: {sentences.shape[0]} vs {labels.shape[0]}') + local_n = torch.tensor([sentences.shape[0]], device=sentences.device, dtype=torch.long) + sizes = [torch.empty_like(local_n) for _ in range(world_size)] + dist.all_gather(sizes, local_n) + sizes_int = [int(s.item()) for s in sizes] + max_n = max(sizes_int) + + def _pad_gather(tensor: 'torch.Tensor') -> 'List[torch.Tensor]': + if tensor.shape[0] < max_n: + pad_shape = (max_n - tensor.shape[0],) + tuple(tensor.shape[1:]) + padded = torch.cat([tensor, tensor.new_zeros(pad_shape)], dim=0) + else: + padded = tensor + buffers = [torch.empty_like(padded) for _ in range(world_size)] + dist.all_gather(buffers, padded.contiguous()) + return buffers + + sent_buffers = _pad_gather(sentences) + label_buffers = _pad_gather(labels) + sentences = torch.cat([sent_buffers[i][:sizes_int[i]] for i in range(world_size)], dim=0) + labels = torch.cat([label_buffers[i][:sizes_int[i]] for i in range(world_size)], dim=0) anchor_idx = torch.nonzero(labels, as_tuple=False).squeeze(-1) if anchor_idx.numel() == 0: diff --git a/src/twinkle/metric/grpo.py b/src/twinkle/metric/grpo.py index 06e082eeb..e2797b1ec 100644 --- a/src/twinkle/metric/grpo.py +++ b/src/twinkle/metric/grpo.py @@ -3,9 +3,12 @@ from typing import Any, Dict, List, Optional, Union from twinkle.data_format import InputFeature, ModelOutput +from twinkle.utils import get_logger from twinkle.utils.transformers_utils import align_logps_to_mask from .base import Metric +logger = get_logger() + class GRPOMetric(Metric): @@ -254,6 +257,11 @@ def accumulate( if len(seq_lens) == 1: merged = torch.cat(label_tensors, dim=0) inputs_list = [{'labels': merged}] + else: + logger.warning( + f'GRPOMetric: logps is a single tensor but inputs_list has ' + f'{len(inputs_list)} mb with mismatched seq_lens={sorted(seq_lens)}. ' + f'Only mb[0] will be accumulated; check the model forward path.') flat_old: Optional[List] = None if old_logps is not None and isinstance(old_logps, (list, tuple)): @@ -284,7 +292,17 @@ def accumulate( # Uncommon: aligned global tensor. Only honour when it # exactly matches the single-mb shape; otherwise drop. import torch as _torch # noqa: F811 - old_slice = old_logps if (_torch.is_tensor(old_logps) and old_logps.shape == logps_mb.shape) else None + if _torch.is_tensor(old_logps) and old_logps.shape == logps_mb.shape: + old_slice = old_logps + else: + if mb_idx == 0: + # Warn once per accumulate call (not per mb) to avoid log spam. + old_shape = tuple(old_logps.shape) if _torch.is_tensor(old_logps) else 'unknown' + logger.warning( + f'GRPOMetric: old_logps shape {old_shape} does not match ' + f'logps_mb shape {tuple(logps_mb.shape)}; ratio/kl metrics will ' + f'be skipped for this step.') + old_slice = None else: old_slice = None From 420d0da101124a625577c417052bd70a3ad6f823 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 18:36:09 +0800 Subject: [PATCH 08/14] wip --- src/twinkle/data_format/output.py | 2 ++ src/twinkle/model/megatron/megatron.py | 34 +++++++++++++------------- src/twinkle/processor/base.py | 1 - 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/twinkle/data_format/output.py b/src/twinkle/data_format/output.py index 763ef246f..596252fb6 100644 --- a/src/twinkle/data_format/output.py +++ b/src/twinkle/data_format/output.py @@ -20,11 +20,13 @@ class ModelOutput(TypedDict, total=False): loss: The loss calculated by the model. logps: The log-probabilities of correct tokens by the model. num_tokens: The token denominator associated with ``loss``. + embeddings: The embeddings output by the model, used be embedding task. """ logits: Optional[OutputType] loss: Optional[OutputType] logps: Optional[OutputType] num_tokens: Optional[OutputType] + embeddings: Optional[OutputType] class LossOutput(TypedDict, total=False): diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index a5ea3fc56..bd24bc027 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -420,7 +420,7 @@ def forward_step_func(data_iterator, model): embeddings = output_tensor elif labels is not None and is_last_pp: _loss_require_logps = getattr(_loss_instance, 'require_logps', True) - _loss_require_entropy = (hasattr(_loss_instance, 'require_entropy') and _loss_instance.require_entropy) + _loss_require_entropy = getattr(_loss_instance, 'require_entropy', True) _packed = batch.get('packed_seq_params') cu_seqlens_q = getattr(_packed, 'cu_seqlens_q', None) if _packed is not None else None if _loss_require_logps: @@ -446,7 +446,7 @@ def forward_step_func(data_iterator, model): _outputs = {'logps': logps} if entropies is not None: _outputs['entropies'] = entropies - if hasattr(_loss_instance, 'require_logits') and _loss_instance.require_logits: + if getattr(_loss_instance, 'require_logits', False): _outputs['logits'] = output_tensor batch, _outputs = processor.unpack_packed_sequences(batch, _outputs) logps = _outputs['logps'] @@ -1139,7 +1139,7 @@ def _load_mcore_optimizer( ) iteration = self._read_iteration(tracker_path) if iteration == 0: - logging.getLogger(__name__).warning(f'No checkpoint found in {checkpoint_dir}') + logger.warning(f'No checkpoint found in {checkpoint_dir}') return iter_dir = os.path.join(checkpoint_dir, f'iter_{iteration:07d}') @@ -1216,21 +1216,21 @@ def _load_mcore_optimizer( @staticmethod def _read_iteration(tracker_path: str) -> int: - if not os.path.exists(tracker_path): - return 0 - with open(tracker_path) as f: - iteration = int(f.read().strip()) + # All ranks must enter the all_reduce together; missing tracker on some + # ranks (e.g. NFS lag, partial mount) must NOT short-circuit, otherwise + # the remaining ranks hang at the collective. Treat missing as 0 and + # let MAX reduction recover the canonical iteration from any rank that + # successfully read the file. + iteration = 0 + if os.path.exists(tracker_path): + with open(tracker_path) as f: + iteration = int(f.read().strip()) if torch.distributed.is_initialized(): - iters_cuda = torch.tensor( - [iteration], - dtype=torch.long, - device='cuda', - ) - torch.distributed.all_reduce( - iters_cuda, - op=torch.distributed.ReduceOp.MAX, - ) - iteration = iters_cuda[0].item() + # Use Platform.get_local_device() to stay backend-agnostic + # (CUDA / NPU / MPS); 'cuda' would crash on NPU. + iters_dev = torch.tensor([iteration], dtype=torch.long, device=Platform.get_local_device()) + torch.distributed.all_reduce(iters_dev, op=torch.distributed.ReduceOp.MAX) + iteration = int(iters_dev[0].item()) return iteration def _merge_lora_adapters(self, adapter_name: str = 'default'): diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index 8709d98ab..d600ec2b7 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -776,6 +776,5 @@ def postprocess_tensor_cp(self, tensor, cu_seqlens=None): if self.device_mesh.cp_world_size <= 1: return tensor from megatron.core import parallel_state as mpu - from twinkle.utils.torch_utils import gather_cp_load_balanced return gather_cp_load_balanced(tensor, mpu.get_context_parallel_group(), seq_dim=1, cu_seqlens=cu_seqlens) From 5f093e5aa4c7ec9cea21f3b6c72dad9b31fbbda2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 21:06:32 +0800 Subject: [PATCH 09/14] wip --- src/twinkle/infra/__init__.py | 2 +- src/twinkle/model/megatron/megatron.py | 12 +++++++---- .../model/megatron/multi_lora_megatron.py | 15 ++++++++----- src/twinkle/server/state/backend/factory.py | 5 ++--- src/twinkle/server/state/base.py | 4 ++-- src/twinkle/server/state/session_manager.py | 4 ++-- src/twinkle/server/telemetry/provider.py | 3 ++- src/twinkle/server/telemetry/worker_init.py | 5 +++-- src/twinkle/utils/platforms/base.py | 21 +++++++++++++++++++ src/twinkle/utils/platforms/gpu.py | 13 ++++++++++++ src/twinkle/utils/platforms/mps.py | 19 +++++++++++++++++ src/twinkle/utils/platforms/npu.py | 21 +++++++++++++++++++ 12 files changed, 104 insertions(+), 20 deletions(-) diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 23fb1fdae..1227584fb 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -687,7 +687,7 @@ def __next__(_self): return decorator -def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp'], Callable] = 'slice', +def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp', 'last_pp_first'], Callable] = 'slice', execute: Literal['first', 'peer', 'all'] = 'all', collect: Union[Literal['none', 'flatten', 'mean', 'sum', 'first', 'last_pp'], Callable] = 'none', sync: bool = False, diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index bd24bc027..c7174df94 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -990,7 +990,9 @@ def _get_rng_state() -> 'ShardedObject': 'random_rng_state': random.getstate(), 'np_rng_state': np.random.get_state(), 'torch_rng_state': torch.get_rng_state(), - 'cuda_rng_state': torch.cuda.get_rng_state(), + # Backend-agnostic device RNG (CUDA / NPU / MPS); key kept as + # 'cuda_rng_state' for backward compatibility with existing checkpoints. + 'cuda_rng_state': Platform.get_device_rng_state(), 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states(), } rng_state_list = [rng_state] @@ -1112,7 +1114,7 @@ def _save_mcore_optimizer( with open(tracker_path, 'w') as f: f.write(str(iteration)) - logging.getLogger(__name__).info(f'Saved mcore optimizer state at iteration {iteration} ' + logger.info(f'Saved mcore optimizer state at iteration {iteration} ' f'to {checkpoint_dir}') def _load_mcore_optimizer( @@ -1201,7 +1203,9 @@ def _load_mcore_optimizer( random.setstate(rng['random_rng_state']) np.random.set_state(rng['np_rng_state']) torch.set_rng_state(rng['torch_rng_state']) - torch.cuda.set_rng_state(rng['cuda_rng_state']) + # Backend-agnostic restore: tolerates ckpt produced on different backend + # (returns None) and avoids hard-coded torch.cuda which crashes on NPU. + Platform.set_device_rng_state(rng.get('cuda_rng_state')) tensor_parallel.get_cuda_rng_tracker().set_states(rng['rng_tracker_states'], ) # Restore iteration counter. @@ -1211,7 +1215,7 @@ def _load_mcore_optimizer( if dist.is_initialized(): dist.barrier() - logging.getLogger(__name__).info(f'Resumed from mcore checkpoint at iteration {iteration} ' + logger.info(f'Resumed from mcore checkpoint at iteration {iteration} ' f'from {checkpoint_dir}') @staticmethod diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 2dd6b7a53..9a7e7784d 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -15,7 +15,7 @@ from transformers import AutoConfig, PretrainedConfig from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union -from twinkle import DeviceMesh, remote_class, remote_function, requires, template, torch_util +from twinkle import DeviceMesh, Platform, remote_class, remote_function, requires, template, torch_util from twinkle.data_format import InputFeature, Trajectory from twinkle.hub import HubOperation from twinkle.infra import collect_tensor_dict @@ -221,8 +221,11 @@ def _save_local_training_rng_state(): 'np_rng_state': np.random.get_state(), 'torch_rng_state': torch.get_rng_state(), } - if torch.cuda.is_available(): - rng_state['cuda_rng_state'] = torch.cuda.get_rng_state() + # Backend-agnostic device RNG capture (CUDA / NPU / MPS). Key is kept as + # 'cuda_rng_state' for backward compatibility with existing checkpoints. + device_rng = Platform.get_device_rng_state() + if device_rng is not None: + rng_state['cuda_rng_state'] = device_rng rng_state['rng_tracker_states'] = tensor_parallel.get_cuda_rng_tracker().get_states() return rng_state @@ -233,8 +236,10 @@ def _load_local_training_rng_state(rng_state): random.setstate(rng_state['random_rng_state']) np.random.set_state(rng_state['np_rng_state']) torch.set_rng_state(rng_state['torch_rng_state']) - if 'cuda_rng_state' in rng_state and torch.cuda.is_available(): - torch.cuda.set_rng_state(rng_state['cuda_rng_state']) + # Backend-agnostic device RNG restore: tolerates ckpt produced on different + # backend (key absent or None) and avoids hard-coded torch.cuda on NPU. + if 'cuda_rng_state' in rng_state: + Platform.set_device_rng_state(rng_state['cuda_rng_state']) tensor_parallel.get_cuda_rng_tracker().set_states(rng_state['rng_tracker_states']) def _save_multi_lora_optimizer(self, checkpoint_dir: str, optimizer_config, **kwargs): diff --git a/src/twinkle/server/state/backend/factory.py b/src/twinkle/server/state/backend/factory.py index 326a6916f..be24c7824 100644 --- a/src/twinkle/server/state/backend/factory.py +++ b/src/twinkle/server/state/backend/factory.py @@ -1,12 +1,11 @@ """Backend factory for creating StateBackend instances based on configuration.""" from __future__ import annotations -import logging - from twinkle.server.config.persistence import PersistenceConfig +from twinkle.utils import get_logger from .base import StateBackend -logger = logging.getLogger(__name__) +logger = get_logger() def create_backend(config: PersistenceConfig | None = None) -> StateBackend: diff --git a/src/twinkle/server/state/base.py b/src/twinkle/server/state/base.py index d931ce336..8cb055ae4 100644 --- a/src/twinkle/server/state/base.py +++ b/src/twinkle/server/state/base.py @@ -1,7 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations -import logging import time from abc import ABC, abstractmethod from datetime import datetime, timezone @@ -9,9 +8,10 @@ from typing import Generic, TypeVar from twinkle.server.state.backend.base import StateBackend +from twinkle.utils import get_logger T = TypeVar('T', bound=BaseModel) -logger = logging.getLogger(__name__) +logger = get_logger() class BaseManager(ABC, Generic[T]): diff --git a/src/twinkle/server/state/session_manager.py b/src/twinkle/server/state/session_manager.py index 442019ea4..1bc901b0c 100644 --- a/src/twinkle/server/state/session_manager.py +++ b/src/twinkle/server/state/session_manager.py @@ -2,14 +2,14 @@ from __future__ import annotations import functools -import logging import time +from twinkle.utils import get_logger from .backend.base import ConcurrencyError, StateBackend from .base import BaseManager from .models import SessionRecord -logger = logging.getLogger(__name__) +logger = get_logger() def _session_touch_transform(existing: dict | None, *, now: float) -> dict | None: diff --git a/src/twinkle/server/telemetry/provider.py b/src/twinkle/server/telemetry/provider.py index 77212c757..059f301fb 100644 --- a/src/twinkle/server/telemetry/provider.py +++ b/src/twinkle/server/telemetry/provider.py @@ -16,8 +16,9 @@ from typing import Any from twinkle.server.config.telemetry import TelemetryConfig +from twinkle.utils import get_logger -logger = logging.getLogger(__name__) +logger = get_logger() # Loggers belonging to the OTLP transport stack. Their own records must never # be routed back through the OTLP LoggingHandler: an exporter error logged diff --git a/src/twinkle/server/telemetry/worker_init.py b/src/twinkle/server/telemetry/worker_init.py index 997f2e140..40edc6628 100644 --- a/src/twinkle/server/telemetry/worker_init.py +++ b/src/twinkle/server/telemetry/worker_init.py @@ -7,10 +7,11 @@ """ from __future__ import annotations -import logging import os -logger = logging.getLogger(__name__) +from twinkle.utils import get_logger + +logger = get_logger() _worker_initialized = False diff --git a/src/twinkle/utils/platforms/base.py b/src/twinkle/utils/platforms/base.py index 71c2bd18f..483c725a3 100644 --- a/src/twinkle/utils/platforms/base.py +++ b/src/twinkle/utils/platforms/base.py @@ -136,3 +136,24 @@ def device_backend(platform: str = None): def get_vllm_device_uuid(device_id: int = 0, platform=None) -> str: platform = Platform.get_platform(platform) return platform.get_vllm_device_uuid(device_id) + + @staticmethod + def get_device_rng_state(platform: str = None): + """Return device-specific RNG state (e.g. CUDA / NPU / MPS). + + Backend-agnostic replacement for hard-coded ``torch.cuda.get_rng_state()``. + Returns ``None`` when no accelerator is available, so callers can safely + skip persistence on CPU-only or unsupported devices. + """ + return Platform.get_platform(platform).get_device_rng_state() + + @staticmethod + def set_device_rng_state(state, *, platform: str = None) -> None: + """Restore device-specific RNG state. + + No-op when ``state`` is ``None`` (e.g. checkpoint produced on a different + backend) or when the current platform has no accelerator available. + """ + if state is None: + return + Platform.get_platform(platform).set_device_rng_state(state) diff --git a/src/twinkle/utils/platforms/gpu.py b/src/twinkle/utils/platforms/gpu.py index 0b99f8855..0f213448e 100644 --- a/src/twinkle/utils/platforms/gpu.py +++ b/src/twinkle/utils/platforms/gpu.py @@ -24,3 +24,16 @@ def device_backend(platform: str = None): def get_vllm_device_uuid(device_id: int = 0) -> str: from vllm.platforms import current_platform return current_platform.get_device_uuid(device_id) + + @staticmethod + def get_device_rng_state(): + import torch + if torch.cuda.is_available(): + return torch.cuda.get_rng_state() + return None + + @staticmethod + def set_device_rng_state(state) -> None: + import torch + if torch.cuda.is_available(): + torch.cuda.set_rng_state(state) diff --git a/src/twinkle/utils/platforms/mps.py b/src/twinkle/utils/platforms/mps.py index e99abb0e0..86ecf751e 100644 --- a/src/twinkle/utils/platforms/mps.py +++ b/src/twinkle/utils/platforms/mps.py @@ -40,3 +40,22 @@ def device_backend(platform: str = None): @staticmethod def get_vllm_device_uuid(device_id: int = 0) -> str: raise NotImplementedError + + @staticmethod + def get_device_rng_state(): + import torch + if hasattr(torch, 'mps') and hasattr(torch.mps, 'get_rng_state'): + try: + return torch.mps.get_rng_state() + except Exception: # noqa: BLE001 + return None + return None + + @staticmethod + def set_device_rng_state(state) -> None: + import torch + if hasattr(torch, 'mps') and hasattr(torch.mps, 'set_rng_state'): + try: + torch.mps.set_rng_state(state) + except Exception: # noqa: BLE001 + pass diff --git a/src/twinkle/utils/platforms/npu.py b/src/twinkle/utils/platforms/npu.py index 89066b280..de15707f6 100644 --- a/src/twinkle/utils/platforms/npu.py +++ b/src/twinkle/utils/platforms/npu.py @@ -133,3 +133,24 @@ def get_vllm_device_uuid(device_id: int = 0) -> str: visible = os.environ.get(Platform.visible_device_env()) raw = f'{socket.gethostname()}:{visible}:{device_id}' return hashlib.sha1(raw.encode('utf-8')).hexdigest()[:16] + + @staticmethod + def get_device_rng_state(): + import torch + try: + import torch_npu # noqa: F401 + except ImportError: + return None + if hasattr(torch, 'npu') and torch.npu.is_available(): + return torch.npu.get_rng_state() + return None + + @staticmethod + def set_device_rng_state(state) -> None: + import torch + try: + import torch_npu # noqa: F401 + except ImportError: + return + if hasattr(torch, 'npu') and torch.npu.is_available(): + torch.npu.set_rng_state(state) From 520f797c32677b4863981e3b8915569f594e1b63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 21:24:13 +0800 Subject: [PATCH 10/14] wip --- src/twinkle/model/megatron/megatron.py | 28 +++++++++++-------- .../model/megatron/multi_lora_megatron.py | 10 +++++++ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index c7174df94..60b45f774 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -1260,7 +1260,7 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non For distributed training: - All PP ranks participate in export (each has different layers) - - Only DP rank 0 actually writes to disk + - Only global rank 0 actually writes shared config files - Uses barrier for synchronization For LoRA training: @@ -1268,12 +1268,9 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non """ # Check if this is LoRA training is_peft_format = (adapter_name != _default_adapter_name) + is_global_zero = (not dist.is_initialized()) or dist.get_rank() == 0 - # Create output directory on rank 0 only - from megatron.core import parallel_state as mpu - dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0 - - if dp_rank == 0: + if is_global_zero: os.makedirs(output_dir, exist_ok=True) # Synchronize before saving @@ -1285,8 +1282,8 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non self.strategy.bridge.save_weights( model, output_dir, peft_format=is_peft_format, adapter_name=adapter_name, converter=lora_converter) - # Save config on rank 0 only - if dp_rank == 0: + # Save config on global rank 0 only (avoid concurrent writers). + if is_global_zero: self.hf_config.save_pretrained(output_dir) if isinstance(model[0], PeftModel): config = model[0].peft_config[adapter_name] @@ -1295,11 +1292,13 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non model[0].peft_config[adapter_name].save_pretrained(output_dir) config.target_modules = target_modules + if dist.is_initialized(): + dist.barrier() + def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_converter=None): """Save in Megatron checkpoint format.""" + is_global_zero = (not dist.is_initialized()) or dist.get_rank() == 0 os.makedirs(output_dir, exist_ok=True) - from megatron.core import parallel_state as mpu - dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0 state_dict = self._get_trainable_parameters(adapter_name) cpu_state_dict = {} for k, v in state_dict.items(): @@ -1315,13 +1314,18 @@ def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_convert rank = dist.get_rank() if dist.is_initialized() else 0 checkpoint_path = os.path.join(output_dir, f'model_rank{rank}.pt') torch.save(cpu_state_dict, checkpoint_path) - # Save config on rank 0 only + # Save shared config on global rank 0 only (avoid concurrent writers). model = self.strategy.unwrap_model(self.model) - if dp_rank == 0: + if is_global_zero: self.hf_config.save_pretrained(output_dir) if isinstance(model[0], PeftModel): model[0].peft_config[adapter_name].save_pretrained(output_dir) + # Finalize barrier: ensure all ranks finish writing model_rank*.pt + # before the caller proceeds (e.g. uploading / loading the ckpt). + if dist.is_initialized(): + dist.barrier() + def _save_tokenizer(self, output_dir: str, **kwargs): from twinkle.utils import is_last_rank if not is_last_rank(): diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 9a7e7784d..ae2eb4b2c 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -256,6 +256,9 @@ def _save_multi_lora_optimizer(self, checkpoint_dir: str, optimizer_config, **kw torch.save(state_dict, self._rank_local_optimizer_path(checkpoint_dir)) + if dist.is_initialized(): + dist.barrier() + def _load_multi_lora_optimizer(self, checkpoint_dir: str, adapter_name: str = '', **kwargs): no_load_optim = kwargs.pop('no_load_optim', False) no_load_rng = kwargs.pop('no_load_rng', False) @@ -265,6 +268,13 @@ def _load_multi_lora_optimizer(self, checkpoint_dir: str, adapter_name: str = '' if not no_load_optim and optimizer_config is not None: if optimizer_config.optimizer is not None and 'optimizer' in state_dict: optimizer_config.optimizer.load_state_dict(state_dict['optimizer']) + device = Platform.get_local_device() + for group_state in optimizer_config.optimizer.state.values(): + if not isinstance(group_state, dict): + continue + for k, v in group_state.items(): + if isinstance(v, torch.Tensor): + group_state[k] = v.to(device) if optimizer_config.lr_scheduler is not None and 'opt_param_scheduler' in state_dict: optimizer_config.lr_scheduler.load_state_dict(state_dict['opt_param_scheduler']) if not no_load_rng and 'rng_state' in state_dict: From df8855691bf706d517642ffc898510f530719769 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Mon, 15 Jun 2026 22:02:48 +0800 Subject: [PATCH 11/14] wip --- src/twinkle/model/megatron/strategy/megatron.py | 1 - src/twinkle/model/transformers/strategy/accelerate.py | 9 +++------ src/twinkle/model/transformers/transformers.py | 9 ++++----- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py index 819014eb8..1bd809025 100644 --- a/src/twinkle/model/megatron/strategy/megatron.py +++ b/src/twinkle/model/megatron/strategy/megatron.py @@ -48,7 +48,6 @@ def __init__( ddp_config: Dict[str, Any] = None, **kwargs, ): - import torch.distributed as dist from megatron.core import mpu self.device_mesh = device_mesh self.use_distributed_optimizer = use_distributed_optimizer diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index d0434991f..018f4c494 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -34,10 +34,8 @@ def __init__( parallelism_config = self._parallelism_config_from_device_mesh(device_mesh) fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config, memory_efficient_init) - kwargs_handlers = [] - kwargs_handlers.append( - InitProcessGroupKwargs( - timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))))) + kwargs_handlers = [InitProcessGroupKwargs( + timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))))] if ddp_config is not None: from accelerate import DistributedDataParallelKwargs ddp_config = DistributedDataParallelKwargs(**ddp_config) @@ -131,8 +129,7 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di return fsdp_plugin def wrap_model(self, model, *args): - result = self.accelerator.prepare(model, *args) - return result + return self.accelerator.prepare(model, *args) def unwrap_model(self, model): return self.accelerator.unwrap_model(model, keep_torch_compile=False) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 61733d7dc..cac375263 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -414,8 +414,8 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec inputs = optimizer_config.template.batch_encode(inputs) # noqa processor: InputProcessor = optimizer_config.processor loss_instance = optimizer_config.loss_instance - loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits) - loss_require_entropy = (hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy) + loss_require_logits = getattr(loss_instance, 'require_logits', False) + loss_require_entropy = getattr(loss_instance, 'require_entropy', False) loss_require_logps = getattr(loss_instance, 'require_logps', True) assert isinstance(processor, InputProcessor), 'Set a correct `InputProcessor` before forwarding' inputs: Dict[str, Any] = processor( @@ -490,8 +490,8 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T processor: InputProcessor = optimizer_config.processor assert isinstance(processor, InputProcessor), 'Set InputProcessor correctly before forwarding' loss_instance = optimizer_config.loss_instance - loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits) - loss_require_entropy = (hasattr(loss_instance, 'require_entropy') and loss_instance.require_entropy) + loss_require_logits = getattr(loss_instance, 'require_logits', False) + loss_require_entropy = getattr(loss_instance, 'require_entropy', False) loss_require_logps = getattr(loss_instance, 'require_logps', True) inputs: Dict[str, Any] = processor( inputs, @@ -929,7 +929,6 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int if optimizer_config.cur_step % interval != 0: return model = self.strategy.unwrap_model(self.model) - processed_state_dict = {} save_kwargs = {} if adapter_name == _default_adapter_name: # Full model save From 28ee773e5e3bc0b0c9427a580731ead33c02ad43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Tue, 16 Jun 2026 14:20:17 +0800 Subject: [PATCH 12/14] wip --- cookbook/transformers/fsdp2.py | 12 +++++- src/twinkle/model/base.py | 3 +- src/twinkle/notifier/__init__.py | 1 + src/twinkle/notifier/base.py | 1 + src/twinkle/notifier/ding_notifier.py | 1 + src/twinkle/processor/base.py | 59 ++++++++++++++++++--------- 6 files changed, 55 insertions(+), 22 deletions(-) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index ad4c917f9..ecad0220c 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -2,6 +2,7 @@ from peft import LoraConfig from tqdm import tqdm +from torch.optim import Muon # PyTorch 2.9+; matrix-orthogonalized momentum optimizer. import twinkle from twinkle import DeviceMesh, get_device_placement, get_logger @@ -64,7 +65,16 @@ def train(): model.add_adapter_to_model( args.lora.adapter_name, lora_config, gradient_accumulation_steps=args.training.gradient_accumulation_steps) - model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate) + # Muon optimizes 2D hidden-layer weight matrices via Newton-Schulz orthogonalization. + # In LoRA training the trainable params are exclusively lora_A / lora_B (both 2D), + # so Muon applies cleanly without an AdamW fallback for 1D params. + # ``adjust_lr_fn='match_rms_adamw'`` rescales the orthogonalized update so the same + # lr / weight_decay tuned for AdamW can be reused directly (Moonshot Muon recipe). + model.set_optimizer( + optimizer_cls=Muon, + lr=args.optimizer.learning_rate, + adjust_lr_fn='match_rms_adamw', + ) # Add LRScheduler for lora `default` model.set_lr_scheduler( diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index a4d4ea064..8ea00d696 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -1,8 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from abc import ABC, abstractmethod -from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union from twinkle import Platform, torch_util from twinkle.data_format import InputFeature, ModelOutput diff --git a/src/twinkle/notifier/__init__.py b/src/twinkle/notifier/__init__.py index 329cb6f1d..067db71a1 100644 --- a/src/twinkle/notifier/__init__.py +++ b/src/twinkle/notifier/__init__.py @@ -1,2 +1,3 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. from .base import Notifier, notify_exception from .ding_notifier import DingNotifier diff --git a/src/twinkle/notifier/base.py b/src/twinkle/notifier/base.py index 6c138c997..6f50ca659 100644 --- a/src/twinkle/notifier/base.py +++ b/src/twinkle/notifier/base.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. import os from typing import Dict, Optional diff --git a/src/twinkle/notifier/ding_notifier.py b/src/twinkle/notifier/ding_notifier.py index fe102d8a5..fc535edd7 100644 --- a/src/twinkle/notifier/ding_notifier.py +++ b/src/twinkle/notifier/ding_notifier.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. import base64 import hashlib import hmac diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index d600ec2b7..c505d91cd 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -42,7 +42,6 @@ class InputProcessor: 'video_grid_thw': 0, 'input_features': 0.0, 'feature_attention_mask': 0, - 'mm_token_type_ids': 0, } # VLM fields to concatenate (not pad) in batch @@ -108,8 +107,12 @@ def to_tensor(_input): # so tensor ops like labels != ignore_index or .to(device) would fail without this. if isinstance(value, np.ndarray): value = torch.from_numpy(value) - elif (isinstance(value, list) and isinstance(value[0], - (int, float, np.number))) or key == 'position_ids': + elif isinstance(value, list) and len(value) > 0 and isinstance( + value[0], (int, float, np.number)): + value = torch.tensor(value) + elif key == 'position_ids' and not isinstance(value, torch.Tensor): + if value is None: + continue value = torch.tensor(value) elif (isinstance(value, list)) and key in ('completion_mask', 'mm_token_type_ids'): value = torch.tensor(value) @@ -284,7 +287,9 @@ def pad_cp_inputs(input_tensor: torch.Tensor, padding_value: int) -> torch.Tenso return input_tensor if cp_size > 1: - position_ids_f = position_ids.flatten() + pos_for_cu = position_ids[:1] if position_ids.dim() >= 2 and position_ids.shape[0] > 1 \ + else position_ids + position_ids_f = pos_for_cu.flatten() indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32) cu_seqlens = torch.cat([ indices_q[position_ids_f == 0], @@ -354,8 +359,11 @@ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], di view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] // (2 * cp_size), *inputs.shape[dim + 1:]) val = val.view(view_shape) - index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu', - pin_memory=True).cuda(non_blocking=True) + index = torch.tensor( + [cp_rank, (2 * cp_size - cp_rank - 1)], + device=inputs.device, + dtype=torch.long, + ) val = val.index_select(dim, index) view_shape = (*inputs.shape[:dim], -1, *inputs.shape[dim + 1:]) new_inputs.append(val.view(view_shape)) @@ -402,17 +410,18 @@ def prepare_transformers_padding_free_patch(self, inputs: List[InputFeature], ** if not padding_free or bool(kwargs.get('enable_sp', False)): return inputs - from twinkle.patch import apply_patch - from twinkle.patch.gdn_padding_free import GatedDeltaNetPaddingFreePatch - - apply_patch( - model, - GatedDeltaNetPaddingFreePatch, - hf_config=kwargs.get('hf_config'), - enable_sp=False, - ) if not getattr(model, '_twinkle_gdn_padding_free_patched', False): - return inputs + from twinkle.patch import apply_patch + from twinkle.patch.gdn_padding_free import GatedDeltaNetPaddingFreePatch + + apply_patch( + model, + GatedDeltaNetPaddingFreePatch, + hf_config=kwargs.get('hf_config'), + enable_sp=False, + ) + if not getattr(model, '_twinkle_gdn_padding_free_patched', False): + return inputs for _inp in inputs: position_ids = _inp.get('position_ids') @@ -631,15 +640,27 @@ def to_transformers_dict(inputs: List[InputFeature], **kwargs) -> List[InputFeat output = {} _keys = [ 'input_ids', - 'input_embeddings', + 'inputs_embeds', 'attention_mask', 'position_ids', 'labels', 'completion_mask', + 'cu_seq_lens_q', + 'cu_seq_lens_k', + 'cu_seqlens_q', + 'cu_seqlens_kv', + 'max_length_q', + 'max_length_k', + 'packed_seq_params', ] + list(InputProcessor.VLM_CONCAT_FIELDS) for key in list(_input.keys()): - if key in _keys: - output[key] = np.array(_input[key]) if not isinstance(_input[key], torch.Tensor) else _input[key] + if key not in _keys: + continue + value = _input[key] + if isinstance(value, torch.Tensor) or not isinstance(value, (list, np.ndarray)): + output[key] = value + else: + output[key] = np.array(value) results.append(InputFeature(**output)) return results From ffeba0b44cb734ba33493e8f4d3fb25fa181d546 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Tue, 16 Jun 2026 14:31:21 +0800 Subject: [PATCH 13/14] wip --- .../sampler/vllm_sampler/vllm_sampler.py | 20 ------------- src/twinkle/template/tools/base.py | 29 +++++-------------- src/twinkle/template/tools/cline.py | 13 ++------- src/twinkle/template/tools/qwen.py | 3 -- .../preprocessor/message_normalizer.py | 12 ++++---- tests/template/test_tool_parsers.py | 22 -------------- 6 files changed, 16 insertions(+), 83 deletions(-) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 32dc1ca50..0433ef5a8 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -1,24 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""vLLM-based sampler using VLLMEngine (AsyncLLM). - -Device Configuration: - vLLMSampler automatically detects the number of available GPUs from - CUDA_VISIBLE_DEVICES environment variable (set by twinkle's ResourceManager) - and configures vLLM's tensor_parallel_size accordingly. - - To use tensor parallelism, configure DeviceGroup with gpus_per_worker > 1: - - # DP2 with TP2 (4 GPUs total, 2 workers, each with 2 GPUs) - DeviceGroup(name='sampler', ranks=[0,1,2,3], gpus_per_worker=2) - - # TP4 (4 GPUs, 1 worker with all 4 GPUs) - DeviceGroup(name='sampler', ranks=[0,1,2,3], gpus_per_worker=4) - -Data Flow: - When multiple vLLMSampler workers exist (DP > 1): - - Data is dispatched via dispatch='slice_dp' (each worker gets a slice) - - Results are collected via collect='flatten' (merged into single list) -""" import asyncio import atexit import numpy as np diff --git a/src/twinkle/template/tools/base.py b/src/twinkle/template/tools/base.py index a6d7040e2..35b63dc82 100644 --- a/src/twinkle/template/tools/base.py +++ b/src/twinkle/template/tools/base.py @@ -10,14 +10,6 @@ class ToolCallParser(ABC): open_marker: Optional[str] = None close_marker: Optional[str] = None - def matches_model(self, model_id: str) -> bool: - """Return True if this parser is the canonical choice for ``model_id``. - - Used for streaming where we must commit to a parser before any text - has arrived. Default False — parser is text-detection-only. - """ - return False - @abstractmethod def detect(self, text: str) -> bool: """Cheap pre-check: does ``text`` carry this format's markup?""" @@ -30,13 +22,14 @@ def parse(self, text: str) -> List[Dict[str, Any]]: def clean(self, text: str) -> str: """Strip parser-specific markup; return plain content text.""" - def detect_result(self, text: str) -> bool: - """Does ``text`` look like a tool-result message for this protocol?""" - return False + def extract_tool_result(self, text: str) -> Optional[str]: + """If ``text`` is a tool-result message of this protocol, return the + body with the protocol-specific prefix stripped; otherwise return ``None``. - def parse_result(self, text: str) -> str: - """Strip protocol-specific result prefix; return the raw tool output body.""" - return text + Default returns ``None`` — only protocols carrying their own tool-result + framing (e.g. Cline) need to override this. + """ + return None class ToolCallRegistry: @@ -56,14 +49,6 @@ def register(cls, parser: ToolCallParser) -> ToolCallParser: def parsers(cls) -> List[ToolCallParser]: return list(cls._parsers) - @classmethod - def select_for_model(cls, model_id: Optional[str]) -> Optional[ToolCallParser]: - mid = (model_id or '').lower() - for p in cls._parsers: - if p.matches_model(mid): - return p - return None - @classmethod def detect_first(cls, text: str) -> Optional[ToolCallParser]: if not text: diff --git a/src/twinkle/template/tools/cline.py b/src/twinkle/template/tools/cline.py index 2673e82ef..8f36324ce 100644 --- a/src/twinkle/template/tools/cline.py +++ b/src/twinkle/template/tools/cline.py @@ -21,7 +21,7 @@ from __future__ import annotations import re -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from .base import ToolCallParser @@ -99,10 +99,6 @@ class ClineParser(ToolCallParser): open_marker = None close_marker = None - def matches_model(self, model_id: str) -> bool: - # Cline is an app-level prompt protocol, not bound to any model family. - return False - def detect(self, text: str) -> bool: if not text or '<' not in text: return False @@ -153,9 +149,6 @@ def clean(self, text: str) -> str: out.append(text[last:]) return ''.join(out).rstrip() - def detect_result(self, text: str) -> bool: - return bool(_RESULT_RE.match(text or '')) - - def parse_result(self, text: str) -> str: + def extract_tool_result(self, text: str) -> Optional[str]: m = _RESULT_RE.match(text or '') - return text[m.end():] if m else text + return text[m.end():] if m else None diff --git a/src/twinkle/template/tools/qwen.py b/src/twinkle/template/tools/qwen.py index 12361b737..6713d570a 100644 --- a/src/twinkle/template/tools/qwen.py +++ b/src/twinkle/template/tools/qwen.py @@ -16,9 +16,6 @@ class HermesQwenParser(ToolCallParser): _PARAMETER_RE = re.compile(r']+)>\s*([\s\S]*?)\s*') _STRIP_RE = re.compile(r'[\s\S]*?(?:|\Z)') - def matches_model(self, model_id: str) -> bool: - return 'qwen' in model_id - def detect(self, text: str) -> bool: return self.open_marker in text diff --git a/src/twinkle_agentic/preprocessor/message_normalizer.py b/src/twinkle_agentic/preprocessor/message_normalizer.py index a8606d8f1..d3074a565 100644 --- a/src/twinkle_agentic/preprocessor/message_normalizer.py +++ b/src/twinkle_agentic/preprocessor/message_normalizer.py @@ -105,12 +105,12 @@ def _normalize_tool_calls(messages: List[Dict[str, Any]]) -> List[Dict[str, Any] nxt_text = msg_content_text(messages[j]) if not nxt_text: break - if parser.detect_result(nxt_text): - body = parser.parse_result(nxt_text) - elif tc_idx == 0 and len(tc_list) == 1: - body = nxt_text - else: - break + body = parser.extract_tool_result(nxt_text) + if body is None: + if tc_idx == 0 and len(tc_list) == 1: + body = nxt_text + else: + break out.append({ 'role': 'tool', 'content': body, diff --git a/tests/template/test_tool_parsers.py b/tests/template/test_tool_parsers.py index 41f6a3a4f..9269ed1f2 100644 --- a/tests/template/test_tool_parsers.py +++ b/tests/template/test_tool_parsers.py @@ -23,11 +23,6 @@ def test_detect(self): assert not self.p.detect('plain text') assert not self.p.detect('') - def test_matches_model(self): - assert self.p.matches_model('qwen2.5-7b') - assert self.p.matches_model('qwen3-32b') - assert not self.p.matches_model('llama-3.1-8b') - def test_parse_json_variant(self): text = '{"name": "get_weather", "arguments": {"city": "Paris"}}' out = self.p.parse(text) @@ -104,10 +99,6 @@ def test_no_block_marker(self): assert self.p.open_marker is None assert self.p.close_marker is None - def test_does_not_match_qwen_model(self): - assert not self.p.matches_model('qwen2.5') - assert not self.p.matches_model('llama-3') - def test_parse_single_action(self): text = 'Thought: search the web.\nAction: search[hello world]' out = self.p.parse(text) @@ -172,11 +163,6 @@ def test_no_marker(self): assert self.p.open_marker is None assert self.p.close_marker is None - def test_does_not_match_any_model_by_default(self): - # Cline is an app-level prompt protocol, not a model-family format. - assert not self.p.matches_model('qwen2.5') - assert not self.p.matches_model('claude-3') - def test_parse_single_arg(self): text = 'src/foo.py' out = self.p.parse(text) @@ -254,14 +240,6 @@ def test_no_parser_for_plain_text(self): assert ToolCallRegistry.detect_first('just some plain text') is None assert ToolCallRegistry.detect_first('') is None - def test_select_for_qwen_picks_hermes(self): - parser = ToolCallRegistry.select_for_model('qwen2.5-7b') - assert parser is not None and parser.name == 'hermes_qwen' - - def test_select_for_unknown_returns_none(self): - assert ToolCallRegistry.select_for_model('llama-3.1-8b') is None - assert ToolCallRegistry.select_for_model(None) is None - if __name__ == '__main__': pytest.main([__file__, '-v']) From a762f0f20e32e3ce709c8fb1c4e79a5135f5b8d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Tue, 16 Jun 2026 16:51:15 +0800 Subject: [PATCH 14/14] wip --- src/twinkle/template/__init__.py | 1 + src/twinkle/template/base.py | 75 ++++++++++++++++++----------- src/twinkle/template/deepseek_v4.py | 26 +++++----- src/twinkle/template/qwen3_5_vl.py | 3 ++ src/twinkle/template/utils.py | 62 ++++++++++++++++-------- 5 files changed, 104 insertions(+), 63 deletions(-) diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py index 6c4bdddd2..168456eab 100644 --- a/src/twinkle/template/__init__.py +++ b/src/twinkle/template/__init__.py @@ -2,3 +2,4 @@ from .base import Template from .deepseek_v4 import DeepseekV4Template from .qwen3_5_vl import Qwen3_5Template +from .tools import ToolCallParser, ToolCallRegistry, ClineParser, HermesQwenParser, ReActParser, VCPParser diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 26c2e4f26..c1e8f069f 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -32,6 +32,19 @@ class Template: video_placeholder: str = '