Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source_en/Components/Gym/Gym.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
The Gym component provides an interface for reinforcement learning environments in Twinkle.

```python
from twinkle.gym import Gym
from twinkle_agentic.env import Gym


class CustomGym(Gym):

Expand Down
3 changes: 2 additions & 1 deletion docs/source_zh/组件/Gym/Gym.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
Gym 组件为 Twinkle 中的强化学习环境提供接口。

```python
from twinkle.gym import Gym
from twinkle_agentic.env import Gym


class CustomGym(Gym):

Expand Down
1 change: 0 additions & 1 deletion src/twinkle/checkpoint_engine/manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py
import time
from typing import List, Optional

from twinkle import Platform, get_logger
Expand Down
1 change: 0 additions & 1 deletion src/twinkle/checkpoint_engine/mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os

from twinkle import Platform, remote_function
from twinkle.checkpoint_engine.base import CheckpointEngine
Expand Down
7 changes: 3 additions & 4 deletions src/twinkle/cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from __future__ import annotations

import os

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Removing from __future__ import annotations will break compatibility with Python < 3.10 because of the use of PEP 604 union types (e.g., str | None) and PEP 585 generic collections (e.g., list[str]) in type annotations. Additionally, keeping it enables postponed evaluation of annotations, allowing forward references (like ConfigRegistry) to be used without enclosing them in quotes. Please keep this import.

Suggested change
import os
from __future__ import annotations
import os

import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from pathlib import Path
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union
from typing import Any, Iterator, Literal


# ────────────────────────────────────────────────────────────────────────────────
# Arg group dataclasses
Expand Down Expand Up @@ -243,7 +242,7 @@ def _resolve_path(self) -> Path | None:
class EnvVarSource(ConfigSource):
"""Reads os.environ; recognizes TWINKLE_ prefix and any key known to the registry."""

def __init__(self, registry: ConfigRegistry):
def __init__(self, registry: 'ConfigRegistry'):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

With from __future__ import annotations restored, you can use ConfigRegistry directly as a type annotation without enclosing it in quotes.

Suggested change
def __init__(self, registry: 'ConfigRegistry'):
def __init__(self, registry: ConfigRegistry):

self._registry = registry

def load(self) -> dict[str, str]:
Expand Down
2 changes: 2 additions & 0 deletions src/twinkle/data_format/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ class ModelOutput(TypedDict, total=False):
loss: The loss calculated by the model.
logps: The log-probabilities of correct tokens by the model.
num_tokens: The token denominator associated with ``loss``.
embeddings: The embeddings output by the model, used be embedding task.
"""
logits: Optional[OutputType]
loss: Optional[OutputType]
logps: Optional[OutputType]
num_tokens: Optional[OutputType]
embeddings: Optional[OutputType]


class LossOutput(TypedDict, total=False):
Expand Down
1 change: 0 additions & 1 deletion src/twinkle/data_format/sampling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import numpy as np
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union

Expand Down
3 changes: 2 additions & 1 deletion src/twinkle/dataloader/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _tracking_iter(self, inner):
def skip_consumed_samples(self, consumed_train_samples: int) -> None:
from torch.utils.data import IterableDataset

if isinstance(self.dataset, IterableDataset):
if isinstance(self.dataset, IterableDataset) or consumed_train_samples is None or consumed_train_samples <= 0:
warnings.warn('IterableDataset does not support consumed-data skipping; continuing without skipping.')
self._skip_samples = 0
return
Expand All @@ -164,6 +164,7 @@ def resume_from_checkpoint(self, consumed_train_samples, **kwargs):

@remote_function()
def get_state(self) -> dict:
"""The dataloader state for saving."""
return {'consumed_train_samples': self._consumed_train_samples}

def _rebuild_sampler_stack(self):
Expand Down
30 changes: 22 additions & 8 deletions src/twinkle/dataset/iterable_packing_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,27 @@ def _fetch_data_out_queue(self, last_res, num_samples):
last_res += res
return last_res

@staticmethod
def _cyclic_iter(iterable):
while True:
yield from iterable
def _write_through_iter(self, iterable):
"""Yields from iterable, meanwhile, save it to disk if needed.
Saving is needed when you are using several datasets at a time.
"""
if not self.cyclic:
for row in iterable:
self._write_through(row)
yield row
return
else:
first_pass = True
while True:
empty = True
for row in iterable:
empty = False
if first_pass:
self._write_through(row)
yield row
if empty:
return
first_pass = False

@remote_function()
def __iter__(self):
Expand All @@ -102,10 +119,7 @@ def __iter__(self):
except StopIteration:
return

if self.cyclic:
iterator = self._cyclic_iter(self.dataset)
else:
iterator = iter(self.dataset)
iterator = self._write_through_iter(self.dataset)
data = []
max_length = self.template.max_length or 2048
while True:
Expand Down
2 changes: 2 additions & 0 deletions src/twinkle/dataset/packing_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def __getitem__(self, index):
assert self._packed_called, 'Call `pack_dataset()` first before index the sample.'
sequence = self.packed_idx[index]
rows = [self.dataset[i] for i in sequence]
for row in rows:
self._write_through(row)
output = {}
for key in rows[0]:
output[key] = [r[key] for r in rows]
Expand Down
2 changes: 0 additions & 2 deletions src/twinkle/gym/__init__.py

This file was deleted.

10 changes: 0 additions & 10 deletions src/twinkle/gym/base.py

This file was deleted.

22 changes: 4 additions & 18 deletions src/twinkle/infra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import functools
import inspect
import itertools
import json
import numpy as np
import os
import random
import sys
from typing import Any, Callable, List, Literal, Optional, TypeVar, Union

Expand Down Expand Up @@ -59,7 +57,7 @@ def _tag_exc(exc: BaseException, caller: Optional[str]) -> None:
prefix = f'[twinkle driver caller: {caller}] '
exc.args = (prefix + str(exc.args[0]), *exc.args[1:]) if exc.args else (prefix.rstrip(), )
exc._twinkle_caller_augmented = True
except Exception: # noqa: BLE001
except Exception: # noqa
pass


Expand Down Expand Up @@ -404,6 +402,7 @@ def dispatch_func(arg, n):

return result
elif dispatch == 'slice_dp':
assert device_mesh is not None
# split by dp. each worker in one ep will receive the same argument
result = []
# if device_mesh is not None:
Expand All @@ -420,14 +419,6 @@ def dispatch_func(arg, n):
import torch
if isinstance(arg, list) or isinstance(arg, torch.Tensor):
_args = []
if device_mesh is None:
total = len(arg)
chunk = max(1, (total + n - 1) // n)
for i in range(n):
start = i * chunk
end = min(total, start + chunk)
_args.append(arg[start:end])
return _args
for i in range(n):
_args.append(arg[device_mesh.get_slice(
len(arg), device_mesh.get_data_rank_from_global_rank(i * _rank_stride))])
Expand Down Expand Up @@ -696,7 +687,7 @@ def __next__(_self):
return decorator


def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp'], Callable] = 'slice',
def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp', 'last_pp_first'], Callable] = 'slice',
execute: Literal['first', 'peer', 'all'] = 'all',
collect: Union[Literal['none', 'flatten', 'mean', 'sum', 'first', 'last_pp'], Callable] = 'none',
sync: bool = False,
Expand Down Expand Up @@ -803,18 +794,13 @@ def wrapper(self, *args, **kwargs) -> T1:
# And this is user independent, only decided by the code.
_local_lazy_collect = self._lazy_collect
if _local_lazy_collect:
# Wrap the deferred collector so that exceptions
# raised when the caller later materializes the
# result also trigger the notifier. Attributes
# (``_futures`` etc.) on the original collector
# are preserved for downstream code paths.
_orig_result_func = result_func

@functools.wraps(_orig_result_func)
def _notifying_result_func(*rargs, **rkwargs):
try:
return _orig_result_func(*rargs, **rkwargs)
except Exception as _e: # noqa: BLE001
except Exception as _e: # noqa
_tag_exc(_e, _caller)
notify_exception(_notifier, _ctx, _e, _name)
raise
Expand Down
Loading
Loading